
Commit 978e798

Remove file filtering (#2050)
* Remove document filtering
* Semver
* Fix integ tests
* Fix file find tuple
* Fix another dangling find tuple
1 parent 429e1b1 commit 978e798


17 files changed: +40 −128 lines

Lines changed: 4 additions & 0 deletions
@@ -0,0 +1,4 @@
+{
+    "type": "major",
+    "description": "Remove document filtering option."
+}

docs/config/yaml.md

Lines changed: 0 additions & 1 deletion
@@ -87,7 +87,6 @@ Our pipeline can ingest .csv, .txt, or .json data from an input location. See th
 - `file_type` **text|csv|json** - The type of input data to load. Default is `text`
 - `encoding` **str** - The encoding of the input file. Default is `utf-8`
 - `file_pattern` **str** - A regex to match input files. Default is `.*\.csv$`, `.*\.txt$`, or `.*\.json$` depending on the specified `file_type`, but you can customize it if needed.
-- `file_filter` **dict** - Key/value pairs to filter. Default is None.
 - `text_column` **str** - (CSV/JSON only) The text column name. If unset we expect a column named `text`.
 - `title_column` **str** - (CSV/JSON only) The title column name, filename will be used if unset.
 - `metadata` **list[str]** - (CSV/JSON only) The additional document attributes fields to keep.
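
With `file_filter` removed, input selection is driven entirely by `file_pattern`. For orientation, here is a minimal sketch of the remaining documented keys expressed as a Python mapping; the key names come from the list above, while the values and the mapping itself are purely illustrative.

# Illustrative only: documented input settings after this change; values are assumptions.
input_settings = {
    "file_type": "csv",
    "encoding": "utf-8",
    "file_pattern": r".*\.csv$",   # tighten this regex to select files, now that file_filter is gone
    "text_column": "text",         # CSV/JSON only
    "title_column": "title",       # CSV/JSON only; filename is used if unset
    "metadata": ["source"],        # CSV/JSON only
}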

graphrag/config/defaults.py

Lines changed: 0 additions & 1 deletion
@@ -248,7 +248,6 @@ class InputDefaults:
     file_type: ClassVar[InputFileType] = InputFileType.text
     encoding: str = "utf-8"
     file_pattern: str = ""
-    file_filter: None = None
     text_column: str = "text"
     title_column: None = None
     metadata: None = None

graphrag/config/models/input_config.py

Lines changed: 0 additions & 4 deletions
@@ -32,10 +32,6 @@ class InputConfig(BaseModel):
         description="The input file pattern to use.",
         default=graphrag_config_defaults.input.file_pattern,
     )
-    file_filter: dict[str, str] | None = Field(
-        description="The optional file filter for the input files.",
-        default=graphrag_config_defaults.input.file_filter,
-    )
     text_column: str = Field(
         description="The input text column to use.",
         default=graphrag_config_defaults.input.text_column,
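
Since `file_filter` is no longer a declared field, callers construct the model from the remaining fields only. A minimal sketch, assuming the import path and field names shown in the diff; the literal values are illustrative.

from graphrag.config.models.input_config import InputConfig

# Sketch: configure input selection through file_pattern alone; file_filter no longer exists.
config = InputConfig(
    file_pattern=r".*\.csv$",
    text_column="text",
)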

graphrag/index/input/csv.py

Lines changed: 1 addition & 10 deletions
@@ -22,19 +22,10 @@ async def load_csv(
     """Load csv inputs from a directory."""
     logger.info("Loading csv files from %s", config.storage.base_dir)
 
-    async def load_file(path: str, group: dict | None) -> pd.DataFrame:
-        if group is None:
-            group = {}
+    async def load_file(path: str) -> pd.DataFrame:
         buffer = BytesIO(await storage.get(path, as_bytes=True))
         data = pd.read_csv(buffer, encoding=config.encoding)
-        additional_keys = group.keys()
-        if len(additional_keys) > 0:
-            data[[*additional_keys]] = data.apply(
-                lambda _row: pd.Series([group[key] for key in additional_keys]), axis=1
-            )
-
         data = process_data_columns(data, config, path)
-
         creation_date = await storage.get_creation_date(path)
         data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
 
graphrag/index/input/json.py

Lines changed: 1 addition & 11 deletions
@@ -22,23 +22,13 @@ async def load_json(
     """Load json inputs from a directory."""
     logger.info("Loading json files from %s", config.storage.base_dir)
 
-    async def load_file(path: str, group: dict | None) -> pd.DataFrame:
-        if group is None:
-            group = {}
+    async def load_file(path: str) -> pd.DataFrame:
         text = await storage.get(path, encoding=config.encoding)
         as_json = json.loads(text)
         # json file could just be a single object, or an array of objects
         rows = as_json if isinstance(as_json, list) else [as_json]
         data = pd.DataFrame(rows)
-
-        additional_keys = group.keys()
-        if len(additional_keys) > 0:
-            data[[*additional_keys]] = data.apply(
-                lambda _row: pd.Series([group[key] for key in additional_keys]), axis=1
-            )
-
         data = process_data_columns(data, config, path)
-
         creation_date = await storage.get_creation_date(path)
         data["creation_date"] = data.apply(lambda _: creation_date, axis=1)
 

graphrag/index/input/text.py

Lines changed: 2 additions & 4 deletions
@@ -22,11 +22,9 @@ async def load_text(
 ) -> pd.DataFrame:
     """Load text inputs from a directory."""
 
-    async def load_file(path: str, group: dict | None = None) -> pd.DataFrame:
-        if group is None:
-            group = {}
+    async def load_file(path: str) -> pd.DataFrame:
         text = await storage.get(path, encoding=config.encoding)
-        new_item = {**group, "text": text}
+        new_item = {"text": text}
         new_item["id"] = gen_sha512_hash(new_item, new_item.keys())
         new_item["title"] = str(Path(path).name)
         new_item["creation_date"] = await storage.get_creation_date(path)

graphrag/index/input/util.py

Lines changed: 3 additions & 8 deletions
@@ -22,22 +22,17 @@ async def load_files(
     storage: PipelineStorage,
 ) -> pd.DataFrame:
     """Load files from storage and apply a loader function."""
-    files = list(
-        storage.find(
-            re.compile(config.file_pattern),
-            file_filter=config.file_filter,
-        )
-    )
+    files = list(storage.find(re.compile(config.file_pattern)))
 
     if len(files) == 0:
         msg = f"No {config.file_type} files found in {config.storage.base_dir}"
         raise ValueError(msg)
 
     files_loaded = []
 
-    for file, group in files:
+    for file in files:
         try:
-            files_loaded.append(await loader(file, group))
+            files_loaded.append(await loader(file))
         except Exception as e:  # noqa: BLE001 (catching Exception is fine here)
             logger.warning("Warning! Error loading file %s. Skipping...", file)
             logger.warning("Error: %s", e)

graphrag/storage/blob_pipeline_storage.py

Lines changed: 7 additions & 22 deletions
@@ -101,15 +101,13 @@ def find(
         self,
         file_pattern: re.Pattern[str],
         base_dir: str | None = None,
-        file_filter: dict[str, Any] | None = None,
         max_count=-1,
-    ) -> Iterator[tuple[str, dict[str, Any]]]:
-        """Find blobs in a container using a file pattern, as well as a custom filter function.
+    ) -> Iterator[str]:
+        """Find blobs in a container using a file pattern.
 
         Params:
             base_dir: The name of the base container.
             file_pattern: The file pattern to use.
-            file_filter: A dictionary of key-value pairs to filter the blobs.
             max_count: The maximum number of blobs to return. If -1, all blobs are returned.
 
         Returns
@@ -131,14 +129,6 @@ def _blobname(blob_name: str) -> str:
             blob_name = blob_name[1:]
             return blob_name
 
-        def item_filter(item: dict[str, Any]) -> bool:
-            if file_filter is None:
-                return True
-
-            return all(
-                re.search(value, item[key]) for key, value in file_filter.items()
-            )
-
         try:
             container_client = self._blob_service_client.get_container_client(
                 self._container_name
@@ -151,14 +141,10 @@ def item_filter(item: dict[str, Any]) -> bool:
             for blob in all_blobs:
                 match = file_pattern.search(blob.name)
                 if match and blob.name.startswith(base_dir):
-                    group = match.groupdict()
-                    if item_filter(group):
-                        yield (_blobname(blob.name), group)
-                        num_loaded += 1
-                        if max_count > 0 and num_loaded >= max_count:
-                            break
-                    else:
-                        num_filtered += 1
+                    yield _blobname(blob.name)
+                    num_loaded += 1
+                    if max_count > 0 and num_loaded >= max_count:
+                        break
                 else:
                     num_filtered += 1
             logger.debug(
@@ -169,10 +155,9 @@ def item_filter(item: dict[str, Any]) -> bool:
             )
         except Exception:  # noqa: BLE001
             logger.warning(
-                "Error finding blobs: base_dir=%s, file_pattern=%s, file_filter=%s",
+                "Error finding blobs: base_dir=%s, file_pattern=%s",
                 base_dir,
                 file_pattern,
-                file_filter,
             )
 
     async def get(
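
After this change, `find` yields plain blob names instead of `(name, group)` tuples. A rough usage sketch against any `PipelineStorage` implementation; constructing the `storage` object is outside this diff and assumed here.

import re

# Sketch: iterate the new Iterator[str] result; no per-file group dict is returned.
pattern = re.compile(r".*\.csv$")
for name in storage.find(pattern, max_count=10):  # storage: a configured PipelineStorage (assumed)
    print(name)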

graphrag/storage/cosmosdb_pipeline_storage.py

Lines changed: 7 additions & 24 deletions
@@ -121,15 +121,13 @@ def find(
         self,
         file_pattern: re.Pattern[str],
         base_dir: str | None = None,
-        file_filter: dict[str, Any] | None = None,
         max_count=-1,
-    ) -> Iterator[tuple[str, dict[str, Any]]]:
-        """Find documents in a Cosmos DB container using a file pattern regex and custom file filter (optional).
+    ) -> Iterator[str]:
+        """Find documents in a Cosmos DB container using a file pattern regex.
 
         Params:
             base_dir: The name of the base directory (not used in Cosmos DB context).
             file_pattern: The file pattern to use.
-            file_filter: A dictionary of key-value pairs to filter the documents.
             max_count: The maximum number of documents to return. If -1, all documents are returned.
 
         Returns
@@ -145,23 +143,12 @@ def find(
         if not self._database_client or not self._container_client:
             return
 
-        def item_filter(item: dict[str, Any]) -> bool:
-            if file_filter is None:
-                return True
-            return all(
-                re.search(value, item.get(key, ""))
-                for key, value in file_filter.items()
-            )
-
         try:
             query = "SELECT * FROM c WHERE RegexMatch(c.id, @pattern)"
             parameters: list[dict[str, Any]] = [
                 {"name": "@pattern", "value": file_pattern.pattern}
             ]
-            if file_filter:
-                for key, value in file_filter.items():
-                    query += f" AND c.{key} = @{key}"
-                    parameters.append({"name": f"@{key}", "value": value})
+
             items = list(
                 self._container_client.query_items(
                     query=query,
@@ -177,14 +164,10 @@ def item_filter(item: dict[str, Any]) -> bool:
             for item in items:
                 match = file_pattern.search(item["id"])
                 if match:
-                    group = match.groupdict()
-                    if item_filter(group):
-                        yield (item["id"], group)
-                        num_loaded += 1
-                        if max_count > 0 and num_loaded >= max_count:
-                            break
-                    else:
-                        num_filtered += 1
+                    yield item["id"]
+                    num_loaded += 1
+                    if max_count > 0 and num_loaded >= max_count:
+                        break
                 else:
                     num_filtered += 1
 