
Commit ad5b512

remove unused columns and rename document_attribute_columns (microsoft#1672)
* remove unused columns and change property document_attribute_columns to metadata
* format file
* fix 'metadata' column on output
* run check
* fix test on nltk
* remove docs changes
1 parent 907d271 commit ad5b512

File tree

11 files changed: +155 −79 lines changed
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "patch",
+    "description": "remove unused columns and change property document_attribute_columns to metadata"
+}

graphrag/config/models/input_config.py

Lines changed: 2 additions & 11 deletions
@@ -40,21 +40,12 @@ class InputConfig(BaseModel):
     file_filter: dict[str, str] | None = Field(
         description="The optional file filter for the input files.", default=None
     )
-    source_column: str | None = Field(
-        description="The input source column to use.", default=None
-    )
-    timestamp_column: str | None = Field(
-        description="The input timestamp column to use.", default=None
-    )
-    timestamp_format: str | None = Field(
-        description="The input timestamp format to use.", default=None
-    )
     text_column: str = Field(
         description="The input text column to use.", default=defs.INPUT_TEXT_COLUMN
     )
     title_column: str | None = Field(
         description="The input title column to use.", default=None
     )
-    document_attribute_columns: list[str] = Field(
-        description="The document attribute columns to use.", default=[]
+    metadata: list[str] | None = Field(
+        description="The document attribute columns to use.", default=None
     )
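
For illustration, a minimal sketch of setting the renamed field on the config model after this change. The column names are hypothetical, and the remaining InputConfig fields are assumed to keep their defaults:

from graphrag.config.models.input_config import InputConfig

# Hypothetical columns; previously these were passed as
# document_attribute_columns=["author", "year"].
config = InputConfig(metadata=["author", "year"])

assert config.metadata == ["author", "year"]
assert InputConfig().metadata is None  # the default is now None instead of []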

graphrag/index/flows/create_final_documents.py

Lines changed: 11 additions & 15 deletions
@@ -9,7 +9,7 @@
 def create_final_documents(
     documents: pd.DataFrame,
     text_units: pd.DataFrame,
-    document_attribute_columns: list[str] | None = None,
+    metadata: list[str] | None = None,
 ) -> pd.DataFrame:
     """All the steps to transform final documents."""
     exploded = (
@@ -46,22 +46,18 @@ def create_final_documents(
     rejoined["id"] = rejoined["id"].astype(str)
     rejoined["human_readable_id"] = rejoined.index + 1

-    # Convert attribute columns to strings and collapse them into a JSON object
-    if document_attribute_columns:
+    # Convert metadata columns to strings and collapse them into a JSON object
+    if metadata:
         # Convert all specified columns to string at once
-        rejoined[document_attribute_columns] = rejoined[
-            document_attribute_columns
-        ].astype(str)
+        rejoined[metadata] = rejoined[metadata].astype(str)

-        # Collapse the document_attribute_columns into a single JSON object column
-        rejoined["attributes"] = rejoined[document_attribute_columns].to_dict(
-            orient="records"
-        )
+        # Collapse the metadata columns into a single JSON object column
+        rejoined["metadata"] = rejoined[metadata].to_dict(orient="records")

-        # Drop the original attribute columns after collapsing them
-        rejoined.drop(columns=document_attribute_columns, inplace=True)
+        # Drop the original metadata columns after collapsing them
+        rejoined.drop(columns=metadata, inplace=True)

-    # set the final column order, but adjust for attributes
+    # set the final column order, but adjust for metadata
     core_columns = [
         "id",
         "human_readable_id",
@@ -70,7 +66,7 @@
         "text_unit_ids",
     ]
     final_columns = [column for column in core_columns if column in rejoined.columns]
-    if document_attribute_columns:
-        final_columns.append("attributes")
+    if metadata:
+        final_columns.append("metadata")

     return rejoined.loc[:, final_columns]
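
As a standalone sketch (not the project's code) of the collapse step above, using a toy DataFrame:

import pandas as pd

df = pd.DataFrame({
    "text": ["doc one", "doc two"],
    "author": ["alice", "bob"],
    "year": [2023, 2024],
})
metadata = ["author", "year"]

# Cast the metadata columns to strings, fold each row into one dict,
# and drop the originals, mirroring the rejoined-frame logic above.
df[metadata] = df[metadata].astype(str)
df["metadata"] = df[metadata].to_dict(orient="records")
df.drop(columns=metadata, inplace=True)

assert df["metadata"][0] == {"author": "alice", "year": "2023"}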

graphrag/index/input/csv.py

Lines changed: 0 additions & 40 deletions
@@ -41,15 +41,6 @@ async def load_file(path: str, group: dict | None) -> pd.DataFrame:
     )
     if "id" not in data.columns:
         data["id"] = data.apply(lambda x: gen_sha512_hash(x, x.keys()), axis=1)
-    if config.source_column is not None and "source" not in data.columns:
-        if config.source_column not in data.columns:
-            log.warning(
-                "source_column %s not found in csv file %s",
-                config.source_column,
-                path,
-            )
-        else:
-            data["source"] = data.apply(lambda x: x[config.source_column], axis=1)
     if config.text_column is not None and "text" not in data.columns:
         if config.text_column not in data.columns:
             log.warning(
@@ -69,37 +60,6 @@ async def load_file(path: str, group: dict | None) -> pd.DataFrame:
         else:
             data["title"] = data.apply(lambda x: x[config.title_column], axis=1)

-    if config.timestamp_column is not None:
-        fmt = config.timestamp_format
-        if fmt is None:
-            msg = "Must specify timestamp_format if timestamp_column is specified"
-            raise ValueError(msg)
-
-        if config.timestamp_column not in data.columns:
-            log.warning(
-                "timestamp_column %s not found in csv file %s",
-                config.timestamp_column,
-                path,
-            )
-        else:
-            data["timestamp"] = pd.to_datetime(
-                data[config.timestamp_column], format=fmt
-            )
-
-        # TODO: Theres probably a less gross way to do this
-        if "year" not in data.columns:
-            data["year"] = data.apply(lambda x: x["timestamp"].year, axis=1)
-        if "month" not in data.columns:
-            data["month"] = data.apply(lambda x: x["timestamp"].month, axis=1)
-        if "day" not in data.columns:
-            data["day"] = data.apply(lambda x: x["timestamp"].day, axis=1)
-        if "hour" not in data.columns:
-            data["hour"] = data.apply(lambda x: x["timestamp"].hour, axis=1)
-        if "minute" not in data.columns:
-            data["minute"] = data.apply(lambda x: x["timestamp"].minute, axis=1)
-        if "second" not in data.columns:
-            data["second"] = data.apply(lambda x: x["timestamp"].second, axis=1)
-
     return data

     file_pattern = (

graphrag/index/workflows/create_final_documents.py

Lines changed: 1 addition & 3 deletions
@@ -28,9 +28,7 @@ async def run_workflow(
     )

     input = config.input
-    output = create_final_documents(
-        documents, text_units, input.document_attribute_columns
-    )
+    output = create_final_documents(documents, text_units, input.metadata)

     await write_table_to_storage(output, workflow_name, context.storage)

tests/unit/config/utils.py

Lines changed: 2 additions & 8 deletions
@@ -94,12 +94,9 @@
         "encoding": defs.INPUT_FILE_ENCODING,
         "file_pattern": defs.INPUT_TEXT_PATTERN,
         "file_filter": None,
-        "source_column": None,
-        "timestamp_column": None,
-        "timestamp_format": None,
         "text_column": defs.INPUT_TEXT_COLUMN,
         "title_column": None,
-        "document_attribute_columns": [],
+        "metadata": None,
     },
     "embed_graph": {
         "enabled": defs.NODE2VEC_ENABLED,
@@ -344,12 +341,9 @@ def assert_input_configs(actual: InputConfig, expected: InputConfig) -> None:
     assert actual.encoding == expected.encoding
     assert actual.file_pattern == expected.file_pattern
     assert actual.file_filter == expected.file_filter
-    assert actual.source_column == expected.source_column
-    assert actual.timestamp_column == expected.timestamp_column
-    assert actual.timestamp_format == expected.timestamp_format
     assert actual.text_column == expected.text_column
     assert actual.title_column == expected.title_column
-    assert actual.document_attribute_columns == expected.document_attribute_columns
+    assert actual.metadata == expected.metadata


 def assert_embed_graph_configs(
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
Lines changed: 86 additions & 0 deletions
@@ -0,0 +1,86 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+from unittest.mock import Mock, patch
+
+from graphrag.config.models.chunking_config import ChunkingConfig
+from graphrag.index.bootstrap import bootstrap
+from graphrag.index.operations.chunk_text.strategies import run_sentences, run_tokens
+from graphrag.index.operations.chunk_text.typing import TextChunk
+
+
+class TestRunSentences:
+    def setup_method(self, method):
+        bootstrap()
+
+    def test_basic_functionality(self):
+        """Test basic sentence splitting without metadata"""
+        input = ["This is a test. Another sentence."]
+        tick = Mock()
+        chunks = list(run_sentences(input, ChunkingConfig(), tick))
+
+        assert len(chunks) == 2
+        assert chunks[0].text_chunk == "This is a test."
+        assert chunks[1].text_chunk == "Another sentence."
+        assert all(c.source_doc_indices == [0] for c in chunks)
+        tick.assert_called_once_with(1)
+
+    def test_multiple_documents(self):
+        """Test processing multiple input documents"""
+        input = ["First. Document.", "Second. Doc."]
+        tick = Mock()
+        chunks = list(run_sentences(input, ChunkingConfig(), tick))
+
+        assert len(chunks) == 4
+        assert chunks[0].source_doc_indices == [0]
+        assert chunks[2].source_doc_indices == [1]
+        assert tick.call_count == 2
+
+    def test_mixed_whitespace_handling(self):
+        """Test input with irregular whitespace"""
+        input = [" Sentence with spaces. Another one! "]
+        chunks = list(run_sentences(input, ChunkingConfig(), Mock()))
+        assert chunks[0].text_chunk == " Sentence with spaces."
+        assert chunks[1].text_chunk == "Another one!"
+
+
+class TestRunTokens:
+    @patch("tiktoken.get_encoding")
+    def test_basic_functionality(self, mock_get_encoding):
+        mock_encoder = Mock()
+        mock_encoder.encode.side_effect = lambda x: list(x.encode())
+        mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
+        mock_get_encoding.return_value = mock_encoder
+
+        # Input and config
+        input = [
+            "Marley was dead: to begin with. There is no doubt whatever about that. The register of his burial was signed by the clergyman, the clerk, the undertaker, and the chief mourner. Scrooge signed it. And Scrooge's name was good upon 'Change, for anything he chose to put his hand to."
+        ]
+        config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
+        tick = Mock()
+
+        # Run the function
+        chunks = list(run_tokens(input, config, tick))
+
+        # Verify output
+        assert len(chunks) > 0
+        assert all(isinstance(chunk, TextChunk) for chunk in chunks)
+        tick.assert_called_once_with(1)
+
+    @patch("tiktoken.get_encoding")
+    def test_non_string_input(self, mock_get_encoding):
+        """Test handling of non-string input (e.g., numbers)."""
+        mock_encoder = Mock()
+        mock_encoder.encode.side_effect = lambda x: list(str(x).encode())
+        mock_encoder.decode.side_effect = lambda x: bytes(x).decode()
+        mock_get_encoding.return_value = mock_encoder
+
+        input = [123]  # Non-string input
+        config = ChunkingConfig(size=5, overlap=1, encoding_model="fake-encoding")
+        tick = Mock()
+
+        chunks = list(run_tokens(input, config, tick))  # type: ignore
+
+        # Verify non-string input is handled
+        assert len(chunks) > 0
+        assert "123" in chunks[0].text_chunk

tests/unit/indexing/text_splitting/test_text_splitting.py

Lines changed: 43 additions & 0 deletions
@@ -5,6 +5,7 @@
 from unittest.mock import MagicMock

 import pytest
+import tiktoken

 from graphrag.index.text_splitting.text_splitting import (
     NoopTextSplitter,
@@ -159,3 +160,45 @@ def test_split_multiple_texts_on_tokens():

     split_multiple_texts_on_tokens(texts, tokenizer, tick=mock_tick)
     mock_tick.assert_called()
+
+
+def test_split_single_text_on_tokens_no_overlap():
+    text = "This is a test text, meaning to be taken seriously by this test only."
+    enc = tiktoken.get_encoding("cl100k_base")
+
+    def encode(text: str) -> list[int]:
+        if not isinstance(text, str):
+            text = f"{text}"
+        return enc.encode(text)
+
+    def decode(tokens: list[int]) -> str:
+        return enc.decode(tokens)
+
+    tokenizer = Tokenizer(
+        chunk_overlap=1,
+        tokens_per_chunk=2,
+        decode=decode,
+        encode=lambda text: encode(text),
+    )
+
+    expected_splits = [
+        "This is",
+        " is a",
+        " a test",
+        " test text",
+        " text,",
+        ", meaning",
+        " meaning to",
+        " to be",
+        " be taken",  # cspell:disable-line
+        " taken seriously",  # cspell:disable-line
+        " seriously by",
+        " by this",  # cspell:disable-line
+        " this test",
+        " test only",
+        " only.",
+        ".",
+    ]
+
+    result = split_single_text_on_tokens(text=text, tokenizer=tokenizer)
+    assert result == expected_splits
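
The expected splits above follow from simple sliding-window arithmetic: with tokens_per_chunk=2 and chunk_overlap=1, the window advances by 2 - 1 = 1 token per chunk, so consecutive chunks share one token. A minimal sketch of that logic (an illustration, not graphrag's split_single_text_on_tokens itself):

def split_on_tokens(tokens: list[int], tokens_per_chunk: int, chunk_overlap: int) -> list[list[int]]:
    # Slide a fixed-size window over the token list, stepping by
    # tokens_per_chunk - chunk_overlap so neighboring chunks overlap.
    chunks = []
    start = 0
    while start < len(tokens):
        chunks.append(tokens[start : start + tokens_per_chunk])
        start += tokens_per_chunk - chunk_overlap
    return chunks

assert split_on_tokens([1, 2, 3, 4], 2, 1) == [[1, 2], [2, 3], [3, 4], [4]]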
