11from lsprotocol .types import Range , Position
22import typing as t
33from pathlib import Path
4+ from pydantic import Field
45
56from sqlmesh .core .audit import StandaloneAudit
67from sqlmesh .core .dialect import normalize_model_name
2324import inspect
2425
2526
26- class Reference (PydanticModel ):
27- """
28- A reference to a model or CTE.
27+ class LSPModelReference (PydanticModel ):
28+ """A LSP reference to a model."""
29+
30+ type : t .Literal ["model" ] = "model"
31+ uri : str
32+ range : Range
33+ markdown_description : t .Optional [str ] = None
2934
30- Attributes:
31- range: The range of the reference in the source file
32- uri: The uri of the referenced model
33- markdown_description: The markdown description of the referenced model
34- target_range: The range of the definition for go-to-definition (optional, used for CTEs)
35- """
3635
36+ class LSPCteReference (PydanticModel ):
37+ """A LSP reference to a CTE."""
38+
39+ type : t .Literal ["cte" ] = "cte"
40+ uri : str
3741 range : Range
42+ target_range : Range
43+
44+
45+ class LSPMacroReference (PydanticModel ):
46+ """A LSP reference to a macro."""
47+
48+ type : t .Literal ["macro" ] = "macro"
3849 uri : str
50+ range : Range
51+ target_range : Range
3952 markdown_description : t .Optional [str ] = None
40- target_range : t .Optional [Range ] = None
53+
54+
55+ Reference = t .Annotated [
56+ t .Union [LSPModelReference , LSPCteReference , LSPMacroReference ], Field (discriminator = "type" )
57+ ]
4158
4259
4360def by_position (position : Position ) -> t .Callable [[Reference ], bool ]:
@@ -136,7 +153,7 @@ def get_model_definitions_for_a_path(
136153 return []
137154
138155 # Find all possible references
139- references = []
156+ references : t . List [ Reference ] = []
140157
141158 with open (file_path , "r" , encoding = "utf-8" ) as file :
142159 read_file = file .readlines ()
@@ -173,7 +190,7 @@ def get_model_definitions_for_a_path(
173190 table_range = to_lsp_range (table_range_sqlmesh )
174191
175192 references .append (
176- Reference (
193+ LSPCteReference (
177194 uri = document_uri .value , # Same file
178195 range = table_range ,
179196 target_range = target_range ,
@@ -227,7 +244,7 @@ def get_model_definitions_for_a_path(
227244 description = generate_markdown_description (referenced_model )
228245
229246 references .append (
230- Reference (
247+ LSPModelReference (
231248 uri = referenced_model_uri .value ,
232249 range = Range (
233250 start = to_lsp_position (start_pos_sqlmesh ),
@@ -286,7 +303,7 @@ def get_macro_definitions_for_a_path(
286303 return []
287304
288305 references = []
289- config_for_model , config_path = lsp_context .context .config_for_path (
306+ _ , config_path = lsp_context .context .config_for_path (
290307 file_path ,
291308 )
292309
@@ -372,7 +389,7 @@ def get_macro_reference(
372389 # Create a reference to the macro definition
373390 macro_uri = URI .from_path (path )
374391
375- return Reference (
392+ return LSPMacroReference (
376393 uri = macro_uri .value ,
377394 range = to_lsp_range (macro_range ),
378395 target_range = Range (
@@ -405,7 +422,7 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio
405422 # Calculate the end line number by counting the number of source lines
406423 end_line_number = line_number + len (source_lines ) - 1
407424
408- return Reference (
425+ return LSPMacroReference (
409426 uri = URI .from_path (Path (filename )).value ,
410427 range = macro_range ,
411428 target_range = Range (
@@ -416,9 +433,99 @@ def get_built_in_macro_reference(macro_name: str, macro_range: Range) -> t.Optio
416433 )
417434
418435
436+ def get_model_find_all_references (
437+ lint_context : LSPContext , document_uri : URI , position : Position
438+ ) -> t .List [LSPModelReference ]:
439+ """
440+ Get all references to a model across the entire project.
441+
442+ This function finds all usages of a model in other files by searching through
443+ all models in the project and checking their dependencies.
444+
445+ Args:
446+ lint_context: The LSP context
447+ document_uri: The URI of the document
448+ position: The position to check for model references
449+
450+ Returns:
451+ A list of references to the model across all files
452+ """
453+ # First, get the references in the current file to determine what model we're looking for
454+ current_file_references = [
455+ ref
456+ for ref in get_model_definitions_for_a_path (lint_context , document_uri )
457+ if isinstance (ref , LSPModelReference )
458+ ]
459+
460+ # Find the model reference at the cursor position
461+ target_model_uri : t .Optional [str ] = None
462+ for ref in current_file_references :
463+ if _position_within_range (position , ref .range ):
464+ # This is a model reference, get the target model URI
465+ target_model_uri = ref .uri
466+ break
467+
468+ if target_model_uri is None :
469+ return []
470+
471+ # Start with the model definition
472+ all_references : t .List [LSPModelReference ] = [
473+ LSPModelReference (
474+ uri = ref .uri ,
475+ range = Range (
476+ start = Position (line = 0 , character = 0 ),
477+ end = Position (line = 0 , character = 0 ),
478+ ),
479+ markdown_description = ref .markdown_description ,
480+ )
481+ ]
482+
483+ # Then add the original reference
484+ for ref in current_file_references :
485+ if ref .uri == target_model_uri and isinstance (ref , LSPModelReference ):
486+ all_references .append (
487+ LSPModelReference (
488+ uri = document_uri .value ,
489+ range = ref .range ,
490+ markdown_description = ref .markdown_description ,
491+ )
492+ )
493+
494+ # Search through the models in the project
495+ for path , target in lint_context .map .items ():
496+ if not isinstance (target , (ModelTarget , AuditTarget )):
497+ continue
498+
499+ file_uri = URI .from_path (path )
500+
501+ # Skip current file, already processed
502+ if file_uri .value == document_uri .value :
503+ continue
504+
505+ # Get model references for this file
506+ file_references = [
507+ ref
508+ for ref in get_model_definitions_for_a_path (lint_context , file_uri )
509+ if isinstance (ref , LSPModelReference )
510+ ]
511+
512+ # Add references that point to the target model file
513+ for ref in file_references :
514+ if ref .uri == target_model_uri and isinstance (ref , LSPModelReference ):
515+ all_references .append (
516+ LSPModelReference (
517+ uri = file_uri .value ,
518+ range = ref .range ,
519+ markdown_description = ref .markdown_description ,
520+ )
521+ )
522+
523+ return all_references
524+
525+
419526def get_cte_references (
420527 lint_context : LSPContext , document_uri : URI , position : Position
421- ) -> t .List [Reference ]:
528+ ) -> t .List [LSPCteReference ]:
422529 """
423530 Get all references to a CTE at a specific position in a document.
424531
@@ -432,12 +539,12 @@ def get_cte_references(
432539 Returns:
433540 A list of references to the CTE (including its definition and all usages)
434541 """
435- references = get_model_definitions_for_a_path (lint_context , document_uri )
436542
437- # Filter for CTE references (those with target_range set and same URI)
438- # TODO: Consider extending Reference class to explicitly indicate reference type instead
439- cte_references = [
440- ref for ref in references if ref .target_range is not None and ref .uri == document_uri .value
543+ # Filter to get the CTE references
544+ cte_references : t .List [LSPCteReference ] = [
545+ ref
546+ for ref in get_model_definitions_for_a_path (lint_context , document_uri )
547+ if isinstance (ref , LSPCteReference )
441548 ]
442549
443550 if not cte_references :
@@ -450,7 +557,7 @@ def get_cte_references(
450557 target_cte_definition_range = ref .target_range
451558 break
452559 # Check if cursor is on the CTE definition
453- elif ref . target_range and _position_within_range (position , ref .target_range ):
560+ elif _position_within_range (position , ref .target_range ):
454561 target_cte_definition_range = ref .target_range
455562 break
456563
@@ -459,27 +566,55 @@ def get_cte_references(
459566
460567 # Add the CTE definition
461568 matching_references = [
462- Reference (
569+ LSPCteReference (
463570 uri = document_uri .value ,
464571 range = target_cte_definition_range ,
465- markdown_description = "CTE definition" ,
572+ target_range = target_cte_definition_range ,
466573 )
467574 ]
468575
469576 # Add all usages
470577 for ref in cte_references :
471578 if ref .target_range == target_cte_definition_range :
472579 matching_references .append (
473- Reference (
580+ LSPCteReference (
474581 uri = document_uri .value ,
475582 range = ref .range ,
476- markdown_description = "CTE usage" ,
583+ target_range = ref . target_range ,
477584 )
478585 )
479586
480587 return matching_references
481588
482589
590+ def get_all_references (
591+ lint_context : LSPContext , document_uri : URI , position : Position
592+ ) -> t .Sequence [Reference ]:
593+ """
594+ Get all references of a symbol at a specific position in a document.
595+
596+ This function determines the type of reference (CTE, model for now) at the cursor
597+ position and returns all references to that symbol across the project.
598+
599+ Args:
600+ lint_context: The LSP context
601+ document_uri: The URI of the document
602+ position: The position to check for references
603+
604+ Returns:
605+ A list of references to the symbol at the given position
606+ """
607+ # First try CTE references (within same file)
608+ if cte_references := get_cte_references (lint_context , document_uri , position ):
609+ return cte_references
610+
611+ # Then try model references (across files)
612+ if model_references := get_model_find_all_references (lint_context , document_uri , position ):
613+ return model_references
614+
615+ return []
616+
617+
483618def _position_within_range (position : Position , range : Range ) -> bool :
484619 """Check if a position is within a given range."""
485620 return (
0 commit comments