Skip to content

Commit 9a74d13

Browse files
dolfim-ibm and vagenas authored
feat: extend source resolution with streams and workdir (#79)
Signed-off-by: Michele Dolfi <[email protected]> Signed-off-by: Panos Vagenas <[email protected]> Signed-off-by: Michele Dolfi <[email protected]> Co-authored-by: Panos Vagenas <[email protected]>
1 parent fc1cfb0 commit 9a74d13

File tree

5 files changed

+172
-27
lines changed

5 files changed

+172
-27
lines changed

docling_core/types/io/__init__.py

Lines changed: 19 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -0,0 +1,19 @@
1+
#
2+
# Copyright IBM Corp. 2024 - 2024
3+
# SPDX-License-Identifier: MIT
4+
#
5+
6+
"""Models for io."""
7+
8+
from io import BytesIO
9+
10+
from pydantic import BaseModel, ConfigDict
11+
12+
13+
class DocumentStream(BaseModel):
    """Wrapper class for a bytes stream with a filename."""

    # BytesIO is not a type pydantic validates natively, hence the opt-in
    # to arbitrary field types.
    model_config = ConfigDict(arbitrary_types_allowed=True)

    # File name associated with the stream.
    name: str
    # In-memory binary content.
    stream: BytesIO

docling_core/utils/file.py

Lines changed: 124 additions & 21 deletions
Original file line number | Diff line number | Diff line change
@@ -7,28 +7,62 @@
77

88
import importlib
99
import tempfile
10+
from io import BytesIO
1011
from pathlib import Path
1112
from typing import Dict, Optional, Union
1213

1314
import requests
1415
from pydantic import AnyHttpUrl, TypeAdapter, ValidationError
16+
from typing_extensions import deprecated
1517

18+
from docling_core.types.io import DocumentStream
1619

17-
def resolve_file_source(
18-
source: Union[Path, AnyHttpUrl, str], headers: Optional[Dict[str, str]] = None
19-
) -> Path:
20-
"""Resolves the source (URL, path) of a file to a local file path.
2120

22-
If a URL is provided, the content is first downloaded to a temporary local file.
21+
def resolve_remote_filename(
    http_url: AnyHttpUrl,
    response_headers: Dict[str, str],
    fallback_filename: str = "file",
) -> str:
    """Resolves the filename from a remote url and its response headers.

    The name is taken from the ``filename`` directive of the
    ``Content-Disposition`` response header when present, otherwise from the
    last segment of the URL path, otherwise from ``fallback_filename``.

    Both the header and the URL are controlled by the remote side, so the
    candidate is always reduced to a bare file name: directory components
    (``/`` or ``\\`` separated) are dropped and the special names ``.`` and
    ``..`` are rejected. This keeps a later ``workdir / filename`` join inside
    ``workdir``.

    Args:
        http_url AnyHttpUrl: The source http url.
        response_headers Dict: Headers received while fetching the remote file.
            The header name is matched case-insensitively.
        fallback_filename str: Filename to use in case none can be determined.

    Returns:
        str: The actual filename of the remote url.
    """

    def _basename(candidate: str) -> str:
        # Normalise Windows separators so that Path.name strips them as well.
        name = Path(candidate.replace("\\", "/")).name
        # Path("..").name is "..", which must never be used as a file name.
        return "" if name in (".", "..") else name

    fname = None

    # try to get filename from response header (header names are
    # case-insensitive, but a plain dict lookup is not)
    cont_disp = next(
        (
            value
            for key, value in response_headers.items()
            if key.lower() == "content-disposition"
        ),
        None,
    )
    if cont_disp:
        for par in cont_disp.strip().split(";"):
            # currently only handling directive "filename" (not "*filename")
            directive, _, value = par.partition("=")
            if directive.strip() == "filename":
                fname = _basename(value.strip().strip("'\"")) or None
                break

    # otherwise, use name from URL:
    if fname is None:
        fname = _basename(http_url.path or "") or fallback_filename

    return fname
49+
50+
51+
def resolve_source_to_stream(
52+
source: Union[Path, AnyHttpUrl, str], headers: Optional[Dict[str, str]] = None
53+
) -> DocumentStream:
54+
"""Resolves the source (URL, path) of a file to a binary stream.
2355
2456
Args:
2557
source (Path | AnyHttpUrl | str): The file input source. Can be a path or URL.
58+
headers (Dict | None): Optional set of headers to use for fetching
59+
the remote URL.
2660
2761
Raises:
2862
ValueError: If source is of unexpected type.
2963
3064
Returns:
31-
Path: The local file path.
65+
DocumentStream: The resolved file loaded as a stream.
3266
"""
3367
try:
3468
http_url: AnyHttpUrl = TypeAdapter(AnyHttpUrl).validate_python(source)
@@ -44,29 +78,98 @@ def resolve_file_source(
4478
# fetch the page
4579
res = requests.get(http_url, stream=True, headers=req_headers)
4680
res.raise_for_status()
47-
fname = None
48-
# try to get filename from response header
49-
if cont_disp := res.headers.get("Content-Disposition"):
50-
for par in cont_disp.strip().split(";"):
51-
# currently only handling directive "filename" (not "*filename")
52-
if (split := par.split("=")) and split[0].strip() == "filename":
53-
fname = "=".join(split[1:]).strip().strip("'\"") or None
54-
break
55-
# otherwise, use name from URL:
56-
if fname is None:
57-
fname = Path(http_url.path or "").name or "file"
58-
local_path = Path(tempfile.mkdtemp()) / fname
59-
with open(local_path, "wb") as f:
60-
for chunk in res.iter_content(chunk_size=1024): # using 1-KB chunks
61-
f.write(chunk)
81+
fname = resolve_remote_filename(http_url=http_url, response_headers=res.headers)
82+
83+
stream = BytesIO(res.content)
84+
doc_stream = DocumentStream(name=fname, stream=stream)
6285
except ValidationError:
6386
try:
6487
local_path = TypeAdapter(Path).validate_python(source)
88+
stream = BytesIO(local_path.read_bytes())
89+
doc_stream = DocumentStream(name=local_path.name, stream=stream)
6590
except ValidationError:
6691
raise ValueError(f"Unexpected source type encountered: {type(source)}")
92+
return doc_stream
93+
94+
95+
def _resolve_source_to_path(
    source: Union[Path, AnyHttpUrl, str],
    headers: Optional[Dict[str, str]] = None,
    workdir: Optional[Path] = None,
) -> Path:
    """Resolve a source to a stream and persist it as a local file.

    The source is first loaded through `resolve_source_to_stream()`; the
    stream content is then written to a file named after the stream inside
    `workdir`.

    Args:
        source (Path | AnyHttpUrl | str): The file input source. Can be a path or URL.
        headers (Dict | None): Optional set of headers to use for fetching
            the remote URL.
        workdir (Path | None): Directory receiving the file. It is created
            (including parents) if missing; a fresh temporary directory is
            used when not set.

    Raises:
        ValueError: If source is of unexpected type.

    Returns:
        Path: The local file path.
    """
    doc_stream = resolve_source_to_stream(source=source, headers=headers)

    # use a temporary directory if not specified
    workdir = Path(tempfile.mkdtemp()) if workdir is None else Path(workdir)

    # create the parent workdir if it doesn't exist
    workdir.mkdir(exist_ok=True, parents=True)

    # The stream name may originate from a remote server; keep only its last
    # component so the file cannot be written outside of workdir.
    safe_name = Path(doc_stream.name.replace("\\", "/")).name
    if safe_name in ("", ".", ".."):
        safe_name = "file"

    # save result to a local file
    local_path = workdir / safe_name
    local_path.write_bytes(doc_stream.stream.read())

    return local_path
68115

69116

117+
def resolve_source_to_path(
    source: Union[Path, AnyHttpUrl, str],
    headers: Optional[Dict[str, str]] = None,
    workdir: Optional[Path] = None,
) -> Path:
    """Resolves the source (URL, path) of a file to a local file path.

    A URL source is downloaded to a local file first. That file is placed in
    `workdir` when one is given, otherwise in a temporary directory.

    Args:
        source (Path | AnyHttpUrl | str): The file input source. Can be a path or URL.
        headers (Dict | None): Optional set of headers to use for fetching
            the remote URL.
        workdir (Path | None): If set, the work directory where the file will
            be downloaded, otherwise a temp dir will be used.

    Raises:
        ValueError: If source is of unexpected type.

    Returns:
        Path: The local file path.
    """
    # Public entry point; the actual work lives in the private helper.
    local_path = _resolve_source_to_path(
        source=source,
        headers=headers,
        workdir=workdir,
    )
    return local_path
145+
146+
147+
@deprecated("Use `resolve_source_to_path()` or `resolve_source_to_stream()` instead")
def resolve_file_source(
    source: Union[Path, AnyHttpUrl, str],
    headers: Optional[Dict[str, str]] = None,
) -> Path:
    """Resolves the source (URL, path) of a file to a local file path.

    Deprecated entry point kept for backward compatibility. A URL source is
    first downloaded to a temporary local file.

    Args:
        source (Path | AnyHttpUrl | str): The file input source. Can be a path or URL.
        headers (Dict | None): Optional set of headers to use for fetching
            the remote URL.

    Raises:
        ValueError: If source is of unexpected type.

    Returns:
        Path: The local file path.
    """
    # No workdir is forwarded, so the helper's default applies.
    local_path = _resolve_source_to_path(
        source=source,
        headers=headers,
    )
    return local_path
171+
172+
70173
def relative_path(src: Path, target: Path) -> Path:
71174
"""Compute the relative path from `src` to `target`.
72175

poetry.lock

Lines changed: 3 additions & 3 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -54,6 +54,7 @@ tabulate = "^0.9.0"
5454
pandas = "^2.1.4"
5555
pillow = "^10.3.0"
5656
pyyaml = ">=5.1,<7.0.0"
57+
typing-extensions = "^4.12.2"
5758

5859
[tool.poetry.group.dev.dependencies]
5960
black = "^24.4.2"

test/test_utils.py

Lines changed: 25 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -10,7 +10,7 @@
1010
from requests import Response
1111

1212
from docling_core.utils.alias import AliasModel
13-
from docling_core.utils.file import resolve_file_source
13+
from docling_core.utils.file import resolve_source_to_path, resolve_source_to_stream
1414

1515

1616
def test_alias_model():
@@ -51,7 +51,7 @@ class AliasModelGrandChild(AliasModelChild):
5151
assert obj.model_dump_json() != json.dumps(data, separators=(",", ":"))
5252

5353

54-
def test_resolve_file_source_url_wout_path(monkeypatch):
54+
def test_resolve_source_to_path_url_wout_path(monkeypatch):
5555
expected_str = "foo"
5656
expected_bytes = bytes(expected_str, "utf-8")
5757

@@ -66,7 +66,29 @@ def get_dummy_response(*args, **kwargs):
6666
"requests.models.Response.iter_content",
6767
lambda *args, **kwargs: [expected_bytes],
6868
)
69-
path = resolve_file_source("https://pypi.org")
69+
path = resolve_source_to_path("https://pypi.org")
7070
with open(path) as f:
7171
text = f.read()
7272
assert text == expected_str
73+
74+
75+
def test_resolve_source_to_stream_url_wout_path(monkeypatch):
    """A URL without a path resolves to a stream named "file" with the body."""
    expected_str = "foo"
    payload = expected_str.encode("utf-8")

    def fake_get(*args, **kwargs):
        # Minimal successful response carrying the payload, no network access.
        response = Response()
        response.status_code = 200
        response._content = payload
        return response

    monkeypatch.setattr("requests.get", fake_get)
    monkeypatch.setattr(
        "requests.models.Response.iter_content",
        lambda *args, **kwargs: [payload],
    )

    doc_stream = resolve_source_to_stream("https://pypi.org")

    # "https://pypi.org" has no path segment, so the fallback name is expected
    assert doc_stream.name == "file"

    text = doc_stream.stream.read().decode("utf8")
    assert text == expected_str

0 commit comments

Comments
 (0)