
Commit 66c2cfb

Support JSON input files (microsoft#1777)
* Add csv loader tests
* Add text loader tests
* Add json input support
* Remove temp path constraint
* Reuse loader code
* Semver
* Set file pattern automatically based on type, if empty
* Remove pattern from smoke test config
* Spelling

Co-authored-by: Alonso Guevara <[email protected]>
1 parent bcb7478 commit 66c2cfb

27 files changed, +386 -107 lines changed
Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "minor",
+    "description": "Add support for JSON input files."
+}

dictionary.txt

Lines changed: 2 additions & 0 deletions
@@ -188,6 +188,8 @@ upvote
 # Misc
 Arxiv
 kwds
+jsons
+txts
 
 # Dulce
 astrotechnician

graphrag/config/defaults.py

Lines changed: 1 addition & 1 deletion
@@ -257,7 +257,7 @@ class InputDefaults:
     storage_account_blob_url: None = None
     container_name: None = None
     encoding: str = "utf-8"
-    file_pattern: str = ".*\\.txt$"
+    file_pattern: str = ""
     file_filter: None = None
     text_column: str = "text"
     title_column: None = None

graphrag/config/enums.py

Lines changed: 2 additions & 0 deletions
@@ -34,6 +34,8 @@ class InputFileType(str, Enum):
     """The CSV input type."""
     text = "text"
     """The text input type."""
+    json = "json"
+    """The JSON input type."""
 
     def __repr__(self):
         """Get a string representation."""

graphrag/config/init_content.py

Lines changed: 1 addition & 3 deletions
@@ -70,10 +70,8 @@
 
 input:
   type: {graphrag_config_defaults.input.type.value} # or blob
-  file_type: {graphrag_config_defaults.input.file_type.value} # or csv
+  file_type: {graphrag_config_defaults.input.file_type.value} # [csv, text, json]
   base_dir: "{graphrag_config_defaults.input.base_dir}"
-  file_encoding: {graphrag_config_defaults.input.encoding}
-  file_pattern: ".*\\\\.txt$$"
 
 chunks:
   size: {graphrag_config_defaults.chunks.size}
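Since file_encoding and file_pattern no longer appear in the generated template, a JSON corpus needs only the file type; the pattern is derived during config validation (see graph_rag_config.py below). A sketch of the resulting input block, assuming the default storage type and base directory:

input:
  type: file # or blob
  file_type: json # [csv, text, json]
  base_dir: "input"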

graphrag/config/models/graph_rag_config.py

Lines changed: 9 additions & 0 deletions
@@ -166,6 +166,14 @@ def _validate_update_index_output_base_dir(self) -> None:
     )
     """The input configuration."""
 
+    def _validate_input_pattern(self) -> None:
+        """Validate the input file pattern based on the specified type."""
+        if len(self.input.file_pattern) == 0:
+            if self.input.file_type == defs.InputFileType.text:
+                self.input.file_pattern = ".*\\.txt$"
+            else:
+                self.input.file_pattern = f".*\\.{self.input.file_type.value}$"
+
     embed_graph: EmbedGraphConfig = Field(
         description="Graph embedding configuration.",
         default=EmbedGraphConfig(),
@@ -336,6 +344,7 @@ def _validate_model(self):
         """Validate the model configuration."""
         self._validate_root_dir()
         self._validate_models()
+        self._validate_input_pattern()
         self._validate_reporting_base_dir()
         self._validate_output_base_dir()
         self._validate_multi_output_base_dirs()
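The net effect: an empty file_pattern is filled in from the file type at validation time, so users only set a pattern when they need a non-default one. A standalone sketch of the same derivation (default_pattern is a hypothetical helper; the real logic lives in the validator above):

from graphrag.config.enums import InputFileType


def default_pattern(file_type: InputFileType) -> str:
    # text inputs use the .txt extension, so they are special-cased;
    # csv and json enum values match their file extensions directly
    if file_type == InputFileType.text:
        return ".*\\.txt$"
    return f".*\\.{file_type.value}$"


assert default_pattern(InputFileType.json) == ".*\\.json$"
assert default_pattern(InputFileType.text) == ".*\\.txt$"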

graphrag/index/input/csv.py

Lines changed: 5 additions & 59 deletions
@@ -4,24 +4,19 @@
 """A module containing load method definition."""
 
 import logging
-import re
 from io import BytesIO
 
 import pandas as pd
 
 from graphrag.config.models.input_config import InputConfig
-from graphrag.index.utils.hashing import gen_sha512_hash
+from graphrag.index.input.util import load_files, process_data_columns
 from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage
 
 log = logging.getLogger(__name__)
 
-DEFAULT_FILE_PATTERN = re.compile(r"(?P<filename>[^\\/]).csv$")
 
-input_type = "csv"
-
-
-async def load(
+async def load_csv(
     config: InputConfig,
     progress: ProgressLogger | None,
     storage: PipelineStorage,
@@ -39,61 +34,12 @@ async def load_file(path: str, group: dict | None) -> pd.DataFrame:
         data[[*additional_keys]] = data.apply(
             lambda _row: pd.Series([group[key] for key in additional_keys]), axis=1
         )
-        if "id" not in data.columns:
-            data["id"] = data.apply(lambda x: gen_sha512_hash(x, x.keys()), axis=1)
-        if config.text_column is not None and "text" not in data.columns:
-            if config.text_column not in data.columns:
-                log.warning(
-                    "text_column %s not found in csv file %s",
-                    config.text_column,
-                    path,
-                )
-            else:
-                data["text"] = data.apply(lambda x: x[config.text_column], axis=1)
-        if config.title_column is not None:
-            if config.title_column not in data.columns:
-                log.warning(
-                    "title_column %s not found in csv file %s",
-                    config.title_column,
-                    path,
-                )
-            else:
-                data["title"] = data.apply(lambda x: x[config.title_column], axis=1)
-        else:
-            data["title"] = data.apply(lambda _: path, axis=1)
+
+        data = process_data_columns(data, config, path)
 
         creation_date = await storage.get_creation_date(path)
         data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
 
         return data
 
-    file_pattern = (
-        re.compile(config.file_pattern)
-        if config.file_pattern is not None
-        else DEFAULT_FILE_PATTERN
-    )
-    files = list(
-        storage.find(
-            file_pattern,
-            progress=progress,
-            file_filter=config.file_filter,
-        )
-    )
-
-    if len(files) == 0:
-        msg = f"No CSV files found in {config.base_dir}"
-        raise ValueError(msg)
-
-    files_loaded = []
-
-    for file, group in files:
-        try:
-            files_loaded.append(await load_file(file, group))
-        except Exception:  # noqa: BLE001 (catching Exception is fine here)
-            log.warning("Warning! Error loading csv file %s. Skipping...", file)
-
-    log.info("Found %d csv files, loading %d", len(files), len(files_loaded))
-    result = pd.concat(files_loaded)
-    total_files_log = f"Total number of unfiltered csv rows: {len(result)}"
-    log.info(total_files_log)
-    return result
+    return await load_files(load_file, config, storage, progress)
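The id/text/title handling removed here moved into graphrag/index/input/util.py, which is not part of this excerpt. A plausible reconstruction of process_data_columns from the deleted lines (the shipped helper may differ, e.g. in its log wording):

import logging

import pandas as pd

from graphrag.config.models.input_config import InputConfig
from graphrag.index.utils.hashing import gen_sha512_hash

log = logging.getLogger(__name__)


def process_data_columns(
    data: pd.DataFrame, config: InputConfig, path: str
) -> pd.DataFrame:
    """Derive the id, text, and title columns shared by all loaders."""
    # hash every row into a stable id unless the file already provides one
    if "id" not in data.columns:
        data["id"] = data.apply(lambda x: gen_sha512_hash(x, x.keys()), axis=1)
    if config.text_column is not None and "text" not in data.columns:
        if config.text_column not in data.columns:
            log.warning("text_column %s not found in file %s", config.text_column, path)
        else:
            data["text"] = data.apply(lambda x: x[config.text_column], axis=1)
    if config.title_column is not None:
        if config.title_column not in data.columns:
            log.warning("title_column %s not found in file %s", config.title_column, path)
        else:
            data["title"] = data.apply(lambda x: x[config.title_column], axis=1)
    else:
        # fall back to the file path as the document title
        data["title"] = data.apply(lambda _: path, axis=1)
    return data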

graphrag/index/input/factory.py

Lines changed: 7 additions & 7 deletions
@@ -10,21 +10,21 @@
 
 import pandas as pd
 
-from graphrag.config.enums import InputType
+from graphrag.config.enums import InputFileType, InputType
 from graphrag.config.models.input_config import InputConfig
-from graphrag.index.input.csv import input_type as csv
-from graphrag.index.input.csv import load as load_csv
-from graphrag.index.input.text import input_type as text
-from graphrag.index.input.text import load as load_text
+from graphrag.index.input.csv import load_csv
+from graphrag.index.input.json import load_json
+from graphrag.index.input.text import load_text
 from graphrag.logger.base import ProgressLogger
 from graphrag.logger.null_progress import NullProgressLogger
 from graphrag.storage.blob_pipeline_storage import BlobPipelineStorage
 from graphrag.storage.file_pipeline_storage import FilePipelineStorage
 
 log = logging.getLogger(__name__)
 loaders: dict[str, Callable[..., Awaitable[pd.DataFrame]]] = {
-    text: load_text,
-    csv: load_csv,
+    InputFileType.text: load_text,
+    InputFileType.csv: load_csv,
+    InputFileType.json: load_json,
 }
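Keying the table on InputFileType makes dispatch a plain dictionary lookup. A minimal sketch of how a caller could select a loader (load_input is a hypothetical wrapper; the factory's actual entry point is outside this excerpt):

import pandas as pd

from graphrag.config.models.input_config import InputConfig
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage


async def load_input(
    config: InputConfig,
    storage: PipelineStorage,
    progress: ProgressLogger | None = None,
) -> pd.DataFrame:
    # look up the loader registered for the configured file type
    loader = loaders.get(config.file_type)
    if loader is None:
        msg = f"Unknown input file type: {config.file_type}"
        raise ValueError(msg)
    return await loader(config, progress, storage)
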
graphrag/index/input/json.py

Lines changed: 49 additions & 0 deletions
@@ -0,0 +1,49 @@
+# Copyright (c) 2024 Microsoft Corporation.
+# Licensed under the MIT License
+
+"""A module containing load method definition."""
+
+import json
+import logging
+
+import pandas as pd
+
+from graphrag.config.models.input_config import InputConfig
+from graphrag.index.input.util import load_files, process_data_columns
+from graphrag.logger.base import ProgressLogger
+from graphrag.storage.pipeline_storage import PipelineStorage
+
+log = logging.getLogger(__name__)
+
+
+async def load_json(
+    config: InputConfig,
+    progress: ProgressLogger | None,
+    storage: PipelineStorage,
+) -> pd.DataFrame:
+    """Load json inputs from a directory."""
+    log.info("Loading json files from %s", config.base_dir)
+
+    async def load_file(path: str, group: dict | None) -> pd.DataFrame:
+        if group is None:
+            group = {}
+        text = await storage.get(path, encoding=config.encoding)
+        as_json = json.loads(text)
+        # json file could just be a single object, or an array of objects
+        rows = as_json if isinstance(as_json, list) else [as_json]
+        data = pd.DataFrame(rows)
+
+        additional_keys = group.keys()
+        if len(additional_keys) > 0:
+            data[[*additional_keys]] = data.apply(
+                lambda _row: pd.Series([group[key] for key in additional_keys]), axis=1
+            )
+
+        data = process_data_columns(data, config, path)
+
+        creation_date = await storage.get_creation_date(path)
+        data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
+
+        return data
+
+    return await load_files(load_file, config, storage, progress)
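An input file can therefore hold either a single document or a list of documents. A hypothetical input/docs.json (field names are illustrative; text_column defaults to "text"):

[
  {"text": "First document body.", "author": "A. Author"},
  {"text": "Second document body.", "author": "B. Author"}
]

Each object becomes one DataFrame row, and extra keys such as author simply become additional columns; a file holding a single object instead of an array is wrapped into a one-row frame.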

graphrag/index/input/text.py

Lines changed: 6 additions & 36 deletions
@@ -4,64 +4,34 @@
 """A module containing load method definition."""
 
 import logging
-import re
 from pathlib import Path
-from typing import Any
 
 import pandas as pd
 
 from graphrag.config.models.input_config import InputConfig
+from graphrag.index.input.util import load_files
 from graphrag.index.utils.hashing import gen_sha512_hash
 from graphrag.logger.base import ProgressLogger
 from graphrag.storage.pipeline_storage import PipelineStorage
 
-DEFAULT_FILE_PATTERN = re.compile(
-    r".*[\\/](?P<source>[^\\/]+)[\\/](?P<year>\d{4})-(?P<month>\d{2})-(?P<day>\d{2})_(?P<author>[^_]+)_\d+\.txt"
-)
-input_type = "text"
 log = logging.getLogger(__name__)
 
 
-async def load(
+async def load_text(
     config: InputConfig,
     progress: ProgressLogger | None,
     storage: PipelineStorage,
 ) -> pd.DataFrame:
     """Load text inputs from a directory."""
 
-    async def load_file(
-        path: str, group: dict | None = None, _encoding: str = "utf-8"
-    ) -> dict[str, Any]:
+    async def load_file(path: str, group: dict | None = None) -> pd.DataFrame:
         if group is None:
             group = {}
-        text = await storage.get(path, encoding="utf-8")
+        text = await storage.get(path, encoding=config.encoding)
         new_item = {**group, "text": text}
         new_item["id"] = gen_sha512_hash(new_item, new_item.keys())
        new_item["title"] = str(Path(path).name)
         new_item["creation_date"] = await storage.get_creation_date(path)
-        return new_item
+        return pd.DataFrame([new_item])
 
-    files = list(
-        storage.find(
-            re.compile(config.file_pattern),
-            progress=progress,
-            file_filter=config.file_filter,
-        )
-    )
-    if len(files) == 0:
-        msg = f"No text files found in {config.base_dir}"
-        raise ValueError(msg)
-    found_files = f"found text files from {config.base_dir}, found {files}"
-    log.info(found_files)
-
-    files_loaded = []
-
-    for file, group in files:
-        try:
-            files_loaded.append(await load_file(file, group))
-        except Exception:  # noqa: BLE001 (catching Exception is fine here)
-            log.warning("Warning! Error loading file %s. Skipping...", file)
-
-    log.info("Found %d files, loading %d", len(files), len(files_loaded))
-
-    return pd.DataFrame(files_loaded)
+    return await load_files(load_file, config, storage, progress)
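Like csv.py, this loader now delegates file discovery, error handling, and concatenation to load_files in graphrag/index/input/util.py, whose diff is not shown here. A plausible sketch inferred from the deleted bodies of both loaders (the shipped version may differ, e.g. in log wording):

import logging
import re
from collections.abc import Awaitable, Callable

import pandas as pd

from graphrag.config.models.input_config import InputConfig
from graphrag.logger.base import ProgressLogger
from graphrag.storage.pipeline_storage import PipelineStorage

log = logging.getLogger(__name__)


async def load_files(
    loader: Callable[..., Awaitable[pd.DataFrame]],
    config: InputConfig,
    storage: PipelineStorage,
    progress: ProgressLogger | None,
) -> pd.DataFrame:
    """Apply a per-file loader to every file matching the configured pattern."""
    files = list(
        storage.find(
            re.compile(config.file_pattern),
            progress=progress,
            file_filter=config.file_filter,
        )
    )
    if len(files) == 0:
        msg = f"No {config.file_type} files found in {config.base_dir}"
        raise ValueError(msg)

    files_loaded = []
    for file, group in files:
        try:
            files_loaded.append(await loader(file, group))
        except Exception:  # noqa: BLE001 (skip unreadable files, keep going)
            log.warning("Warning! Error loading file %s. Skipping...", file)

    log.info("Found %d files, loading %d", len(files), len(files_loaded))
    return pd.concat(files_loaded)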
