Skip to content
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ export class LanguageDefinition {
* legacy pathways
*/
getScopeHandler(scopeType: ScopeType) {
if (!this.query.captureNames.includes(scopeType.type)) {
if (!this.query.hasCapture(scopeType.type)) {
return undefined;
}

Expand Down
Original file line number Diff line number Diff line change
@@ -1,72 +1,65 @@
import type { Position, TextDocument } from "@cursorless/common";
import { showError, type TreeSitter } from "@cursorless/common";
import { groupBy, uniq } from "lodash-es";
import type { Point, Query } from "web-tree-sitter";
import { type TreeSitter } from "@cursorless/common";
import type * as treeSitter from "web-tree-sitter";
import { ide } from "../../singletons/ide.singleton";
import { getNodeRange } from "../../util/nodeSelectors";
import type {
MutableQueryCapture,
MutableQueryMatch,
QueryCapture,
QueryMatch,
} from "./QueryCapture";
import { checkCaptureStartEnd } from "./checkCaptureStartEnd";
import { isContainedInErrorNode } from "./isContainedInErrorNode";
import { parsePredicates } from "./parsePredicates";
import { predicateToString } from "./predicateToString";
import { rewriteStartOfEndOf } from "./rewriteStartOfEndOf";
import { normalizeCaptureName } from "./normalizeCaptureName";
import { parsePredicatesWithErrorHandling } from "./parsePredicatesWithErrorHandling";
import { positionToPoint } from "./positionToPoint";
import {
getStartOfEndOfRange,
rewriteStartOfEndOf,
} from "./rewriteStartOfEndOf";
import { treeSitterQueryCache } from "./treeSitterQueryCache";

/**
* Wrapper around a tree-sitter query that provides a more convenient API, and
* defines our own custom predicate operators
*/
export class TreeSitterQuery {
private shouldCheckCaptures: boolean;

private constructor(
private treeSitter: TreeSitter,

/**
* The raw tree-sitter query as parsed by tree-sitter from the query file
*/
private query: Query,
private query: treeSitter.Query,

/**
* The predicates for each pattern in the query. Each element of the outer
* array corresponds to a pattern, and each element of the inner array
* corresponds to a predicate for that pattern.
*/
private patternPredicates: ((match: MutableQueryMatch) => boolean)[][],
) {}

static create(languageId: string, treeSitter: TreeSitter, query: Query) {
const { errors, predicates } = parsePredicates(query.predicates);

if (errors.length > 0) {
for (const error of errors) {
const context = [
`language ${languageId}`,
`pattern ${error.patternIdx}`,
`predicate \`${predicateToString(
query.predicates[error.patternIdx][error.predicateIdx],
)}\``,
].join(", ");

void showError(
ide().messages,
"TreeSitterQuery.parsePredicates",
`Error parsing predicate for ${context}: ${error.error}`,
);
}
) {
this.shouldCheckCaptures = ide().runMode !== "production";
}

// We show errors to the user, but we don't want to crash the extension
// unless we're in test mode
if (ide().runMode === "test") {
throw new Error("Invalid predicates");
}
}
/**
 * Builds a {@link TreeSitterQuery} for the given raw tree-sitter query.
 *
 * The query's predicates are parsed by `parsePredicatesWithErrorHandling`
 * and handed to the constructor together with the query itself.
 *
 * @param languageId Language id forwarded to the predicate parser
 * @param treeSitter Tree-sitter service stored on the new instance
 * @param query The raw tree-sitter query to wrap
 * @returns A new `TreeSitterQuery` wrapping `query`
 */
static create(
  languageId: string,
  treeSitter: TreeSitter,
  query: treeSitter.Query,
) {
  return new TreeSitterQuery(
    treeSitter,
    query,
    parsePredicatesWithErrorHandling(languageId, query),
  );
}

/**
 * Reports whether the underlying query declares a capture whose normalized
 * name equals `name`.
 *
 * @param name Capture name to look for, compared against each raw capture
 * name after it has been passed through `normalizeCaptureName`
 * @returns `true` if at least one capture name normalizes to `name`
 */
hasCapture(name: string): boolean {
  for (const rawName of this.query.captureNames) {
    if (normalizeCaptureName(rawName) === name) {
      return true;
    }
  }

  return false;
}

matches(
document: TextDocument,
start?: Position,
Expand All @@ -84,74 +77,114 @@ export class TreeSitterQuery {
start?: Position,
end?: Position,
): QueryMatch[] {
return this.query
.matches(this.treeSitter.getTree(document).rootNode, {
startPosition: start == null ? undefined : positionToPoint(start),
endPosition: end == null ? undefined : positionToPoint(end),
})
.map(
({ pattern, captures }): MutableQueryMatch => ({
patternIdx: pattern,
captures: captures.map(({ name, node }) => ({
name,
node,
document,
range: getNodeRange(node),
insertionDelimiter: undefined,
allowMultiple: false,
hasError: () => isContainedInErrorNode(node),
})),
}),
)
.filter((match) =>
this.patternPredicates[match.patternIdx].every((predicate) =>
predicate(match),
),
)
.map((match): QueryMatch => {
// Merge the ranges of all captures with the same name into a single
// range and return one capture with that name. We consider captures
// with names `@foo`, `@foo.start`, and `@foo.end` to have the same
// name, for which we'd return a capture with name `foo`.
const captures: QueryCapture[] = Object.entries(
groupBy(match.captures, ({ name }) => normalizeCaptureName(name)),
).map(([name, captures]) => {
captures = rewriteStartOfEndOf(captures);
const capturesAreValid = checkCaptureStartEnd(
captures,
ide().messages,
);

if (!capturesAreValid && ide().runMode === "test") {
throw new Error("Invalid captures");
}

return {
name,
range: captures
.map(({ range }) => range)
.reduce((accumulator, range) => range.union(accumulator)),
allowMultiple: captures.some((capture) => capture.allowMultiple),
insertionDelimiter: captures.find(
(capture) => capture.insertionDelimiter != null,
)?.insertionDelimiter,
hasError: () => captures.some((capture) => capture.hasError()),
};
});

return { ...match, captures };
});
const matches = this.getTreeMatches(document, start, end);
const results: QueryMatch[] = [];

for (const match of matches) {
const mutableMatch = this.createMutableQueryMatch(document, match);

if (!this.runPredicates(mutableMatch)) {
continue;
}

results.push(this.createQueryMatch(mutableMatch));
}

return results;
}

get captureNames() {
return uniq(this.query.captureNames.map(normalizeCaptureName));
/**
 * Runs the raw tree-sitter query against the root node of the document's
 * syntax tree.
 *
 * @param document The document whose tree is queried
 * @param start Optional start position; converted with `positionToPoint`
 * @param end Optional end position; converted with `positionToPoint`
 * @returns The matches returned by the raw query
 */
private getTreeMatches(
  document: TextDocument,
  start?: Position,
  end?: Position,
) {
  const { rootNode } = this.treeSitter.getTree(document);

  // A missing bound is passed through as `undefined`
  const startPosition = start == null ? undefined : positionToPoint(start);
  const endPosition = end == null ? undefined : positionToPoint(end);

  return this.query.matches(rootNode, { startPosition, endPosition });
}
}

/**
 * Strips a trailing `.start` / `.end` and/or `.startOf` / `.endOf` suffix
 * from a capture name, eg `foo.start.endOf` becomes `foo`.
 *
 * @param name The raw capture name
 * @returns The name without the trailing suffix(es)
 */
function normalizeCaptureName(name: string): string {
  const trailingSuffixes = /(\.(start|end))?(\.(startOf|endOf))?$/;
  return name.replace(trailingSuffixes, "");
}
/**
 * Converts a raw tree-sitter match into a `MutableQueryMatch`.
 *
 * Each raw capture keeps its name and node, and gains the owning document,
 * the node's range (via `getNodeRange`), an unset insertion delimiter,
 * `allowMultiple: false`, and a `hasError` callback that defers to
 * `isContainedInErrorNode`.
 *
 * @param document The document the match was found in
 * @param match The raw tree-sitter match
 * @returns The mutable match
 */
private createMutableQueryMatch(
  document: TextDocument,
  match: treeSitter.QueryMatch,
): MutableQueryMatch {
  const patternIdx = match.pattern;

  const captures: MutableQueryMatch["captures"] = match.captures.map(
    (rawCapture) => {
      const { name, node } = rawCapture;
      const range = getNodeRange(node);
      const hasError = () => isContainedInErrorNode(node);

      return {
        name,
        node,
        document,
        range,
        insertionDelimiter: undefined,
        allowMultiple: false,
        hasError,
      };
    },
  );

  return { patternIdx, captures };
}

function positionToPoint(start: Position): Point {
return { row: start.line, column: start.character };
/**
 * Evaluates the predicates registered for the match's pattern.
 *
 * @param match The match to test; predicates receive it as their argument
 * @returns `true` if every predicate for the pattern accepts the match.
 * Evaluation stops at the first predicate that rejects it.
 */
private runPredicates(match: MutableQueryMatch): boolean {
  const predicates = this.patternPredicates[match.patternIdx];
  return predicates.every((predicate) => predicate(match));
}

/**
 * Collapses the captures of a match so that there is exactly one capture per
 * normalized name.
 *
 * Captures whose names normalize to the same value (eg `@foo`, `@foo.start`,
 * `@foo.end`, and their `.startOf` / `.endOf` variants) are merged: their
 * ranges are unioned, `allowMultiple` is true if any of them allows it, the
 * first non-nullish `insertionDelimiter` wins, and `hasError` reports whether
 * any of them has an error. Merged captures keep the order in which their
 * name was first seen.
 *
 * @param match The match whose captures should be merged
 * @returns A match containing the merged captures
 */
private createQueryMatch(match: MutableQueryMatch): QueryMatch {
  // For each normalized name: the merged capture we return (`acc`) and the
  // raw captures that were folded into it (`captures`)
  const groups = new Map<
    string,
    { acc: MutableQueryCapture; captures: MutableQueryCapture[] }
  >();
  const merged: MutableQueryCapture[] = [];

  for (const capture of match.captures) {
    const name = normalizeCaptureName(capture.name);
    const range = getStartOfEndOfRange(capture);
    const group = groups.get(name);

    if (group != null) {
      // Name already seen: fold this capture into the existing one
      const { acc } = group;
      acc.range = acc.range.union(range);
      acc.allowMultiple = acc.allowMultiple || capture.allowMultiple;
      acc.insertionDelimiter =
        acc.insertionDelimiter ?? capture.insertionDelimiter;
      group.captures.push(capture);
      continue;
    }

    // First capture with this name: it seeds the merged capture. `hasError`
    // closes over `captures`, so captures added later are also consulted.
    const captures = [capture];
    const acc = {
      ...capture,
      name,
      range,
      hasError: () => captures.some((c) => c.hasError()),
    };
    merged.push(acc);
    groups.set(name, { acc, captures });
  }

  if (this.shouldCheckCaptures) {
    this.checkCaptures(Array.from(groups.values()));
  }

  return { captures: merged };
}

/**
 * Validates each group of captures with `checkCaptureStartEnd`, after
 * rewriting `.startOf` / `.endOf` suffixes via `rewriteStartOfEndOf`.
 *
 * @param matches Groups of captures to validate
 * @throws Error("Invalid captures") if a group is invalid and the ide run
 * mode is `"test"`
 */
private checkCaptures(matches: { captures: MutableQueryCapture[] }[]) {
  for (const { captures } of matches) {
    const isValid = checkCaptureStartEnd(
      rewriteStartOfEndOf(captures),
      ide().messages,
    );

    if (isValid) {
      continue;
    }

    // Invalid captures are only fatal when running tests
    if (ide().runMode === "test") {
      throw new Error("Invalid captures");
    }
  }
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
/**
 * Matches an optional `.start` / `.end` suffix followed by an optional
 * `.startOf` / `.endOf` suffix at the end of a capture name.
 */
const CAPTURE_NAME_SUFFIX = /(\.(start|end))?(\.(startOf|endOf))?$/;

/**
 * Strips the trailing `.start` / `.end` and/or `.startOf` / `.endOf` suffix
 * from a capture name, eg `foo.start.endOf` becomes `foo`.
 *
 * @param name The raw capture name
 * @returns The name without the trailing suffix(es)
 */
export function normalizeCaptureName(name: string): string {
  return name.replace(CAPTURE_NAME_SUFFIX, "");
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,38 @@
import { showError } from "@cursorless/common";
import type { Query } from "web-tree-sitter";
import { ide } from "../../singletons/ide.singleton";
import { parsePredicates } from "./parsePredicates";
import { predicateToString } from "./predicateToString";

/**
 * Parses the predicates of a tree-sitter query and reports any parse errors
 * to the user.
 *
 * @param languageId Language id, used only in the error messages
 * @param query The raw tree-sitter query whose predicates are parsed
 * @returns The predicates returned by `parsePredicates`
 * @throws Error("Invalid predicates") if any predicate failed to parse and
 * the ide run mode is `"test"`
 */
export function parsePredicatesWithErrorHandling(
  languageId: string,
  query: Query,
) {
  const { errors, predicates } = parsePredicates(query.predicates);

  if (errors.length === 0) {
    return predicates;
  }

  for (const { patternIdx, predicateIdx, error } of errors) {
    const predicate = predicateToString(
      query.predicates[patternIdx][predicateIdx],
    );
    const context = `language ${languageId}, pattern ${patternIdx}, predicate \`${predicate}\``;

    void showError(
      ide().messages,
      "TreeSitterQuery.parsePredicates",
      `Error parsing predicate for ${context}: ${error}`,
    );
  }

  // Errors are surfaced to the user, but they only abort when running tests
  // so that a bad predicate doesn't crash the extension
  if (ide().runMode === "test") {
    throw new Error("Invalid predicates");
  }

  return predicates;
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
import type { Position } from "@cursorless/common";
import type { Point } from "web-tree-sitter";

/**
 * Converts a `Position` into a tree-sitter `Point`: `line` becomes `row` and
 * `character` becomes `column`.
 *
 * @param start The position to convert
 * @returns The equivalent point
 */
export function positionToPoint(start: Position): Point {
  const { line: row, character: column } = start;
  return { row, column };
}
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
import type { Range } from "@cursorless/common";
import type { MutableQueryCapture } from "./QueryCapture";

/**
Expand All @@ -11,22 +12,29 @@ import type { MutableQueryCapture } from "./QueryCapture";
export function rewriteStartOfEndOf(
  captures: MutableQueryCapture[],
): MutableQueryCapture[] {
  // Returns a copy of every capture with a trailing `.startOf` / `.endOf`
  // removed from its name and its range collapsed accordingly: `.startOf`
  // yields the empty range at the start of the original range, `.endOf` the
  // empty range at its end. Captures without either suffix keep their name
  // and range. The input captures are not mutated.
  //
  // The suffix handling lives in `getStartOfEndOfRange` and
  // `getStartOfEndOfName` so that it is defined in exactly one place; this
  // function previously carried a second, unreachable copy of that logic.
  return captures.map((capture) => ({
    ...capture,
    range: getStartOfEndOfRange(capture),
    name: getStartOfEndOfName(capture),
  }));
}

/**
 * Computes the effective range of a capture, honouring a trailing `.startOf`
 * or `.endOf` in its name.
 *
 * @param capture The capture to inspect
 * @returns The empty range at the start of the capture's range for
 * `.startOf`, the empty range at its end for `.endOf`, and the capture's
 * range itself otherwise
 */
export function getStartOfEndOfRange(capture: MutableQueryCapture): Range {
  const { name, range } = capture;

  // Which edge of the range the suffix selects, if any
  const edge = name.endsWith(".startOf")
    ? "start"
    : name.endsWith(".endOf")
      ? "end"
      : undefined;

  return edge == null ? range : range[edge].toEmptyRange();
}

/**
 * Returns the capture's name with a trailing `.startOf` or `.endOf` removed.
 *
 * @param capture The capture whose name is inspected
 * @returns The name without the suffix, or the name unchanged if it ends
 * with neither suffix
 */
function getStartOfEndOfName(capture: MutableQueryCapture): string {
  const { name } = capture;

  for (const suffix of [".startOf", ".endOf"]) {
    if (name.endsWith(suffix)) {
      return name.slice(0, -suffix.length);
    }
  }

  return name;
}
Loading