11from  __future__ import  annotations 
22
3- import  ast 
4- import  os 
5- import  re 
6- from  collections  import  defaultdict 
7- from  typing  import  TYPE_CHECKING 
8- 
9- import  jedi 
10- import  tiktoken 
113from  jedi .api .classes  import  Name 
12- 
13- from  codeflash .cli_cmds .console  import  logger 
14- from  codeflash .code_utils .code_extractor  import  get_code 
154from  codeflash .code_utils .code_utils  import  (
165    get_qualified_name ,
17-     module_name_from_file_path ,
18-     path_belongs_to_site_packages ,
19- )
20- from  codeflash .discovery .functions_to_optimize  import  FunctionToOptimize 
21- from  codeflash .models .models  import  FunctionParent , FunctionSource 
22- 
23- if  TYPE_CHECKING :
24-     from  pathlib  import  Path 
256
7+ )
268
279def  belongs_to_method (name : Name , class_name : str , method_name : str ) ->  bool :
2810    """Check if the given name belongs to the specified method.""" 
@@ -58,242 +40,4 @@ def belongs_to_function_qualified(name: Name, qualified_function_name: str) -> b
5840                return  get_qualified_name (name .module_name , name .full_name ) ==  qualified_function_name 
5941        return  False 
6042    except  ValueError :
61-         return  False 
62- 
63- # 
64- # def get_type_annotation_context( 
65- #     function: FunctionToOptimize, jedi_script: jedi.Script, project_root_path: Path 
66- # ) -> tuple[list[FunctionSource], set[tuple[str, str]]]: 
67- #     function_name: str = function.function_name 
68- #     file_path: Path = function.file_path 
69- #     file_contents: str = file_path.read_text(encoding="utf8") 
70- #     try: 
71- #         module: ast.Module = ast.parse(file_contents) 
72- #     except SyntaxError as e: 
73- #         logger.exception(f"get_type_annotation_context - Syntax error in code: {e}") 
74- #         return [], set() 
75- #     sources: list[FunctionSource] = [] 
76- #     ast_parents: list[FunctionParent] = [] 
77- #     contextual_dunder_methods = set() 
78- # 
79- #     def get_annotation_source( 
80- #         j_script: jedi.Script, name: str, node_parents: list[FunctionParent], line_no: int, col_no: str 
81- #     ) -> None: 
82- #         try: 
83- #             definition: list[Name] = j_script.goto( 
84- #                 line=line_no, column=col_no, follow_imports=True, follow_builtin_imports=False 
85- #             ) 
86- #         except Exception as ex: 
87- #             if hasattr(name, "full_name"): 
88- #                 logger.exception(f"Error while getting definition for {name.full_name}: {ex}") 
89- #             else: 
90- #                 logger.exception(f"Error while getting definition: {ex}") 
91- #             definition = [] 
92- #         if definition:  # TODO can be multiple definitions 
93- #             definition_path = definition[0].module_path 
94- # 
95- #             # The definition is part of this project and not defined within the original function 
96- #             if ( 
97- #                 str(definition_path).startswith(str(project_root_path) + os.sep) 
98- #                 and definition[0].full_name 
99- #                 and not path_belongs_to_site_packages(definition_path) 
100- #                 and not belongs_to_function(definition[0], function_name) 
101- #             ): 
102- #                 source_code = get_code([FunctionToOptimize(definition[0].name, definition_path, node_parents[:-1])]) 
103- #                 if source_code[0]: 
104- #                     sources.append( 
105- #                         FunctionSource( 
106- #                             fully_qualified_name=definition[0].full_name, 
107- #                             jedi_definition=definition[0], 
108- #                             source_code=source_code[0], 
109- #                             file_path=definition_path, 
110- #                             qualified_name=definition[0].full_name.removeprefix(definition[0].module_name + "."), 
111- #                             only_function_name=definition[0].name, 
112- #                         ) 
113- #                     ) 
114- #                     contextual_dunder_methods.update(source_code[1]) 
115- # 
116- #     def visit_children( 
117- #         node: ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, node_parents: list[FunctionParent] 
118- #     ) -> None: 
119- #         child: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module 
120- #         for child in ast.iter_child_nodes(node): 
121- #             visit(child, node_parents) 
122- # 
123- #     def visit_all_annotation_children( 
124- #         node: ast.Subscript | ast.Name | ast.BinOp, node_parents: list[FunctionParent] 
125- #     ) -> None: 
126- #         if isinstance(node, ast.BinOp) and isinstance(node.op, ast.BitOr): 
127- #             visit_all_annotation_children(node.left, node_parents) 
128- #             visit_all_annotation_children(node.right, node_parents) 
129- #         if isinstance(node, ast.Name) and hasattr(node, "id"): 
130- #             name: str = node.id 
131- #             line_no: int = node.lineno 
132- #             col_no: int = node.col_offset 
133- #             get_annotation_source(jedi_script, name, node_parents, line_no, col_no) 
134- #         if isinstance(node, ast.Subscript): 
135- #             if hasattr(node, "slice"): 
136- #                 if isinstance(node.slice, ast.Subscript): 
137- #                     visit_all_annotation_children(node.slice, node_parents) 
138- #                 elif isinstance(node.slice, ast.Tuple): 
139- #                     for elt in node.slice.elts: 
140- #                         if isinstance(elt, (ast.Name, ast.Subscript)): 
141- #                             visit_all_annotation_children(elt, node_parents) 
142- #                 elif isinstance(node.slice, ast.Name): 
143- #                     visit_all_annotation_children(node.slice, node_parents) 
144- #             if hasattr(node, "value"): 
145- #                 visit_all_annotation_children(node.value, node_parents) 
146- # 
147- #     def visit( 
148- #         node: ast.AST | ast.FunctionDef | ast.AsyncFunctionDef | ast.ClassDef | ast.Module, 
149- #         node_parents: list[FunctionParent], 
150- #     ) -> None: 
151- #         if isinstance(node, (ast.Module, ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)): 
152- #             if isinstance(node, (ast.FunctionDef, ast.AsyncFunctionDef)): 
153- #                 if node.name == function_name and node_parents == function.parents: 
154- #                     arg: ast.arg 
155- #                     for arg in node.args.args: 
156- #                         if arg.annotation: 
157- #                             visit_all_annotation_children(arg.annotation, node_parents) 
158- #                     if node.returns: 
159- #                         visit_all_annotation_children(node.returns, node_parents) 
160- # 
161- #             if not isinstance(node, ast.Module): 
162- #                 node_parents.append(FunctionParent(node.name, type(node).__name__)) 
163- #             visit_children(node, node_parents) 
164- #             if not isinstance(node, ast.Module): 
165- #                 node_parents.pop() 
166- # 
167- #     visit(module, ast_parents) 
168- # 
169- #     return sources, contextual_dunder_methods 
170- 
171- 
172- # def get_function_variables_definitions( 
173- #     function_to_optimize: FunctionToOptimize, project_root_path: Path 
174- # ) -> tuple[list[FunctionSource], set[tuple[str, str]]]: 
175- #     function_name = function_to_optimize.function_name 
176- #     file_path = function_to_optimize.file_path 
177- #     script = jedi.Script(path=file_path, project=jedi.Project(path=project_root_path)) 
178- #     sources: list[FunctionSource] = [] 
179- #     contextual_dunder_methods = set() 
180- #     # TODO: The function name condition can be stricter so that it does not clash with other class names etc. 
181- #     # TODO: The function could have been imported as some other name, 
182- #     #  we should be checking for the translation as well. Also check for the original function name. 
183- #     names = [] 
184- #     for ref in script.get_names(all_scopes=True, definitions=False, references=True): 
185- #         if ref.full_name: 
186- #             if function_to_optimize.parents: 
187- #                 # Check if the reference belongs to the specified class when FunctionParent is provided 
188- #                 if belongs_to_method(ref, function_to_optimize.parents[-1].name, function_name): 
189- #                     names.append(ref) 
190- #             elif belongs_to_function(ref, function_name): 
191- #                 names.append(ref) 
192- # 
193- #     for name in names: 
194- #         try: 
195- #             definitions: list[Name] = name.goto(follow_imports=True, follow_builtin_imports=False) 
196- #         except Exception as e: 
197- #             try: 
198- #                 logger.exception(f"Error while getting definition for {name.full_name}: {e}") 
199- #             except Exception as e: 
200- #                 # name.full_name can also throw exceptions sometimes 
201- #                 logger.exception(f"Error while getting definition: {e}") 
202- #             definitions = [] 
203- #         if definitions: 
204- #             # TODO: there can be multiple definitions, see how to handle such cases 
205- #             definition = definitions[0] 
206- #             definition_path = definition.module_path 
207- # 
208- #             # The definition is part of this project and not defined within the original function 
209- #             if ( 
210- #                 str(definition_path).startswith(str(project_root_path) + os.sep) 
211- #                 and not path_belongs_to_site_packages(definition_path) 
212- #                 and definition.full_name 
213- #                 and not belongs_to_function(definition, function_name) 
214- #             ): 
215- #                 module_name = module_name_from_file_path(definition_path, project_root_path) 
216- #                 m = re.match(rf"{module_name}\.(.*)\.{definitions[0].name}", definitions[0].full_name) 
217- #                 parents = [] 
218- #                 if m: 
219- #                     parents = [FunctionParent(m.group(1), "ClassDef")] 
220- # 
221- #                 source_code = get_code( 
222- #                     [FunctionToOptimize(function_name=definitions[0].name, file_path=definition_path, parents=parents)] 
223- #                 ) 
224- #                 if source_code[0]: 
225- #                     sources.append( 
226- #                         FunctionSource( 
227- #                             fully_qualified_name=definition.full_name, 
228- #                             jedi_definition=definition, 
229- #                             source_code=source_code[0], 
230- #                             file_path=definition_path, 
231- #                             qualified_name=definition.full_name.removeprefix(definition.module_name + "."), 
232- #                             only_function_name=definition.name, 
233- #                         ) 
234- #                     ) 
235- #                     contextual_dunder_methods.update(source_code[1]) 
236- #     annotation_sources, annotation_dunder_methods = get_type_annotation_context( 
237- #         function_to_optimize, script, project_root_path 
238- #     ) 
239- #     sources[:0] = annotation_sources  # prepend the annotation sources 
240- #     contextual_dunder_methods.update(annotation_dunder_methods) 
241- #     existing_fully_qualified_names = set() 
242- #     no_parent_sources: dict[Path, dict[str, set[FunctionSource]]] = defaultdict(lambda: defaultdict(set)) 
243- #     parent_sources = set() 
244- #     for source in sources: 
245- #         if (fully_qualified_name := source.fully_qualified_name) not in existing_fully_qualified_names: 
246- #             if not source.qualified_name.count("."): 
247- #                 no_parent_sources[source.file_path][source.qualified_name].add(source) 
248- #             else: 
249- #                 parent_sources.add(source) 
250- #             existing_fully_qualified_names.add(fully_qualified_name) 
251- #     deduped_parent_sources = [ 
252- #         source 
253- #         for source in parent_sources 
254- #         if source.file_path not in no_parent_sources 
255- #         or source.qualified_name.rpartition(".")[0] not in no_parent_sources[source.file_path] 
256- #     ] 
257- #     deduped_no_parent_sources = [ 
258- #         source for k1 in no_parent_sources for k2 in no_parent_sources[k1] for source in no_parent_sources[k1][k2] 
259- #     ] 
260- #     return deduped_no_parent_sources + deduped_parent_sources, contextual_dunder_methods 
261- # 
262- # 
263- # MAX_PROMPT_TOKENS = 4096  # 128000  # gpt-4-128k 
264- # 
265- # 
266- # def get_constrained_function_context_and_helper_functions( 
267- #     function_to_optimize: FunctionToOptimize, 
268- #     project_root_path: Path, 
269- #     code_to_optimize: str, 
270- #     max_tokens: int = MAX_PROMPT_TOKENS, 
271- # ) -> tuple[str, list[FunctionSource], set[tuple[str, str]]]: 
272- #     helper_functions, dunder_methods = get_function_variables_definitions(function_to_optimize, project_root_path) 
273- #     tokenizer = tiktoken.encoding_for_model("gpt-3.5-turbo") 
274- #     code_to_optimize_tokens = tokenizer.encode(code_to_optimize) 
275- # 
276- #     if not function_to_optimize.parents: 
277- #         helper_functions_sources = [function.source_code for function in helper_functions] 
278- #     else: 
279- #         helper_functions_sources = [ 
280- #             function.source_code 
281- #             for function in helper_functions 
282- #             if not function.qualified_name.count(".") 
283- #             or function.qualified_name.split(".")[0] != function_to_optimize.parents[0].name 
284- #         ] 
285- #     helper_functions_tokens = [len(tokenizer.encode(function)) for function in helper_functions_sources] 
286- # 
287- #     context_list = [] 
288- #     context_len = len(code_to_optimize_tokens) 
289- #     logger.debug(f"ORIGINAL CODE TOKENS LENGTH: {context_len}") 
290- #     logger.debug(f"ALL DEPENDENCIES TOKENS LENGTH: {sum(helper_functions_tokens)}") 
291- #     for function_source, source_len in zip(helper_functions_sources, helper_functions_tokens): 
292- #         if context_len + source_len <= max_tokens: 
293- #             context_list.append(function_source) 
294- #             context_len += source_len 
295- #         else: 
296- #             break 
297- #     logger.debug(f"FINAL OPTIMIZATION CONTEXT TOKENS LENGTH: {context_len}") 
298- #     helper_code: str = "\n".join(context_list) 
299- #     return helper_code, helper_functions, dunder_methods 
43+         return  False 
0 commit comments