Commit 36948b8

Various minor updates (#1932)
* Add text unit ids to Community model
* Add graph utilities
* Turn off LCC for clustering by default
* Simplify embeddings config/flow
* Semver
1 parent ee1b2db commit 36948b8

File tree

16 files changed: +330 −147 lines changed
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "A few fixes and enhancements for better reuse and flow."
+}

docs/config/yaml.md

Lines changed: 1 addition & 2 deletions
@@ -201,8 +201,7 @@ Supported embeddings names are:
 - `vector_store_id` **str** - Name of vector store definition to write to.
 - `batch_size` **int** - The maximum batch size to use.
 - `batch_max_tokens` **int** - The maximum batch # of tokens.
-- `target` **required|all|selected|none** - Determines which set of embeddings to export.
-- `names` **list[str]** - If target=selected, this should be an explicit list of the embeddings names we support.
+- `names` **list[str]** - List of the embeddings names to run (must be in supported list).

 ### extract_graph

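With `target` removed, the set of embeddings to generate is controlled by `names` alone. A minimal sketch of the simplified config, assuming the remaining `TextEmbeddingConfig` fields keep their defaults; the explicit name used below is illustrative and must come from the supported list:

```python
from graphrag.config.embeddings import default_embeddings
from graphrag.config.models.text_embedding_config import TextEmbeddingConfig

# With no explicit `names`, the default embeddings list is used.
embed_config = TextEmbeddingConfig()
print(embed_config.names)  # same entries as default_embeddings

# An explicit subset can still be requested by listing names directly,
# which replaces the old target=selected behavior.
subset_config = TextEmbeddingConfig(names=["entity.description"])
print(subset_config.names)
```
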
graphrag/config/defaults.py

Lines changed: 2 additions & 3 deletions
@@ -7,6 +7,7 @@
 from pathlib import Path
 from typing import Literal

+from graphrag.config.embeddings import default_embeddings
 from graphrag.config.enums import (
     AsyncType,
     AuthType,
@@ -18,7 +19,6 @@
     NounPhraseExtractorType,
     OutputType,
     ReportingType,
-    TextEmbeddingTarget,
 )
 from graphrag.index.operations.build_noun_graph.np_extractors.stop_words import (
     EN_STOP_WORDS,
@@ -147,9 +147,8 @@ class EmbedTextDefaults:
     model: str = "text-embedding-3-small"
     batch_size: int = 16
     batch_max_tokens: int = 8191
-    target = TextEmbeddingTarget.required
     model_id: str = DEFAULT_EMBEDDING_MODEL_ID
-    names: list[str] = field(default_factory=list)
+    names: list[str] = field(default_factory=lambda: default_embeddings)
     strategy: None = None
     vector_store_id: str = DEFAULT_VECTOR_STORE_ID

graphrag/config/embeddings.py

Lines changed: 2 additions & 54 deletions
@@ -3,9 +3,6 @@

 """A module containing embeddings values."""

-from graphrag.config.enums import TextEmbeddingTarget
-from graphrag.config.models.graph_rag_config import GraphRagConfig
-
 entity_title_embedding = "entity.title"
 entity_description_embedding = "entity.description"
 relationship_description_embedding = "relationship.description"
@@ -25,60 +22,11 @@
     community_full_content_embedding,
     text_unit_text_embedding,
 }
-required_embeddings: set[str] = {
+default_embeddings: list[str] = [
     entity_description_embedding,
     community_full_content_embedding,
     text_unit_text_embedding,
-}
-
-
-def get_embedded_fields(settings: GraphRagConfig) -> set[str]:
-    """Get the fields to embed based on the enum or specifically selected embeddings."""
-    match settings.embed_text.target:
-        case TextEmbeddingTarget.all:
-            return all_embeddings
-        case TextEmbeddingTarget.required:
-            return required_embeddings
-        case TextEmbeddingTarget.selected:
-            return set(settings.embed_text.names)
-        case TextEmbeddingTarget.none:
-            return set()
-        case _:
-            msg = f"Unknown embeddings target: {settings.embed_text.target}"
-            raise ValueError(msg)
-
-
-def get_embedding_settings(
-    settings: GraphRagConfig,
-    vector_store_params: dict | None = None,
-) -> dict:
-    """Transform GraphRAG config into settings for workflows."""
-    # TEMP
-    embeddings_llm_settings = settings.get_language_model_config(
-        settings.embed_text.model_id
-    )
-    vector_store_settings = settings.get_vector_store_config(
-        settings.embed_text.vector_store_id
-    ).model_dump()
-
-    #
-    # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
-    # settings.vector_store.base contains connection information, or may be undefined
-    # settings.vector_store.<vector_name> contains the specific settings for this embedding
-    #
-    strategy = settings.embed_text.resolved_strategy(
-        embeddings_llm_settings
-    )  # get the default strategy
-    strategy.update({
-        "vector_store": {
-            **(vector_store_params or {}),
-            **(vector_store_settings),
-        }
-    })  # update the default strategy with the vector store settings
-    # This ensures the vector store config is part of the strategy and not the global config
-    return {
-        "strategy": strategy,
-    }
+]


 def create_collection_name(

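For reference, the old `required` set now lives on as the `default_embeddings` list, and any configured name still has to be one of the supported `all_embeddings`. A small sanity-check sketch of that relationship:

```python
from graphrag.config.embeddings import all_embeddings, default_embeddings

# The defaults are a subset of the supported embedding names.
assert set(default_embeddings) <= all_embeddings

# Anything configured under embed_text.names should be drawn from this set.
print(sorted(all_embeddings))
```
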
graphrag/config/enums.py

Lines changed: 12 additions & 13 deletions
@@ -87,19 +87,6 @@ def __repr__(self):
         return f'"{self.value}"'


-class TextEmbeddingTarget(str, Enum):
-    """The target to use for text embeddings."""
-
-    all = "all"
-    required = "required"
-    selected = "selected"
-    none = "none"
-
-    def __repr__(self):
-        """Get a string representation."""
-        return f'"{self.value}"'
-
-
 class ModelType(str, Enum):
     """LLMType enum class definition."""

@@ -176,3 +163,15 @@ class NounPhraseExtractorType(str, Enum):
     """Noun phrase extractor based on dependency parsing and NER using SpaCy."""
     CFG = "cfg"
     """Noun phrase extractor combining CFG-based noun-chunk extraction and NER."""
+
+
+class ModularityMetric(str, Enum):
+    """Enum for the modularity metric to use."""
+
+    Graph = "graph"
+    """Graph modularity metric."""
+
+    LCC = "lcc"
+
+    WeightedComponents = "weighted_components"
+    """Weighted components modularity metric."""
Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing get_embedding_settings."""
+
+from graphrag.config.models.graph_rag_config import GraphRagConfig
+
+
+def get_embedding_settings(
+    settings: GraphRagConfig,
+    vector_store_params: dict | None = None,
+) -> dict:
+    """Transform GraphRAG config into settings for workflows."""
+    # TEMP
+    embeddings_llm_settings = settings.get_language_model_config(
+        settings.embed_text.model_id
+    )
+    vector_store_settings = settings.get_vector_store_config(
+        settings.embed_text.vector_store_id
+    ).model_dump()
+
+    #
+    # If we get to this point, settings.vector_store is defined, and there's a specific setting for this embedding.
+    # settings.vector_store.base contains connection information, or may be undefined
+    # settings.vector_store.<vector_name> contains the specific settings for this embedding
+    #
+    strategy = settings.embed_text.resolved_strategy(
+        embeddings_llm_settings
+    )  # get the default strategy
+    strategy.update({
+        "vector_store": {
+            **(vector_store_params or {}),
+            **(vector_store_settings),
+        }
+    })  # update the default strategy with the vector store settings
+    # This ensures the vector store config is part of the strategy and not the global config
+    return {
+        "strategy": strategy,
+    }

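A sketch of calling the relocated helper from a workflow. The import path is an assumption (the new file's path is not visible in this extract), and the extra parameters are illustrative; note that, per the merge order above, the configured vector store settings take precedence over keys passed in `vector_store_params`:

```python
from graphrag.config.get_embedding_settings import get_embedding_settings  # assumed path
from graphrag.config.models.graph_rag_config import GraphRagConfig


def build_text_embed_strategy(settings: GraphRagConfig) -> dict:
    """Resolve the embedding strategy with vector store config folded in."""
    text_embed = get_embedding_settings(
        settings,
        vector_store_params={"index_name": "example"},  # illustrative extra params
    )
    # The helper returns {"strategy": {...}}; workflows pass the strategy onward.
    return text_embed["strategy"]
```
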
graphrag/config/models/text_embedding_config.py

Lines changed: 0 additions & 5 deletions
@@ -6,7 +6,6 @@
 from pydantic import BaseModel, Field

 from graphrag.config.defaults import graphrag_config_defaults
-from graphrag.config.enums import TextEmbeddingTarget
 from graphrag.config.models.language_model_config import LanguageModelConfig


@@ -29,10 +28,6 @@ class TextEmbeddingConfig(BaseModel):
         description="The batch max tokens to use.",
         default=graphrag_config_defaults.embed_text.batch_max_tokens,
     )
-    target: TextEmbeddingTarget = Field(
-        description="The target to use. 'all', 'required', 'selected', or 'none'.",
-        default=graphrag_config_defaults.embed_text.target,
-    )
     names: list[str] = Field(
         description="The specific embeddings to perform.",
         default=graphrag_config_defaults.embed_text.names,

graphrag/data_model/community.py

Lines changed: 5 additions & 0 deletions
@@ -28,6 +28,9 @@ class Community(Named):
     relationship_ids: list[str] | None = None
     """List of relationship IDs related to the community (optional)."""

+    text_unit_ids: list[str] | None = None
+    """List of text unit IDs related to the community (optional)."""
+
     covariate_ids: dict[str, list[str]] | None = None
     """Dictionary of different types of covariates related to the community (optional), e.g. claims"""

@@ -50,6 +53,7 @@ def from_dict(
         level_key: str = "level",
         entities_key: str = "entity_ids",
         relationships_key: str = "relationship_ids",
+        text_units_key: str = "text_unit_ids",
         covariates_key: str = "covariate_ids",
         parent_key: str = "parent",
         children_key: str = "children",
@@ -67,6 +71,7 @@
             short_id=d.get(short_id_key),
             entity_ids=d.get(entities_key),
             relationship_ids=d.get(relationships_key),
+            text_unit_ids=d.get(text_units_key),
             covariate_ids=d.get(covariates_key),
             attributes=d.get(attributes_key),
             size=d.get(size_key),

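A minimal sketch of the updated model in use. The record values are illustrative, and only the keys visible in this diff plus basic id/title/level fields are assumed:

```python
from graphrag.data_model.community import Community

# Illustrative community record; text_unit_ids is the newly added field.
record = {
    "id": "b5d4c9a6",
    "title": "Community 0",
    "level": "0",
    "parent": "",
    "children": [],
    "entity_ids": ["e1", "e2"],
    "relationship_ids": ["r1"],
    "text_unit_ids": ["tu1", "tu2"],
}

community = Community.from_dict(record)
print(community.text_unit_ids)  # ['tu1', 'tu2']
```
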
graphrag/index/operations/build_noun_graph/build_noun_graph.py

Lines changed: 2 additions & 47 deletions
@@ -15,6 +15,7 @@
     BaseNounPhraseExtractor,
 )
 from graphrag.index.utils.derive_from_rows import derive_from_rows
+from graphrag.index.utils.graphs import calculate_pmi_edge_weights
 from graphrag.index.utils.hashing import gen_sha512_hash


@@ -127,52 +128,6 @@ def _extract_edges(
     ]
     if normalize_edge_weights:
         # use PMI weight instead of raw weight
-        grouped_edge_df = _calculate_pmi_edge_weights(nodes_df, grouped_edge_df)
+        grouped_edge_df = calculate_pmi_edge_weights(nodes_df, grouped_edge_df)

     return grouped_edge_df
-
-
-def _calculate_pmi_edge_weights(
-    nodes_df: pd.DataFrame,
-    edges_df: pd.DataFrame,
-    node_name_col="title",
-    node_freq_col="frequency",
-    edge_weight_col="weight",
-    edge_source_col="source",
-    edge_target_col="target",
-) -> pd.DataFrame:
-    """
-    Calculate pointwise mutual information (PMI) edge weights.
-
-    pmi(x,y) = log2(p(x,y) / (p(x)p(y)))
-    p(x,y) = edge_weight(x,y) / total_edge_weights
-    p(x) = freq_occurrence(x) / total_freq_occurrences
-    """
-    copied_nodes_df = nodes_df[[node_name_col, node_freq_col]]
-
-    total_edge_weights = edges_df[edge_weight_col].sum()
-    total_freq_occurrences = nodes_df[node_freq_col].sum()
-    copied_nodes_df["prop_occurrence"] = (
-        copied_nodes_df[node_freq_col] / total_freq_occurrences
-    )
-    copied_nodes_df = copied_nodes_df.loc[:, [node_name_col, "prop_occurrence"]]
-
-    edges_df["prop_weight"] = edges_df[edge_weight_col] / total_edge_weights
-    edges_df = (
-        edges_df.merge(
-            copied_nodes_df, left_on=edge_source_col, right_on=node_name_col, how="left"
-        )
-        .drop(columns=[node_name_col])
-        .rename(columns={"prop_occurrence": "source_prop"})
-    )
-    edges_df = (
-        edges_df.merge(
-            copied_nodes_df, left_on=edge_target_col, right_on=node_name_col, how="left"
-        )
-        .drop(columns=[node_name_col])
-        .rename(columns={"prop_occurrence": "target_prop"})
-    )
-    edges_df[edge_weight_col] = edges_df["prop_weight"] * np.log2(
-        edges_df["prop_weight"] / (edges_df["source_prop"] * edges_df["target_prop"])
-    )
-    return edges_df.drop(columns=["prop_weight", "source_prop", "target_prop"])

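The PMI weighting itself is unchanged; it has simply moved into the shared graph utilities. A toy sketch of the relocated helper with made-up frames, assuming the column defaults shown in the removed signature above carry over:

```python
import pandas as pd

from graphrag.index.utils.graphs import calculate_pmi_edge_weights

# Made-up noun-phrase frequencies and raw co-occurrence edge weights.
nodes_df = pd.DataFrame({"title": ["a", "b", "c"], "frequency": [4, 2, 2]})
edges_df = pd.DataFrame({
    "source": ["a", "a"],
    "target": ["b", "c"],
    "weight": [3.0, 1.0],
})

# Raw weights are replaced with weighted PMI: p(x,y) * log2(p(x,y) / (p(x) * p(y))).
weighted = calculate_pmi_edge_weights(nodes_df, edges_df)
print(weighted[["source", "target", "weight"]])
```
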
graphrag/index/operations/cluster_graph.py

Lines changed: 1 addition & 3 deletions
@@ -6,6 +6,7 @@
 import logging

 import networkx as nx
+from graspologic.partition import hierarchical_leiden

 from graphrag.index.utils.stable_lcc import stable_largest_connected_component

@@ -60,9 +61,6 @@
     seed: int | None = None,
 ) -> tuple[dict[int, dict[str, int]], dict[int, int]]:
     """Return Leiden root communities and their hierarchy mapping."""
-    # NOTE: This import is done here to reduce the initial import time of the graphrag package
-    from graspologic.partition import hierarchical_leiden
-
     if use_lcc:
         graph = stable_largest_connected_component(graph)

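With the lazy import removed, graspologic now loads with the module. A minimal sketch of the underlying hierarchical Leiden call on a toy graph (parameters are illustrative); per this change set, the largest-connected-component reduction is off by default, so the whole graph is clustered unless `use_lcc` is explicitly enabled:

```python
import networkx as nx
from graspologic.partition import hierarchical_leiden

# Toy graph standing in for the extracted entity graph, with string node ids.
graph = nx.relabel_nodes(nx.karate_club_graph(), lambda n: f"n{n}")

# No LCC reduction here, mirroring the new default behavior.
partitions = hierarchical_leiden(graph, max_cluster_size=10, random_seed=42)
for partition in list(partitions)[:5]:
    print(partition.node, partition.level, partition.cluster)
```
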