11import type { Position , TextDocument } from "@cursorless/common" ;
22import { showError , type TreeSitter } from "@cursorless/common" ;
3- import { groupBy , uniq } from "lodash-es " ;
4- import type { Point , Query } from "web-tree-sitter" ;
3+ import type Parser from "web-tree-sitter " ;
4+ import type { Query } from "web-tree-sitter" ;
55import { ide } from "../../singletons/ide.singleton" ;
66import { getNodeRange } from "../../util/nodeSelectors" ;
77import type {
8+ ModifiableQueryCapture ,
9+ MutableQueryCapture ,
810 MutableQueryMatch ,
9- QueryCapture ,
1011 QueryMatch ,
1112} from "./QueryCapture" ;
1213import { checkCaptureStartEnd } from "./checkCaptureStartEnd" ;
1314import { isContainedInErrorNode } from "./isContainedInErrorNode" ;
15+ import { normalizeCaptureName } from "./normalizeCaptureName" ;
1416import { parsePredicates } from "./parsePredicates" ;
17+ import { positionToPoint } from "./positionToPoint" ;
1518import { predicateToString } from "./predicateToString" ;
16- import { rewriteStartOfEndOf } from "./rewriteStartOfEndOf" ;
19+ import {
20+ getStartOfEndOfRange ,
21+ rewriteStartOfEndOf ,
22+ } from "./rewriteStartOfEndOf" ;
1723import { treeSitterQueryCache } from "./treeSitterQueryCache" ;
1824
1925/**
@@ -67,6 +73,12 @@ export class TreeSitterQuery {
6773 return new TreeSitterQuery ( treeSitter , query , predicates ) ;
6874 }
6975
76+ hasCapture ( name : string ) : boolean {
77+ return this . query . captureNames . some (
78+ ( n ) => normalizeCaptureName ( n ) === name ,
79+ ) ;
80+ }
81+
7082 matches (
7183 document : TextDocument ,
7284 start ?: Position ,
@@ -84,74 +96,111 @@ export class TreeSitterQuery {
8496 start ?: Position ,
8597 end ?: Position ,
8698 ) : QueryMatch [ ] {
87- return this . query
88- . matches ( this . treeSitter . getTree ( document ) . rootNode , {
89- startPosition : start == null ? undefined : positionToPoint ( start ) ,
90- endPosition : end == null ? undefined : positionToPoint ( end ) ,
91- } )
92- . map (
93- ( { pattern, captures } ) : MutableQueryMatch => ( {
94- patternIdx : pattern ,
95- captures : captures . map ( ( { name, node } ) => ( {
96- name,
97- node,
98- document,
99- range : getNodeRange ( node ) ,
100- insertionDelimiter : undefined ,
101- allowMultiple : false ,
102- hasError : ( ) => isContainedInErrorNode ( node ) ,
103- } ) ) ,
104- } ) ,
105- )
106- . filter ( ( match ) =>
107- this . patternPredicates [ match . patternIdx ] . every ( ( predicate ) =>
108- predicate ( match ) ,
109- ) ,
110- )
111- . map ( ( match ) : QueryMatch => {
112- // Merge the ranges of all captures with the same name into a single
113- // range and return one capture with that name. We consider captures
114- // with names `@foo`, `@foo.start`, and `@foo.end` to have the same
115- // name, for which we'd return a capture with name `foo`.
116- const captures : QueryCapture [ ] = Object . entries (
117- groupBy ( match . captures , ( { name } ) => normalizeCaptureName ( name ) ) ,
118- ) . map ( ( [ name , captures ] ) => {
119- captures = rewriteStartOfEndOf ( captures ) ;
120- const capturesAreValid = checkCaptureStartEnd (
121- captures ,
122- ide ( ) . messages ,
123- ) ;
124-
125- if ( ! capturesAreValid && ide ( ) . runMode === "test" ) {
126- throw new Error ( "Invalid captures" ) ;
127- }
128-
129- return {
130- name,
131- range : captures
132- . map ( ( { range } ) => range )
133- . reduce ( ( accumulator , range ) => range . union ( accumulator ) ) ,
134- allowMultiple : captures . some ( ( capture ) => capture . allowMultiple ) ,
135- insertionDelimiter : captures . find (
136- ( capture ) => capture . insertionDelimiter != null ,
137- ) ?. insertionDelimiter ,
138- hasError : ( ) => captures . some ( ( capture ) => capture . hasError ( ) ) ,
139- } ;
140- } ) ;
141-
142- return { ...match , captures } ;
143- } ) ;
99+ const results : QueryMatch [ ] = [ ] ;
100+ const isTesting = ide ( ) . runMode === "test" ;
101+
102+ const matches = this . query . matches (
103+ this . treeSitter . getTree ( document ) . rootNode ,
104+ {
105+ startPosition : start != null ? positionToPoint ( start ) : undefined ,
106+ endPosition : end != null ? positionToPoint ( end ) : undefined ,
107+ } ,
108+ ) ;
109+
110+ for ( const match of matches ) {
111+ const mutableMatch = createMutableQueryMatch ( document , match ) ;
112+
113+ if ( this . runPredicates ( mutableMatch ) ) {
114+ results . push ( createQueryMatch ( mutableMatch , isTesting ) ) ;
115+ }
116+ }
117+
118+ return results ;
144119 }
145120
146- get captureNames ( ) {
147- return uniq ( this . query . captureNames . map ( normalizeCaptureName ) ) ;
121+ private runPredicates ( match : MutableQueryMatch ) : boolean {
122+ for ( const predicate of this . patternPredicates [ match . patternIdx ] ) {
123+ if ( ! predicate ( match ) ) {
124+ return false ;
125+ }
126+ }
127+ return true ;
148128 }
149129}
150130
151- function normalizeCaptureName ( name : string ) : string {
152- return name . replace ( / ( \. ( s t a r t | e n d ) ) ? ( \. ( s t a r t O f | e n d O f ) ) ? $ / , "" ) ;
131+ function createMutableQueryMatch (
132+ document : TextDocument ,
133+ match : Parser . QueryMatch ,
134+ ) : MutableQueryMatch {
135+ return {
136+ patternIdx : match . pattern ,
137+ captures : match . captures . map ( ( { name, node } ) => ( {
138+ name,
139+ node,
140+ document,
141+ range : getNodeRange ( node ) ,
142+ insertionDelimiter : undefined ,
143+ allowMultiple : false ,
144+ hasError : ( ) => isContainedInErrorNode ( node ) ,
145+ } ) ) ,
146+ } ;
153147}
154148
155- function positionToPoint ( start : Position ) : Point {
156- return { row : start . line , column : start . character } ;
149+ function createQueryMatch (
150+ match : MutableQueryMatch ,
151+ isTesting : boolean ,
152+ ) : QueryMatch {
153+ const result : ModifiableQueryCapture [ ] = [ ] ;
154+ const resultMap = new Map <
155+ string ,
156+ { acc : ModifiableQueryCapture ; captures : MutableQueryCapture [ ] }
157+ > ( ) ;
158+
159+ // Merge the ranges of all captures with the same name into a single
160+ // range and return one capture with that name. We consider captures
161+ // with names `@foo`, `@foo.start`, and `@foo.end` to have the same
162+ // name, for which we'd return a capture with name `foo`.
163+
164+ for ( const capture of match . captures ) {
165+ const name = normalizeCaptureName ( capture . name ) ;
166+ const range = getStartOfEndOfRange ( capture ) ;
167+ const existing = resultMap . get ( name ) ;
168+
169+ if ( existing == null ) {
170+ const accumulator = {
171+ name,
172+ range,
173+ allowMultiple : capture . allowMultiple ,
174+ insertionDelimiter : capture . insertionDelimiter ,
175+ hasError : ( ) => capture . hasError ( ) ,
176+ } ;
177+ result . push ( accumulator ) ;
178+ resultMap . set ( name , {
179+ acc : accumulator ,
180+ captures : [ capture ] ,
181+ } ) ;
182+ } else {
183+ existing . acc . range = existing . acc . range . union ( range ) ;
184+ existing . acc . allowMultiple =
185+ existing . acc . allowMultiple || capture . allowMultiple ;
186+ existing . acc . insertionDelimiter =
187+ existing . acc . insertionDelimiter ?? capture . insertionDelimiter ;
188+ existing . acc . hasError = ( ) => existing . captures . some ( ( c ) => c . hasError ( ) ) ;
189+ existing . captures . push ( capture ) ;
190+ }
191+ }
192+
193+ if ( isTesting ) {
194+ for ( const captureGroup of resultMap . values ( ) ) {
195+ const capturesAreValid = checkCaptureStartEnd (
196+ rewriteStartOfEndOf ( captureGroup . captures ) ,
197+ ide ( ) . messages ,
198+ ) ;
199+ if ( ! capturesAreValid ) {
200+ throw new Error ( "Invalid captures" ) ;
201+ }
202+ }
203+ }
204+
205+ return { captures : result } ;
157206}
0 commit comments