
Commit 61a2dd6

ianrob committed
Fix issues in batch fallback strategies
1 parent 736c2a2 commit 61a2dd6

File tree

4 files changed: +31 −14 lines


lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/extract/batch_llm_proposition_extractor_sync.py

Lines changed: 15 additions & 6 deletions
@@ -63,20 +63,29 @@ def _get_json(self, node, llm, inference_parameters):
         }

     def _run_non_batch_extractor(self, nodes):
+
+        all_nodes = [node for node in nodes]

         extractor = LLMPropositionExtractor(
             prompt_template=self.prompt_template,
             source_metadata_field=self.source_metadata_field
         )
-
-        return extractor.extract(nodes)
+
+        extracted = extractor.extract(all_nodes)
+
+        results = [{n.node_id: e[PROPOSITIONS_KEY]} for (n, e) in zip(all_nodes, extracted)]
+
+        return results

     def _update_node(self, node:TextNode, node_metadata_map):
         if node.node_id in node_metadata_map:
-            raw_response = node_metadata_map[node.node_id]
-            propositions = raw_response.split('\n')
-            propositions_model = Propositions(propositions=[p for p in propositions if p])
-            node.metadata[PROPOSITIONS_KEY] = propositions_model.model_dump()['propositions']
+            proposition_data = node_metadata_map[node.node_id]
+            if isinstance(proposition_data, list):
+                node.metadata[PROPOSITIONS_KEY] = proposition_data
+            else:
+                propositions = proposition_data.split('\n')
+                propositions_model = Propositions(propositions=[p for p in propositions if p])
+                node.metadata[PROPOSITIONS_KEY] = propositions_model.model_dump()['propositions']
         else:
             node.metadata[PROPOSITIONS_KEY] = []
         return node
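Taken together, the non-batch fallback now returns one {node_id: propositions} mapping per node instead of the inner extractor's raw return value, and _update_node accepts either a pre-parsed list (from that fallback) or a raw newline-delimited string (from the batch response). A minimal, self-contained sketch of that dual handling, using plain dicts in place of the toolkit's TextNode and Propositions types (names and shapes below are illustrative assumptions, not the actual implementation):

    PROPOSITIONS_KEY = 'propositions'  # placeholder; the toolkit defines its own constant

    def update_node(node, node_metadata_map):
        # node is a plain dict standing in for the toolkit's TextNode
        data = node_metadata_map.get(node['node_id'])
        if data is None:
            node['metadata'][PROPOSITIONS_KEY] = []
        elif isinstance(data, list):
            # fallback path: _run_non_batch_extractor already produced a parsed list
            node['metadata'][PROPOSITIONS_KEY] = data
        else:
            # batch path: raw newline-delimited LLM text
            node['metadata'][PROPOSITIONS_KEY] = [p for p in data.split('\n') if p]
        return node

    # pre-parsed list, as the fallback now supplies
    print(update_node({'node_id': 'n1', 'metadata': {}}, {'n1': ['Paris is in France.']}))

    # raw string, as a batch response supplies
    print(update_node({'node_id': 'n2', 'metadata': {}}, {'n2': 'Paris is in France.\nThe Seine flows through Paris.'}))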

lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/extract/batch_topic_extractor_sync.py

Lines changed: 14 additions & 6 deletions
@@ -76,6 +76,8 @@ def _get_json(self, node, llm, inference_parameters):
         }

     def _run_non_batch_extractor(self, nodes):
+
+        all_nodes = [node for node in nodes]

         extractor = TopicExtractor(
             prompt_template=self.prompt_template,
@@ -84,14 +86,20 @@ def _run_non_batch_extractor(self, nodes):
             topic_provider=self.topic_provider
         )

-        return extractor.extract(nodes)
+        extracted = extractor.extract(all_nodes)
+
+        results = [{n.id_: e[TOPICS_KEY]} for (n, e) in zip(all_nodes, extracted)]
+
+        return results

     def _update_node(self, node:TextNode, node_metadata_map):
         if node.node_id in node_metadata_map:
-            raw_response = node_metadata_map[node.node_id]
-            (topics, _) = parse_extracted_topics(raw_response)
-            node.metadata[TOPICS_KEY] = topics.model_dump()
+            topic_data = node_metadata_map[node.node_id]
+            if isinstance(topic_data, dict):
+                node.metadata[TOPICS_KEY] = topic_data
+            else:
+                (topics, _) = parse_extracted_topics(topic_data)
+                node.metadata[TOPICS_KEY] = topics.model_dump()
         else:
             node.metadata[TOPICS_KEY] = []
-        return node
-
+        return node
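The topic-side fallback makes the same move: pair each input node with the corresponding entry from the inner extractor's output and key it by node id, so the later _update_node lookup also works for fallback results. A toy version of that pairing (it assumes, as the zip implies, that extract() returns one metadata dict per node in input order; the stand-in class and key value below are hypothetical):

    TOPICS_KEY = 'topics'  # placeholder; the toolkit defines its own constant

    class FakeTopicExtractor:
        # stand-in for TopicExtractor: returns one metadata dict per node, in order
        def extract(self, nodes):
            return [{TOPICS_KEY: {'topics': [f'topic for {n["id_"]}']}} for n in nodes]

    def run_non_batch_extractor(nodes):
        all_nodes = list(nodes)  # materialise, since the sequence is iterated twice below
        extracted = FakeTopicExtractor().extract(all_nodes)
        # one {node_id: topic_metadata} mapping per node, ready for _update_node lookups
        return [{n['id_']: e[TOPICS_KEY]} for (n, e) in zip(all_nodes, extracted)]

    print(run_non_batch_extractor([{'id_': 'n1'}, {'id_': 'n2'}]))
    # [{'n1': {'topics': ['topic for n1']}}, {'n2': {'topics': ['topic for n2']}}]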

lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/extract/llm_proposition_extractor.py

Lines changed: 1 addition & 1 deletion
@@ -133,7 +133,7 @@ async def _extract_propositions_for_nodes(self, nodes):
             jobs,
             show_progress=self.show_progress,
             workers=self.num_workers,
-            desc=f'Extracting propositions [nodes: {len(nodes)}, num_workers: {self.num_workers}]'
+            desc=f'Extracting propositions [nodes: {len(jobs)}, num_workers: {self.num_workers}]'
         )

     async def _extract_propositions_for_node(self, node):

lexical-graph/src/graphrag_toolkit/lexical_graph/indexing/extract/topic_extractor.py

Lines changed: 1 addition & 1 deletion
@@ -125,7 +125,7 @@ async def _extract_for_nodes(self, nodes):
             jobs,
             show_progress=self.show_progress,
             workers=self.num_workers,
-            desc=f'Extracting topics [nodes: {len(nodes)}, num_workers: {self.num_workers}]'
+            desc=f'Extracting topics [nodes: {len(jobs)}, num_workers: {self.num_workers}]'
         )

     def _get_metadata_or_default(self, metadata, key, default):
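The two one-line changes above only touch the progress-bar description: it now reports len(jobs) rather than len(nodes). A plausible reading (an assumption, not something this commit states) is that jobs are created for only a subset of the input nodes, so the two counts can diverge; a toy illustration:

    # hypothetical filtering, purely for illustration: when jobs cover only a
    # subset of the input nodes, len(jobs) is the count worth reporting
    nodes = ['node-1', 'node-2', 'node-3', 'node-4']
    already_extracted = {'node-2'}

    jobs = [n for n in nodes if n not in already_extracted]

    num_workers = 4
    print(f'Extracting propositions [nodes: {len(jobs)}, num_workers: {num_workers}]')
    # Extracting propositions [nodes: 3, num_workers: 4]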
