Skip to content

Commit 75d94ab

Browse files
authored
fix: add handling for str filenames in save/load methods (#205)
* fix: add handling for str filenames in save/load methods
  Signed-off-by: Musashi Hinck <[email protected]>
* fix: add handling for str filenames in save/load methods
  Signed-off-by: Musashi Hinck <[email protected]>

---------

Signed-off-by: Musashi Hinck <[email protected]>
Co-authored-by: Musashi Hinck <[email protected]>
1 parent 510649e commit 75d94ab

File tree

2 files changed

+42
-14
lines changed

2 files changed

+42
-14
lines changed

docling_core/types/doc/document.py

Lines changed: 24 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -2410,12 +2410,14 @@ def export_to_element_tree(self) -> str:
24102410

24112411
def save_as_json(
24122412
self,
2413-
filename: Path,
2413+
filename: Union[str, Path],
24142414
artifacts_dir: Optional[Path] = None,
24152415
image_mode: ImageRefMode = ImageRefMode.EMBEDDED,
24162416
indent: int = 2,
24172417
):
24182418
"""Save as json."""
2419+
if isinstance(filename, str):
2420+
filename = Path(filename)
24192421
artifacts_dir, reference_path = self._get_output_paths(filename, artifacts_dir)
24202422

24212423
if image_mode == ImageRefMode.REFERENCED:
@@ -2430,7 +2432,7 @@ def save_as_json(
24302432
json.dump(out, fw, indent=indent)
24312433

24322434
@classmethod
2433-
def load_from_json(cls, filename: Path) -> "DoclingDocument":
2435+
def load_from_json(cls, filename: Union[str, Path]) -> "DoclingDocument":
24342436
"""load_from_json.
24352437
24362438
:param filename: The filename to load a saved DoclingDocument from a .json.
@@ -2440,17 +2442,21 @@ def load_from_json(cls, filename: Path) -> "DoclingDocument":
24402442
:rtype: DoclingDocument
24412443
24422444
"""
2445+
if isinstance(filename, str):
2446+
filename = Path(filename)
24432447
with open(filename, "r", encoding="utf-8") as f:
24442448
return cls.model_validate_json(f.read())
24452449

24462450
def save_as_yaml(
24472451
self,
2448-
filename: Path,
2452+
filename: Union[str, Path],
24492453
artifacts_dir: Optional[Path] = None,
24502454
image_mode: ImageRefMode = ImageRefMode.EMBEDDED,
24512455
default_flow_style: bool = False,
24522456
):
24532457
"""Save as yaml."""
2458+
if isinstance(filename, str):
2459+
filename = Path(filename)
24542460
artifacts_dir, reference_path = self._get_output_paths(filename, artifacts_dir)
24552461

24562462
if image_mode == ImageRefMode.REFERENCED:
@@ -2465,7 +2471,7 @@ def save_as_yaml(
24652471
yaml.dump(out, fw, default_flow_style=default_flow_style)
24662472

24672473
@classmethod
2468-
def load_from_yaml(cls, filename: Path) -> "DoclingDocument":
2474+
def load_from_yaml(cls, filename: Union[str, Path]) -> "DoclingDocument":
24692475
"""load_from_yaml.
24702476
24712477
Args:
@@ -2474,6 +2480,8 @@ def load_from_yaml(cls, filename: Path) -> "DoclingDocument":
24742480
Returns:
24752481
DoclingDocument: the loaded DoclingDocument
24762482
"""
2483+
if isinstance(filename, str):
2484+
filename = Path(filename)
24772485
with open(filename, encoding="utf-8") as f:
24782486
data = yaml.load(f, Loader=yaml.FullLoader)
24792487
return DoclingDocument.model_validate(data)
@@ -2491,7 +2499,7 @@ def export_to_dict(
24912499

24922500
def save_as_markdown(
24932501
self,
2494-
filename: Path,
2502+
filename: Union[str, Path],
24952503
artifacts_dir: Optional[Path] = None,
24962504
delim: str = "\n\n",
24972505
from_element: int = 0,
@@ -2507,6 +2515,8 @@ def save_as_markdown(
25072515
included_content_layers: set[ContentLayer] = DEFAULT_CONTENT_LAYERS,
25082516
):
25092517
"""Save to markdown."""
2518+
if isinstance(filename, str):
2519+
filename = Path(filename)
25102520
artifacts_dir, reference_path = self._get_output_paths(filename, artifacts_dir)
25112521

25122522
if image_mode == ImageRefMode.REFERENCED:
@@ -2634,7 +2644,7 @@ def export_to_text( # noqa: C901
26342644

26352645
def save_as_html(
26362646
self,
2637-
filename: Path,
2647+
filename: Union[str, Path],
26382648
artifacts_dir: Optional[Path] = None,
26392649
from_element: int = 0,
26402650
to_element: int = sys.maxsize,
@@ -2647,6 +2657,8 @@ def save_as_html(
26472657
included_content_layers: set[ContentLayer] = DEFAULT_CONTENT_LAYERS,
26482658
):
26492659
"""Save to HTML."""
2660+
if isinstance(filename, str):
2661+
filename = Path(filename)
26502662
artifacts_dir, reference_path = self._get_output_paths(filename, artifacts_dir)
26512663

26522664
if image_mode == ImageRefMode.REFERENCED:
@@ -2672,8 +2684,10 @@ def save_as_html(
26722684
fw.write(html_out)
26732685

26742686
def _get_output_paths(
2675-
self, filename: Path, artifacts_dir: Optional[Path] = None
2687+
self, filename: Union[str, Path], artifacts_dir: Optional[Path] = None
26762688
) -> Tuple[Path, Optional[Path]]:
2689+
if isinstance(filename, str):
2690+
filename = Path(filename)
26772691
if artifacts_dir is None:
26782692
# Remove the extension and add '_pictures'
26792693
artifacts_dir = filename.with_suffix("")
@@ -3455,7 +3469,7 @@ def save_as_document_tokens(self, *args, **kwargs):
34553469

34563470
def save_as_doctags(
34573471
self,
3458-
filename: Path,
3472+
filename: Union[str, Path],
34593473
delim: str = "",
34603474
from_element: int = 0,
34613475
to_element: int = sys.maxsize,
@@ -3470,6 +3484,8 @@ def save_as_doctags(
34703484
add_table_cell_text: bool = True,
34713485
):
34723486
r"""Save the document content to DocTags format."""
3487+
if isinstance(filename, str):
3488+
filename = Path(filename)
34733489
out = self.export_to_document_tokens(
34743490
delim=delim,
34753491
from_element=from_element,

docling_core/types/doc/page.py

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -546,7 +546,7 @@ def export_to_dict(self) -> Dict:
546546

547547
def save_as_json(
548548
self,
549-
filename: Path,
549+
filename: Union[str, Path],
550550
indent: int = 2,
551551
):
552552
"""Save the page data as a JSON file.
@@ -555,12 +555,14 @@ def save_as_json(
555555
filename: Path to save the JSON file
556556
indent: Indentation level for JSON formatting
557557
"""
558+
if isinstance(filename, str):
559+
filename = Path(filename)
558560
out = self.export_to_dict()
559561
with open(filename, "w", encoding="utf-8") as fw:
560562
json.dump(out, fw, indent=indent)
561563

562564
@classmethod
563-
def load_from_json(cls, filename: Path) -> "SegmentedPdfPage":
565+
def load_from_json(cls, filename: Union[str, Path]) -> "SegmentedPdfPage":
564566
"""Load page data from a JSON file.
565567
566568
Args:
@@ -569,6 +571,8 @@ def load_from_json(cls, filename: Path) -> "SegmentedPdfPage":
569571
Returns:
570572
Instantiated SegmentedPdfPage object
571573
"""
574+
if isinstance(filename, str):
575+
filename = Path(filename)
572576
with open(filename, "r", encoding="utf-8") as f:
573577
return cls.model_validate_json(f.read())
574578

@@ -1155,19 +1159,21 @@ def export_to_dict(self, mode: str = "json") -> Dict:
11551159
"""
11561160
return self.model_dump(mode=mode, by_alias=True, exclude_none=True)
11571161

1158-
def save_as_json(self, filename: Path, indent: int = 2):
1162+
def save_as_json(self, filename: Union[str, Path], indent: int = 2):
11591163
"""Save the table of contents as a JSON file.
11601164
11611165
Args:
11621166
filename: Path to save the JSON file
11631167
indent: Indentation level for JSON formatting
11641168
"""
1169+
if isinstance(filename, str):
1170+
filename = Path(filename)
11651171
out = self.export_to_dict()
11661172
with open(filename, "w", encoding="utf-8") as fw:
11671173
json.dump(out, fw, indent=indent)
11681174

11691175
@classmethod
1170-
def load_from_json(cls, filename: Path) -> "PdfTableOfContents":
1176+
def load_from_json(cls, filename: Union[str, Path]) -> "PdfTableOfContents":
11711177
"""Load table of contents from a JSON file.
11721178
11731179
Args:
@@ -1176,6 +1182,8 @@ def load_from_json(cls, filename: Path) -> "PdfTableOfContents":
11761182
Returns:
11771183
Instantiated PdfTableOfContents object
11781184
"""
1185+
if isinstance(filename, str):
1186+
filename = Path(filename)
11791187
with open(filename, "r", encoding="utf-8") as f:
11801188
return cls.model_validate_json(f.read())
11811189

@@ -1213,19 +1221,21 @@ def export_to_dict(
12131221
"""
12141222
return self.model_dump(mode=mode, by_alias=True, exclude_none=True)
12151223

1216-
def save_as_json(self, filename: Path, indent: int = 2):
1224+
def save_as_json(self, filename: Union[str, Path], indent: int = 2):
12171225
"""Save the document as a JSON file.
12181226
12191227
Args:
12201228
filename: Path to save the JSON file
12211229
indent: Indentation level for JSON formatting
12221230
"""
1231+
if isinstance(filename, str):
1232+
filename = Path(filename)
12231233
out = self.export_to_dict()
12241234
with open(filename, "w", encoding="utf-8") as fw:
12251235
json.dump(out, fw, indent=indent)
12261236

12271237
@classmethod
1228-
def load_from_json(cls, filename: Path) -> "ParsedPdfDocument":
1238+
def load_from_json(cls, filename: Union[str, Path]) -> "ParsedPdfDocument":
12291239
"""Load document from a JSON file.
12301240
12311241
Args:
@@ -1234,5 +1244,7 @@ def load_from_json(cls, filename: Path) -> "ParsedPdfDocument":
12341244
Returns:
12351245
Instantiated ParsedPdfDocument object
12361246
"""
1247+
if isinstance(filename, str):
1248+
filename = Path(filename)
12371249
with open(filename, "r", encoding="utf-8") as f:
12381250
return cls.model_validate_json(f.read())

0 commit comments

Comments
 (0)