4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241212190223784600.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Streamline flows."
}
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20241213181544864279.json
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Move extractor code to co-locate with operations."
}
5 changes: 0 additions & 5 deletions graphrag/config/models/claim_extraction_config.py
@@ -37,12 +37,7 @@ class ClaimExtractionConfig(LLMConfig):

def resolved_strategy(self, root_dir: str, encoding_model: str | None) -> dict:
"""Get the resolved claim extraction strategy."""
from graphrag.index.operations.extract_covariates import (
ExtractClaimsStrategyType,
)

return self.strategy or {
"type": ExtractClaimsStrategyType.graph_intelligence,
"llm": self.llm.model_dump(),
**self.parallelization.model_dump(),
"extraction_prompt": (Path(root_dir) / self.prompt)
2 changes: 1 addition & 1 deletion graphrag/config/models/embed_graph_config.py
@@ -36,7 +36,7 @@ class EmbedGraphConfig(BaseModel):

def resolved_strategy(self) -> dict:
"""Get the resolved node2vec strategy."""
from graphrag.index.operations.embed_graph import (
from graphrag.index.operations.embed_graph.typing import (
EmbedGraphStrategyType,
)

14 changes: 1 addition & 13 deletions graphrag/index/flows/compute_communities.py
@@ -9,15 +9,11 @@

from graphrag.index.operations.cluster_graph import cluster_graph
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage


async def compute_communities(
def compute_communities(
base_relationship_edges: pd.DataFrame,
storage: PipelineStorage,
clustering_strategy: dict[str, Any],
snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to create the base entity graph."""
graph = create_graph(base_relationship_edges)
@@ -32,12 +28,4 @@ async def compute_communities(
).explode("title")
base_communities["community"] = base_communities["community"].astype(int)

if snapshot_transient_enabled:
await snapshot(
base_communities,
name="base_communities",
storage=storage,
formats=["parquet"],
)

return base_communities
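
Note: compute_communities is now synchronous and no longer writes the transient snapshot itself. The calling workflow is not part of this excerpt, but a minimal caller-side sketch (the wrapper name and surrounding variables are assumptions) would presumably look like this, reusing the snapshot call removed from the flow above:

# Hypothetical caller-side sketch; the actual workflow module is not shown in this diff.
import pandas as pd

from graphrag.index.flows.compute_communities import compute_communities
from graphrag.index.operations.snapshot import snapshot
from graphrag.storage.pipeline_storage import PipelineStorage


async def run_compute_communities(
    base_relationship_edges: pd.DataFrame,
    clustering_strategy: dict,
    storage: PipelineStorage,
    snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
    # the flow itself is now a plain synchronous call
    base_communities = compute_communities(
        base_relationship_edges,
        clustering_strategy=clustering_strategy,
    )
    # the snapshot removed from the flow would now happen at this level
    if snapshot_transient_enabled:
        await snapshot(
            base_communities,
            name="base_communities",
            storage=storage,
            formats=["parquet"],
        )
    return base_communities
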
20 changes: 2 additions & 18 deletions graphrag/index/flows/create_base_text_units.py
@@ -15,18 +15,14 @@
)

from graphrag.index.operations.chunk_text import chunk_text
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.utils.hashing import gen_sha512_hash
from graphrag.storage.pipeline_storage import PipelineStorage


async def create_base_text_units(
def create_base_text_units(
documents: pd.DataFrame,
callbacks: VerbCallbacks,
storage: PipelineStorage,
chunk_by_columns: list[str],
chunk_strategy: dict[str, Any] | None = None,
snapshot_transient_enabled: bool = False,
) -> pd.DataFrame:
"""All the steps to transform base text_units."""
sort = documents.sort_values(by=["id"], ascending=[True])
@@ -74,19 +70,7 @@ async def create_base_text_units(
# rename for downstream consumption
chunked.rename(columns={"chunk": "text"}, inplace=True)

output = cast(
"pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True)
)

if snapshot_transient_enabled:
await snapshot(
output,
name="create_base_text_units",
storage=storage,
formats=["parquet"],
)

return output
return cast("pd.DataFrame", chunked[chunked["text"].notna()].reset_index(drop=True))


# TODO: would be nice to inline this completely in the main method with pandas
12 changes: 6 additions & 6 deletions graphrag/index/flows/create_final_community_reports.py
@@ -12,7 +12,12 @@
)

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.graph.extractors.community_reports.schemas import (
from graphrag.index.operations.summarize_communities import (
prepare_community_reports,
restore_community_hierarchy,
summarize_communities,
)
from graphrag.index.operations.summarize_communities.community_reports_extractor.schemas import (
CLAIM_DESCRIPTION,
CLAIM_DETAILS,
CLAIM_ID,
@@ -32,11 +37,6 @@
NODE_ID,
NODE_NAME,
)
from graphrag.index.operations.summarize_communities import (
prepare_community_reports,
restore_community_hierarchy,
summarize_communities,
)


async def create_final_community_reports(
2 changes: 1 addition & 1 deletion graphrag/index/flows/create_final_covariates.py
@@ -13,7 +13,7 @@
)

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.operations.extract_covariates import (
from graphrag.index.operations.extract_covariates.extract_covariates import (
extract_covariates,
)

23 changes: 14 additions & 9 deletions graphrag/index/flows/create_final_nodes.py
@@ -10,9 +10,10 @@
VerbCallbacks,
)

from graphrag.index.operations.compute_degree import compute_degree
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.embed_graph import embed_graph
from graphrag.index.operations.layout_graph import layout_graph
from graphrag.index.operations.embed_graph.embed_graph import embed_graph
from graphrag.index.operations.layout_graph.layout_graph import layout_graph


def create_final_nodes(
@@ -37,15 +38,19 @@ def create_final_nodes(
layout_strategy,
embeddings=graph_embeddings,
)
nodes = base_entity_nodes.merge(
layout, left_on="title", right_on="label", how="left"
)

joined = nodes.merge(base_communities, on="title", how="left")
joined["level"] = joined["level"].fillna(0).astype(int)
joined["community"] = joined["community"].fillna(-1).astype(int)
degrees = compute_degree(graph)

return joined.loc[
nodes = (
base_entity_nodes.merge(layout, left_on="title", right_on="label", how="left")
.merge(degrees, on="title", how="left")
.merge(base_communities, on="title", how="left")
)
nodes["level"] = nodes["level"].fillna(0).astype(int)
nodes["community"] = nodes["community"].fillna(-1).astype(int)
# disconnected nodes and those with no community even at level 0 can be missing degree
nodes["degree"] = nodes["degree"].fillna(0).astype(int)
return nodes.loc[
:,
[
"id",
9 changes: 7 additions & 2 deletions graphrag/index/flows/create_final_relationships.py
@@ -5,20 +5,25 @@

import pandas as pd

from graphrag.index.operations.compute_degree import compute_degree
from graphrag.index.operations.compute_edge_combined_degree import (
compute_edge_combined_degree,
)
from graphrag.index.operations.create_graph import create_graph


def create_final_relationships(
base_relationship_edges: pd.DataFrame,
base_entity_nodes: pd.DataFrame,
) -> pd.DataFrame:
"""All the steps to transform final relationships."""
relationships = base_relationship_edges

graph = create_graph(base_relationship_edges)
degrees = compute_degree(graph)

relationships["combined_degree"] = compute_edge_combined_degree(
relationships,
base_entity_nodes,
degrees,
node_name_column="title",
node_degree_column="degree",
edge_source_column="source",
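
The degree calculation that both create_final_nodes and create_final_relationships now import lives in a shared compute_degree operation. Its contents are not shown in this diff, but based on the private _compute_degree helper deleted from extract_graph.py further down, it is presumably a small function along these lines (the "title" column name is an assumption inferred from the on="title" merges above; the original helper used "name"):

# Hypothetical sketch of graphrag/index/operations/compute_degree.py; not shown in this diff.
import networkx as nx
import pandas as pd


def compute_degree(graph: nx.Graph) -> pd.DataFrame:
    """Create a DataFrame with one row per node and its degree."""
    return pd.DataFrame([
        {"title": node, "degree": int(degree)}
        for node, degree in graph.degree  # type: ignore
    ])
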
93 changes: 13 additions & 80 deletions graphrag/index/flows/extract_graph.py
@@ -6,41 +6,33 @@
from typing import Any
from uuid import uuid4

import networkx as nx
import pandas as pd
from datashaper import (
AsyncType,
VerbCallbacks,
)

from graphrag.cache.pipeline_cache import PipelineCache
from graphrag.index.operations.create_graph import create_graph
from graphrag.index.operations.extract_entities import extract_entities
from graphrag.index.operations.snapshot import snapshot
from graphrag.index.operations.snapshot_graphml import snapshot_graphml
from graphrag.index.operations.summarize_descriptions import (
summarize_descriptions,
)
from graphrag.storage.pipeline_storage import PipelineStorage


async def extract_graph(
text_units: pd.DataFrame,
callbacks: VerbCallbacks,
cache: PipelineCache,
storage: PipelineStorage,
extraction_strategy: dict[str, Any] | None = None,
extraction_num_threads: int = 4,
extraction_async_mode: AsyncType = AsyncType.AsyncIO,
entity_types: list[str] | None = None,
summarization_strategy: dict[str, Any] | None = None,
summarization_num_threads: int = 4,
snapshot_graphml_enabled: bool = False,
snapshot_transient_enabled: bool = False,
) -> tuple[pd.DataFrame, pd.DataFrame]:
"""All the steps to create the base entity graph."""
# this returns a graph for each text unit, to be merged later
entity_dfs, relationship_dfs = await extract_entities(
entities, relationships = await extract_entities(
text_units,
callbacks,
cache,
@@ -52,87 +44,38 @@ async def extract_graph(
num_threads=extraction_num_threads,
)

if not _validate_data(entity_dfs):
if not _validate_data(entities):
error_msg = "Entity Extraction failed. No entities detected during extraction."
callbacks.error(error_msg)
raise ValueError(error_msg)

if not _validate_data(relationship_dfs):
if not _validate_data(relationships):
error_msg = (
"Entity Extraction failed. No relationships detected during extraction."
)
callbacks.error(error_msg)
raise ValueError(error_msg)

merged_entities = _merge_entities(entity_dfs)
merged_relationships = _merge_relationships(relationship_dfs)

entity_summaries, relationship_summaries = await summarize_descriptions(
merged_entities,
merged_relationships,
entities,
relationships,
callbacks,
cache,
strategy=summarization_strategy,
num_threads=summarization_num_threads,
)

base_relationship_edges = _prep_edges(merged_relationships, relationship_summaries)

graph = create_graph(base_relationship_edges)

base_entity_nodes = _prep_nodes(merged_entities, entity_summaries, graph)

if snapshot_graphml_enabled:
# todo: extract graphs at each level, and add in meta like descriptions
await snapshot_graphml(
graph,
name="graph",
storage=storage,
)
base_relationship_edges = _prep_edges(relationships, relationship_summaries)

if snapshot_transient_enabled:
await snapshot(
base_entity_nodes,
name="base_entity_nodes",
storage=storage,
formats=["parquet"],
)
await snapshot(
base_relationship_edges,
name="base_relationship_edges",
storage=storage,
formats=["parquet"],
)
base_entity_nodes = _prep_nodes(entities, entity_summaries)

return (base_entity_nodes, base_relationship_edges)


def _merge_entities(entity_dfs) -> pd.DataFrame:
all_entities = pd.concat(entity_dfs, ignore_index=True)
return (
all_entities.groupby(["name", "type"], sort=False)
.agg({"description": list, "source_id": list})
.reset_index()
)


def _merge_relationships(relationship_dfs) -> pd.DataFrame:
all_relationships = pd.concat(relationship_dfs, ignore_index=False)
return (
all_relationships.groupby(["source", "target"], sort=False)
.agg({"description": list, "source_id": list, "weight": "sum"})
.reset_index()
)


def _prep_nodes(entities, summaries, graph) -> pd.DataFrame:
degrees_df = _compute_degree(graph)
def _prep_nodes(entities, summaries) -> pd.DataFrame:
entities.drop(columns=["description"], inplace=True)
nodes = (
entities.merge(summaries, on="name", how="left")
.merge(degrees_df, on="name")
.drop_duplicates(subset="name")
.rename(columns={"name": "title", "source_id": "text_unit_ids"})
nodes = entities.merge(summaries, on="title", how="left").drop_duplicates(
subset="title"
)
nodes = nodes.loc[nodes["title"].notna()].reset_index()
nodes["human_readable_id"] = nodes.index
@@ -145,22 +88,12 @@ def _prep_edges(relationships, summaries) -> pd.DataFrame:
relationships.drop(columns=["description"])
.drop_duplicates(subset=["source", "target"])
.merge(summaries, on=["source", "target"], how="left")
.rename(columns={"source_id": "text_unit_ids"})
)
edges["human_readable_id"] = edges.index
edges["id"] = edges["human_readable_id"].apply(lambda _x: str(uuid4()))
return edges


def _compute_degree(graph: nx.Graph) -> pd.DataFrame:
return pd.DataFrame([
{"name": node, "degree": int(degree)}
for node, degree in graph.degree # type: ignore
])


def _validate_data(df_list: list[pd.DataFrame]) -> bool:
"""Validate that the dataframe list is valid. At least one dataframe must contain data."""
return any(
len(df) > 0 for df in df_list
) # Check for len, not .empty, as the dfs have schemas in some cases
def _validate_data(df: pd.DataFrame) -> bool:
"""Validate that the dataframe has data."""
return len(df) > 0
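
The deletion of _merge_entities and _merge_relationships, together with extract_entities now returning single frames, suggests the per-text-unit merge happens inside the extract_entities operation. That operation's internals are not part of this excerpt, so the following is only a rough sketch adapted from the deleted helpers, with "name" swapped for "title" to match the on="title" merge in _prep_nodes (both the function names and the column rename are assumptions):

# Hypothetical sketch of the merge step presumed to live inside extract_entities; not shown here.
import pandas as pd


def _merge_entity_dfs(entity_dfs: list[pd.DataFrame]) -> pd.DataFrame:
    all_entities = pd.concat(entity_dfs, ignore_index=True)
    return (
        all_entities.groupby(["title", "type"], sort=False)
        .agg({"description": list, "source_id": list})
        .reset_index()
    )


def _merge_relationship_dfs(relationship_dfs: list[pd.DataFrame]) -> pd.DataFrame:
    all_relationships = pd.concat(relationship_dfs, ignore_index=False)
    return (
        all_relationships.groupby(["source", "target"], sort=False)
        .agg({"description": list, "source_id": list, "weight": "sum"})
        .reset_index()
    )
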
3 changes: 1 addition & 2 deletions graphrag/index/flows/generate_text_embeddings.py
@@ -129,9 +129,8 @@ async def _run_and_snapshot_embeddings(
strategy=text_embed_config["strategy"],
)

data = data.loc[:, ["id", "embedding"]]

if snapshot_embeddings_enabled is True:
data = data.loc[:, ["id", "embedding"]]
await snapshot(
data,
name=f"embeddings.{name}",
4 changes: 0 additions & 4 deletions graphrag/index/graph/__init__.py

This file was deleted.

8 changes: 0 additions & 8 deletions graphrag/index/graph/embedding/__init__.py

This file was deleted.
