
Commit 5f31f71

Commit message: fix
Parent: d30ad5a

File tree: 3 files changed, +40 -6 lines


hugegraph-llm/src/hugegraph_llm/operators/index_op/build_gremlin_example_index.py

Lines changed: 3 additions & 3 deletions

@@ -37,17 +37,17 @@ def __init__(self, embedding: BaseEmbedding, examples: List[Dict[str, str]]):

    def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
        embed_dim = 0
-
+
        if len(self.examples) > 0:
            # Use the new async parallel embedding approach from upstream
            queries = [example["query"] for example in self.examples]
            # TODO: refactor function chain async to avoid blocking
            examples_embedding = asyncio.run(get_embeddings_parallel(self.embedding, queries))
            embed_dim = len(examples_embedding[0])
-
+
            vector_index = VectorIndex(embed_dim)
            vector_index.add(examples_embedding, self.examples)
            vector_index.to_index_file(self.index_dir, self.filename_prefix)
-
+
        context["embed_dim"] = embed_dim
        return context
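
The helper get_embeddings_parallel itself is not part of this diff. As a rough illustration of the approach the comment above refers to, here is a minimal sketch of how such a helper could fan out embedding calls with asyncio.gather; the function name get_embeddings_parallel_sketch and the method async_get_text_embedding on the embedding object are assumptions for illustration, not the upstream implementation.

import asyncio
from typing import List


async def get_embeddings_parallel_sketch(embedding, queries: List[str]) -> List[List[float]]:
    # Fire one embedding coroutine per query and await them all concurrently.
    # `async_get_text_embedding` is assumed to be an async method on the embedding object.
    tasks = [embedding.async_get_text_embedding(q) for q in queries]
    return list(await asyncio.gather(*tasks))


# A synchronous caller such as run() above would block on the event loop,
# e.g. embeddings = asyncio.run(get_embeddings_parallel_sketch(embedding, queries)),
# which is exactly the blocking pattern the TODO proposes to refactor away.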

hugegraph-llm/src/hugegraph_llm/operators/llm_op/keyword_extract.py

Lines changed: 34 additions & 0 deletions

@@ -75,6 +75,40 @@ def run(self, context: Dict[str, Any]) -> Dict[str, Any]:
        context["call_count"] = context.get("call_count", 0) + 1
        return context

+    async def arun(self, context: Dict[str, Any]) -> Dict[str, Any]:
+        """Async version of keyword extraction."""
+        if self._query is None:
+            self._query = context.get("query")
+            assert self._query is not None, "No query for keywords extraction."
+        else:
+            context["query"] = self._query
+
+        if self._llm is None:
+            self._llm = LLMs().get_extract_llm()
+        assert isinstance(self._llm, BaseLLM), "Invalid LLM Object."
+
+        self._language = context.get("language", self._language).lower()
+        self._max_keywords = context.get("max_keywords", self._max_keywords)
+
+        prompt_run = f"{self._extract_template.format(question=self._query, max_keywords=self._max_keywords)}"
+        start_time = time.perf_counter()
+
+        # Key change: use the async LLM call
+        response = await self._llm.agenerate(prompt=prompt_run)
+
+        end_time = time.perf_counter()
+        log.debug("Keyword extraction time: %.2f seconds", end_time - start_time)
+
+        keywords = self._extract_keywords_from_response(
+            response=response, lowercase=False, start_token="KEYWORDS:"
+        )
+        keywords = {k.replace("'", "") for k in keywords}
+        context["keywords"] = list(keywords)
+        log.info("User Query: %s\nKeywords: %s", self._query, context["keywords"])
+
+        context["call_count"] = context.get("call_count", 0) + 1
+        return context
+
    def _extract_keywords_from_response(
        self,
        response: str,
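
For context, a minimal usage sketch of the new arun() from an async caller is shown below. It assumes the extractor class in this file is named KeywordExtract and that it can be constructed with defaults; neither detail appears in this diff, so both are illustrative assumptions.

import asyncio

from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract  # class name assumed


async def extract_keywords(query: str) -> list:
    op = KeywordExtract()                      # default construction is an assumption
    context = {"query": query, "max_keywords": 5, "language": "english"}
    context = await op.arun(context)           # awaits the non-blocking agenerate() call
    return context["keywords"]


# e.g. keywords = asyncio.run(extract_keywords("What is HugeGraph?"))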

hugegraph-llm/src/tests/operators/index_op/test_build_gremlin_example_index.py

Lines changed: 3 additions & 3 deletions

@@ -51,11 +51,11 @@ def setUp(self):
        self.patcher2 = patch("hugegraph_llm.operators.index_op.build_gremlin_example_index.get_index_folder_name")
        self.mock_get_index_folder_name = self.patcher2.start()
        self.mock_get_index_folder_name.return_value = "hugegraph"
-
+
        self.patcher3 = patch("hugegraph_llm.operators.index_op.build_gremlin_example_index.get_filename_prefix")
        self.mock_get_filename_prefix = self.patcher3.start()
        self.mock_get_filename_prefix.return_value = "test_prefix"
-
+
        self.patcher4 = patch("hugegraph_llm.operators.index_op.build_gremlin_example_index.get_embeddings_parallel")
        self.mock_get_embeddings_parallel = self.patcher4.start()
        self.mock_get_embeddings_parallel.return_value = [[0.1, 0.2, 0.3], [0.4, 0.5, 0.6]]

@@ -128,7 +128,7 @@ def test_run_with_empty_examples(self):

        # The run method should handle empty examples gracefully
        result = builder.run(context)
-
+
        # Should return embed_dim as 0 for empty examples
        self.assertEqual(result["embed_dim"], 0)
        self.assertEqual(result["test"], "value")  # Original context should be preserved
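
The commit does not add a test for the new arun() method. A hypothetical companion test (not part of this commit) could cover it with unittest.IsolatedAsyncioTestCase and an AsyncMock for the LLM, mirroring the patch/mock style of the existing test module. The BaseLLM import path, the KeywordExtract class name, and default construction are all assumptions for illustration.

import unittest
from unittest.mock import AsyncMock, MagicMock

from hugegraph_llm.models.llms.base import BaseLLM                         # import path assumed
from hugegraph_llm.operators.llm_op.keyword_extract import KeywordExtract  # class name assumed


class TestKeywordExtractAsync(unittest.IsolatedAsyncioTestCase):
    async def test_arun_extracts_keywords(self):
        op = KeywordExtract()                  # default construction assumed
        op._llm = MagicMock(spec=BaseLLM)      # spec= lets the isinstance assert in arun() pass
        # agenerate() is awaited inside arun(), so it must be an AsyncMock
        op._llm.agenerate = AsyncMock(return_value="KEYWORDS: HugeGraph, graph database")

        context = await op.arun({"query": "What is HugeGraph?", "max_keywords": 5})

        op._llm.agenerate.assert_awaited_once()
        self.assertIn("keywords", context)
        self.assertEqual(context["call_count"], 1)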
