Skip to content

Commit 85fd827

Browse files
Increase performance for Tree sitter matches (#2713)
First, we were checking the captures against the SCM files in production, which was an oversight - that check should only happen when debugging. Second, the existing implementation used multiple map and filter passes. This part of the code is very time sensitive and can run quite long for large files, since we are iterating over every capture in the entire file. This implementation aims to reduce the number of iteration steps. Fixes #2656 ## Checklist - [/] I have added [tests](https://www.cursorless.org/docs/contributing/test-case-recorder/) - [/] I have updated the [docs](https://github.com/cursorless-dev/cursorless/tree/main/docs) and [cheatsheet](https://github.com/cursorless-dev/cursorless/tree/main/cursorless-talon/src/cheatsheet) - [/] I have not broken the cheatsheet
1 parent 6c4c02a commit 85fd827

File tree

6 files changed

+207
-119
lines changed

6 files changed

+207
-119
lines changed

packages/cursorless-engine/src/languages/LanguageDefinition.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ export class LanguageDefinition {
7474
* legacy pathways
7575
*/
7676
getScopeHandler(scopeType: ScopeType) {
77-
if (!this.query.captureNames.includes(scopeType.type)) {
77+
if (!this.query.hasCapture(scopeType.type)) {
7878
return undefined;
7979
}
8080

Lines changed: 133 additions & 100 deletions
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,65 @@
11
import type { Position, TextDocument } from "@cursorless/common";
2-
import { showError, type TreeSitter } from "@cursorless/common";
3-
import { groupBy, uniq } from "lodash-es";
4-
import type { Point, Query } from "web-tree-sitter";
2+
import { type TreeSitter } from "@cursorless/common";
3+
import type * as treeSitter from "web-tree-sitter";
54
import { ide } from "../../singletons/ide.singleton";
65
import { getNodeRange } from "../../util/nodeSelectors";
76
import type {
7+
MutableQueryCapture,
88
MutableQueryMatch,
9-
QueryCapture,
109
QueryMatch,
1110
} from "./QueryCapture";
1211
import { checkCaptureStartEnd } from "./checkCaptureStartEnd";
1312
import { isContainedInErrorNode } from "./isContainedInErrorNode";
14-
import { parsePredicates } from "./parsePredicates";
15-
import { predicateToString } from "./predicateToString";
16-
import { rewriteStartOfEndOf } from "./rewriteStartOfEndOf";
13+
import { normalizeCaptureName } from "./normalizeCaptureName";
14+
import { parsePredicatesWithErrorHandling } from "./parsePredicatesWithErrorHandling";
15+
import { positionToPoint } from "./positionToPoint";
16+
import {
17+
getStartOfEndOfRange,
18+
rewriteStartOfEndOf,
19+
} from "./rewriteStartOfEndOf";
1720
import { treeSitterQueryCache } from "./treeSitterQueryCache";
1821

1922
/**
2023
* Wrapper around a tree-sitter query that provides a more convenient API, and
2124
* defines our own custom predicate operators
2225
*/
2326
export class TreeSitterQuery {
27+
private shouldCheckCaptures: boolean;
28+
2429
private constructor(
2530
private treeSitter: TreeSitter,
2631

2732
/**
2833
* The raw tree-sitter query as parsed by tree-sitter from the query file
2934
*/
30-
private query: Query,
35+
private query: treeSitter.Query,
3136

3237
/**
3338
* The predicates for each pattern in the query. Each element of the outer
3439
* array corresponds to a pattern, and each element of the inner array
3540
* corresponds to a predicate for that pattern.
3641
*/
3742
private patternPredicates: ((match: MutableQueryMatch) => boolean)[][],
38-
) {}
39-
40-
static create(languageId: string, treeSitter: TreeSitter, query: Query) {
41-
const { errors, predicates } = parsePredicates(query.predicates);
42-
43-
if (errors.length > 0) {
44-
for (const error of errors) {
45-
const context = [
46-
`language ${languageId}`,
47-
`pattern ${error.patternIdx}`,
48-
`predicate \`${predicateToString(
49-
query.predicates[error.patternIdx][error.predicateIdx],
50-
)}\``,
51-
].join(", ");
52-
53-
void showError(
54-
ide().messages,
55-
"TreeSitterQuery.parsePredicates",
56-
`Error parsing predicate for ${context}: ${error.error}`,
57-
);
58-
}
43+
) {
44+
this.shouldCheckCaptures = ide().runMode !== "production";
45+
}
5946

60-
// We show errors to the user, but we don't want to crash the extension
61-
// unless we're in test mode
62-
if (ide().runMode === "test") {
63-
throw new Error("Invalid predicates");
64-
}
65-
}
47+
/**
 * Factory for {@link TreeSitterQuery}. Obtains the predicates for `query` via
 * `parsePredicatesWithErrorHandling(languageId, query)` and passes them,
 * together with the tree-sitter handle and the raw query, to the private
 * constructor.
 *
 * @param languageId Forwarded to `parsePredicatesWithErrorHandling`
 * @param treeSitter Tree-sitter handle stored on the new instance
 * @param query The raw tree-sitter query to wrap
 * @returns A new `TreeSitterQuery`
 */
static create(
48+
languageId: string,
49+
treeSitter: TreeSitter,
50+
query: treeSitter.Query,
51+
) {
52+
const predicates = parsePredicatesWithErrorHandling(languageId, query);
6653

6754
return new TreeSitterQuery(treeSitter, query, predicates);
6855
}
6956

57+
/**
 * Tells whether the underlying query declares a capture whose normalized
 * name equals `name`. Every raw capture name is run through
 * `normalizeCaptureName` before it is compared.
 *
 * @param name The (already normalized) capture name to look for
 * @returns `true` if at least one capture name normalizes to `name`
 */
hasCapture(name: string): boolean {
58+
return this.query.captureNames.some(
59+
(n) => normalizeCaptureName(n) === name,
60+
);
61+
}
62+
7063
matches(
7164
document: TextDocument,
7265
start?: Position,
@@ -84,74 +77,114 @@ export class TreeSitterQuery {
8477
start?: Position,
8578
end?: Position,
8679
): QueryMatch[] {
87-
return this.query
88-
.matches(this.treeSitter.getTree(document).rootNode, {
89-
startPosition: start == null ? undefined : positionToPoint(start),
90-
endPosition: end == null ? undefined : positionToPoint(end),
91-
})
92-
.map(
93-
({ pattern, captures }): MutableQueryMatch => ({
94-
patternIdx: pattern,
95-
captures: captures.map(({ name, node }) => ({
96-
name,
97-
node,
98-
document,
99-
range: getNodeRange(node),
100-
insertionDelimiter: undefined,
101-
allowMultiple: false,
102-
hasError: () => isContainedInErrorNode(node),
103-
})),
104-
}),
105-
)
106-
.filter((match) =>
107-
this.patternPredicates[match.patternIdx].every((predicate) =>
108-
predicate(match),
109-
),
110-
)
111-
.map((match): QueryMatch => {
112-
// Merge the ranges of all captures with the same name into a single
113-
// range and return one capture with that name. We consider captures
114-
// with names `@foo`, `@foo.start`, and `@foo.end` to have the same
115-
// name, for which we'd return a capture with name `foo`.
116-
const captures: QueryCapture[] = Object.entries(
117-
groupBy(match.captures, ({ name }) => normalizeCaptureName(name)),
118-
).map(([name, captures]) => {
119-
captures = rewriteStartOfEndOf(captures);
120-
const capturesAreValid = checkCaptureStartEnd(
121-
captures,
122-
ide().messages,
123-
);
124-
125-
if (!capturesAreValid && ide().runMode === "test") {
126-
throw new Error("Invalid captures");
127-
}
128-
129-
return {
130-
name,
131-
range: captures
132-
.map(({ range }) => range)
133-
.reduce((accumulator, range) => range.union(accumulator)),
134-
allowMultiple: captures.some((capture) => capture.allowMultiple),
135-
insertionDelimiter: captures.find(
136-
(capture) => capture.insertionDelimiter != null,
137-
)?.insertionDelimiter,
138-
hasError: () => captures.some((capture) => capture.hasError()),
139-
};
140-
});
141-
142-
return { ...match, captures };
143-
});
80+
const matches = this.getTreeMatches(document, start, end);
81+
const results: QueryMatch[] = [];
82+
83+
for (const match of matches) {
84+
const mutableMatch = this.createMutableQueryMatch(document, match);
85+
86+
if (!this.runPredicates(mutableMatch)) {
87+
continue;
88+
}
89+
90+
results.push(this.createQueryMatch(mutableMatch));
91+
}
92+
93+
return results;
14494
}
14595

146-
get captureNames() {
147-
return uniq(this.query.captureNames.map(normalizeCaptureName));
96+
/**
 * Runs the raw tree-sitter query against the root node of the tree returned
 * by `this.treeSitter.getTree(document)`.
 *
 * @param document The document whose tree is queried
 * @param start Optional start position; converted with `positionToPoint`
 * and passed as `startPosition`, or `undefined` when not given
 * @param end Optional end position; converted with `positionToPoint` and
 * passed as `endPosition`, or `undefined` when not given
 * @returns Whatever `this.query.matches` returns for those options
 */
private getTreeMatches(
97+
document: TextDocument,
98+
start?: Position,
99+
end?: Position,
100+
) {
101+
const { rootNode } = this.treeSitter.getTree(document);
102+
return this.query.matches(rootNode, {
103+
// `!= null` covers both `null` and `undefined`
startPosition: start != null ? positionToPoint(start) : undefined,
104+
endPosition: end != null ? positionToPoint(end) : undefined,
105+
});
148106
}
149-
}
150107

151-
function normalizeCaptureName(name: string): string {
152-
return name.replace(/(\.(start|end))?(\.(startOf|endOf))?$/, "");
153-
}
108+
/**
 * Converts a raw tree-sitter match into a `MutableQueryMatch`.
 *
 * The pattern index is copied from `match.pattern`. Each raw capture becomes
 * a capture object carrying its name and node, the given `document`, a range
 * computed by `getNodeRange(node)`, an `undefined` insertion delimiter,
 * `allowMultiple` set to `false`, and a `hasError` callback.
 *
 * @param document The document attached to every capture
 * @param match The raw match as returned by tree-sitter
 * @returns The newly built mutable match
 */
private createMutableQueryMatch(
109+
document: TextDocument,
110+
match: treeSitter.QueryMatch,
111+
): MutableQueryMatch {
112+
return {
113+
patternIdx: match.pattern,
114+
captures: match.captures.map(({ name, node }) => ({
115+
name,
116+
node,
117+
document,
118+
range: getNodeRange(node),
119+
insertionDelimiter: undefined,
120+
allowMultiple: false,
121+
// Deferred: `isContainedInErrorNode` only runs when `hasError()` is called
hasError: () => isContainedInErrorNode(node),
122+
})),
123+
};
124+
}
154125

155-
function positionToPoint(start: Position): Point {
156-
return { row: start.line, column: start.character };
126+
/**
 * Evaluates the predicates registered for the match's pattern
 * (`this.patternPredicates[match.patternIdx]`) in order.
 *
 * @param match The match handed to each predicate
 * @returns `false` as soon as one predicate returns a falsy value (later
 * predicates are not called); `true` if none did
 */
private runPredicates(match: MutableQueryMatch): boolean {
127+
for (const predicate of this.patternPredicates[match.patternIdx]) {
128+
if (!predicate(match)) {
129+
return false;
130+
}
131+
}
132+
return true;
133+
}
134+
135+
/**
 * Collapses the captures of a match into one capture per normalized name, in
 * a single pass over `match.captures`.
 *
 * The output keeps the order in which each normalized name was first seen.
 * When `this.shouldCheckCaptures` is set, the grouped captures are passed to
 * `checkCaptures` before returning.
 *
 * @param match The match whose captures are merged
 * @returns An object whose `captures` holds the merged captures
 */
private createQueryMatch(match: MutableQueryMatch): QueryMatch {
136+
const result: MutableQueryCapture[] = [];
137+
// Per normalized name: the merged capture (`acc`) and the raw captures that
// contributed to it
const map = new Map<
138+
string,
139+
{ acc: MutableQueryCapture; captures: MutableQueryCapture[] }
140+
>();
141+
142+
// Merge the ranges of all captures with the same name into a single
143+
// range and return one capture with that name. We consider captures
144+
// with names `@foo`, `@foo.start`, and `@foo.end` to have the same
145+
// name, for which we'd return a capture with name `foo`.
146+
147+
for (const capture of match.captures) {
148+
const name = normalizeCaptureName(capture.name);
149+
const range = getStartOfEndOfRange(capture);
150+
const existing = map.get(name);
151+
152+
if (existing == null) {
153+
// First capture with this normalized name: start a new accumulator from
// a copy of it, overriding name, range and hasError
const captures = [capture];
154+
const acc = {
155+
...capture,
156+
name,
157+
range,
158+
// Closes over `captures`, which later iterations push onto, so the
// callback consults every capture in the group when it is invoked
hasError: () => captures.some((c) => c.hasError()),
159+
};
160+
result.push(acc);
161+
map.set(name, { acc, captures });
162+
} else {
163+
// Subsequent capture with this name: fold it into the accumulator
existing.acc.range = existing.acc.range.union(range);
164+
existing.acc.allowMultiple =
165+
existing.acc.allowMultiple || capture.allowMultiple;
166+
// Keep the first non-nullish insertion delimiter seen for this name
existing.acc.insertionDelimiter =
167+
existing.acc.insertionDelimiter ?? capture.insertionDelimiter;
168+
existing.captures.push(capture);
169+
}
170+
}
171+
172+
if (this.shouldCheckCaptures) {
173+
this.checkCaptures(Array.from(map.values()));
174+
}
175+
176+
return { captures: result };
177+
}
178+
179+
/**
 * Validates each group of captures: the group is passed through
 * `rewriteStartOfEndOf` and then to `checkCaptureStartEnd` together with
 * `ide().messages`.
 *
 * @param matches Groups of captures; each group is checked independently
 * @throws Error "Invalid captures" when a group is reported invalid and
 * `ide().runMode` is `"test"`
 */
private checkCaptures(matches: { captures: MutableQueryCapture[] }[]) {
180+
for (const match of matches) {
181+
const capturesAreValid = checkCaptureStartEnd(
182+
rewriteStartOfEndOf(match.captures),
183+
ide().messages,
184+
);
185+
// Outside test mode an invalid group does not throw here
if (!capturesAreValid && ide().runMode === "test") {
186+
throw new Error("Invalid captures");
187+
}
188+
}
189+
}
157190
}
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
/**
 * Strips the positional suffixes from the end of a capture name: an optional
 * `.start` or `.end`, followed by an optional `.startOf` or `.endOf`. For
 * example `foo.start`, `foo.endOf` and `foo.end.startOf` all become `foo`.
 * A name with none of these suffixes is returned unchanged.
 *
 * @param name The raw capture name
 * @returns The name with those trailing suffixes removed
 */
export function normalizeCaptureName(name: string): string {
2+
return name.replace(/(\.(start|end))?(\.(startOf|endOf))?$/, "");
3+
}
Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
import { showError } from "@cursorless/common";
2+
import type { Query } from "web-tree-sitter";
3+
import { ide } from "../../singletons/ide.singleton";
4+
import { parsePredicates } from "./parsePredicates";
5+
import { predicateToString } from "./predicateToString";
6+
7+
/**
 * Parses `query.predicates` with `parsePredicates` and reports every parse
 * error through `showError`.
 *
 * @param languageId Language id, used only in the error message context
 * @param query The tree-sitter query whose predicates are parsed
 * @returns The `predicates` value produced by `parsePredicates`
 * @throws Error "Invalid predicates" when at least one error was reported
 * and `ide().runMode` is `"test"`
 */
export function parsePredicatesWithErrorHandling(
8+
languageId: string,
9+
query: Query,
10+
) {
11+
const { errors, predicates } = parsePredicates(query.predicates);
12+
13+
if (errors.length > 0) {
14+
for (const error of errors) {
15+
// Human-readable location of the failing predicate: language, pattern
// index, and the predicate rendered by `predicateToString`
const context = [
16+
`language ${languageId}`,
17+
`pattern ${error.patternIdx}`,
18+
`predicate \`${predicateToString(
19+
query.predicates[error.patternIdx][error.predicateIdx],
20+
)}\``,
21+
].join(", ");
22+
23+
// `void`: the result of `showError` is intentionally not awaited
void showError(
24+
ide().messages,
25+
"TreeSitterQuery.parsePredicates",
26+
`Error parsing predicate for ${context}: ${error.error}`,
27+
);
28+
}
29+
30+
// We show errors to the user, but we don't want to crash the extension
31+
// unless we're in test mode
32+
if (ide().runMode === "test") {
33+
throw new Error("Invalid predicates");
34+
}
35+
}
36+
37+
return predicates;
38+
}
Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import type { Position } from "@cursorless/common";
2+
import type { Point } from "web-tree-sitter";
3+
4+
/**
 * Converts a `Position` into a tree-sitter `Point`: `line` becomes `row` and
 * `character` becomes `column`.
 *
 * @param start The position to convert
 * @returns The equivalent point
 */
export function positionToPoint(start: Position): Point {
5+
return { row: start.line, column: start.character };
6+
}
Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { Range } from "@cursorless/common";
12
import type { MutableQueryCapture } from "./QueryCapture";
23

34
/**
@@ -11,22 +12,29 @@ import type { MutableQueryCapture } from "./QueryCapture";
1112
export function rewriteStartOfEndOf(
1213
captures: MutableQueryCapture[],
1314
): MutableQueryCapture[] {
14-
return captures.map((capture) => {
15-
// Remove trailing .startOf and .endOf, adjusting ranges.
16-
if (capture.name.endsWith(".startOf")) {
17-
return {
18-
...capture,
19-
name: capture.name.replace(/\.startOf$/, ""),
20-
range: capture.range.start.toEmptyRange(),
21-
};
22-
}
23-
if (capture.name.endsWith(".endOf")) {
24-
return {
25-
...capture,
26-
name: capture.name.replace(/\.endOf$/, ""),
27-
range: capture.range.end.toEmptyRange(),
28-
};
29-
}
30-
return capture;
31-
});
15+
return captures.map((capture) => ({
16+
...capture,
17+
range: getStartOfEndOfRange(capture),
18+
name: getStartOfEndOfName(capture),
19+
}));
20+
}
21+
22+
/**
 * Computes the range a capture contributes, based on its name suffix.
 *
 * @param capture The capture whose `name` and `range` are inspected
 * @returns An empty range at `capture.range.start` when the name ends with
 * `.startOf`, an empty range at `capture.range.end` when it ends with
 * `.endOf`, and `capture.range` itself otherwise
 */
export function getStartOfEndOfRange(capture: MutableQueryCapture): Range {
23+
if (capture.name.endsWith(".startOf")) {
24+
return capture.range.start.toEmptyRange();
25+
}
26+
if (capture.name.endsWith(".endOf")) {
27+
return capture.range.end.toEmptyRange();
28+
}
29+
return capture.range;
30+
}
31+
32+
/**
 * Returns the capture's name with a trailing `.startOf` or `.endOf` removed.
 * Names without either suffix are returned unchanged.
 *
 * @param capture The capture whose `name` is inspected
 * @returns The name without the trailing `.startOf` / `.endOf`
 */
function getStartOfEndOfName(capture: MutableQueryCapture): string {
33+
if (capture.name.endsWith(".startOf")) {
34+
// 8 === ".startOf".length
return capture.name.slice(0, -8);
35+
}
36+
if (capture.name.endsWith(".endOf")) {
37+
// 6 === ".endOf".length
return capture.name.slice(0, -6);
38+
}
39+
return capture.name;
3240
}

0 commit comments

Comments
 (0)