
Commit b94290e

add option to add metadata into text chunks (#1681)

* add new options
* add metadata json into input document
* remove doc change
* add metadata column into text loader
* prepend_metadata
* run fix
* fix tests and patch
* fix test
* add warning for metadata tokens > config size
* fix typo and run fix
* fix test_integration
* fix test
* run check
* rename and fix chunking
* fix
* fix
* fix test verbs
* fix
* fix tests
* fix chunking
* fix index
* fix cosmos test
* fix vars
* fix after PR
* fix

1 parent b9dc7b9, commit b94290e

27 files changed: +527 -54 lines
Lines changed: 4 additions & 0 deletions

```diff
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "add option to prepend metadata into chunks"
+}
```

graphrag/config/defaults.py

Lines changed: 2 additions & 0 deletions

```diff
@@ -67,6 +67,8 @@
 CHUNK_OVERLAP = 100
 CHUNK_GROUP_BY_COLUMNS = ["id"]
 CHUNK_STRATEGY = ChunkStrategyType.tokens
+CHUNK_PREPEND_METADATA = False
+CHUNK_SIZE_INCLUDES_METADATA = False

 # Claim extraction
 DESCRIPTION = "Any claims or facts that could be relevant to information discovery."
```

graphrag/config/models/chunking_config.py

Lines changed: 8 additions & 0 deletions

```diff
@@ -26,3 +26,11 @@ class ChunkingConfig(BaseModel):
     encoding_model: str = Field(
         description="The encoding model to use.", default=defs.ENCODING_MODEL
     )
+    prepend_metadata: bool = Field(
+        description="Prepend metadata into each chunk.",
+        default=defs.CHUNK_PREPEND_METADATA,
+    )
+    chunk_size_includes_metadata: bool = Field(
+        description="Count metadata in max tokens.",
+        default=defs.CHUNK_SIZE_INCLUDES_METADATA,
+    )
```
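
For orientation, the two new options are set like any other ChunkingConfig field. A minimal sketch, with illustrative values (only the field names and defaults come from the diffs above):

```python
from graphrag.config.models.chunking_config import ChunkingConfig

config = ChunkingConfig(
    size=1200,
    overlap=100,
    prepend_metadata=True,              # copy document metadata into every chunk
    chunk_size_includes_metadata=True,  # count metadata tokens against `size`
)
```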

graphrag/index/flows/create_base_text_units.py

Lines changed: 56 additions & 11 deletions

```diff
@@ -3,13 +3,15 @@

 """All the steps to transform base text_units."""

-from typing import cast
+import json
+from typing import Any, cast

 import pandas as pd

 from graphrag.callbacks.workflow_callbacks import WorkflowCallbacks
 from graphrag.config.models.chunking_config import ChunkStrategyType
 from graphrag.index.operations.chunk_text.chunk_text import chunk_text
+from graphrag.index.operations.chunk_text.strategies import get_encoding_fn
 from graphrag.index.utils.hashing import gen_sha512_hash
 from graphrag.logger.progress import Progress

@@ -22,6 +24,8 @@ def create_base_text_units(
     overlap: int,
     encoding_model: str,
     strategy: ChunkStrategyType,
+    prepend_metadata: bool = False,
+    chunk_size_includes_metadata: bool = False,
 ) -> pd.DataFrame:
     """All the steps to transform base text_units."""
     sort = documents.sort_values(by=["id"], ascending=[True])
@@ -32,25 +36,66 @@

     callbacks.progress(Progress(percent=0))

+    agg_dict = {"text_with_ids": list}
+    if "metadata" in documents:
+        agg_dict["metadata"] = "first"  # type: ignore
+
     aggregated = (
         (
             sort.groupby(group_by_columns, sort=False)
             if len(group_by_columns) > 0
             else sort.groupby(lambda _x: True)
         )
-        .agg(texts=("text_with_ids", list))
+        .agg(agg_dict)
         .reset_index()
     )
+    aggregated.rename(columns={"text_with_ids": "texts"}, inplace=True)

-    aggregated["chunks"] = chunk_text(
-        aggregated,
-        column="texts",
-        size=size,
-        overlap=overlap,
-        encoding_model=encoding_model,
-        strategy=strategy,
-        callbacks=callbacks,
-    )
+    def chunker(row: dict[str, Any]) -> Any:
+        line_delimiter = ".\n"
+        metadata_str = ""
+        metadata_tokens = 0
+
+        if prepend_metadata and "metadata" in row:
+            metadata = row["metadata"]
+            if isinstance(metadata, str):
+                metadata = json.loads(metadata)
+            if isinstance(metadata, dict):
+                metadata_str = (
+                    line_delimiter.join(f"{k}: {v}" for k, v in metadata.items())
+                    + line_delimiter
+                )
+
+            if chunk_size_includes_metadata:
+                encode, _ = get_encoding_fn(encoding_model)
+                metadata_tokens = len(encode(metadata_str))
+                if metadata_tokens >= size:
+                    message = "Metadata tokens exceeds the maximum tokens per chunk. Please increase the tokens per chunk."
+                    raise ValueError(message)
+
+        chunked = chunk_text(
+            pd.DataFrame([row]).reset_index(drop=True),
+            column="texts",
+            size=size - metadata_tokens,
+            overlap=overlap,
+            encoding_model=encoding_model,
+            strategy=strategy,
+            callbacks=callbacks,
+        )[0]
+
+        if prepend_metadata:
+            for index, chunk in enumerate(chunked):
+                if isinstance(chunk, str):
+                    chunked[index] = metadata_str + chunk
+                else:
+                    chunked[index] = (
+                        (chunk[0], metadata_str + chunk[1], chunk[2]) if chunk else None
+                    )
+
+        row["chunks"] = chunked
+        return row
+
+    aggregated = aggregated.apply(lambda row: chunker(row), axis=1)

     aggregated = cast("pd.DataFrame", aggregated[[*group_by_columns, "chunks"]])
     aggregated = aggregated.explode("chunks")
```
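
The heart of the change is the chunker closure: build a "key: value" metadata prefix, optionally charge its token count against the chunk budget, then prepend it to every chunk. A standalone sketch of just that bookkeeping (tiktoken is the tokenizer the strategies use; the helper function and sample values here are illustrative, not part of the commit):

```python
import json

import tiktoken


def metadata_prefix_and_budget(
    metadata: dict | str, size: int, encoding_name: str = "cl100k_base"
) -> tuple[str, int]:
    """Build the metadata prefix and the token budget left for chunk text."""
    if isinstance(metadata, str):
        metadata = json.loads(metadata)  # metadata may arrive JSON-encoded
    line_delimiter = ".\n"
    prefix = line_delimiter.join(f"{k}: {v}" for k, v in metadata.items()) + line_delimiter
    # Charge the prefix against the chunk size, as chunk_size_includes_metadata does
    metadata_tokens = len(tiktoken.get_encoding(encoding_name).encode(prefix))
    if metadata_tokens >= size:
        msg = "Metadata tokens exceeds the maximum tokens per chunk."
        raise ValueError(msg)
    return prefix, size - metadata_tokens


prefix, budget = metadata_prefix_and_budget({"title": "report.csv", "year": 2024}, size=1200)
# Each chunk is then cut to `budget` tokens of text and emitted as `prefix` + text.
```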

graphrag/index/flows/create_final_documents.py

Lines changed: 3 additions & 15 deletions

```diff
@@ -7,9 +7,7 @@


 def create_final_documents(
-    documents: pd.DataFrame,
-    text_units: pd.DataFrame,
-    metadata: list[str] | None = None,
+    documents: pd.DataFrame, text_units: pd.DataFrame
 ) -> pd.DataFrame:
     """All the steps to transform final documents."""
     exploded = (
@@ -46,27 +44,17 @@ def create_final_documents(
     rejoined["id"] = rejoined["id"].astype(str)
     rejoined["human_readable_id"] = rejoined.index + 1

-    # Convert metadata columns to strings and collapse them into a JSON object
-    if metadata:
-        # Convert all specified columns to string at once
-        rejoined[metadata] = rejoined[metadata].astype(str)
-
-        # Collapse the metadata columns into a single JSON object column
-        rejoined["metadata"] = rejoined[metadata].to_dict(orient="records")
-
-        # Drop the original metadata columns after collapsing them
-        rejoined.drop(columns=metadata, inplace=True)
-
     # set the final column order, but adjust for metadata
     core_columns = [
         "id",
         "human_readable_id",
         "title",
         "text",
         "text_unit_ids",
+        "creation_date",
     ]
     final_columns = [column for column in core_columns if column in rejoined.columns]
-    if metadata:
+    if "metadata" in rejoined.columns:
         final_columns.append("metadata")

     return rejoined.loc[:, final_columns]
```
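
The metadata collapsing deleted here is not gone: it moves upstream into graphrag/index/input/factory.py (below), so by the time create_final_documents runs, a metadata column either already exists or it doesn't. A toy replay of the remaining column-order logic (the frame contents are made up):

```python
import pandas as pd

rejoined = pd.DataFrame({
    "id": ["1"],
    "title": ["report.csv"],
    "text": ["..."],
    "metadata": [{"author": "lee"}],
})

core_columns = ["id", "human_readable_id", "title", "text", "text_unit_ids", "creation_date"]
final_columns = [column for column in core_columns if column in rejoined.columns]
if "metadata" in rejoined.columns:
    final_columns.append("metadata")

print(final_columns)  # ['id', 'title', 'text', 'metadata']
```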

graphrag/index/input/csv.py

Lines changed: 6 additions & 1 deletion

```diff
@@ -50,7 +50,7 @@ async def load_file(path: str, group: dict | None) -> pd.DataFrame:
            )
        else:
            data["text"] = data.apply(lambda x: x[config.text_column], axis=1)
-       if config.title_column is not None and "title" not in data.columns:
+       if config.title_column is not None:
            if config.title_column not in data.columns:
                log.warning(
                    "title_column %s not found in csv file %s",
@@ -59,6 +59,11 @@ async def load_file(path: str, group: dict | None) -> pd.DataFrame:
                )
            else:
                data["title"] = data.apply(lambda x: x[config.title_column], axis=1)
+       else:
+           data["title"] = data.apply(lambda _: path, axis=1)
+
+       creation_date = await storage.get_creation_date(path)
+       data["creation_date"] = data.apply(lambda _: creation_date, axis=1)

        return data

```

graphrag/index/input/factory.py

Lines changed: 17 additions & 2 deletions

```diff
@@ -72,8 +72,23 @@ async def create_input(
             f"Loading Input ({config.file_type})", transient=False
         )
         loader = loaders[config.file_type]
-        results = await loader(config, progress, storage)
-        return cast("pd.DataFrame", results)
+        result = await loader(config, progress, storage)
+        # Convert metadata columns to strings and collapse them into a JSON object
+        if config.metadata:
+            if all(col in result.columns for col in config.metadata):
+                # Collapse the metadata columns into a single JSON object column
+                result["metadata"] = result[config.metadata].apply(
+                    lambda row: row.to_dict(), axis=1
+                )
+            else:
+                value_error_msg = (
+                    "One or more metadata columns not found in the DataFrame."
+                )
+                raise ValueError(value_error_msg)
+
+            result[config.metadata] = result[config.metadata].astype(str)
+
+        return cast("pd.DataFrame", result)

     msg = f"Unknown input type {config.file_type}"
     raise ValueError(msg)
```
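
The collapse itself is plain pandas; replayed here on a toy frame (the column names are invented for the example):

```python
import pandas as pd

result = pd.DataFrame({"text": ["..."], "author": ["lee"], "year": [2024]})
metadata_columns = ["author", "year"]  # what config.metadata would name

# Collapse the metadata columns into a single JSON-object column,
# then stringify the originals, mirroring the factory code above
result["metadata"] = result[metadata_columns].apply(lambda row: row.to_dict(), axis=1)
result[metadata_columns] = result[metadata_columns].astype(str)

print(result.loc[0, "metadata"])  # {'author': 'lee', 'year': 2024}
```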

graphrag/index/input/text.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -38,6 +38,7 @@ async def load_file(
         new_item = {**group, "text": text}
         new_item["id"] = gen_sha512_hash(new_item, new_item.keys())
         new_item["title"] = str(Path(path).name)
+        new_item["creation_date"] = await storage.get_creation_date(path)
         return new_item

     files = list(
```
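
Both input loaders now also stamp each record with the file's creation date via the storage layer's get_creation_date. A sketch of the pattern with a stub storage object (the stub is invented; only the get_creation_date call comes from the diffs):

```python
import asyncio


class StubStorage:
    """Stand-in for the pipeline storage; the real one reads file timestamps."""

    async def get_creation_date(self, path: str) -> str:
        return "2025-01-01"


async def load_item(path: str, storage: StubStorage) -> dict:
    new_item = {"title": path, "text": "..."}
    new_item["creation_date"] = await storage.get_creation_date(path)
    return new_item


print(asyncio.run(load_item("input/report.txt", StubStorage())))
```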

graphrag/index/operations/chunk_text/chunk_text.py

Lines changed: 9 additions & 7 deletions

```diff
@@ -58,14 +58,21 @@ def chunk_text(

     num_total = _get_num_total(input, column)
     tick = progress_ticker(callbacks.progress, num_total)
+
     # collapse the config back to a single object to support "polymorphic" function call
     config = ChunkingConfig(size=size, overlap=overlap, encoding_model=encoding_model)
+
     return cast(
         "pd.Series",
         input.apply(
             cast(
                 "Any",
-                lambda x: run_strategy(strategy_exec, x[column], config, tick),
+                lambda x: run_strategy(
+                    strategy_exec,
+                    x[column],
+                    config,
+                    tick,
+                ),
             ),
             axis=1,
         ),
@@ -85,12 +92,7 @@ def run_strategy(
     # We can work with both just a list of text content
     # or a list of tuples of (document_id, text content)
     # text_to_chunk = '''
-    texts = []
-    for item in input:
-        if isinstance(item, str):
-            texts.append(item)
-        else:
-            texts.append(item[1])
+    texts = [item if isinstance(item, str) else item[1] for item in input]

     strategy_results = strategy_exec(texts, config, tick)
```
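
The rewritten texts line is behavior-preserving: it accepts both plain strings and (document_id, text) tuples, as a quick check shows:

```python
mixed = ["plain text", ("doc-1", "text from a tuple")]
texts = [item if isinstance(item, str) else item[1] for item in mixed]
print(texts)  # ['plain text', 'text from a tuple']
```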

graphrag/index/operations/chunk_text/strategies.py

Lines changed: 16 additions & 7 deletions

```diff
@@ -17,13 +17,8 @@
 from graphrag.logger.progress import ProgressTicker


-def run_tokens(
-    input: list[str], config: ChunkingConfig, tick: ProgressTicker
-) -> Iterable[TextChunk]:
-    """Chunks text into chunks based on encoding tokens."""
-    tokens_per_chunk = config.size
-    chunk_overlap = config.overlap
-    encoding_name = config.encoding_model
+def get_encoding_fn(encoding_name):
+    """Get the encoding model."""
     enc = tiktoken.get_encoding(encoding_name)

     def encode(text: str) -> list[int]:
@@ -34,6 +29,20 @@ def encode(text: str) -> list[int]:
     def decode(tokens: list[int]) -> str:
         return enc.decode(tokens)

+    return encode, decode
+
+
+def run_tokens(
+    input: list[str],
+    config: ChunkingConfig,
+    tick: ProgressTicker,
+) -> Iterable[TextChunk]:
+    """Chunks text into chunks based on encoding tokens."""
+    tokens_per_chunk = config.size
+    chunk_overlap = config.overlap
+    encoding_name = config.encoding_model
+
+    encode, decode = get_encoding_fn(encoding_name)
     return split_multiple_texts_on_tokens(
         input,
         Tokenizer(
```
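
This refactor extracts the tiktoken encode/decode pair so create_base_text_units can count metadata tokens without running a full chunking pass. A quick round trip (the encoding name is tiktoken's own; the sample string is arbitrary):

```python
from graphrag.index.operations.chunk_text.strategies import get_encoding_fn

encode, decode = get_encoding_fn("cl100k_base")
tokens = encode("prepend metadata into chunks")
assert decode(tokens) == "prepend metadata into chunks"
```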
