33from tree_sitter_languages .core import get_language , get_parser
44from typing_extensions import Protocol
55
6+ from patchwork .common .context_strategy .langugues import LanguageProtocol
67from patchwork .common .context_strategy .position import Position
78
89
@@ -47,20 +48,32 @@ def is_file_supported(self, filename: str, src: list[str]) -> bool:
4748 """
4849 ...
4950
51+ @property
52+ def language (self ) -> LanguageProtocol :
53+ """
54+ Retrieve the language for the current context strategy.
55+
56+ Returns:
57+ str: The language for the current context strategy.
58+ """
59+ ...
60+
5061
5162class TreeSitterStrategy (ContextStrategyProtocol ):
52- def __init__ (self , language : str , query : str , exts : list [str ]):
63+ def __init__ (self , language : str , query : str , exts : list [str ], language_protocol : LanguageProtocol ):
5364 """
5465 Initialize the instance with specified language, query, and file extensions.
5566
5667 Args:
5768 language (str): The programming language for the search.
5869 query (str): The search query.
5970 exts (list[str]): The list of file extensions to consider for the search.
71+ language_protocol (LanguageProtocol): The language protocol associated.
6072 """
61- self .language = language
73+ self .tree_sitter_language = language
6274 self .query = query
6375 self .exts = exts
76+ self .language_protocol = language_protocol
6477
6578 def query_src (self , src : list [str ]):
6679 """
@@ -72,8 +85,8 @@ def query_src(self, src: list[str]):
7285 Returns:
7386 list: Returns a list of captures that match the query in the source code's abstract syntax tree (AST).
7487 """
75- language = get_language (self .language )
76- parser = get_parser (self .language )
88+ language = get_language (self .tree_sitter_language )
89+ parser = get_parser (self .tree_sitter_language )
7790 tree = parser .parse ("" .join (src ).encode ("utf-8-sig" ))
7891 return language .query (self .query ).captures (tree .root_node )
7992
@@ -100,6 +113,7 @@ def get_contexts(self, src: list[str]) -> list[Position]:
100113 end = node .end_point [0 ] + 1 ,
101114 start_col = node .start_point [1 ],
102115 end_col = node .end_point [1 ] + 1 ,
116+ language = self .language ,
103117 )
104118 positions .append (position )
105119
@@ -115,6 +129,7 @@ def get_contexts(self, src: list[str]) -> list[Position]:
115129 end = node .end_point [0 ] + 1 ,
116130 start_col = node .start_point [1 ],
117131 end_col = node .end_point [1 ] + 1 ,
132+ language = self .language ,
118133 )
119134 break
120135
@@ -150,3 +165,7 @@ def is_file_supported(self, filename: str, src: list[str]) -> bool:
150165 bool: True if the file's extension is in the list of supported extensions and `src` is not empty, otherwise False.
151166 """
152167 return any (filename .endswith (ext ) for ext in self .exts ) and len (src ) > 0
168+
169+ @property
170+ def language (self ) -> LanguageProtocol :
171+ return self .language_protocol
0 commit comments