
Commit 99b160d

More pyright
1 parent 762988b commit 99b160d

2 files changed (+17, -13 lines)

graphrag/index/flows/create_final_community_reports.py (5 additions, 5 deletions)

@@ -116,7 +116,7 @@ async def create_final_community_reports(
 def _prep_nodes(input: pd.DataFrame) -> pd.DataFrame:
     """Prepare nodes by filtering, filling missing descriptions, and creating NODE_DETAILS."""
     # Filter rows where community is not -1
-    input = input.loc[input.loc[:,COMMUNITY_ID] != -1]
+    input = input.loc[input.loc[:, COMMUNITY_ID] != -1]
 
     # Fill missing values in NODE_DESCRIPTION
     input.loc[:, NODE_DESCRIPTION] = input.loc[:, NODE_DESCRIPTION].fillna(

@@ -136,8 +136,8 @@ def _prep_edges(input: pd.DataFrame) -> pd.DataFrame:
     input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)
 
     # Create EDGE_DETAILS column
-    input[EDGE_DETAILS] = input.loc[:,
-        [EDGE_ID, EDGE_SOURCE, EDGE_TARGET, EDGE_DESCRIPTION, EDGE_DEGREE]
+    input[EDGE_DETAILS] = input.loc[
+        :, [EDGE_ID, EDGE_SOURCE, EDGE_TARGET, EDGE_DESCRIPTION, EDGE_DEGREE]
     ].to_dict(orient="records")
 
     return input

@@ -148,8 +148,8 @@ def _prep_claims(input: pd.DataFrame) -> pd.DataFrame:
     input.fillna(value={NODE_DESCRIPTION: "No Description"}, inplace=True)
 
     # Create CLAIM_DETAILS column
-    input[CLAIM_DETAILS] = input.loc[:,
-        [CLAIM_ID, CLAIM_SUBJECT, CLAIM_TYPE, CLAIM_STATUS, CLAIM_DESCRIPTION]
+    input[CLAIM_DETAILS] = input.loc[
+        :, [CLAIM_ID, CLAIM_SUBJECT, CLAIM_TYPE, CLAIM_STATUS, CLAIM_DESCRIPTION]
     ].to_dict(orient="records")
 
     return input
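Both the edge and claim hunks replace the split `input.loc[:, <list on the next line>]` form with an explicit `input.loc[:, [columns]]` call, presumably so pyright can resolve the column selection cleanly. A minimal standalone sketch of the same pattern, using hypothetical column names rather than the repo's EDGE_* schema constants:

import pandas as pd

# Hypothetical edge table; the column names stand in for EDGE_ID,
# EDGE_SOURCE, EDGE_TARGET, EDGE_DESCRIPTION and EDGE_DEGREE.
edges = pd.DataFrame({
    "id": [1, 2],
    "source": ["a", "b"],
    "target": ["b", "c"],
    "description": ["a-b", "b-c"],
    "degree": [2, 3],
})

# Select the detail columns with an explicit row slice, then pack each row
# into a dict -- the same .loc[:, [...]].to_dict(orient="records") shape
# the commit settles on for EDGE_DETAILS and CLAIM_DETAILS.
edges["edge_details"] = edges.loc[
    :, ["id", "source", "target", "description", "degree"]
].to_dict(orient="records")

print(edges["edge_details"].iloc[0])
# {'id': 1, 'source': 'a', 'target': 'b', 'description': 'a-b', 'degree': 2}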

graphrag/index/graph/extractors/community_reports/prep_community_report_context.py (12 additions, 8 deletions)

@@ -46,24 +46,28 @@ def prep_community_report_context(
 
     # Filter by community level
     level_context_df = local_context_df[
-        local_context_df.loc[:,schemas.COMMUNITY_LEVEL] == level
+        local_context_df.loc[:, schemas.COMMUNITY_LEVEL] == level
     ]
 
     # Filter valid and invalid contexts using boolean logic
-    valid_context_df = level_context_df[~level_context_df.loc[:,schemas.CONTEXT_EXCEED_FLAG]]
-    invalid_context_df = level_context_df[level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]]
+    valid_context_df = level_context_df.loc[
+        ~level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
+    ]
+    invalid_context_df = level_context_df.loc[
+        level_context_df.loc[:, schemas.CONTEXT_EXCEED_FLAG]
+    ]
 
     # there is no report to substitute with, so we just trim the local context of the invalid context records
     # this case should only happen at the bottom level of the community hierarchy where there are no sub-communities
     if invalid_context_df.empty:
         return valid_context_df
 
     if report_df.empty:
-        invalid_context_df[schemas.CONTEXT_STRING] = _sort_and_trim_context(
+        invalid_context_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
             invalid_context_df, max_tokens
         )
-        invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df[
-            schemas.CONTEXT_STRING
+        invalid_context_df[schemas.CONTEXT_SIZE] = invalid_context_df.loc[
+            :, schemas.CONTEXT_STRING
         ].map(num_tokens)
         invalid_context_df[schemas.CONTEXT_EXCEED_FLAG] = 0
         return union(valid_context_df, invalid_context_df)

@@ -80,12 +84,12 @@ def prep_community_report_context(
     # handle any remaining invalid records that can't be subsituted with sub-community reports
     # this should be rare, but if it happens, we will just trim the local context to fit the limit
     remaining_df = _antijoin_reports(invalid_context_df, community_df)
-    remaining_df[schemas.CONTEXT_STRING] = _sort_and_trim_context(
+    remaining_df.loc[:, schemas.CONTEXT_STRING] = _sort_and_trim_context(
         remaining_df, max_tokens
     )
 
     result = union(valid_context_df, community_df, remaining_df)
-    result.loc[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)
+    result[schemas.CONTEXT_SIZE] = result.loc[:, schemas.CONTEXT_STRING].map(num_tokens)
 
     result[schemas.CONTEXT_EXCEED_FLAG] = 0
     return result
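One change in the last hunk is more than a style fix: `result.loc[schemas.CONTEXT_SIZE] = ...` becomes `result[schemas.CONTEXT_SIZE] = ...`. In pandas, `df.loc[label] = ...` with a bare label addresses a row, while `df[label] = ...` creates or overwrites a column. A small sketch of the difference, using hypothetical column names rather than the repo's schema constants:

import pandas as pd

df = pd.DataFrame({"context_string": ["abc", "de"]})

# Column assignment: df[name] = ... adds a "context_size" column with one
# value per row; string length stands in here for num_tokens.
fixed = df.copy()
fixed["context_size"] = fixed.loc[:, "context_string"].map(len)
print(fixed["context_size"].tolist())    # [3, 2]

# Row assignment: df.loc[name] = ... with a single label targets a *row*, so
# this appends a row labelled "context_size" instead of creating a column.
buggy = df.copy()
buggy.loc["context_size"] = 0
print("context_size" in buggy.columns)   # False
print("context_size" in buggy.index)     # True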
