1
1
import type { Position , TextDocument } from "@cursorless/common" ;
2
- import { showError , type TreeSitter } from "@cursorless/common" ;
3
- import { groupBy , uniq } from "lodash-es" ;
4
- import type { Point , Query } from "web-tree-sitter" ;
2
+ import { type TreeSitter } from "@cursorless/common" ;
3
+ import type * as treeSitter from "web-tree-sitter" ;
5
4
import { ide } from "../../singletons/ide.singleton" ;
6
5
import { getNodeRange } from "../../util/nodeSelectors" ;
7
6
import type {
7
+ MutableQueryCapture ,
8
8
MutableQueryMatch ,
9
- QueryCapture ,
10
9
QueryMatch ,
11
10
} from "./QueryCapture" ;
12
11
import { checkCaptureStartEnd } from "./checkCaptureStartEnd" ;
13
12
import { isContainedInErrorNode } from "./isContainedInErrorNode" ;
14
- import { parsePredicates } from "./parsePredicates" ;
15
- import { predicateToString } from "./predicateToString" ;
16
- import { rewriteStartOfEndOf } from "./rewriteStartOfEndOf" ;
13
+ import { normalizeCaptureName } from "./normalizeCaptureName" ;
14
+ import { parsePredicatesWithErrorHandling } from "./parsePredicatesWithErrorHandling" ;
15
+ import { positionToPoint } from "./positionToPoint" ;
16
+ import {
17
+ getStartOfEndOfRange ,
18
+ rewriteStartOfEndOf ,
19
+ } from "./rewriteStartOfEndOf" ;
17
20
import { treeSitterQueryCache } from "./treeSitterQueryCache" ;
18
21
19
22
/**
20
23
* Wrapper around a tree-sitter query that provides a more convenient API, and
21
24
* defines our own custom predicate operators
22
25
*/
23
26
export class TreeSitterQuery {
27
+ private shouldCheckCaptures : boolean ;
28
+
24
29
private constructor (
25
30
private treeSitter : TreeSitter ,
26
31
27
32
/**
28
33
* The raw tree-sitter query as parsed by tree-sitter from the query file
29
34
*/
30
- private query : Query ,
35
+ private query : treeSitter . Query ,
31
36
32
37
/**
33
38
* The predicates for each pattern in the query. Each element of the outer
34
39
* array corresponds to a pattern, and each element of the inner array
35
40
* corresponds to a predicate for that pattern.
36
41
*/
37
42
private patternPredicates : ( ( match : MutableQueryMatch ) => boolean ) [ ] [ ] ,
38
- ) { }
39
-
40
- static create ( languageId : string , treeSitter : TreeSitter , query : Query ) {
41
- const { errors, predicates } = parsePredicates ( query . predicates ) ;
42
-
43
- if ( errors . length > 0 ) {
44
- for ( const error of errors ) {
45
- const context = [
46
- `language ${ languageId } ` ,
47
- `pattern ${ error . patternIdx } ` ,
48
- `predicate \`${ predicateToString (
49
- query . predicates [ error . patternIdx ] [ error . predicateIdx ] ,
50
- ) } \``,
51
- ] . join ( ", " ) ;
52
-
53
- void showError (
54
- ide ( ) . messages ,
55
- "TreeSitterQuery.parsePredicates" ,
56
- `Error parsing predicate for ${ context } : ${ error . error } ` ,
57
- ) ;
58
- }
43
+ ) {
44
+ this . shouldCheckCaptures = ide ( ) . runMode !== "production" ;
45
+ }
59
46
60
- // We show errors to the user, but we don't want to crash the extension
61
- // unless we're in test mode
62
- if ( ide ( ) . runMode === "test" ) {
63
- throw new Error ( "Invalid predicates" ) ;
64
- }
65
- }
47
+ static create (
48
+ languageId : string ,
49
+ treeSitter : TreeSitter ,
50
+ query : treeSitter . Query ,
51
+ ) {
52
+ const predicates = parsePredicatesWithErrorHandling ( languageId , query ) ;
66
53
67
54
return new TreeSitterQuery ( treeSitter , query , predicates ) ;
68
55
}
69
56
57
+ hasCapture ( name : string ) : boolean {
58
+ return this . query . captureNames . some (
59
+ ( n ) => normalizeCaptureName ( n ) === name ,
60
+ ) ;
61
+ }
62
+
70
63
matches (
71
64
document : TextDocument ,
72
65
start ?: Position ,
@@ -84,74 +77,114 @@ export class TreeSitterQuery {
84
77
start ?: Position ,
85
78
end ?: Position ,
86
79
) : QueryMatch [ ] {
87
- return this . query
88
- . matches ( this . treeSitter . getTree ( document ) . rootNode , {
89
- startPosition : start == null ? undefined : positionToPoint ( start ) ,
90
- endPosition : end == null ? undefined : positionToPoint ( end ) ,
91
- } )
92
- . map (
93
- ( { pattern, captures } ) : MutableQueryMatch => ( {
94
- patternIdx : pattern ,
95
- captures : captures . map ( ( { name, node } ) => ( {
96
- name,
97
- node,
98
- document,
99
- range : getNodeRange ( node ) ,
100
- insertionDelimiter : undefined ,
101
- allowMultiple : false ,
102
- hasError : ( ) => isContainedInErrorNode ( node ) ,
103
- } ) ) ,
104
- } ) ,
105
- )
106
- . filter ( ( match ) =>
107
- this . patternPredicates [ match . patternIdx ] . every ( ( predicate ) =>
108
- predicate ( match ) ,
109
- ) ,
110
- )
111
- . map ( ( match ) : QueryMatch => {
112
- // Merge the ranges of all captures with the same name into a single
113
- // range and return one capture with that name. We consider captures
114
- // with names `@foo`, `@foo.start`, and `@foo.end` to have the same
115
- // name, for which we'd return a capture with name `foo`.
116
- const captures : QueryCapture [ ] = Object . entries (
117
- groupBy ( match . captures , ( { name } ) => normalizeCaptureName ( name ) ) ,
118
- ) . map ( ( [ name , captures ] ) => {
119
- captures = rewriteStartOfEndOf ( captures ) ;
120
- const capturesAreValid = checkCaptureStartEnd (
121
- captures ,
122
- ide ( ) . messages ,
123
- ) ;
124
-
125
- if ( ! capturesAreValid && ide ( ) . runMode === "test" ) {
126
- throw new Error ( "Invalid captures" ) ;
127
- }
128
-
129
- return {
130
- name,
131
- range : captures
132
- . map ( ( { range } ) => range )
133
- . reduce ( ( accumulator , range ) => range . union ( accumulator ) ) ,
134
- allowMultiple : captures . some ( ( capture ) => capture . allowMultiple ) ,
135
- insertionDelimiter : captures . find (
136
- ( capture ) => capture . insertionDelimiter != null ,
137
- ) ?. insertionDelimiter ,
138
- hasError : ( ) => captures . some ( ( capture ) => capture . hasError ( ) ) ,
139
- } ;
140
- } ) ;
141
-
142
- return { ...match , captures } ;
143
- } ) ;
80
+ const matches = this . getTreeMatches ( document , start , end ) ;
81
+ const results : QueryMatch [ ] = [ ] ;
82
+
83
+ for ( const match of matches ) {
84
+ const mutableMatch = this . createMutableQueryMatch ( document , match ) ;
85
+
86
+ if ( ! this . runPredicates ( mutableMatch ) ) {
87
+ continue ;
88
+ }
89
+
90
+ results . push ( this . createQueryMatch ( mutableMatch ) ) ;
91
+ }
92
+
93
+ return results ;
144
94
}
145
95
146
- get captureNames ( ) {
147
- return uniq ( this . query . captureNames . map ( normalizeCaptureName ) ) ;
96
+ private getTreeMatches (
97
+ document : TextDocument ,
98
+ start ?: Position ,
99
+ end ?: Position ,
100
+ ) {
101
+ const { rootNode } = this . treeSitter . getTree ( document ) ;
102
+ return this . query . matches ( rootNode , {
103
+ startPosition : start != null ? positionToPoint ( start ) : undefined ,
104
+ endPosition : end != null ? positionToPoint ( end ) : undefined ,
105
+ } ) ;
148
106
}
149
- }
150
107
151
- function normalizeCaptureName ( name : string ) : string {
152
- return name . replace ( / ( \. ( s t a r t | e n d ) ) ? ( \. ( s t a r t O f | e n d O f ) ) ? $ / , "" ) ;
153
- }
108
+ private createMutableQueryMatch (
109
+ document : TextDocument ,
110
+ match : treeSitter . QueryMatch ,
111
+ ) : MutableQueryMatch {
112
+ return {
113
+ patternIdx : match . pattern ,
114
+ captures : match . captures . map ( ( { name, node } ) => ( {
115
+ name,
116
+ node,
117
+ document,
118
+ range : getNodeRange ( node ) ,
119
+ insertionDelimiter : undefined ,
120
+ allowMultiple : false ,
121
+ hasError : ( ) => isContainedInErrorNode ( node ) ,
122
+ } ) ) ,
123
+ } ;
124
+ }
154
125
155
- function positionToPoint ( start : Position ) : Point {
156
- return { row : start . line , column : start . character } ;
126
+ private runPredicates ( match : MutableQueryMatch ) : boolean {
127
+ for ( const predicate of this . patternPredicates [ match . patternIdx ] ) {
128
+ if ( ! predicate ( match ) ) {
129
+ return false ;
130
+ }
131
+ }
132
+ return true ;
133
+ }
134
+
135
+ private createQueryMatch ( match : MutableQueryMatch ) : QueryMatch {
136
+ const result : MutableQueryCapture [ ] = [ ] ;
137
+ const map = new Map <
138
+ string ,
139
+ { acc : MutableQueryCapture ; captures : MutableQueryCapture [ ] }
140
+ > ( ) ;
141
+
142
+ // Merge the ranges of all captures with the same name into a single
143
+ // range and return one capture with that name. We consider captures
144
+ // with names `@foo`, `@foo.start`, and `@foo.end` to have the same
145
+ // name, for which we'd return a capture with name `foo`.
146
+
147
+ for ( const capture of match . captures ) {
148
+ const name = normalizeCaptureName ( capture . name ) ;
149
+ const range = getStartOfEndOfRange ( capture ) ;
150
+ const existing = map . get ( name ) ;
151
+
152
+ if ( existing == null ) {
153
+ const captures = [ capture ] ;
154
+ const acc = {
155
+ ...capture ,
156
+ name,
157
+ range,
158
+ hasError : ( ) => captures . some ( ( c ) => c . hasError ( ) ) ,
159
+ } ;
160
+ result . push ( acc ) ;
161
+ map . set ( name , { acc, captures } ) ;
162
+ } else {
163
+ existing . acc . range = existing . acc . range . union ( range ) ;
164
+ existing . acc . allowMultiple =
165
+ existing . acc . allowMultiple || capture . allowMultiple ;
166
+ existing . acc . insertionDelimiter =
167
+ existing . acc . insertionDelimiter ?? capture . insertionDelimiter ;
168
+ existing . captures . push ( capture ) ;
169
+ }
170
+ }
171
+
172
+ if ( this . shouldCheckCaptures ) {
173
+ this . checkCaptures ( Array . from ( map . values ( ) ) ) ;
174
+ }
175
+
176
+ return { captures : result } ;
177
+ }
178
+
179
+ private checkCaptures ( matches : { captures : MutableQueryCapture [ ] } [ ] ) {
180
+ for ( const match of matches ) {
181
+ const capturesAreValid = checkCaptureStartEnd (
182
+ rewriteStartOfEndOf ( match . captures ) ,
183
+ ide ( ) . messages ,
184
+ ) ;
185
+ if ( ! capturesAreValid && ide ( ) . runMode === "test" ) {
186
+ throw new Error ( "Invalid captures" ) ;
187
+ }
188
+ }
189
+ }
157
190
}
0 commit comments