Skip to content

Commit 13aca76

Browse files
Add max_chars
1 parent 051c1c5 commit 13aca76

File tree

4 files changed

+69
-4
lines changed

4 files changed

+69
-4
lines changed

src/neo4j_graphrag/experimental/components/data_loader.py

Lines changed: 18 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -49,11 +49,21 @@ def get_document_metadata(
4949
) -> Dict[str, str] | None:
5050
return metadata
5151

52+
@staticmethod
def _apply_max_chars(text: str, max_chars: Optional[int] = None) -> str:
    """Return ``text`` limited to its first ``max_chars`` characters.

    Args:
        text: The text to limit.
        max_chars: Maximum number of characters to keep. ``None`` means no
            limit, and ``text`` is returned unchanged.

    Returns:
        ``text`` itself when ``max_chars`` is ``None``, otherwise the slice
        ``text[:max_chars]``.

    Raises:
        ValueError: If ``max_chars`` is negative.
    """
    if max_chars is not None:
        # Reject negatives explicitly: a negative slice bound would silently
        # drop characters from the end instead of capping the length.
        if max_chars < 0:
            raise ValueError("max_chars must be >= 0")
        return text[:max_chars]
    return text
59+
5260
@abstractmethod
async def run(
    self,
    filepath: Union[str, Path],
    metadata: Optional[Dict[str, str]] = None,
    fs: Optional[Union[AbstractFileSystem, str]] = None,
    max_chars: Optional[int] = None,
) -> LoadedDocument:
    """Load the document at ``filepath``; concrete loaders must implement this.

    Args:
        filepath: Location of the document, as a string or ``Path``.
        metadata: Optional string-to-string metadata for the document.
        fs: Optional filesystem, given as an instance or as a string name.
        max_chars: Optional limit on the number of characters of text.

    Returns:
        The loaded document.
    """
    ...
5868

5969

