Skip to content

Commit 42affed

Browse files
Increase performance for Tree sitter matches
1 parent 43e9e21 commit 42affed

File tree

6 files changed

+163
-86
lines changed

6 files changed

+163
-86
lines changed

packages/cursorless-engine/src/languages/LanguageDefinition.ts

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,7 @@ export class LanguageDefinition {
7474
* legacy pathways
7575
*/
7676
getScopeHandler(scopeType: ScopeType) {
77-
if (!this.query.captureNames.includes(scopeType.type)) {
77+
if (!this.query.hasCapture(scopeType.type)) {
7878
return undefined;
7979
}
8080

packages/cursorless-engine/src/languages/TreeSitterQuery/QueryCapture.ts

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,17 @@ export interface QueryCapture {
4949
hasError(): boolean;
5050
}
5151

52+
/**
53+
* A capture of a query pattern against a syntax tree that can be modified.
54+
*/
55+
export interface ModifiableQueryCapture {
56+
readonly name: string;
57+
range: Range;
58+
allowMultiple: boolean;
59+
insertionDelimiter: string | undefined;
60+
hasError(): boolean;
61+
}
62+
5263
/**
5364
* A match of a query pattern against a syntax tree.
5465
*/

packages/cursorless-engine/src/languages/TreeSitterQuery/TreeSitterQuery.ts

Lines changed: 116 additions & 67 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,25 @@
11
import type { Position, TextDocument } from "@cursorless/common";
22
import { showError, type TreeSitter } from "@cursorless/common";
3-
import { groupBy, uniq } from "lodash-es";
4-
import type { Point, Query } from "web-tree-sitter";
3+
import type Parser from "web-tree-sitter";
4+
import type { Query } from "web-tree-sitter";
55
import { ide } from "../../singletons/ide.singleton";
66
import { getNodeRange } from "../../util/nodeSelectors";
77
import type {
8+
ModifiableQueryCapture,
9+
MutableQueryCapture,
810
MutableQueryMatch,
9-
QueryCapture,
1011
QueryMatch,
1112
} from "./QueryCapture";
1213
import { checkCaptureStartEnd } from "./checkCaptureStartEnd";
1314
import { isContainedInErrorNode } from "./isContainedInErrorNode";
15+
import { normalizeCaptureName } from "./normalizeCaptureName";
1416
import { parsePredicates } from "./parsePredicates";
17+
import { positionToPoint } from "./positionToPoint";
1518
import { predicateToString } from "./predicateToString";
16-
import { rewriteStartOfEndOf } from "./rewriteStartOfEndOf";
19+
import {
20+
getStartOfEndOfRange,
21+
rewriteStartOfEndOf,
22+
} from "./rewriteStartOfEndOf";
1723
import { treeSitterQueryCache } from "./treeSitterQueryCache";
1824

1925
/**
@@ -67,6 +73,12 @@ export class TreeSitterQuery {
6773
return new TreeSitterQuery(treeSitter, query, predicates);
6874
}
6975

76+
hasCapture(name: string): boolean {
77+
return this.query.captureNames.some(
78+
(n) => normalizeCaptureName(n) === name,
79+
);
80+
}
81+
7082
matches(
7183
document: TextDocument,
7284
start?: Position,
@@ -84,74 +96,111 @@ export class TreeSitterQuery {
8496
start?: Position,
8597
end?: Position,
8698
): QueryMatch[] {
87-
return this.query
88-
.matches(this.treeSitter.getTree(document).rootNode, {
89-
startPosition: start == null ? undefined : positionToPoint(start),
90-
endPosition: end == null ? undefined : positionToPoint(end),
91-
})
92-
.map(
93-
({ pattern, captures }): MutableQueryMatch => ({
94-
patternIdx: pattern,
95-
captures: captures.map(({ name, node }) => ({
96-
name,
97-
node,
98-
document,
99-
range: getNodeRange(node),
100-
insertionDelimiter: undefined,
101-
allowMultiple: false,
102-
hasError: () => isContainedInErrorNode(node),
103-
})),
104-
}),
105-
)
106-
.filter((match) =>
107-
this.patternPredicates[match.patternIdx].every((predicate) =>
108-
predicate(match),
109-
),
110-
)
111-
.map((match): QueryMatch => {
112-
// Merge the ranges of all captures with the same name into a single
113-
// range and return one capture with that name. We consider captures
114-
// with names `@foo`, `@foo.start`, and `@foo.end` to have the same
115-
// name, for which we'd return a capture with name `foo`.
116-
const captures: QueryCapture[] = Object.entries(
117-
groupBy(match.captures, ({ name }) => normalizeCaptureName(name)),
118-
).map(([name, captures]) => {
119-
captures = rewriteStartOfEndOf(captures);
120-
const capturesAreValid = checkCaptureStartEnd(
121-
captures,
122-
ide().messages,
123-
);
124-
125-
if (!capturesAreValid && ide().runMode === "test") {
126-
throw new Error("Invalid captures");
127-
}
128-
129-
return {
130-
name,
131-
range: captures
132-
.map(({ range }) => range)
133-
.reduce((accumulator, range) => range.union(accumulator)),
134-
allowMultiple: captures.some((capture) => capture.allowMultiple),
135-
insertionDelimiter: captures.find(
136-
(capture) => capture.insertionDelimiter != null,
137-
)?.insertionDelimiter,
138-
hasError: () => captures.some((capture) => capture.hasError()),
139-
};
140-
});
141-
142-
return { ...match, captures };
143-
});
99+
const results: QueryMatch[] = [];
100+
const isTesting = ide().runMode === "test";
101+
102+
const matches = this.query.matches(
103+
this.treeSitter.getTree(document).rootNode,
104+
{
105+
startPosition: start != null ? positionToPoint(start) : undefined,
106+
endPosition: end != null ? positionToPoint(end) : undefined,
107+
},
108+
);
109+
110+
for (const match of matches) {
111+
const mutableMatch = createMutableQueryMatch(document, match);
112+
113+
if (this.runPredicates(mutableMatch)) {
114+
results.push(createQueryMatch(mutableMatch, isTesting));
115+
}
116+
}
117+
118+
return results;
144119
}
145120

146-
get captureNames() {
147-
return uniq(this.query.captureNames.map(normalizeCaptureName));
121+
private runPredicates(match: MutableQueryMatch): boolean {
122+
for (const predicate of this.patternPredicates[match.patternIdx]) {
123+
if (!predicate(match)) {
124+
return false;
125+
}
126+
}
127+
return true;
148128
}
149129
}
150130

151-
function normalizeCaptureName(name: string): string {
152-
return name.replace(/(\.(start|end))?(\.(startOf|endOf))?$/, "");
131+
function createMutableQueryMatch(
132+
document: TextDocument,
133+
match: Parser.QueryMatch,
134+
): MutableQueryMatch {
135+
return {
136+
patternIdx: match.pattern,
137+
captures: match.captures.map(({ name, node }) => ({
138+
name,
139+
node,
140+
document,
141+
range: getNodeRange(node),
142+
insertionDelimiter: undefined,
143+
allowMultiple: false,
144+
hasError: () => isContainedInErrorNode(node),
145+
})),
146+
};
153147
}
154148

155-
function positionToPoint(start: Position): Point {
156-
return { row: start.line, column: start.character };
149+
function createQueryMatch(
150+
match: MutableQueryMatch,
151+
isTesting: boolean,
152+
): QueryMatch {
153+
const result: ModifiableQueryCapture[] = [];
154+
const resultMap = new Map<
155+
string,
156+
{ acc: ModifiableQueryCapture; captures: MutableQueryCapture[] }
157+
>();
158+
159+
// Merge the ranges of all captures with the same name into a single
160+
// range and return one capture with that name. We consider captures
161+
// with names `@foo`, `@foo.start`, and `@foo.end` to have the same
162+
// name, for which we'd return a capture with name `foo`.
163+
164+
for (const capture of match.captures) {
165+
const name = normalizeCaptureName(capture.name);
166+
const range = getStartOfEndOfRange(capture);
167+
const existing = resultMap.get(name);
168+
169+
if (existing == null) {
170+
const accumulator = {
171+
name,
172+
range,
173+
allowMultiple: capture.allowMultiple,
174+
insertionDelimiter: capture.insertionDelimiter,
175+
hasError: () => capture.hasError(),
176+
};
177+
result.push(accumulator);
178+
resultMap.set(name, {
179+
acc: accumulator,
180+
captures: [capture],
181+
});
182+
} else {
183+
existing.acc.range = existing.acc.range.union(range);
184+
existing.acc.allowMultiple =
185+
existing.acc.allowMultiple || capture.allowMultiple;
186+
existing.acc.insertionDelimiter =
187+
existing.acc.insertionDelimiter ?? capture.insertionDelimiter;
188+
existing.acc.hasError = () => existing.captures.some((c) => c.hasError());
189+
existing.captures.push(capture);
190+
}
191+
}
192+
193+
if (isTesting) {
194+
for (const captureGroup of resultMap.values()) {
195+
const capturesAreValid = checkCaptureStartEnd(
196+
rewriteStartOfEndOf(captureGroup.captures),
197+
ide().messages,
198+
);
199+
if (!capturesAreValid) {
200+
throw new Error("Invalid captures");
201+
}
202+
}
203+
}
204+
205+
return { captures: result };
157206
}

packages/cursorless-engine/src/languages/TreeSitterQuery/normalizeCaptureName.ts

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
export function normalizeCaptureName(name: string): string {
2+
return name.replace(/(\.(start|end))?(\.(startOf|endOf))?$/, "");
3+
}

packages/cursorless-engine/src/languages/TreeSitterQuery/positionToPoint.ts

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
import type { Position } from "@cursorless/common";
2+
import type { Point } from "web-tree-sitter";
3+
4+
export function positionToPoint(start: Position): Point {
5+
return { row: start.line, column: start.character };
6+
}

packages/cursorless-engine/src/languages/TreeSitterQuery/rewriteStartOfEndOf.ts

Lines changed: 26 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
import type { Range } from "@cursorless/common";
12
import type { MutableQueryCapture } from "./QueryCapture";
23

34
/**
@@ -11,22 +12,29 @@ import type { MutableQueryCapture } from "./QueryCapture";
1112
export function rewriteStartOfEndOf(
1213
captures: MutableQueryCapture[],
1314
): MutableQueryCapture[] {
14-
return captures.map((capture) => {
15-
// Remove trailing .startOf and .endOf, adjusting ranges.
16-
if (capture.name.endsWith(".startOf")) {
17-
return {
18-
...capture,
19-
name: capture.name.replace(/\.startOf$/, ""),
20-
range: capture.range.start.toEmptyRange(),
21-
};
22-
}
23-
if (capture.name.endsWith(".endOf")) {
24-
return {
25-
...capture,
26-
name: capture.name.replace(/\.endOf$/, ""),
27-
range: capture.range.end.toEmptyRange(),
28-
};
29-
}
30-
return capture;
31-
});
15+
return captures.map((capture) => ({
16+
...capture,
17+
range: getStartOfEndOfRange(capture),
18+
name: getStartOfEndOfName(capture),
19+
}));
20+
}
21+
22+
export function getStartOfEndOfRange(capture: MutableQueryCapture): Range {
23+
if (capture.name.endsWith(".startOf")) {
24+
return capture.range.start.toEmptyRange();
25+
}
26+
if (capture.name.endsWith(".endOf")) {
27+
return capture.range.end.toEmptyRange();
28+
}
29+
return capture.range;
30+
}
31+
32+
function getStartOfEndOfName(capture: MutableQueryCapture): string {
33+
if (capture.name.endsWith(".startOf")) {
34+
return capture.name.slice(0, -8);
35+
}
36+
if (capture.name.endsWith(".endOf")) {
37+
return capture.name.slice(0, -6);
38+
}
39+
return capture.name;
3240
}

0 commit comments

Comments
 (0)