diff --git a/packages/cursorless-engine/src/languages/LegacyLanguageId.ts b/packages/cursorless-engine/src/languages/LegacyLanguageId.ts index f4252da4e5..f439cc2f16 100644 --- a/packages/cursorless-engine/src/languages/LegacyLanguageId.ts +++ b/packages/cursorless-engine/src/languages/LegacyLanguageId.ts @@ -2,6 +2,6 @@ * The language IDs that we have full tree-sitter support for using our legacy * modifiers. */ -export const legacyLanguageIds = ["clojure", "latex", "rust"] as const; +export const legacyLanguageIds = ["latex", "rust"] as const; export type LegacyLanguageId = (typeof legacyLanguageIds)[number]; diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/QueryCapture.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/QueryCapture.ts index 5df0fcb802..856bfadc15 100644 --- a/packages/cursorless-engine/src/languages/TreeSitterQuery/QueryCapture.ts +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/QueryCapture.ts @@ -1,17 +1,18 @@ import type { Range, TextDocument } from "@cursorless/common"; -import type { Point } from "web-tree-sitter"; +import type { Point, TreeCursor } from "web-tree-sitter"; /** * Simple representation of the tree sitter syntax node. Used by * {@link MutableQueryCapture} to avoid using range/text and other mutable * parameters directly from the node. */ -interface SimpleSyntaxNode { +export interface SimpleSyntaxNode { readonly id: number; readonly type: string; readonly isNamed: boolean; readonly parent: SimpleSyntaxNode | null; readonly children: Array; + walk(): TreeCursor; } /** diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/getChildNodesForFieldName.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/getChildNodesForFieldName.ts new file mode 100644 index 0000000000..e887e254a0 --- /dev/null +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/getChildNodesForFieldName.ts @@ -0,0 +1,19 @@ +import type { SimpleSyntaxNode } from "./QueryCapture"; + +export function getChildNodesForFieldName( + node: SimpleSyntaxNode, + fieldName: string, +): SimpleSyntaxNode[] { + const nodes = []; + const treeCursor = node.walk(); + let hasNext = treeCursor.gotoFirstChild(); + + while (hasNext) { + if (treeCursor.currentFieldName === fieldName) { + nodes.push(treeCursor.currentNode); + } + hasNext = treeCursor.gotoNextSibling(); + } + + return nodes; +} diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/isEven.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/isEven.ts new file mode 100644 index 0000000000..12e9dacea7 --- /dev/null +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/isEven.ts @@ -0,0 +1,30 @@ +import type { SimpleSyntaxNode } from "./QueryCapture"; + +/** + * Checks if a node is at an even index within its parent's field. + * + * @param node - The node to check. + * @param fieldName - The name of the field in the parent node. + * @returns True if the node is at an even index, false otherwise. + */ +export function isEven(node: SimpleSyntaxNode, fieldName: string): boolean { + if (node.parent == null) { + throw Error("Node has no parent"); + } + + const treeCursor = node.parent.walk(); + let hasNext = treeCursor.gotoFirstChild(); + let even = true; + + while (hasNext) { + if (treeCursor.currentFieldName === fieldName) { + if (treeCursor.currentNode.id === node.id) { + return even; + } + even = !even; + } + hasNext = treeCursor.gotoNextSibling(); + } + + throw Error(`Node not found in parent for field: ${fieldName}`); +} diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/queryPredicateOperators.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/queryPredicateOperators.ts index 00985321fb..2565679bb9 100644 --- a/packages/cursorless-engine/src/languages/TreeSitterQuery/queryPredicateOperators.ts +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/queryPredicateOperators.ts @@ -4,8 +4,49 @@ import { z } from "zod"; import { makeRangeFromPositions } from "../../util/nodeSelectors"; import type { MutableQueryCapture } from "./QueryCapture"; import { QueryPredicateOperator } from "./QueryPredicateOperator"; +import { isEven } from "./isEven"; import { q } from "./operatorArgumentSchemaTypes"; +/** + * A predicate operator that returns true if the node is at an even index within + * its parents field. For example, `(#even? @foo value)` will accept the match + * if the `@foo` capture is at an even index among its parents value children. + */ +class Even extends QueryPredicateOperator { + name = "even?" as const; + schema = z.tuple([q.node, q.string]); + run({ node }: MutableQueryCapture, fieldName: string) { + return isEven(node, fieldName); + } +} + +/** + * A predicate operator that returns true if the node is at an odd index within + * its parents field. For example, `(#odd? @foo value)` will accept the match + * if the `@foo` capture is at an odd index among its parents value children. + */ +class Odd extends QueryPredicateOperator { + name = "odd?" as const; + schema = z.tuple([q.node, q.string]); + run({ node }: MutableQueryCapture, fieldName: string) { + return !isEven(node, fieldName); + } +} + +/** + * A predicate operator that returns true if the node matches the given text. + * For example, `(#text? @foo bar)` will accept the match if the `@foo` + * captures text is `bar`. It is acceptable to pass in multiple texts, e.g. + * `(#text? @foo bar baz)`. + */ +class Text extends QueryPredicateOperator { + name = "text?" as const; + schema = z.tuple([q.node, q.string]).rest(q.string); + run({ document, range }: MutableQueryCapture, ...texts: string[]) { + return texts.includes(document.getText(range)); + } +} + /** * A predicate operator that returns true if the node is of the given type. * For example, `(#type? @foo string)` will accept the match if the `@foo` @@ -388,6 +429,9 @@ class EmptySingleMultiDelimiter extends QueryPredicateOperator Math.floor(nodeIndex / 2) * 2 + parity, - ); -} - -function mapParityNodeFinder(parity: 0 | 1) { - return parityNodeFinder(patternFinder("map_lit"), parity); -} - -/** - * Creates a node finder which will apply a transformation to the index of a - * value node and return the node at the given index of the nodes parent - * @param parentFinder A finder which will be applied to the parent to determine - * whether it is a match - * @param indexTransform A function that will be applied to the index of the - * value node. The node at the given index will be used instead of the node - * itself - * @returns A node finder based on the given description - */ -function indexNodeFinder( - parentFinder: NodeFinder, - indexTransform: (index: number) => number, -) { - return (node: Node) => { - const parent = node.parent; - - if (parent == null || parentFinder(parent) == null) { - return null; - } - - const valueNodes = getValueNodes(parent); - - const nodeIndex = valueNodes.findIndex(({ id }) => id === node.id); - - if (nodeIndex === -1) { - // FIXME: In the future we might conceivably try to handle saying "take - // item" when the selection is inside a comment between the key and value - return null; - } - - const desiredIndex = indexTransform(nodeIndex); - - if (desiredIndex === -1) { - return null; - } - - return valueNodes[desiredIndex]; - }; -} - -/** - * Return the "value" node children of a given node. These are the items in a list - * @param node The node whose children to get - * @returns A list of the value node children of the given node - */ -const getValueNodes = (node: Node) => getChildNodesForFieldName(node, "value"); - -// A function call is a list literal which is not quoted -const functionCallPattern = "~quoting_lit.list_lit!"; -const functionCallFinder = patternFinder(functionCallPattern); - -/** - * Matches a function call if the name of the function is one of the given names - * @param names The acceptable function names - * @returns The function call node if the name matches otherwise null - */ -function functionNameBasedFinder(...names: string[]) { - return (node: Node) => { - const functionCallNode = functionCallFinder(node); - if (functionCallNode == null) { - return null; - } - - const functionNode = getValueNodes(functionCallNode)[0]; - - return names.includes(functionNode?.text) ? functionCallNode : null; - }; -} - -function functionNameBasedMatcher(...names: string[]) { - return matcher(functionNameBasedFinder(...names)); -} - -const functionFinder = functionNameBasedFinder("defn", "defmacro"); - -const functionNameMatcher = chainedMatcher([ - functionFinder, - (functionNode) => getValueNodes(functionNode)[1], -]); - -const ifStatementFinder = functionNameBasedFinder( - "if", - "if-let", - "when", - "when-let", -); - -const ifStatementMatcher = matcher(ifStatementFinder); - -const nodeMatchers: Partial< - Record -> = { - collectionKey: matcher(mapParityNodeFinder(0)), - value: matcher(mapParityNodeFinder(1)), - - // FIXME: Handle formal parameters - argumentOrParameter: matcher( - indexNodeFinder(patternFinder(functionCallPattern), (nodeIndex: number) => - nodeIndex !== 0 ? nodeIndex : -1, - ), - ), - - functionCall: functionCallPattern, - functionCallee: chainedMatcher([ - functionCallFinder, - (functionNode) => getValueNodes(functionNode)[0], - ]), - - namedFunction: matcher(functionFinder), - - functionName: functionNameMatcher, - - // FIXME: Handle `let` declarations, defs, etc - name: functionNameMatcher, - - anonymousFunction: cascadingMatcher( - functionNameBasedMatcher("fn"), - patternMatcher("anon_fn_lit"), - ), - - ifStatement: ifStatementMatcher, - - condition: chainedMatcher([ - ifStatementFinder, - (node) => getValueNodes(node)[1], - ]), -}; - -export default createPatternMatchers(nodeMatchers); diff --git a/packages/cursorless-engine/src/languages/getNodeMatcher.ts b/packages/cursorless-engine/src/languages/getNodeMatcher.ts index 0fd62cc28d..91c5fef7ec 100644 --- a/packages/cursorless-engine/src/languages/getNodeMatcher.ts +++ b/packages/cursorless-engine/src/languages/getNodeMatcher.ts @@ -1,6 +1,6 @@ +import type { SimpleScopeTypeType } from "@cursorless/common"; import { UnsupportedLanguageError } from "@cursorless/common"; import type { Node } from "web-tree-sitter"; -import type { SimpleScopeTypeType } from "@cursorless/common"; import type { NodeMatcher, NodeMatcherValue, @@ -8,9 +8,8 @@ import type { } from "../typings/Types"; import { notSupported } from "../util/nodeMatchers"; import { selectionWithEditorFromRange } from "../util/selectionUtils"; -import clojure from "./clojure"; -import type { LegacyLanguageId } from "./LegacyLanguageId"; import latex from "./latex"; +import type { LegacyLanguageId } from "./LegacyLanguageId"; import rust from "./rust"; export function getNodeMatcher( @@ -41,7 +40,6 @@ export const languageMatchers: Record< LegacyLanguageId, Partial> > = { - clojure, latex, rust, }; diff --git a/queries/clojure.scm b/queries/clojure.scm index 407e0640ac..2ec5c41e99 100644 --- a/queries/clojure.scm +++ b/queries/clojure.scm @@ -79,3 +79,85 @@ open: "{" @collectionItem.iteration.start.endOf close: "}" @collectionItem.iteration.end.startOf ) @collectionItem.iteration.domain + +;;!! (foo) +;;! ^^^^^ +( + (list_lit + . + value: (_) @functionCallee + ) @functionCall @functionCallee.domain @argumentOrParameter.iteration + ;; A function call is a list literal which is not quoted + (#not-parent-type? @functionCall quoting_lit) +) + +;;!! (foo :bar) +;;! ^^^^ +( + (list_lit + . + value: (_) + value: (_) @argumentOrParameter + ) @_dummy + (#not-parent-type? @_dummy quoting_lit) +) + +;;!! (defn foo [] 5) +;;! ^^^^^^^^^^^^^^^ +;;! ^^^ +( + (list_lit + . + value: (_) @_dummy + . + value: (_) @name @functionName + ) @namedFunction @name.domain @functionName.domain + (#text? @_dummy defn defmacro) + (#not-parent-type? @namedFunction quoting_lit) +) + +;;!! (fn [] 5) +;;! ^^^^^^^^^ +( + (list_lit + . + value: (_) @_dummy + ) @anonymousFunction + (#text? @_dummy fn) + (#not-parent-type? @anonymousFunction quoting_lit) +) + +;;!! #(+ 1 1) +;;! ^^^^^^^^ +(anon_fn_lit) @anonymousFunction + +;;!! (if true "hello") +;;! ^^^^^^^^^^^^^^^^^ +;;! ^^^^ +( + (list_lit + . + value: (_) @_dummy + . + value: (_) @condition + ) @ifStatement @condition.domain + (#text? @_dummy "if" "if-let" "when" "when-let") + (#not-parent-type? @ifStatement quoting_lit) +) + +;;!! {:foo 1, :bar 2} +;;! ^^^^ ^^^^ +;;! ^ ^ +(map_lit + value: (_) @collectionKey @collectionKey.domain.start @value.domain.start + value: (_) @value @collectionKey.domain.end @value.domain.end + (#even? @collectionKey value) + (#odd? @value value) +) + +;;!! {:foo 1, :bar 2} +;;! ^^^^^^^^^^^^^^ +(map_lit + "{" @collectionKey.iteration.start.endOf @value.iteration.start.endOf + "}" @collectionKey.iteration.end.startOf @value.iteration.end.startOf +)