@@ -83,14 +93,17 @@ async def run(
8393
filepath: Union[str, Path],
8494
metadata: Optional[Dict[str, str]] = None,
8595
fs: Optional[Union[AbstractFileSystem, str]] = None,
96+
max_chars: Optional[int] = None,
8697
) -> LoadedDocument:
8798
if not isinstance(filepath, str):
8899
filepath = str(filepath)
89100
if isinstance(fs, str):
90101
fs = fsspec.filesystem(fs)
91102
elif fs is None:
92103
fs = LocalFileSystem()
93-
text = self.load_file(filepath, fs)
104+
text = self._apply_max_chars(
105+
self.load_file(filepath, fs), max_chars=max_chars
106+
)
94107
return LoadedDocument(
95108
text=text,
96109
document_info=DocumentInfo(
@@ -121,14 +134,17 @@ async def run(
121134
filepath: Union[str, Path],
122135
metadata: Optional[Dict[str, str]] = None,
123136
fs: Optional[Union[AbstractFileSystem, str]] = None,
137+
max_chars: Optional[int] = None,
124138
) -> LoadedDocument:
125139
if not isinstance(filepath, str):
126140
filepath = str(filepath)
127141
if isinstance(fs, str):
128142
fs = fsspec.filesystem(fs)
129143
elif fs is None:
130144
fs = LocalFileSystem()
131-
text = MarkdownLoader.load_file(filepath, fs)
145+
text = self._apply_max_chars(
146+
MarkdownLoader.load_file(filepath, fs), max_chars=max_chars
147+
)
132148
return LoadedDocument(
133149
text=text,
134150
document_info=DocumentInfo(

src/neo4j_graphrag/experimental/pipeline/config/template_pipeline/simple_kg_builder.py

Lines changed: 15 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -85,14 +85,23 @@ async def run(
8585
filepath: Union[str, Path],
8686
metadata: Optional[dict[str, str]] = None,
8787
fs: Optional[Union[AbstractFileSystem, str]] = None,
88+
max_chars: Optional[int] = None,
8889
) -> LoadedDocument:
8990
path_str = str(filepath)
9091
suffix = Path(path_str).suffix.lower()
9192
if suffix == ".pdf":
92-
return await PdfLoader().run(filepath=path_str, metadata=metadata, fs=fs)
93+
return await PdfLoader().run(
94+
filepath=path_str,
95+
metadata=metadata,
96+
fs=fs,
97+
max_chars=max_chars,
98+
)
9399
if suffix in (".md", ".markdown"):
94100
return await MarkdownLoader().run(
95-
filepath=path_str, metadata=metadata, fs=fs
101+
filepath=path_str,
102+
metadata=metadata,
103+
fs=fs,
104+
max_chars=max_chars,
96105
)
97106
raise UnsupportedDocumentFormatError(
98107
f"Unsupported document format: {suffix!r}. "
@@ -430,6 +439,10 @@ def get_run_params(self, user_input: dict[str, Any]) -> dict[str, Any]:
430439
)
431440
run_params["file_loader"]["filepath"] = file_path
432441
run_params["file_loader"]["metadata"] = user_input.get("document_metadata")
442+
max_chars = user_input.get("max_chars")
443+
# Backward-compatible: only forward new arg for the default loader.
444+
if max_chars is not None and self.file_loader is None:
445+
run_params["file_loader"]["max_chars"] = max_chars
433446
else:
434447
if not text:
435448
raise PipelineDefinitionError(

tests/unit/experimental/components/test_data_loader.py

Lines changed: 16 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -87,6 +87,14 @@ async def test_markdown_loader_run() -> None:
8787
assert "# Hello" in doc.text
8888

8989

90+
@pytest.mark.asyncio
async def test_markdown_loader_run_max_chars() -> None:
    """``MarkdownLoader.run`` with ``max_chars=7`` yields only ``"# Hello"``."""
    sample_file = BASE_DIR / "sample_data/hello.md"
    doc = await MarkdownLoader().run(filepath=str(sample_file), max_chars=7)
    assert doc.text == "# Hello"
96+
97+
9098
@pytest.mark.asyncio
9199
async def test_pdf_loader_run() -> None:
92100
"""``PdfLoader.run`` wraps ``load_file`` with :class:`DocumentInfo` (default ``fs``)."""
@@ -98,6 +106,14 @@ async def test_pdf_loader_run() -> None:
98106
assert doc.text == "Lorem ipsum dolor sit amet."
99107

100108

109+
@pytest.mark.asyncio
async def test_pdf_loader_run_max_chars() -> None:
    """``PdfLoader.run`` with ``max_chars=5`` yields only ``"Lorem"``."""
    sample_file = BASE_DIR / "sample_data/lorem_ipsum.pdf"
    doc = await PdfLoader().run(filepath=str(sample_file), max_chars=5)
    assert doc.text == "Lorem"
115+
116+
101117
@pytest.mark.asyncio
102118
async def test_pdf_loader_run_fs_string_resolves_with_fsspec(
103119
dummy_pdf_path: str,

tests/unit/experimental/pipeline/config/template_pipeline/test_simple_kg_builder.py

Lines changed: 20 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -370,6 +370,26 @@ def test_simple_kg_pipeline_config_run_params_from_file_file_path() -> None:
370370
}
371371

372372

373+
def test_simple_kg_pipeline_config_run_params_from_file_file_path_with_max_chars() -> (
    None
):
    """With the default loader, ``max_chars`` is forwarded in the ``file_loader`` params."""
    user_input = {"file_path": "my_file", "max_chars": 42}
    expected = {
        "file_loader": {"filepath": "my_file", "metadata": None, "max_chars": 42}
    }
    config = SimpleKGPipelineConfig(from_file=True)
    assert config.get_run_params(user_input) == expected
380+
381+
382+
def test_simple_kg_pipeline_config_run_params_custom_file_loader_ignores_max_chars() -> (
    None
):
    """With a user-supplied ``file_loader``, ``max_chars`` is left out of the run params."""
    custom_loader = ComponentType(PdfLoader())
    config = SimpleKGPipelineConfig(from_file=True, file_loader=custom_loader)
    run_params = config.get_run_params({"file_path": "my_file", "max_chars": 42})
    assert run_params == {"file_loader": {"filepath": "my_file", "metadata": None}}
391+
392+
373393
def test_simple_kg_pipeline_config_run_params_from_text_text() -> None:
374394
config = SimpleKGPipelineConfig(from_file=False)
375395
run_params = config.get_run_params({"text": "my text"})

0 commit comments

Comments
 (0)