diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/QueryCapture.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/QueryCapture.ts index 0c4eba64b9..0de5d611a4 100644 --- a/packages/cursorless-engine/src/languages/TreeSitterQuery/QueryCapture.ts +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/QueryCapture.ts @@ -1,28 +1,5 @@ import type { Range, TextDocument } from "@cursorless/common"; -import type { Point, TreeCursor } from "web-tree-sitter"; - -/** - * Simple representation of the tree sitter syntax node. Used by - * {@link MutableQueryCapture} to avoid using range/text and other mutable - * parameters directly from the node. - */ -export interface SimpleSyntaxNode { - readonly id: number; - readonly type: string; - readonly isNamed: boolean; - readonly parent: SimpleSyntaxNode | null; - readonly children: Array; - walk(): TreeCursor; -} - -/** - * Add start and end position to the simple syntax node. Used by the `child-range!` predicate. - */ -interface SimpleChildSyntaxNode extends SimpleSyntaxNode { - readonly startPosition: Point; - readonly endPosition: Point; - readonly text: string; -} +import type { Node } from "web-tree-sitter"; /** * A capture of a query pattern against a syntax tree. Often corresponds to a @@ -69,8 +46,9 @@ export interface QueryMatch { export interface MutableQueryCapture extends QueryCapture { /** * The tree-sitter node that was captured. + * This may be undefined if the range has been modified by a query predicate. */ - readonly node: SimpleSyntaxNode; + node: Node | undefined; readonly document: TextDocument; range: Range; diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/getChildNodesForFieldName.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/getChildNodesForFieldName.ts index e887e254a0..681f7e163c 100644 --- a/packages/cursorless-engine/src/languages/TreeSitterQuery/getChildNodesForFieldName.ts +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/getChildNodesForFieldName.ts @@ -1,9 +1,9 @@ -import type { SimpleSyntaxNode } from "./QueryCapture"; +import type { Node } from "web-tree-sitter"; export function getChildNodesForFieldName( - node: SimpleSyntaxNode, + node: Node, fieldName: string, -): SimpleSyntaxNode[] { +): Node[] { const nodes = []; const treeCursor = node.walk(); let hasNext = treeCursor.gotoFirstChild(); diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/getNode.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/getNode.ts new file mode 100644 index 0000000000..0c7d6c22af --- /dev/null +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/getNode.ts @@ -0,0 +1,11 @@ +import type { Node } from "web-tree-sitter"; +import type { MutableQueryCapture } from "./QueryCapture"; + +export function getNode(capture: MutableQueryCapture): Node { + if (capture.node == null) { + throw Error( + `Capture ${capture.name} has no node. The range of the capture has already been updated and no longer matches a specific node.`, + ); + } + return capture.node; +} diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/isEven.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/isEven.ts index 12e9dacea7..1dff9d5ca8 100644 --- a/packages/cursorless-engine/src/languages/TreeSitterQuery/isEven.ts +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/isEven.ts @@ -1,4 +1,4 @@ -import type { SimpleSyntaxNode } from "./QueryCapture"; +import type { Node } from "web-tree-sitter"; /** * Checks if a node is at an even index within its parent's field. @@ -7,7 +7,7 @@ import type { SimpleSyntaxNode } from "./QueryCapture"; * @param fieldName - The name of the field in the parent node. * @returns True if the node is at an even index, false otherwise. */ -export function isEven(node: SimpleSyntaxNode, fieldName: string): boolean { +export function isEven(node: Node, fieldName: string): boolean { if (node.parent == null) { throw Error("Node has no parent"); } diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/queryPredicateOperators.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/queryPredicateOperators.ts index dc8594d6ef..a1678b4378 100644 --- a/packages/cursorless-engine/src/languages/TreeSitterQuery/queryPredicateOperators.ts +++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/queryPredicateOperators.ts @@ -1,11 +1,13 @@ import { Position, Range, adjustPosition } from "@cursorless/common"; import type { Point } from "web-tree-sitter"; import { z } from "zod"; +import { getNode } from "./getNode"; import { isEven } from "./isEven"; import { makeRangeFromPositions } from "./makeRangeFromPositions"; import { q } from "./operatorArgumentSchemaTypes"; import type { MutableQueryCapture } from "./QueryCapture"; import { QueryPredicateOperator } from "./QueryPredicateOperator"; +import { setRange } from "./setRange"; /** * A predicate operator that returns true if the node is at an even index within @@ -15,8 +17,8 @@ import { QueryPredicateOperator } from "./QueryPredicateOperator"; class Even extends QueryPredicateOperator { name = "even?" as const; schema = z.tuple([q.node, q.string]); - run({ node }: MutableQueryCapture, fieldName: string) { - return isEven(node, fieldName); + run(capture: MutableQueryCapture, fieldName: string) { + return isEven(getNode(capture), fieldName); } } @@ -28,8 +30,8 @@ class Even extends QueryPredicateOperator { class Odd extends QueryPredicateOperator { name = "odd?" as const; schema = z.tuple([q.node, q.string]); - run({ node }: MutableQueryCapture, fieldName: string) { - return !isEven(node, fieldName); + run(capture: MutableQueryCapture, fieldName: string) { + return !isEven(getNode(capture), fieldName); } } @@ -42,8 +44,8 @@ class Odd extends QueryPredicateOperator { class Text extends QueryPredicateOperator { name = "text?" as const; schema = z.tuple([q.node, q.string]).rest(q.string); - run({ document, range }: MutableQueryCapture, ...texts: string[]) { - return texts.includes(document.getText(range)); + run(capture: MutableQueryCapture, ...texts: string[]) { + return texts.includes(getNode(capture).text); } } @@ -56,8 +58,8 @@ class Text extends QueryPredicateOperator { class Type extends QueryPredicateOperator { name = "type?" as const; schema = z.tuple([q.node, q.string]).rest(q.string); - run({ node }: MutableQueryCapture, ...types: string[]) { - return types.includes(node.type); + run(capture: MutableQueryCapture, ...types: string[]) { + return types.includes(getNode(capture).type); } } @@ -70,8 +72,8 @@ class Type extends QueryPredicateOperator { class NotType extends QueryPredicateOperator { name = "not-type?" as const; schema = z.tuple([q.node, q.string]).rest(q.string); - run({ node }: MutableQueryCapture, ...types: string[]) { - return !types.includes(node.type); + run(capture: MutableQueryCapture, ...types: string[]) { + return !types.includes(getNode(capture).type); } } @@ -84,7 +86,8 @@ class NotType extends QueryPredicateOperator { class NotParentType extends QueryPredicateOperator { name = "not-parent-type?" as const; schema = z.tuple([q.node, q.string]).rest(q.string); - run({ node }: MutableQueryCapture, ...types: string[]) { + run(capture: MutableQueryCapture, ...types: string[]) { + const node = getNode(capture); return node.parent == null || !types.includes(node.parent.type); } } @@ -97,7 +100,8 @@ class NotParentType extends QueryPredicateOperator { class IsNthChild extends QueryPredicateOperator { name = "is-nth-child?" as const; schema = z.tuple([q.node, q.integer]); - run({ node }: MutableQueryCapture, n: number) { + run(capture: MutableQueryCapture, n: number) { + const node = getNode(capture); return node.parent?.children.findIndex((n) => n.id === node.id) === n; } } @@ -112,8 +116,10 @@ class HasMultipleChildrenOfType extends QueryPredicateOperator n.type === type).length; + run(capture: MutableQueryCapture, type: string) { + const count = getNode(capture).children.filter( + (n) => n.type === type, + ).length; return count > 1; } } @@ -128,15 +134,13 @@ class ChildRange extends QueryPredicateOperator { ]); run( - nodeInfo: MutableQueryCapture, + capture: MutableQueryCapture, startIndex: number, endIndex?: number, excludeStart?: boolean, excludeEnd?: boolean, ) { - const { - node: { children }, - } = nodeInfo; + const children = getNode(capture).children; startIndex = startIndex < 0 ? children.length + startIndex : startIndex; endIndex = endIndex == null ? -1 : endIndex; @@ -145,9 +149,12 @@ class ChildRange extends QueryPredicateOperator { const start = children[startIndex]; const end = children[endIndex]; - nodeInfo.range = makeRangeFromPositions( - excludeStart ? start.endPosition : start.startPosition, - excludeEnd ? end.startPosition : end.endPosition, + setRange( + capture, + makeRangeFromPositions( + excludeStart ? start.endPosition : start.startPosition, + excludeEnd ? end.startPosition : end.endPosition, + ), ); return true; @@ -161,10 +168,13 @@ class CharacterRange extends QueryPredicateOperator { z.tuple([q.node, q.integer, q.integer]), ]); - run(nodeInfo: MutableQueryCapture, startOffset: number, endOffset?: number) { - nodeInfo.range = new Range( - nodeInfo.range.start.translate(undefined, startOffset), - nodeInfo.range.end.translate(undefined, endOffset ?? 0), + run(capture: MutableQueryCapture, startOffset: number, endOffset?: number) { + setRange( + capture, + new Range( + capture.range.start.translate(undefined, startOffset), + capture.range.end.translate(undefined, endOffset ?? 0), + ), ); return true; @@ -189,9 +199,9 @@ class ShrinkToMatch extends QueryPredicateOperator { name = "shrink-to-match!" as const; schema = z.tuple([q.node, q.string]); - run(nodeInfo: MutableQueryCapture, pattern: string) { - const { document, range } = nodeInfo; - const text = document.getText(range); + run(capture: MutableQueryCapture, pattern: string) { + const { document, range } = capture; + const text = getNode(capture).text; const match = text.match(new RegExp(pattern, "ds")); if (match?.index == null) { @@ -203,9 +213,12 @@ class ShrinkToMatch extends QueryPredicateOperator { const baseOffset = document.offsetAt(range.start); - nodeInfo.range = new Range( - document.positionAt(baseOffset + startOffset), - document.positionAt(baseOffset + endOffset), + setRange( + capture, + new Range( + document.positionAt(baseOffset + startOffset), + document.positionAt(baseOffset + endOffset), + ), ); return true; @@ -225,8 +238,8 @@ class GrowToNamedSiblings extends QueryPredicateOperator { name = "grow-to-named-siblings!" as const; schema = z.union([z.tuple([q.node]), z.tuple([q.node, q.string])]); - run(nodeInfo: MutableQueryCapture, notText?: string) { - const { node, range } = nodeInfo; + run(capture: MutableQueryCapture, notText?: string) { + const node = getNode(capture); if (node.parent == null) { throw Error("Node has no parent"); @@ -254,9 +267,12 @@ class GrowToNamedSiblings extends QueryPredicateOperator { } if (endPosition != null) { - nodeInfo.range = new Range( - range.start, - new Position(endPosition.row, endPosition.column), + setRange( + capture, + new Range( + capture.range.start, + new Position(endPosition.row, endPosition.column), + ), ); } @@ -272,16 +288,21 @@ class TrimEnd extends QueryPredicateOperator { name = "trim-end!" as const; schema = z.tuple([q.node]); - run(nodeInfo: MutableQueryCapture) { - const { document, range } = nodeInfo; - const text = document.getText(range); + run(capture: MutableQueryCapture) { + const { document, range } = capture; + const text = getNode(capture).text; const whitespaceLength = text.length - text.trimEnd().length; + if (whitespaceLength > 0) { - nodeInfo.range = new Range( - range.start, - adjustPosition(document, range.end, -whitespaceLength), + setRange( + capture, + new Range( + range.start, + adjustPosition(document, range.end, -whitespaceLength), + ), ); } + return true; } } @@ -293,9 +314,9 @@ class DocumentRange extends QueryPredicateOperator { name = "document-range!" as const; schema = z.tuple([q.node]).rest(q.node); - run(...nodeInfos: MutableQueryCapture[]) { - for (const nodeInfo of nodeInfos) { - nodeInfo.range = nodeInfo.document.range; + run(...captures: MutableQueryCapture[]) { + for (const capture of captures) { + setRange(capture, capture.document.range); } return true; @@ -321,28 +342,15 @@ class AllowMultiple extends QueryPredicateOperator { return true; } - run(...nodeInfos: MutableQueryCapture[]) { - for (const nodeInfo of nodeInfos) { - nodeInfo.allowMultiple = true; + run(...captures: MutableQueryCapture[]) { + for (const capture of captures) { + capture.allowMultiple = true; } return true; } } -/** - * A predicate operator that logs a node, for debugging. - */ -class Log extends QueryPredicateOperator { - name = "log!" as const; - schema = z.tuple([q.node]); - - run(nodeInfo: MutableQueryCapture) { - console.log(`#log!: ${nodeInfo.name}@${nodeInfo.range}`); - return true; - } -} - /** * A predicate operator that sets the insertion delimiter of the match. For * example, `(#insertion-delimiter! @foo ", ")` will set the insertion delimiter @@ -352,18 +360,18 @@ class InsertionDelimiter extends QueryPredicateOperator { name = "insertion-delimiter!" as const; schema = z.tuple([q.node, q.string]); - run(nodeInfo: MutableQueryCapture, insertionDelimiter: string) { - nodeInfo.insertionDelimiter = insertionDelimiter; + run(capture: MutableQueryCapture, insertionDelimiter: string) { + capture.insertionDelimiter = insertionDelimiter; return true; } } /** - * A predicate operator that sets the insertion delimiter of {@link nodeInfo} to + * A predicate operator that sets the insertion delimiter of {@link capture} to * either {@link insertionDelimiterConsequence} or * {@link insertionDelimiterAlternative} depending on whether - * {@link conditionNodeInfo} is single or multiline, respectively. For example, + * {@link conditionCapture} is single or multiline, respectively. For example, * * ```scm * (#single-or-multi-line-delimiter! @foo @bar ", " ",\n") @@ -377,12 +385,12 @@ class SingleOrMultilineDelimiter extends QueryPredicateOperator child.isNamed, ); - nodeInfo.insertionDelimiter = isEmpty + capture.insertionDelimiter = isEmpty ? insertionDelimiterEmpty - : conditionNodeInfo.range.isSingleLine + : conditionCapture.range.isSingleLine ? insertionDelimiterSingleLine : insertionDelimiterMultiline; @@ -427,7 +435,6 @@ class EmptySingleMultiDelimiter extends QueryPredicateOperator.*)\"#+") - ) [ diff --git a/queries/scala.scm b/queries/scala.scm index f62fac220f..73a706f546 100644 --- a/queries/scala.scm +++ b/queries/scala.scm @@ -157,8 +157,8 @@ "(" @argumentList.removal.start.endOf @argumentOrParameter.iteration.start.endOf ")" @argumentList.removal.end.startOf @argumentOrParameter.iteration.end.startOf ) @argumentList - (#child-range! @argumentList 1 -2) (#empty-single-multi-delimiter! @argumentList @argumentList "" ", " ",\n") + (#child-range! @argumentList 1 -2) ) @argumentList.domain @argumentOrParameter.iteration.domain ;;!! def foo(aaa: Int, bbb: Int) = x @@ -168,8 +168,8 @@ "(" @argumentList.removal.start.endOf @argumentOrParameter.iteration.start.endOf ")" @argumentList.removal.end.startOf @argumentOrParameter.iteration.end.startOf ) @argumentList - (#child-range! @argumentList 1 -2) (#empty-single-multi-delimiter! @argumentList @argumentList "" ", " ",\n") + (#child-range! @argumentList 1 -2) ) @argumentList.domain @argumentOrParameter.iteration.domain ;;!! foo(aaa, bbb) @@ -179,8 +179,8 @@ "(" @argumentList.removal.start.endOf @argumentOrParameter.iteration.start.endOf ")" @argumentList.removal.end.startOf @argumentOrParameter.iteration.end.startOf ) @argumentList - (#child-range! @argumentList 1 -2) (#empty-single-multi-delimiter! @argumentList @argumentList "" ", " ",\n") + (#child-range! @argumentList 1 -2) ) @argumentList.domain @argumentOrParameter.iteration.domain operator: (operator_identifier) @disqualifyDelimiter diff --git a/queries/talon.scm b/queries/talon.scm index ec82915292..74b212f428 100644 --- a/queries/talon.scm +++ b/queries/talon.scm @@ -183,8 +183,8 @@ "(" @argumentList.removal.start.endOf @argumentOrParameter.iteration.start.endOf ")" @argumentList.removal.end.startOf @argumentOrParameter.iteration.end.startOf ) @argumentList - (#child-range! @argumentList 1 -2) (#empty-single-multi-delimiter! @argumentList @argumentList "" ", " ",\n") + (#child-range! @argumentList 1 -2) ) @argumentList.domain @argumentOrParameter.iteration.domain ;;!! # foo