Skip to content

Commit a159426

Browse files
authored
Add more verb tests (#1773)
* Add NLP verb test
* Add finalize_graph tests
* Add more thorough final column assertions
1 parent b4b8b81 commit a159426

7 files changed, with 154 additions (+154) and 45 deletions (-45) lines changed

tests/verbs/test_create_communities.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
55
from graphrag.config.create_graphrag_config import create_graphrag_config
6+
from graphrag.data_model.schemas import COMMUNITIES_FINAL_COLUMNS
67
from graphrag.index.workflows.create_communities import (
78
run_workflow,
89
)
@@ -36,11 +37,14 @@ async def test_create_communities():
3637

3738
actual = await load_table_from_storage("communities", context.storage)
3839

39-
assert "period" in expected.columns
4040
columns = list(expected.columns.values)
41+
# don't compare period since it is created with the current date each time
4142
columns.remove("period")
4243
compare_outputs(
4344
actual,
4445
expected,
4546
columns=columns,
4647
)
48+
49+
for column in COMMUNITIES_FINAL_COLUMNS:
50+
assert column in actual.columns

tests/verbs/test_create_community_reports.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
66
from graphrag.config.create_graphrag_config import create_graphrag_config
77
from graphrag.config.enums import ModelType
8+
from graphrag.data_model.schemas import COMMUNITY_REPORTS_FINAL_COLUMNS
89
from graphrag.index.operations.summarize_communities.community_reports_extractor import (
910
CommunityReportResponse,
1011
FindingModel,
@@ -80,3 +81,6 @@ async def test_create_community_reports():
8081
# assert a handful of mock data items to confirm they get put in the right spot
8182
assert actual["rank"][:1][0] == 2
8283
assert actual["rating_explanation"][:1][0] == "<rating_explanation>"
84+
85+
for column in COMMUNITY_REPORTS_FINAL_COLUMNS:
86+
assert column in actual.columns

tests/verbs/test_create_final_documents.py

Lines changed: 8 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
55
from graphrag.config.create_graphrag_config import create_graphrag_config
6+
from graphrag.data_model.schemas import DOCUMENTS_FINAL_COLUMNS
67
from graphrag.index.workflows.create_final_documents import (
78
run_workflow,
89
)
@@ -36,6 +37,9 @@ async def test_create_final_documents():
3637

3738
compare_outputs(actual, expected)
3839

40+
for column in DOCUMENTS_FINAL_COLUMNS:
41+
assert column in actual.columns
42+
3943

4044
async def test_create_final_documents_with_metadata_column():
4145
context = await create_test_context(
@@ -58,12 +62,7 @@ async def test_create_final_documents_with_metadata_column():
5862

5963
actual = await load_table_from_storage("documents", context.storage)
6064

61-
# our test dataframe does not have metadata, so we'll assert without it
62-
# and separately confirm it is in the output
63-
compare_outputs(
64-
actual, expected, columns=["id", "human_readable_id", "text", "metadata"]
65-
)
66-
assert len(actual.columns) == 7
67-
assert "title" in actual.columns
68-
assert "text_unit_ids" in actual.columns
69-
assert "metadata" in actual.columns
65+
compare_outputs(actual, expected)
66+
67+
for column in DOCUMENTS_FINAL_COLUMNS:
68+
assert column in actual.columns

tests/verbs/test_create_final_text_units.py

Lines changed: 4 additions & 33 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33

44
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
55
from graphrag.config.create_graphrag_config import create_graphrag_config
6+
from graphrag.data_model.schemas import TEXT_UNITS_FINAL_COLUMNS
67
from graphrag.index.workflows.create_final_text_units import (
78
run_workflow,
89
)
@@ -39,37 +40,7 @@ async def test_create_final_text_units():
3940

4041
actual = await load_table_from_storage("text_units", context.storage)
4142

42-
compare_outputs(actual, expected)
43-
44-
45-
async def test_create_final_text_units_no_covariates():
46-
expected = load_test_table("text_units")
47-
48-
context = await create_test_context(
49-
storage=[
50-
"text_units",
51-
"entities",
52-
"relationships",
53-
"covariates",
54-
],
55-
)
43+
for column in TEXT_UNITS_FINAL_COLUMNS:
44+
assert column in actual.columns
5645

57-
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
58-
config.extract_claims.enabled = False
59-
60-
await run_workflow(
61-
config,
62-
context,
63-
NoopWorkflowCallbacks(),
64-
)
65-
66-
actual = await load_table_from_storage("text_units", context.storage)
67-
68-
# we're short a covariate_ids column
69-
columns = list(expected.columns.values)
70-
columns.remove("covariate_ids")
71-
compare_outputs(
72-
actual,
73-
expected,
74-
columns=columns,
75-
)
46+
compare_outputs(actual, expected)

tests/verbs/test_extract_covariates.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
77
from graphrag.config.create_graphrag_config import create_graphrag_config
88
from graphrag.config.enums import ModelType
9+
from graphrag.data_model.schemas import COVARIATES_FINAL_COLUMNS
910
from graphrag.index.workflows.extract_covariates import (
1011
run_workflow,
1112
)
@@ -26,7 +27,6 @@
2627

2728
async def test_extract_covariates():
2829
input = load_test_table("text_units")
29-
expected = load_test_table("covariates")
3030

3131
context = await create_test_context(
3232
storage=["text_units"],
@@ -52,7 +52,9 @@ async def test_extract_covariates():
5252

5353
actual = await load_table_from_storage("covariates", context.storage)
5454

55-
assert len(actual.columns) == len(expected.columns)
55+
for column in COVARIATES_FINAL_COLUMNS:
56+
assert column in actual.columns
57+
5658
# our mock only returns one covariate per text unit, so that's a 1:1 mapping versus the LLM-extracted content in the test data
5759
assert len(actual) == len(input)
5860

tests/verbs/test_extract_graph_nlp.py

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,38 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
5+
from graphrag.config.create_graphrag_config import create_graphrag_config
6+
from graphrag.index.workflows.extract_graph_nlp import (
7+
run_workflow,
8+
)
9+
from graphrag.utils.storage import load_table_from_storage
10+
11+
from .util import (
12+
DEFAULT_MODEL_CONFIG,
13+
create_test_context,
14+
)
15+
16+
17+
async def test_extract_graph_nlp():
18+
context = await create_test_context(
19+
storage=["text_units"],
20+
)
21+
22+
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
23+
24+
await run_workflow(
25+
config,
26+
context,
27+
NoopWorkflowCallbacks(),
28+
)
29+
30+
nodes_actual = await load_table_from_storage("entities", context.storage)
31+
edges_actual = await load_table_from_storage("relationships", context.storage)
32+
33+
# this will be the raw count of entities and edges with no pruning
34+
# with NLP it is deterministic, so we can assert exact row counts
35+
assert len(nodes_actual) == 1148
36+
assert len(nodes_actual.columns) == 5
37+
assert len(edges_actual) == 29445
38+
assert len(edges_actual.columns) == 5

tests/verbs/test_finalize_graph.py

Lines changed: 91 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,91 @@
1+
# Copyright (c) 2024 Microsoft Corporation.
2+
# Licensed under the MIT License
3+
4+
from graphrag.callbacks.noop_workflow_callbacks import NoopWorkflowCallbacks
5+
from graphrag.config.create_graphrag_config import create_graphrag_config
6+
from graphrag.data_model.schemas import (
7+
ENTITIES_FINAL_COLUMNS,
8+
RELATIONSHIPS_FINAL_COLUMNS,
9+
)
10+
from graphrag.index.workflows.finalize_graph import (
11+
run_workflow,
12+
)
13+
from graphrag.utils.storage import load_table_from_storage, write_table_to_storage
14+
15+
from .util import (
16+
DEFAULT_MODEL_CONFIG,
17+
create_test_context,
18+
load_test_table,
19+
)
20+
21+
22+
async def test_finalize_graph():
23+
context = await _prep_tables()
24+
25+
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
26+
27+
await run_workflow(
28+
config,
29+
context,
30+
NoopWorkflowCallbacks(),
31+
)
32+
33+
nodes_actual = await load_table_from_storage("entities", context.storage)
34+
edges_actual = await load_table_from_storage("relationships", context.storage)
35+
36+
assert len(nodes_actual) == 251
37+
assert len(edges_actual) == 372
38+
39+
# x and y will be zero with the default configuration, because we do not embed/umap
40+
assert nodes_actual["x"].sum() == 0
41+
assert nodes_actual["y"].sum() == 0
42+
43+
for column in ENTITIES_FINAL_COLUMNS:
44+
assert column in nodes_actual.columns
45+
for column in RELATIONSHIPS_FINAL_COLUMNS:
46+
assert column in edges_actual.columns
47+
48+
49+
async def test_finalize_graph_umap():
50+
context = await _prep_tables()
51+
52+
config = create_graphrag_config({"models": DEFAULT_MODEL_CONFIG})
53+
54+
config.embed_graph.enabled = True
55+
config.umap.enabled = True
56+
57+
await run_workflow(
58+
config,
59+
context,
60+
NoopWorkflowCallbacks(),
61+
)
62+
63+
nodes_actual = await load_table_from_storage("entities", context.storage)
64+
edges_actual = await load_table_from_storage("relationships", context.storage)
65+
66+
assert len(nodes_actual) == 251
67+
assert len(edges_actual) == 372
68+
69+
# x and y should have some value other than zero due to umap
70+
assert nodes_actual["x"].sum() != 0
71+
assert nodes_actual["y"].sum() != 0
72+
73+
for column in ENTITIES_FINAL_COLUMNS:
74+
assert column in nodes_actual.columns
75+
for column in RELATIONSHIPS_FINAL_COLUMNS:
76+
assert column in edges_actual.columns
77+
78+
79+
async def _prep_tables():
80+
context = await create_test_context(
81+
storage=["entities", "relationships"],
82+
)
83+
84+
# edit the tables to eliminate final fields that wouldn't be on the inputs
85+
entities = load_test_table("entities")
86+
entities.drop(columns=["x", "y", "degree"], inplace=True)
87+
await write_table_to_storage(entities, "entities", context.storage)
88+
relationships = load_test_table("relationships")
89+
relationships.drop(columns=["combined_degree"], inplace=True)
90+
await write_table_to_storage(relationships, "relationships", context.storage)
91+
return context

0 commit comments

Comments
 (0)