Commit 5ef2399

Chore/remove iterrows (#1708)
* Remove most iterrows usages
* Semver
* Ruff
* Pyright
* Format
1 parent f14cda2 commit 5ef2399

File tree

7 files changed: +208 −326 lines

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "Optimize data iteration by removing some iterrows from code"
+}

graphrag/index/operations/create_graph.py

Lines changed: 1 addition & 1 deletion
@@ -18,6 +18,6 @@ def create_graph(
 
     if nodes is not None:
         nodes.set_index(node_id, inplace=True)
-        graph.add_nodes_from((n, dict(d)) for n, d in nodes.iterrows())
+        graph.add_nodes_from(nodes.to_dict("index").items())
 
     return graph
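
Note: nodes.to_dict("index") returns a mapping of index label to per-row attribute dict, so its .items() yields the same (node, attributes) pairs that add_nodes_from expects, without building a pandas Series per row. A minimal sketch of the equivalence (not part of the commit, toy data only):

import networkx as nx
import pandas as pd

nodes = pd.DataFrame({"title": ["a", "b"], "degree": [2, 1]}).set_index("title")

graph = nx.Graph()
# to_dict("index") -> {"a": {"degree": 2}, "b": {"degree": 1}}; .items() gives
# (node, attr_dict) pairs, the form add_nodes_from accepts.
graph.add_nodes_from(nodes.to_dict("index").items())
print(graph.nodes(data=True))  # each node with its attribute dict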

graphrag/index/operations/snapshot_rows.py

Lines changed: 0 additions & 85 deletions
This file was deleted.

graphrag/index/operations/summarize_descriptions/summarize_descriptions.py

Lines changed: 6 additions & 6 deletions
@@ -89,12 +89,12 @@ async def get_summarized(
 
     node_futures = [
         do_summarize_descriptions(
-            str(row[1]["title"]),
-            sorted(set(row[1]["description"])),
+            str(row.title),  # type: ignore
+            sorted(set(row.description)),  # type: ignore
             ticker,
             semaphore,
         )
-        for row in nodes.iterrows()
+        for row in nodes.itertuples(index=False)
     ]
 
     node_results = await asyncio.gather(*node_futures)
@@ -109,12 +109,12 @@ async def get_summarized(
 
     edge_futures = [
         do_summarize_descriptions(
-            (str(row[1]["source"]), str(row[1]["target"])),
-            sorted(set(row[1]["description"])),
+            (str(row.source), str(row.target)),  # type: ignore
+            sorted(set(row.description)),  # type: ignore
             ticker,
             semaphore,
         )
-        for row in edges.iterrows()
+        for row in edges.itertuples(index=False)
     ]
 
     edge_results = await asyncio.gather(*edge_futures)
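
Note: iterrows() yields (index, Series) pairs, which is why the old code reached through row[1]; itertuples(index=False) yields lightweight namedtuples with attribute access, hence row.title / row.description (and the # type: ignore hints, since the tuple fields are not statically typed). A minimal sketch of the difference (not part of the commit, toy data only):

import pandas as pd

nodes = pd.DataFrame({"title": ["a", "b"], "description": [["x"], ["y", "x"]]})

# Old style: (index, Series) pairs, accessed positionally and by label.
for row in nodes.iterrows():
    print(row[0], row[1]["title"])

# New style: namedtuples with attribute access, no per-row Series construction.
for row in nodes.itertuples(index=False):
    print(row.title, sorted(set(row.description)))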

graphrag/index/update/entities.py

Lines changed: 6 additions & 3 deletions
@@ -119,11 +119,12 @@ async def _run_entity_summarization(
 
     # Prepare tasks for async summarization where needed
     async def process_row(row):
-        description = row["description"]
+        # Accessing attributes directly from the named tuple.
+        description = row.description
         if isinstance(description, list) and len(description) > 1:
             # Run entity summarization asynchronously
             result = await run_entity_summarization(
-                row["title"],
+                row.title,
                 description,
                 callbacks,
                 cache,
@@ -134,7 +135,9 @@ async def process_row(row):
         return description[0] if isinstance(description, list) else description
 
     # Create a list of async tasks for summarization
-    tasks = [process_row(row) for _, row in entities_df.iterrows()]
+    tasks = [
+        process_row(row) for row in entities_df.itertuples(index=False, name="Entity")
+    ]
     results = await asyncio.gather(*tasks)
 
     # Update the 'description' column in the DataFrame
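
Note: name="Entity" only labels the generated namedtuple class; the substantive change is that process_row now receives a tuple with attribute access instead of a Series, so the `for _, row in ...` unpacking disappears. A minimal sketch of the pattern (not part of the commit; a stand-in replaces the real run_entity_summarization call):

import asyncio
import pandas as pd

entities_df = pd.DataFrame({
    "title": ["a", "b"],
    "description": [["only one"], ["first", "second"]],
})

async def process_row(row):
    # row is a namedtuple named "Entity"; attribute access replaces row["description"].
    description = row.description
    if isinstance(description, list) and len(description) > 1:
        return " ".join(description)  # stand-in for the real summarization call
    return description[0] if isinstance(description, list) else description

async def main():
    tasks = [
        process_row(row) for row in entities_df.itertuples(index=False, name="Entity")
    ]
    return await asyncio.gather(*tasks)

print(asyncio.run(main()))  # ['only one', 'first second']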

graphrag/query/input/loaders/dfs.py

Lines changed: 62 additions & 42 deletions
@@ -21,6 +21,16 @@
 )
 
 
+def _prepare_records(df: pd.DataFrame) -> list[dict]:
+    """
+    Reset index and convert the DataFrame to a list of dictionaries.
+
+    We rename the reset index column to 'Index' for consistency.
+    """
+    df_reset = df.reset_index().rename(columns={"index": "Index"})
+    return df_reset.to_dict("records")
+
+
 def read_entities(
     df: pd.DataFrame,
     id_col: str = "id",
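
Note: _prepare_records is the new shared helper for all the loaders below; reset_index() turns the positional index into an "Index" column, so the record dicts keep the value that str(idx) used to provide under iterrows. A minimal sketch of what it returns (not part of the commit, toy data only):

import pandas as pd

df = pd.DataFrame({"id": ["e1", "e2"], "title": ["a", "b"]})
records = df.reset_index().rename(columns={"index": "Index"}).to_dict("records")
print(records)
# [{'Index': 0, 'id': 'e1', 'title': 'a'}, {'Index': 1, 'id': 'e2', 'title': 'b'}]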
@@ -35,12 +45,14 @@ def read_entities(
     rank_col: str | None = "degree",
     attributes_cols: list[str] | None = None,
 ) -> list[Entity]:
-    """Read entities from a dataframe."""
-    entities = []
-    for idx, row in df.iterrows():
-        entity = Entity(
+    """Read entities from a dataframe using pre-converted records."""
+    records = _prepare_records(df)
+    return [
+        Entity(
             id=to_str(row, id_col),
-            short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
+            short_id=to_optional_str(row, short_id_col)
+            if short_id_col
+            else str(row["Index"]),
             title=to_str(row, title_col),
             type=to_optional_str(row, type_col),
             description=to_optional_str(row, description_col),
@@ -57,8 +69,8 @@ def read_entities(
                 else None
             ),
         )
-        entities.append(entity)
-    return entities
+        for row in records
+    ]
 
 
 def read_relationships(
@@ -74,12 +86,14 @@ def read_relationships(
     text_unit_ids_col: str | None = "text_unit_ids",
     attributes_cols: list[str] | None = None,
 ) -> list[Relationship]:
-    """Read relationships from a dataframe."""
-    relationships = []
-    for idx, row in df.iterrows():
-        rel = Relationship(
+    """Read relationships from a dataframe using pre-converted records."""
+    records = _prepare_records(df)
+    return [
+        Relationship(
             id=to_str(row, id_col),
-            short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
+            short_id=to_optional_str(row, short_id_col)
+            if short_id_col
+            else str(row["Index"]),
             source=to_str(row, source_col),
             target=to_str(row, target_col),
             description=to_optional_str(row, description_col),
@@ -95,8 +109,8 @@ def read_relationships(
                 else None
             ),
         )
-        relationships.append(rel)
-    return relationships
+        for row in records
+    ]
 
 
 def read_covariates(
@@ -108,12 +122,14 @@ def read_covariates(
     text_unit_ids_col: str | None = "text_unit_ids",
     attributes_cols: list[str] | None = None,
 ) -> list[Covariate]:
-    """Read covariates from a dataframe."""
-    covariates = []
-    for idx, row in df.iterrows():
-        cov = Covariate(
+    """Read covariates from a dataframe using pre-converted records."""
+    records = _prepare_records(df)
+    return [
+        Covariate(
             id=to_str(row, id_col),
-            short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
+            short_id=to_optional_str(row, short_id_col)
+            if short_id_col
+            else str(row["Index"]),
             subject_id=to_str(row, subject_col),
             covariate_type=(
                 to_str(row, covariate_type_col) if covariate_type_col else "claim"
@@ -125,8 +141,8 @@ def read_covariates(
                 else None
             ),
         )
-        covariates.append(cov)
-    return covariates
+        for row in records
+    ]
 
 
 def read_communities(
@@ -141,12 +157,14 @@ def read_communities(
     sub_communities_col: str | None = "sub_community_ids",
     attributes_cols: list[str] | None = None,
 ) -> list[Community]:
-    """Read communities from a dataframe."""
-    communities = []
-    for idx, row in df.iterrows():
-        comm = Community(
+    """Read communities from a dataframe using pre-converted records."""
+    records = _prepare_records(df)
+    return [
+        Community(
             id=to_str(row, id_col),
-            short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
+            short_id=to_optional_str(row, short_id_col)
+            if short_id_col
+            else str(row["Index"]),
             title=to_str(row, title_col),
             level=to_str(row, level_col),
             entity_ids=to_optional_list(row, entities_col, item_type=str),
@@ -161,8 +179,8 @@ def read_communities(
                 else None
             ),
         )
-        communities.append(comm)
-    return communities
+        for row in records
+    ]
 
 
 def read_community_reports(
@@ -177,12 +195,14 @@ def read_community_reports(
     content_embedding_col: str | None = "full_content_embedding",
     attributes_cols: list[str] | None = None,
 ) -> list[CommunityReport]:
-    """Read community reports from a dataframe."""
-    reports = []
-    for idx, row in df.iterrows():
-        report = CommunityReport(
+    """Read community reports from a dataframe using pre-converted records."""
+    records = _prepare_records(df)
+    return [
+        CommunityReport(
             id=to_str(row, id_col),
-            short_id=to_optional_str(row, short_id_col) if short_id_col else str(idx),
+            short_id=to_optional_str(row, short_id_col)
+            if short_id_col
+            else str(row["Index"]),
             title=to_str(row, title_col),
             community_id=to_str(row, community_col),
             summary=to_str(row, summary_col),
@@ -197,8 +217,8 @@ def read_community_reports(
                 else None
             ),
         )
-        reports.append(report)
-    return reports
+        for row in records
+    ]
 
 
 def read_text_units(
@@ -212,12 +232,12 @@ def read_text_units(
     document_ids_col: str | None = "document_ids",
     attributes_cols: list[str] | None = None,
 ) -> list[TextUnit]:
-    """Read text units from a dataframe."""
-    text_units = []
-    for idx, row in df.iterrows():
-        chunk = TextUnit(
+    """Read text units from a dataframe using pre-converted records."""
+    records = _prepare_records(df)
+    return [
+        TextUnit(
             id=to_str(row, id_col),
-            short_id=str(idx),
+            short_id=str(row["Index"]),
             text=to_str(row, text_col),
             entity_ids=to_optional_list(row, entities_col, item_type=str),
             relationship_ids=to_optional_list(row, relationships_col, item_type=str),
@@ -232,5 +252,5 @@ def read_text_units(
                 else None
            ),
         )
-        text_units.append(chunk)
-    return text_units
+        for row in records
+    ]
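
Note: every read_* loader follows the same shape after this change: convert once with _prepare_records, then build the result list in a single comprehension, with row["Index"] replacing the old iterrows idx for the short_id fallback. A minimal sketch of the pattern against a hypothetical record type (not the real Entity/Relationship models):

from dataclasses import dataclass

import pandas as pd


@dataclass
class Record:  # hypothetical stand-in for Entity, Relationship, etc.
    id: str
    short_id: str
    title: str


def read_records(df: pd.DataFrame) -> list[Record]:
    records = df.reset_index().rename(columns={"index": "Index"}).to_dict("records")
    return [
        Record(
            id=str(row["id"]),
            short_id=str(row["Index"]),  # positional index replaces the old iterrows idx
            title=str(row["title"]),
        )
        for row in records
    ]


print(read_records(pd.DataFrame({"id": ["e1"], "title": ["a"]})))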
