@@ -450,75 +450,76 @@ def get_model_find_all_references(
450450 Returns:
451451 A list of references to the model across all files
452452 """
453- # First, get the references in the current file to determine what model we're looking for
454- current_file_references = [
455- ref
456- for ref in get_model_definitions_for_a_path (lint_context , document_uri )
457- if isinstance (ref , LSPModelReference )
458- ]
459-
460453 # Find the model reference at the cursor position
461- target_model_uri : t .Optional [str ] = None
462- for ref in current_file_references :
463- if _position_within_range (position , ref .range ):
464- # This is a model reference, get the target model URI
465- target_model_uri = ref .uri
466- break
454+ model_at_position = next (
455+ filter (
456+ lambda ref : isinstance (ref , LSPModelReference )
457+ and _position_within_range (position , ref .range ),
458+ get_model_definitions_for_a_path (lint_context , document_uri ),
459+ ),
460+ None ,
461+ )
467462
468- if target_model_uri is None :
463+ if not model_at_position :
469464 return []
470465
466+ assert isinstance (model_at_position , LSPModelReference ) # for mypy
467+
468+ target_model_uri = model_at_position .uri
469+
471470 # Start with the model definition
472471 all_references : t .List [LSPModelReference ] = [
473472 LSPModelReference (
474- uri = ref .uri ,
473+ uri = model_at_position .uri ,
475474 range = Range (
476475 start = Position (line = 0 , character = 0 ),
477476 end = Position (line = 0 , character = 0 ),
478477 ),
479- markdown_description = ref .markdown_description ,
478+ markdown_description = model_at_position .markdown_description ,
480479 )
481480 ]
482481
483- # Then add the original reference
484- for ref in current_file_references :
485- if ref .uri == target_model_uri and isinstance (ref , LSPModelReference ):
486- all_references .append (
487- LSPModelReference (
488- uri = document_uri .value ,
489- range = ref .range ,
490- markdown_description = ref .markdown_description ,
491- )
482+ # Then add references from the current file
483+ current_file_refs = filter (
484+ lambda ref : isinstance (ref , LSPModelReference ) and ref .uri == target_model_uri ,
485+ get_model_definitions_for_a_path (lint_context , document_uri ),
486+ )
487+
488+ for ref in current_file_refs :
489+ assert isinstance (ref , LSPModelReference ) # for mypy
490+
491+ all_references .append (
492+ LSPModelReference (
493+ uri = document_uri .value ,
494+ range = ref .range ,
495+ markdown_description = ref .markdown_description ,
492496 )
497+ )
493498
494499 # Search through the models in the project
495- for path , target in lint_context .map .items ():
496- if not isinstance (target , (ModelTarget , AuditTarget )):
497- continue
498-
500+ for path , _ in lint_context .map .items ():
499501 file_uri = URI .from_path (path )
500502
501503 # Skip current file, already processed
502504 if file_uri .value == document_uri .value :
503505 continue
504506
505- # Get model references for this file
506- file_references = [
507- ref
508- for ref in get_model_definitions_for_a_path (lint_context , file_uri )
509- if isinstance (ref , LSPModelReference )
510- ]
511-
512- # Add references that point to the target model file
513- for ref in file_references :
514- if ref .uri == target_model_uri and isinstance (ref , LSPModelReference ):
515- all_references .append (
516- LSPModelReference (
517- uri = file_uri .value ,
518- range = ref .range ,
519- markdown_description = ref .markdown_description ,
520- )
507+ # Get model references that point to the target model
508+ matching_refs = filter (
509+ lambda ref : isinstance (ref , LSPModelReference ) and ref .uri == target_model_uri ,
510+ get_model_definitions_for_a_path (lint_context , file_uri ),
511+ )
512+
513+ for ref in matching_refs :
514+ assert isinstance (ref , LSPModelReference ) # for mypy
515+
516+ all_references .append (
517+ LSPModelReference (
518+ uri = file_uri .value ,
519+ range = ref .range ,
520+ markdown_description = ref .markdown_description ,
521521 )
522+ )
522523
523524 return all_references
524525
@@ -587,13 +588,83 @@ def get_cte_references(
587588 return matching_references
588589
589590
591+ def get_macro_find_all_references (
592+ lsp_context : LSPContext , document_uri : URI , position : Position
593+ ) -> t .List [LSPMacroReference ]:
594+ """
595+ Get all references to a macro at a specific position in a document.
596+
597+ This function finds all usages of a macro across the entire project.
598+
599+ Args:
600+ lsp_context: The LSP context
601+ document_uri: The URI of the document
602+ position: The position to check for macro references
603+
604+ Returns:
605+ A list of references to the macro across all files
606+ """
607+ # Find the macro reference at the cursor position
608+ macro_at_position = next (
609+ filter (
610+ lambda ref : isinstance (ref , LSPMacroReference )
611+ and _position_within_range (position , ref .range ),
612+ get_macro_definitions_for_a_path (lsp_context , document_uri ),
613+ ),
614+ None ,
615+ )
616+
617+ if not macro_at_position :
618+ return []
619+
620+ assert isinstance (macro_at_position , LSPMacroReference ) # for mypy
621+
622+ target_macro_uri = macro_at_position .uri
623+ target_macro_target_range = macro_at_position .target_range
624+
625+ # Start with the macro definition
626+ all_references : t .List [LSPMacroReference ] = [
627+ LSPMacroReference (
628+ uri = target_macro_uri ,
629+ range = target_macro_target_range ,
630+ target_range = target_macro_target_range ,
631+ markdown_description = None ,
632+ )
633+ ]
634+
635+ # Search through all SQL and audit files in the project
636+ for path , _ in lsp_context .map .items ():
637+ file_uri = URI .from_path (path )
638+
639+ # Get macro references that point to the same macro definition
640+ matching_refs = filter (
641+ lambda ref : isinstance (ref , LSPMacroReference )
642+ and ref .uri == target_macro_uri
643+ and ref .target_range == target_macro_target_range ,
644+ get_macro_definitions_for_a_path (lsp_context , file_uri ),
645+ )
646+
647+ for ref in matching_refs :
648+ assert isinstance (ref , LSPMacroReference ) # for mypy
649+ all_references .append (
650+ LSPMacroReference (
651+ uri = file_uri .value ,
652+ range = ref .range ,
653+ target_range = ref .target_range ,
654+ markdown_description = ref .markdown_description ,
655+ )
656+ )
657+
658+ return all_references
659+
660+
590661def get_all_references (
591662 lint_context : LSPContext , document_uri : URI , position : Position
592663) -> t .Sequence [Reference ]:
593664 """
594665 Get all references of a symbol at a specific position in a document.
595666
596- This function determines the type of reference (CTE, model for now ) at the cursor
667+ This function determines the type of reference (CTE, model or macro ) at the cursor
597668 position and returns all references to that symbol across the project.
598669
599670 Args:
@@ -612,6 +683,10 @@ def get_all_references(
612683 if model_references := get_model_find_all_references (lint_context , document_uri , position ):
613684 return model_references
614685
686+ # Finally try macro references (across files)
687+ if macro_references := get_macro_find_all_references (lint_context , document_uri , position ):
688+ return macro_references
689+
615690 return []
616691
617692
0 commit comments