@@ -52,6 +52,9 @@ def extract_names_from_patch(self, patch_content: str) -> Tuple[Set[str], Set[st
5252 pass
5353
@abstractmethod
def extract_functions_from_patch(self, patch_content: str) -> Set[str]:
    """Return the set of names extracted from the given patch text.

    Abstract hook with no behavior of its own; concrete analyzers are
    expected to override it.

    :param patch_content: text of the patch to inspect
    :return: set of extracted names (this declaration itself returns None)
    """
@abstractmethod
def extract_definitions(self, content: str, names: Set[str]) -> Dict[str, str]:
    """Return a mapping of definitions for the requested names.

    Abstract hook with no behavior of its own; concrete analyzers are
    expected to override it.

    :param content: source text to search
    :param names: names whose definitions are wanted
    :return: dictionary of definitions, presumably keyed by name
        (this declaration itself returns None)
    """
5760
@@ -196,6 +199,26 @@ def extract_names_from_patch(self, patch_content: str) -> Tuple[Set[str], Set[st
196199
197200 return functions , variables
198201
def extract_functions_from_patch(self, patch_content: str) -> Set[str]:
    """Collect candidate type and function names mentioned in a patch.

    This is a purely lexical scan, performed one line at a time:

    * every identifier starting with an uppercase letter is treated as a
      type name;
    * every identifier directly followed by ``(`` (optionally separated by
      whitespace on the same line) is treated as a function name.

    No filtering of keywords or diff markers is done, so the result is a
    superset of the real names.

    :param patch_content: raw text of the patch
    :return: set of all identifiers matched by either rule
    """
    # Compile once; the patterns are reused for every line of the patch.
    type_pattern = re.compile(r'\b([A-Z][a-zA-Z0-9_]*)\b')
    function_pattern = re.compile(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\(')

    extracted_info: Set[str] = set()

    # Scan line by line so that a name at the end of one line is never
    # paired with a parenthesis that opens the next line.
    for line in patch_content.split('\n'):
        # Types: assumed to start with an uppercase letter.
        extracted_info.update(type_pattern.findall(line))
        # Functions: assumed to be followed by an opening parenthesis.
        extracted_info.update(function_pattern.findall(line))

    return extracted_info
221+
199222 def extract_definitions (self , content : str , names : Set [str ]) -> Dict [str , str ]:
200223 tree = ast .parse (content )
201224 definitions = {}
@@ -293,7 +316,7 @@ def _is_from_project(self, node, current_file: str) -> bool:
293316 return False
294317 file_path = os .path .abspath (node .location .file .name )
295318 return file_path .startswith (self .project_root ) and (
296- file_path == current_file or not file_path .endswith (('.h' , '.hpp' )))
319+ file_path == current_file or not file_path .endswith (('.h' , '.hpp' )))
297320
298321 def _get_element_content (self , node ) -> str :
299322 try :
@@ -315,23 +338,55 @@ def _is_likely_external(self, content: str) -> bool:
315338
def analyze_dependencies(self, file_path: str, content: str) -> List[str]:
    """Analyze a file's dependencies, keeping only project-internal ones.

    The ``#include`` directives in ``content`` are collected and handed to
    ``self.find_dependencies`` for conversion and filtering. For every
    dependency returned, ``self.find_implementation_file`` is asked for a
    matching implementation file, and any hit is added to the result.

    :param file_path: path of the file being analyzed
    :param content: text of that file
    :return: de-duplicated list of dependencies, header dependencies first
        and then implementation files, each in first-seen order
    """
    # Find every #include directive (both <...> and "..." forms).
    includes = re.findall(r'#include\s*[<"]([^>"]+)[>"]', content)

    # Convert and filter the raw include names.
    project_dependencies = self.find_dependencies(file_path, includes)

    # For each header dependency, try to locate its implementation file.
    implementation_dependencies = []
    for dep in project_dependencies:
        impl_file = self.find_implementation_file(dep)
        if impl_file:
            implementation_dependencies.append(impl_file)

    # dict.fromkeys removes duplicates while keeping a stable order, unlike
    # list(set(...)) whose ordering can change from run to run.
    return list(dict.fromkeys(project_dependencies + implementation_dependencies))
363+
def find_implementation_file(self, header_path: str) -> Optional[str]:
    """Look up the implementation file that corresponds to a header.

    Two passes are made over the candidate extensions, in order of
    preference ``.cpp``, ``.cxx``, ``.cc``, ``.c``:

    1. the header path with its extension swapped is looked for among the
       values of ``self.file_index``;
    2. failing that, the bare file name with the new extension is looked
       up as a key of ``self.file_index``.

    ``self.file_index`` is presumably a mapping of file name to project
    path; verify against the code that builds it.

    :param header_path: path of the header file
    :return: path of the implementation file, or None if none is found
    """
    implementation_extensions = ('.cpp', '.cxx', '.cc', '.c')
    base_name = os.path.splitext(header_path)[0]

    # Build the set once so each extension check is O(1) rather than a
    # fresh linear scan over every indexed path.
    indexed_paths = set(self.file_index.values())
    for ext in implementation_extensions:
        impl_path = base_name + ext
        if impl_path in indexed_paths:
            return impl_path

    # Nothing next to the header: fall back to a lookup by file name alone.
    file_name = os.path.basename(base_name)
    for ext in implementation_extensions:
        impl_file = file_name + ext
        if impl_file in self.file_index:
            return self.file_index[impl_file]

    return None
332387
333388 def extract_names_from_patch (self , patch_content : str ) -> Tuple [Set [str ], Set [str ]]:
334- tu = self .index .parse ('tmp.cpp' , unsaved_files = [('tmp.cpp' , patch_content )])
389+ tu = self .index .parse ('tmp.cpp' , unsaved_files = [('tmp.cpp' , patch_content )], args = [ '-std=c++11' ] )
335390 functions = set ()
336391 variables = set ()
337392
@@ -344,9 +399,63 @@ def visit_node(node):
344399 for child in node .get_children ():
345400 visit_node (child )
346401
347- visit_node (tu .cursor )
402+ for child in tu .cursor .get_children ():
403+ visit_node (child )
404+
405+ # visit_node(tu.cursor)
348406 return functions , variables
349407
def extract_functions_from_patch(self, patch_content: str) -> Set[str]:
    """Extract candidate function names from C/C++ patch text using regexes.

    Two passes are made over the whole patch:

    1. text shaped like a function definition or declaration
       (``type name(args)``, optionally namespace-qualified);
    2. text shaped like a call (``name(args)``), skipped when the token
       immediately before it is a control-structure keyword.

    Every candidate is passed through ``self._is_valid_function`` so that
    system functions and control keywords are left out.

    :param patch_content: raw text of the patch
    :return: set of function names found by either pass
    """
    # Function definition or declaration, possibly namespace-qualified.
    function_def_pattern = r'(?:(?:\w+::)*\w+\s+)+(\w+(?:::\w+)*)\s*\([^)]*\)\s*(?:const)?\s*(?:{\s*)?'
    # Potential function call (also matches control structures).
    potential_call_pattern = r'(\w+(?:::\w+)*)\s*\([^)]*\)'

    # System functions and keywords to exclude (extend as needed).
    system_functions = {'std::', 'boost::', 'printf', 'scanf', 'malloc', 'free', 'new', 'delete'}
    control_structures = {'if', 'while', 'for', 'switch', 'catch'}

    functions: Set[str] = set()

    # Pass 1: function definitions and declarations.
    for match in re.finditer(function_def_pattern, patch_content):
        func_name = match.group(1)
        if self._is_valid_function(func_name, system_functions, control_structures):
            functions.add(func_name)

    # Pass 2: potential function calls.
    for match in re.finditer(potential_call_pattern, patch_content):
        func_name = match.group(1)
        if not self._is_valid_function(func_name, system_functions, control_structures):
            continue
        # Look at the last token within the 20 characters before the match.
        preceding = patch_content[max(0, match.start() - 20):match.start()].split()
        # No preceding token means the call opens the patch, so there is no
        # control keyword in front of it and it must be kept.
        if not preceding or preceding[-1] not in control_structures:
            functions.add(func_name)

    return functions
446+
447+ def _is_valid_function (self , name : str , system_functions : Set [str ], control_structures : Set [str ]) -> bool :
448+ """
449+ 检查名称是否为有效的函数名(不是系统函数或控制结构)
450+
451+ :param name: 要检查的名称
452+ :param system_functions: 系统函数集合
453+ :param control_structures: 控制结构集合
454+ :return: 如果是有效的函数名则返回True,否则返回False
455+ """
456+ return not any (name .startswith (sys_func ) for sys_func in system_functions ) and name not in control_structures
457+
458+
350459 def extract_definitions (self , content : str , names : Set [str ]) -> Dict [str , str ]:
351460 tu = self .index .parse ('tmp.cpp' , unsaved_files = [('tmp.cpp' , content )])
352461 definitions = {}
0 commit comments