Skip to content

Commit 5f8bd31

Browse files
committed
Optimizing code and vector queries
1 parent 99d9a25 commit 5f8bd31

File tree

6 files changed

+195
-31
lines changed

6 files changed

+195
-31
lines changed

apps/review.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -231,6 +231,10 @@ async def review_specific_pr(pr_url: str):
231231
repo_detail = github.parse_pullrequest_url(pr_url)
232232
pr_data = await github.get_pullrequest(repo_detail.get_repo_fullname(), repo_detail.number)
233233
head_sha = pr_data['head']['sha']
234-
commit_message = f"{pr_data['title']}\n\n{pr_data.get('body', '')}"
234+
pr_body = pr_data.get('body', '')
235+
pr_title = pr_data['title']
236+
if not pr_body:
237+
pr_body = ""
238+
commit_message = f"{pr_title}\n\n{pr_body}"
235239
await review_pull_request(repo_detail.get_repo_fullname(), repo_detail.number, head_sha, commit_message)
236240

apps/webhook/handles.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -245,7 +245,8 @@ async def pull_request_handler(action: str, payload, event, delivery, headers):
245245
if not settings.REVIEW_MODEL.api_key:
246246
logger.info(f"Thread: {delivery}: No review model, skip")
247247
return
248-
248+
if not body:
249+
body = ""
249250
await review.review_pull_request(repo_name, pr_number, head_sha, f"{title}\n\n{body}")
250251

251252

core/analyze/analyzer.py

