diff --git a/packages/cursorless-engine/src/languages/LanguageDefinition.ts b/packages/cursorless-engine/src/languages/LanguageDefinition.ts
index e2f2ee9086..93e93fb6d0 100644
--- a/packages/cursorless-engine/src/languages/LanguageDefinition.ts
+++ b/packages/cursorless-engine/src/languages/LanguageDefinition.ts
@@ -74,7 +74,7 @@ export class LanguageDefinition {
    * legacy pathways
    */
   getScopeHandler(scopeType: ScopeType) {
-    if (!this.query.captureNames.includes(scopeType.type)) {
+    if (!this.query.hasCapture(scopeType.type)) {
       return undefined;
     }
 
diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQuery.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQuery.ts
index 69785ed0bf..60d3763a02 100644
--- a/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQuery.ts
+++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQuery.ts
@@ -1,19 +1,22 @@
 import type { Position, TextDocument } from "@cursorless/common";
-import { showError, type TreeSitter } from "@cursorless/common";
-import { groupBy, uniq } from "lodash-es";
-import type { Point, Query } from "web-tree-sitter";
+import { type TreeSitter } from "@cursorless/common";
+import type * as treeSitter from "web-tree-sitter";
 import { ide } from "../../singletons/ide.singleton";
 import { getNodeRange } from "../../util/nodeSelectors";
 import type {
+  MutableQueryCapture,
   MutableQueryMatch,
-  QueryCapture,
   QueryMatch,
 } from "./QueryCapture";
 import { checkCaptureStartEnd } from "./checkCaptureStartEnd";
 import { isContainedInErrorNode } from "./isContainedInErrorNode";
-import { parsePredicates } from "./parsePredicates";
-import { predicateToString } from "./predicateToString";
-import { rewriteStartOfEndOf } from "./rewriteStartOfEndOf";
+import { normalizeCaptureName } from "./normalizeCaptureName";
+import { parsePredicatesWithErrorHandling } from "./parsePredicatesWithErrorHandling";
+import { positionToPoint } from "./positionToPoint";
+import {
+  getStartOfEndOfRange,
+  rewriteStartOfEndOf,
+} from "./rewriteStartOfEndOf";
 import { treeSitterQueryCache } from "./treeSitterQueryCache";
 
 /**
@@ -21,13 +24,15 @@ import { treeSitterQueryCache } from "./treeSitterQueryCache";
  * defines our own custom predicate operators
  */
 export class TreeSitterQuery {
+  private shouldCheckCaptures: boolean;
+
   private constructor(
     private treeSitter: TreeSitter,
 
     /**
      * The raw tree-sitter query as parsed by tree-sitter from the query file
      */
-    private query: Query,
+    private query: treeSitter.Query,
 
     /**
      * The predicates for each pattern in the query. Each element of the outer
@@ -35,38 +40,26 @@ export class TreeSitterQuery {
      * corresponds to a predicate for that pattern.
      */
     private patternPredicates: ((match: MutableQueryMatch) => boolean)[][],
-  ) {}
-
-  static create(languageId: string, treeSitter: TreeSitter, query: Query) {
-    const { errors, predicates } = parsePredicates(query.predicates);
-
-    if (errors.length > 0) {
-      for (const error of errors) {
-        const context = [
-          `language ${languageId}`,
-          `pattern ${error.patternIdx}`,
-          `predicate \`${predicateToString(
-            query.predicates[error.patternIdx][error.predicateIdx],
-          )}\``,
-        ].join(", ");
-
-        void showError(
-          ide().messages,
-          "TreeSitterQuery.parsePredicates",
-          `Error parsing predicate for ${context}: ${error.error}`,
-        );
-      }
+  ) {
+    this.shouldCheckCaptures = ide().runMode !== "production";
+  }
 
-      // We show errors to the user, but we don't want to crash the extension
-      // unless we're in test mode
-      if (ide().runMode === "test") {
-        throw new Error("Invalid predicates");
-      }
-    }
+  static create(
+    languageId: string,
+    treeSitter: TreeSitter,
+    query: treeSitter.Query,
+  ) {
+    const predicates = parsePredicatesWithErrorHandling(languageId, query);
 
     return new TreeSitterQuery(treeSitter, query, predicates);
   }
 
+  hasCapture(name: string): boolean {
+    return this.query.captureNames.some(
+      (n) => normalizeCaptureName(n) === name,
+    );
+  }
+
   matches(
     document: TextDocument,
     start?: Position,
@@ -84,74 +77,114 @@ export class TreeSitterQuery {
     start?: Position,
     end?: Position,
   ): QueryMatch[] {
-    return this.query
-      .matches(this.treeSitter.getTree(document).rootNode, {
-        startPosition: start == null ? undefined : positionToPoint(start),
-        endPosition: end == null ? undefined : positionToPoint(end),
-      })
-      .map(
-        ({ pattern, captures }): MutableQueryMatch => ({
-          patternIdx: pattern,
-          captures: captures.map(({ name, node }) => ({
-            name,
-            node,
-            document,
-            range: getNodeRange(node),
-            insertionDelimiter: undefined,
-            allowMultiple: false,
-            hasError: () => isContainedInErrorNode(node),
-          })),
-        }),
-      )
-      .filter((match) =>
-        this.patternPredicates[match.patternIdx].every((predicate) =>
-          predicate(match),
-        ),
-      )
-      .map((match): QueryMatch => {
-        // Merge the ranges of all captures with the same name into a single
-        // range and return one capture with that name. We consider captures
-        // with names `@foo`, `@foo.start`, and `@foo.end` to have the same
-        // name, for which we'd return a capture with name `foo`.
-        const captures: QueryCapture[] = Object.entries(
-          groupBy(match.captures, ({ name }) => normalizeCaptureName(name)),
-        ).map(([name, captures]) => {
-          captures = rewriteStartOfEndOf(captures);
-          const capturesAreValid = checkCaptureStartEnd(
-            captures,
-            ide().messages,
-          );
-
-          if (!capturesAreValid && ide().runMode === "test") {
-            throw new Error("Invalid captures");
-          }
-
-          return {
-            name,
-            range: captures
-              .map(({ range }) => range)
-              .reduce((accumulator, range) => range.union(accumulator)),
-            allowMultiple: captures.some((capture) => capture.allowMultiple),
-            insertionDelimiter: captures.find(
-              (capture) => capture.insertionDelimiter != null,
-            )?.insertionDelimiter,
-            hasError: () => captures.some((capture) => capture.hasError()),
-          };
-        });
-
-        return { ...match, captures };
-      });
+    const matches = this.getTreeMatches(document, start, end);
+    const results: QueryMatch[] = [];
+
+    for (const match of matches) {
+      const mutableMatch = this.createMutableQueryMatch(document, match);
+
+      if (!this.runPredicates(mutableMatch)) {
+        continue;
+      }
+
+      results.push(this.createQueryMatch(mutableMatch));
+    }
+
+    return results;
   }
 
-  get captureNames() {
-    return uniq(this.query.captureNames.map(normalizeCaptureName));
+  private getTreeMatches(
+    document: TextDocument,
+    start?: Position,
+    end?: Position,
+  ) {
+    const { rootNode } = this.treeSitter.getTree(document);
+    return this.query.matches(rootNode, {
+      startPosition: start != null ? positionToPoint(start) : undefined,
+      endPosition: end != null ? positionToPoint(end) : undefined,
+    });
   }
-}
 
-function normalizeCaptureName(name: string): string {
-  return name.replace(/(\.(start|end))?(\.(startOf|endOf))?$/, "");
-}
+  private createMutableQueryMatch(
+    document: TextDocument,
+    match: treeSitter.QueryMatch,
+  ): MutableQueryMatch {
+    return {
+      patternIdx: match.pattern,
+      captures: match.captures.map(({ name, node }) => ({
+        name,
+        node,
+        document,
+        range: getNodeRange(node),
+        insertionDelimiter: undefined,
+        allowMultiple: false,
+        hasError: () => isContainedInErrorNode(node),
+      })),
+    };
+  }
 
-function positionToPoint(start: Position): Point {
-  return { row: start.line, column: start.character };
+  private runPredicates(match: MutableQueryMatch): boolean {
+    for (const predicate of this.patternPredicates[match.patternIdx]) {
+      if (!predicate(match)) {
+        return false;
+      }
+    }
+    return true;
+  }
+
+  private createQueryMatch(match: MutableQueryMatch): QueryMatch {
+    const result: MutableQueryCapture[] = [];
+    const map = new Map<
+      string,
+      { acc: MutableQueryCapture; captures: MutableQueryCapture[] }
+    >();
+
+    // Merge the ranges of all captures with the same name into a single
+    // range and return one capture with that name. We consider captures
+    // with names `@foo`, `@foo.start`, and `@foo.end` to have the same
+    // name, for which we'd return a capture with name `foo`.
+
+    for (const capture of match.captures) {
+      const name = normalizeCaptureName(capture.name);
+      const range = getStartOfEndOfRange(capture);
+      const existing = map.get(name);
+
+      if (existing == null) {
+        const captures = [capture];
+        const acc = {
+          ...capture,
+          name,
+          range,
+          hasError: () => captures.some((c) => c.hasError()),
+        };
+        result.push(acc);
+        map.set(name, { acc, captures });
+      } else {
+        existing.acc.range = existing.acc.range.union(range);
+        existing.acc.allowMultiple =
+          existing.acc.allowMultiple || capture.allowMultiple;
+        existing.acc.insertionDelimiter =
+          existing.acc.insertionDelimiter ?? capture.insertionDelimiter;
+        existing.captures.push(capture);
+      }
+    }
+
+    if (this.shouldCheckCaptures) {
+      this.checkCaptures(Array.from(map.values()));
+    }
+
+    return { captures: result };
+  }
+
+  private checkCaptures(matches: { captures: MutableQueryCapture[] }[]) {
+    for (const match of matches) {
+      const capturesAreValid = checkCaptureStartEnd(
+        rewriteStartOfEndOf(match.captures),
+        ide().messages,
+      );
+      if (!capturesAreValid && ide().runMode === "test") {
+        throw new Error("Invalid captures");
+      }
+    }
+  }
 }
diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/normalizeCaptureName.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/normalizeCaptureName.ts
new file mode 100644
index 0000000000..5322ff1556
--- /dev/null
+++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/normalizeCaptureName.ts
@@ -0,0 +1,3 @@
+export function normalizeCaptureName(name: string): string {
+  return name.replace(/(\.(start|end))?(\.(startOf|endOf))?$/, "");
+}
diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/parsePredicatesWithErrorHandling.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/parsePredicatesWithErrorHandling.ts
new file mode 100644
index 0000000000..6798939249
--- /dev/null
+++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/parsePredicatesWithErrorHandling.ts
@@ -0,0 +1,38 @@
+import { showError } from "@cursorless/common";
+import type { Query } from "web-tree-sitter";
+import { ide } from "../../singletons/ide.singleton";
+import { parsePredicates } from "./parsePredicates";
+import { predicateToString } from "./predicateToString";
+
+export function parsePredicatesWithErrorHandling(
+  languageId: string,
+  query: Query,
+) {
+  const { errors, predicates } = parsePredicates(query.predicates);
+
+  if (errors.length > 0) {
+    for (const error of errors) {
+      const context = [
+        `language ${languageId}`,
+        `pattern ${error.patternIdx}`,
+        `predicate \`${predicateToString(
+          query.predicates[error.patternIdx][error.predicateIdx],
+        )}\``,
+      ].join(", ");
+
+      void showError(
+        ide().messages,
+        "TreeSitterQuery.parsePredicates",
+        `Error parsing predicate for ${context}: ${error.error}`,
+      );
+    }
+
+    // We show errors to the user, but we don't want to crash the extension
+    // unless we're in test mode
+    if (ide().runMode === "test") {
+      throw new Error("Invalid predicates");
+    }
+  }
+
+  return predicates;
+}
diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/positionToPoint.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/positionToPoint.ts
new file mode 100644
index 0000000000..af5650a309
--- /dev/null
+++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/positionToPoint.ts
@@ -0,0 +1,6 @@
+import type { Position } from "@cursorless/common";
+import type { Point } from "web-tree-sitter";
+
+export function positionToPoint(start: Position): Point {
+  return { row: start.line, column: start.character };
+}
diff --git a/packages/cursorless-engine/src/languages/TreeSitterQuery/rewriteStartOfEndOf.ts b/packages/cursorless-engine/src/languages/TreeSitterQuery/rewriteStartOfEndOf.ts
index 463ba9a15c..76fc022411 100644
--- a/packages/cursorless-engine/src/languages/TreeSitterQuery/rewriteStartOfEndOf.ts
+++ b/packages/cursorless-engine/src/languages/TreeSitterQuery/rewriteStartOfEndOf.ts
@@ -1,3 +1,4 @@
+import type { Range } from "@cursorless/common";
 import type { MutableQueryCapture } from "./QueryCapture";
 
 /**
@@ -11,22 +12,29 @@ import type { MutableQueryCapture } from "./QueryCapture";
 export function rewriteStartOfEndOf(
   captures: MutableQueryCapture[],
 ): MutableQueryCapture[] {
-  return captures.map((capture) => {
-    // Remove trailing .startOf and .endOf, adjusting ranges.
-    if (capture.name.endsWith(".startOf")) {
-      return {
-        ...capture,
-        name: capture.name.replace(/\.startOf$/, ""),
-        range: capture.range.start.toEmptyRange(),
-      };
-    }
-    if (capture.name.endsWith(".endOf")) {
-      return {
-        ...capture,
-        name: capture.name.replace(/\.endOf$/, ""),
-        range: capture.range.end.toEmptyRange(),
-      };
-    }
-    return capture;
-  });
+  return captures.map((capture) => ({
+    ...capture,
+    range: getStartOfEndOfRange(capture),
+    name: getStartOfEndOfName(capture),
+  }));
+}
+
+export function getStartOfEndOfRange(capture: MutableQueryCapture): Range {
+  if (capture.name.endsWith(".startOf")) {
+    return capture.range.start.toEmptyRange();
+  }
+  if (capture.name.endsWith(".endOf")) {
+    return capture.range.end.toEmptyRange();
+  }
+  return capture.range;
+}
+
+function getStartOfEndOfName(capture: MutableQueryCapture): string {
+  if (capture.name.endsWith(".startOf")) {
+    return capture.name.slice(0, -8);
+  }
+  if (capture.name.endsWith(".endOf")) {
+    return capture.name.slice(0, -6);
+  }
+  return capture.name;
 }