Lines changed: 120 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,9 @@ def extract_names_from_patch(self, patch_content: str) -> Tuple[Set[str], Set[st
5252
pass
5353

5454
@abstractmethod
55+
def extract_functions_from_patch(self, patch_content: str) -> Set[str]:
56+
pass
57+
@abstractmethod
5558
def extract_definitions(self, content: str, names: Set[str]) -> Dict[str, str]:
5659
pass
5760

@@ -196,6 +199,26 @@ def extract_names_from_patch(self, patch_content: str) -> Tuple[Set[str], Set[st
196199

197200
return functions, variables
198201

202+
def extract_functions_from_patch(self, patch_content: str) -> Set[str]:
    """Collect names that look like types or calls in a patch.

    The scan is purely lexical and is done line by line:

    * every word that starts with an uppercase letter is treated as a
      type name;
    * every identifier followed by ``(`` (optional whitespace in between)
      is treated as a function or constructor call.

    :param patch_content: patch text to scan.
    :return: the union of all type-like and call-like names found.
    """
    # Compile once, outside the per-line loop.
    type_re = re.compile(r'\b([A-Z][a-zA-Z0-9_]*)\b')
    call_re = re.compile(r'\b([a-zA-Z_][a-zA-Z0-9_]*)\s*\(')

    extracted_info = set()
    # Match per line: ``\s*`` would otherwise pair a name with a "(" that
    # sits on a following line.
    for line in patch_content.split('\n'):
        extracted_info.update(type_re.findall(line))
        extracted_info.update(call_re.findall(line))
    return extracted_info
221+
199222
def extract_definitions(self, content: str, names: Set[str]) -> Dict[str, str]:
200223
tree = ast.parse(content)
201224
definitions = {}
@@ -293,7 +316,7 @@ def _is_from_project(self, node, current_file: str) -> bool:
293316
return False
294317
file_path = os.path.abspath(node.location.file.name)
295318
return file_path.startswith(self.project_root) and (
296-
file_path == current_file or not file_path.endswith(('.h', '.hpp')))
319+
file_path == current_file or not file_path.endswith(('.h', '.hpp')))
297320

298321
def _get_element_content(self, node) -> str:
299322
try:
@@ -315,23 +338,55 @@ def _is_likely_external(self, content: str) -> bool:
315338

316339
def analyze_dependencies(self, file_path: str, content: str) -> List[str]:
317340
"""
318-
分析文件的依赖关系,并过滤掉非项目内的依赖
341+
分析文件的依赖关系,并过滤掉非项目内的依赖
319342
320-
:param file_path: 当前分析的文件路径
321-
:param content: 文件内容
322-
:param base_path: 项目的基础路径
323-
:return: 项目内的依赖列表
324-
"""
343+
:param file_path: 当前分析的文件路径
344+
:param content: 文件内容
345+
:param base_path: 项目的基础路径
346+
:return: 项目内的依赖列表
347+
"""
325348
# 查找所有的 #include 语句
326349
includes = re.findall(r'#include\s*[<"]([^>"]+)[>"]', content)
327350

328351
# 转换和过滤依赖
329352
project_dependencies = self.find_dependencies(file_path, includes)
330-
# 去重并返回
331-
return list(set(project_dependencies))
353+
# 对于每个头文件依赖,尝试找到对应的实现文件
354+
implementation_dependencies = []
355+
for dep in project_dependencies:
356+
impl_file = self.find_implementation_file(dep)
357+
if impl_file:
358+
implementation_dependencies.append(impl_file)
359+
360+
# 合并头文件和实现文件的依赖,去重并返回
361+
all_dependencies = list(set(project_dependencies + implementation_dependencies))
362+
return all_dependencies
363+
364+
def find_implementation_file(self, header_path: str) -> Optional[str]:
    """Find the implementation file that belongs to a header.

    Two strategies are tried, in this order:

    1. Replace the header's extension with each implementation extension
       and look for that exact path among the values of
       ``self.file_index``.
    2. Look up the bare file name (with each implementation extension)
       as a key of ``self.file_index`` and return the mapped path.

    :param header_path: relative path of the header file.
    :return: relative path of the implementation file, or ``None`` when
        no candidate is found.
    """
    implementation_extensions = ['.cpp', '.cxx', '.cc', '.c']
    base_name = os.path.splitext(header_path)[0]

    # Build the lookup set once instead of scanning ``values()`` for every
    # extension. NOTE(review): assumes the index values are path strings
    # (hashable), as the declared return type suggests -- confirm.
    known_paths = set(self.file_index.values())
    for ext in implementation_extensions:
        impl_path = base_name + ext
        if impl_path in known_paths:
            return impl_path

    # Nothing next to the header: fall back to a project-wide lookup by
    # file name.
    file_name = os.path.basename(base_name)
    for ext in implementation_extensions:
        impl_file = file_name + ext
        if impl_file in self.file_index:
            return self.file_index[impl_file]

    return None
332387

333388
def extract_names_from_patch(self, patch_content: str) -> Tuple[Set[str], Set[str]]:
334-
tu = self.index.parse('tmp.cpp', unsaved_files=[('tmp.cpp', patch_content)])
389+
tu = self.index.parse('tmp.cpp', unsaved_files=[('tmp.cpp', patch_content)], args=['-std=c++11'])
335390
functions = set()
336391
variables = set()
337392

@@ -344,9 +399,63 @@ def visit_node(node):
344399
for child in node.get_children():
345400
visit_node(child)
346401

347-
visit_node(tu.cursor)
402+
for child in tu.cursor.get_children():
403+
visit_node(child)
404+
405+
# visit_node(tu.cursor)
348406
return functions, variables
349407

408+
def extract_functions_from_patch(self, patch_content: str) -> Set[str]:
    """Extract function names defined or called in a C/C++ patch.

    The scan is regex based (no parsing) and runs two passes over the
    whole patch text:

    1. function definitions/declarations, optionally namespace-qualified;
    2. anything of the form ``name(...)``, i.e. potential calls.

    Names rejected by ``self._is_valid_function`` (system functions and
    control-structure keywords) are skipped in both passes.

    :param patch_content: patch text to scan.
    :return: set of (possibly ``::``-qualified) function names.
    """
    # Function definition or declaration, possibly with a namespace.
    function_def_pattern = r'(?:(?:\w+::)*\w+\s+)+(\w+(?:::\w+)*)\s*\([^)]*\)\s*(?:const)?\s*(?:{\s*)?'
    # Potential function call or control structure.
    potential_call_pattern = r'(\w+(?:::\w+)*)\s*\([^)]*\)'

    # System functions and keywords (extend as needed).
    system_functions = {'std::', 'boost::', 'printf', 'scanf', 'malloc', 'free', 'new', 'delete'}
    control_structures = {'if', 'while', 'for', 'switch', 'catch'}

    functions = set()

    for match in re.finditer(function_def_pattern, patch_content):
        func_name = match.group(1)
        if self._is_valid_function(func_name, system_functions, control_structures):
            functions.add(func_name)

    for match in re.finditer(potential_call_pattern, patch_content):
        func_name = match.group(1)
        if not self._is_valid_function(func_name, system_functions, control_structures):
            continue
        # Look at the token just before the match; a control keyword there
        # means this is not a plain call.
        preceding = patch_content[max(0, match.start() - 20):match.start()].split()
        # No preceding token (call at the very start of the patch) must not
        # discard the name.
        if not preceding or preceding[-1] not in control_structures:
            functions.add(func_name)

    return functions
446+
447+
def _is_valid_function(self, name: str, system_functions: Set[str], control_structures: Set[str]) -> bool:
448+
"""
449+
检查名称是否为有效的函数名(不是系统函数或控制结构)
450+
451+
:param name: 要检查的名称
452+
:param system_functions: 系统函数集合
453+
:param control_structures: 控制结构集合
454+
:return: 如果是有效的函数名则返回True,否则返回False
455+
"""
456+
return not any(name.startswith(sys_func) for sys_func in system_functions) and name not in control_structures
457+
458+
350459
def extract_definitions(self, content: str, names: Set[str]) -> Dict[str, str]:
351460
tu = self.index.parse('tmp.cpp', unsaved_files=[('tmp.cpp', content)])
352461
definitions = {}

core/analyze/base.py

Lines changed: 54 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@
1616
import glob
1717
import json
1818
import os
19+
import re
1920
import shutil
2021
from typing import List, Dict, Any, Optional
2122

@@ -55,7 +56,7 @@ def __init__(self, repo_fullname: str, milvus_uri: Optional[str] = None):
5556
self.dependencies = {}
5657
self.exclude_path = []
5758
self.milvus_uri = milvus_uri
58-
self.code_elements_collection = f"code_{self.repo_fullname.replace('/', '_').lower()}"
59+
self.code_elements_collection = f"v1_code_{self.repo_fullname.replace('/', '_').lower()}"
5960
self.code_elements_collection_loaded = False
6061
self.init_lock = asyncio.Lock()
6162

@@ -92,26 +93,26 @@ async def check_elements_collection(self) -> bool:
9293
FieldSchema(name="language", dtype=DataType.VARCHAR, max_length=20),
9394
FieldSchema(name="element_type", dtype=DataType.VARCHAR, max_length=20),
9495
FieldSchema(name="element_name", dtype=DataType.VARCHAR, max_length=100),
95-
FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=20000),
96+
FieldSchema(name="content", dtype=DataType.VARCHAR, max_length=65535),
9697
FieldSchema(name="embedding", dtype=DataType.FLOAT_VECTOR, dim=768)
9798
]
98-
schema = CollectionSchema(fields=fields, description="代码元素嵌入向量")
99+
schema = CollectionSchema(fields=fields, description="Code search collection")
99100
await milvus_manager.create_collection(
100101
dimension=768,
101-
metric_type="COSINE",
102+
metric_type="IP",
102103
collection_name=self.code_elements_collection,
103104
schema=schema,
104105
vector_field_name="embedding",
105-
description="代码元素嵌入向量"
106+
description="Code search collection"
106107
)
107108
# 创建向量索引
108109
index_params = IndexParams()
109110
try:
110111
# 判断milvus使用的模式, 本地或者内存
111112
if self.milvus_uri == "sqlite://:memory:" or not self.milvus_uri or self.milvus_uri.startswith("/"):
112-
index_params.add_index("embedding", "FLAT", "embedding_index", metric_type="COSINE")
113+
index_params.add_index("embedding", "FLAT", "embedding_index", metric_type="IP")
113114
else:
114-
index_params.add_index("embedding", "IVF_FLAT", "embedding_index", nlist=1024, metric_type="COSINE")
115+
index_params.add_index("embedding", "IVF_FLAT", "embedding_index", nlist=1024, metric_type="IP")
115116
await milvus_manager.create_index(
116117
collection_name=self.code_elements_collection,
117118
index_params=index_params
@@ -306,6 +307,8 @@ async def save_to_db(self, file_detail: FileDetails):
306307
# logger.info(f"Too many code elements in {file_detail.file_name}, only saving the first 60.")
307308
exclude_types_list = [CodeElementType.CONSTANT, CodeElementType.VARIABLE]
308309
for element in file_detail.code_elements:
310+
if not element['name'] or len(element['name']) == 0:
311+
continue
309312
if element['type'] in exclude_types_list:
310313
continue
311314
if f'{element["type"]}_{element["name"]}' in added_set:
@@ -439,28 +442,65 @@ async def generate_project_overview(self, summary: Dict[str, Any]) -> str:
439442
messages, settings.REVIEW_MODEL, 0.3, 50, 0.9)
440443
return overview
441444

445+
def clean_patch(self, patch_content: str) -> str:
    """Reduce a unified-diff patch to the code it leaves behind.

    Line handling:

    * hunk headers (``@@ -a,b +c,d @@ tail``) keep only the text after the
      last ``@@``;
    * removed lines (``-``) are dropped;
    * added lines (``+``) and context lines (leading space) lose their
      one-character diff marker, so both keep the same indentation;
    * the ``\\ No newline at end of file`` marker is dropped;
    * any other line (e.g. an empty one) is kept unchanged.

    :param patch_content: patch text in unified-diff format.
    :return: the cleaned text, lines joined with ``\\n``.
    """
    cleaned_patch = []
    for line in patch_content.split('\n'):
        if line.startswith('@@'):
            cleaned_patch.append(line.rsplit('@@', 1)[1])
        elif line.startswith(('-', '\\')):
            # Removed line, or the "\ No newline at end of file" marker.
            continue
        elif line.startswith(('+', ' ')):
            # Strip the marker column from added and context lines alike;
            # otherwise context lines end up one column deeper.
            cleaned_patch.append(line[1:])
        else:
            cleaned_patch.append(line)
    return '\n'.join(cleaned_patch)
460+
442461
async def get_review_context(self, filename: str, patch_content: str) -> Dict[str, Any]:
443462
"""
444463
审查所需要的上下文信息
445464
"""
446-
patch_embedding = await embedding_model.async_encode_text(patch_content)
447-
448-
search_params = {"metric_type": "COSINE", "params": {"nprobe": 20}}
465+
patch_embedding = []
466+
patch_content = self.clean_patch(patch_content)
467+
language = utils.get_support_file_language(filename)
468+
analyzer = self.analyzers.get(language)
469+
code_elements = list(analyzer.extract_functions_from_patch(patch_content))
470+
if not code_elements or len(code_elements) == 0:
471+
code_elements = patch_content.split("\n")
472+
code_elements_count = len(code_elements)
473+
limit = 20 // code_elements_count
474+
if limit < 1:
475+
limit = 1
476+
code_elements = code_elements[:20]
477+
for element in code_elements:
478+
element_v = await embedding_model.async_encode_text(element)
479+
patch_embedding.append(element_v.tolist())
480+
search_params = {"metric_type": "IP", "params": {"nprobe": 10}}
449481
await self.check_elements_collection()
450482
results = await milvus_manager.search(
451483
collection_name=self.code_elements_collection,
452-
data=[patch_embedding.tolist()],
484+
# filter="element_name in ['" + "','".join(code_elements) + "']",
485+
# filter="element_name != ''",
486+
data=patch_embedding,
453487
anns_field="embedding",
454488
search_params=search_params,
455-
limit=20,
489+
limit=limit,
456490
output_fields=["file_path", "language", "element_type", "element_name", "content"]
457491
)
492+
related_elements = []
493+
for result in results:
494+
if isinstance(result, dict):
495+
continue
496+
for code_element in result:
497+
related_elements.append(code_element)
458498

459-
related_elements = results[0]
460499
# 获取相关元素的上下文信息
461500
context_info = self.get_context_info(related_elements)
462501
# 分析补丁中的依赖关系
463502
patch_dependencies = self.get_dependencies(filename)
503+
logger.info("Dependencies: %s", patch_dependencies)
464504
# 项目的概述 "project_overview.md"
465505
project_overview = ""
466506
overview_path = os.path.join(self.analyze_data_path, "project_overview.md")
@@ -500,7 +540,7 @@ def get_dependencies(self, filename: str) -> Dict[str, str]:
500540
if not index_detail:
501541
return result
502542
for file_name in index_detail.dependencies:
503-
if len(result) > 5:
543+
if len(result) > 6:
504544
return result
505545
# 读取依赖文件的内容
506546
file_path = os.path.join(self.project_source_path, file_name)

core/analyze/index.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,7 @@ class IndexItem(pydantic.BaseModel):
3939
language: str
4040
last_modified: float
4141
dependencies: List[str]
42+
# code_elements: List[Dict[str, Any]]
4243

4344

4445
INDEX_PATH_PREFIX = '.index'
@@ -92,7 +93,8 @@ def insert_or_update(self, file_detail: FileDetails):
9293
code_hash=file_detail.code_hash,
9394
language=file_detail.language,
9495
last_modified=os.path.getmtime(os.path.join(self.source_path, file_detail.file_name)),
95-
dependencies=file_detail.dependencies
96+
dependencies=file_detail.dependencies,
97+
# code_elements=file_detail.code_elements
9698
)
9799
with open(index_file_name, 'w') as f:
98100
f.write(index_item.json())

core/embedding.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,13 @@ def get_model(self) -> TextEmbedding:
5959
self.load()
6060
return self.embedding_model
6161

62+
def normalize_vector(self, vector: np.ndarray) -> np.ndarray:
63+
"""Apply L2 normalization to the vector; a zero-norm vector is returned unchanged."""
64+
norm = np.linalg.norm(vector)
65+
# Guard against division by zero for an all-zero vector.
if norm == 0:
66+
return vector
67+
return vector / norm
68+
6269
def chunk_text(self, text: str, chunk_size: int = 1000) -> List[str]:
6370
"""将文本分割成更小的块"""
6471
words = text.split()
@@ -72,15 +79,16 @@ def encode_text(self, text: str, chunk_size: int = 1000) -> np.ndarray:
7279
model = self.get_model()
7380
if len(text.split()) <= chunk_size:
7481
embeddings = next(model.embed([text]))
75-
return embeddings
7682
else:
7783
chunks = self.chunk_text(text, chunk_size)
7884
embeddings = []
7985
for chunk in chunks:
8086
chunk_embedding = next(model.embed([chunk]))
8187
embeddings.append(chunk_embedding)
8288
gc.collect() # Force garbage collection after each chunk
83-
return np.mean(embeddings, axis=0)
89+
embeddings = np.mean(embeddings, axis=0)
90+
gc.collect()
91+
return self.normalize_vector(embeddings)
8492
except Exception as e:
8593
logger.error(f"Failed to encode text: {e}")
8694
return np.zeros(model.dim) # Use the dimension from the model
@@ -98,7 +106,7 @@ def process_large_document(self, document: str, chunk_size: int = 1000) -> np.nd
98106
result = np.mean(embeddings, axis=0)
99107
del embeddings
100108
gc.collect()
101-
return result
109+
return self.normalize_vector(result)
102110

103111
# 异步包装器
104112
async def async_encode_text(self, text: str, chunk_size: int = 1000) -> np.ndarray:

0 commit comments

Comments
 (0)