Skip to content

Commit b1525b6

Browse files
Merge branch 'main' into nli/external_predictions
2 parents 33511c9 + 15888fd commit b1525b6

File tree

3 files changed

+202
-49
lines changed

3 files changed

+202
-49
lines changed

docling_eval/campaign_tools/cvat_deliveries_to_hf.py

Lines changed: 149 additions & 37 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
from collections import defaultdict
88
from dataclasses import dataclass
99
from enum import Enum
10+
from fnmatch import fnmatch
1011
from io import BytesIO
1112
from pathlib import Path
1213
from typing import Dict, Iterable, List, Sequence, Set
@@ -55,9 +56,12 @@ class CombinedBuildStats:
5556
submission_count: int
5657
subsets: List[str]
5758

58-
def as_config(self, dataset_dir_name: str, split: str) -> ConfigEntry:
59+
def as_config(
60+
self, dataset_dir_name: str, split: str, config_name: str | None = None
61+
) -> ConfigEntry:
62+
name = config_name if config_name is not None else "default"
5963
pattern = f"{dataset_dir_name}/{split}/shard_*.parquet"
60-
return ConfigEntry(name="default", split=split, path_pattern=pattern)
64+
return ConfigEntry(name=name, split=split, path_pattern=pattern)
6165

6266

6367
@dataclass(frozen=True)
@@ -67,6 +71,18 @@ class DocumentAsset:
6771
doc_hash: str | None
6872

6973

74+
@dataclass(frozen=True)
class SplitRule:
    """One ``--subset-split`` routing rule: subsets whose name matches
    ``pattern`` (fnmatch-style glob) are routed to split ``split``."""

    # fnmatch glob matched against subset names
    pattern: str
    # target split name (e.g. "train", "validation")
    split: str
78+
79+
80+
@dataclass(frozen=True)
class BuildResult:
    """Pairs the README config entry for one built split with its stats."""

    # front-matter config (name/split/path pattern) for this split's directory
    config: ConfigEntry
    # record/submission counts and subset list produced by the build
    stats: CombinedBuildStats
84+
85+
7086
def _find_overview_path(assets_root: Path) -> Path | None:
7187
for candidate_name in ("cvat_annotation_overview.json", "cvat_overview.json"):
7288
candidate_path = assets_root / candidate_name
@@ -229,6 +245,57 @@ def discover_subset_sources(
229245
return subset_dirs
230246

231247

248+
def _parse_subset_split_rules(rules: List[str]) -> List[SplitRule]:
    """Parse ``--subset-split`` CLI values of the form ``pattern=split``.

    Raises typer.BadParameter when a value has no ``=`` separator or when
    either side of the separator is empty.
    """
    parsed: List[SplitRule] = []
    for raw in rules:
        # partition splits on the first '=' only; an empty separator means
        # no '=' was present at all.
        pattern, separator, split_name = raw.partition("=")
        if not separator:
            raise typer.BadParameter(
                f"Invalid --subset-split value '{raw}'. Expected format pattern=split."
            )
        if not pattern or not split_name:
            raise typer.BadParameter(
                f"Invalid --subset-split value '{raw}'. Both pattern and split are required."
            )
        parsed.append(SplitRule(pattern=pattern, split=split_name))
    return parsed
262+
263+
264+
def _route_subsets_to_splits(
    subset_sources: Dict[str, List[Path]],
    rules: List[SplitRule],
    default_split: str,
    *,
    fail_on_unmatched: bool,
) -> Dict[str, Dict[str, List[Path]]]:
    """Assign each subset to a split using the first matching fnmatch rule.

    Unmatched subsets fall back to ``default_split`` unless
    ``fail_on_unmatched`` is set, in which case a RuntimeError is raised.
    Returns ``{split: {subset_name: source_paths}}``.
    """
    routed: Dict[str, Dict[str, List[Path]]] = defaultdict(dict)
    for subset_name, sources in subset_sources.items():
        # First rule whose glob matches wins; None when nothing matches.
        target = next(
            (rule.split for rule in rules if fnmatch(subset_name, rule.pattern)),
            None,
        )
        if target is None:
            if fail_on_unmatched:
                raise RuntimeError(
                    f"Subset '{subset_name}' did not match any --subset-split rule "
                    "and --fail-on-unmatched was provided."
                )
            target = default_split
        routed[target][subset_name] = sources
    return routed
291+
292+
293+
def _dataset_dir_for_split(base_name: str, split: str, multi_split: bool) -> str:
294+
if not multi_split:
295+
return base_name
296+
return f"{base_name}-{split}"
297+
298+
232299
def ensure_clean_dir(path: Path) -> None:
233300
if path.exists():
234301
shutil.rmtree(path)
@@ -448,38 +515,40 @@ def build_combined_dataset(
448515
)
449516

450517

451-
def render_configs_block(config: ConfigEntry) -> List[str]:
518+
def render_configs_block(configs: Sequence[ConfigEntry]) -> List[str]:
    """Render the YAML ``configs:`` block for the README front matter.

    Emits one ``config_name`` entry per config, pointing that config's split
    at its parquet path pattern (Hugging Face dataset-card layout).
    """
    rendered: List[str] = ["configs:"]
    for entry in configs:
        rendered += [
            f"- config_name: {entry.name}",
            "  data_files:",
            f"  - split: {entry.split}",
            f"    path: {entry.path_pattern}",
        ]
    return rendered
458526

459527

460-
def render_dataset_info_block(config: ConfigEntry) -> List[str]:
528+
def render_dataset_info_block(configs: Sequence[ConfigEntry]) -> List[str]:
    """Render the YAML ``dataset_info:`` block for the README front matter.

    One entry per config; feature rows are fetched via iter_dataset_features()
    once per config, as in the original (presumably it may return a one-shot
    iterable — confirm before hoisting out of the loop).
    """
    rendered: List[str] = ["dataset_info:"]
    for entry in configs:
        rendered.append(f"- config_name: {entry.name}")
        rendered.append("  features:")
        for feature_name, attr, value in iter_dataset_features():
            rendered.append(f"  - name: {feature_name}")
            rendered.append(f"    {attr}: {value}")
        # Blank line separates consecutive config entries.
        rendered.append("")
    return rendered
470539

471540

472541
def write_readme(
473542
output_root: Path,
474-
config: ConfigEntry,
475-
stats: CombinedBuildStats,
543+
builds: Sequence[BuildResult],
476544
license_name: str,
477545
export_kind: DeliveryExportKind,
478546
custom_dirname: str | None = None,
479547
) -> None:
480548
lines: List[str] = ["---"]
481-
lines.extend(render_configs_block(config))
482-
lines.extend(render_dataset_info_block(config))
549+
configs = [build.config for build in builds]
550+
lines.extend(render_configs_block(configs))
551+
lines.extend(render_dataset_info_block(configs))
483552
lines.append(f"license: {license_name}")
484553
lines.append("---")
485554
lines.append("")
@@ -504,9 +573,12 @@ def write_readme(
504573
)
505574
lines.append("")
506575
lines.append("## Dataset Statistics")
507-
lines.append(f"- Total records: {stats.record_count}")
508-
lines.append(f"- Total submissions: {stats.submission_count}")
509-
lines.append(f"- Subsets included: {', '.join(f'`{s}`' for s in stats.subsets)}")
576+
for build in builds:
577+
lines.append(
578+
f"- `{build.config.name}` ({build.config.split}): "
579+
f"{build.stats.record_count} records from {build.stats.submission_count} submissions; "
580+
f"subsets: {', '.join(f'`{s}`' for s in build.stats.subsets)}"
581+
)
510582
readme_path = output_root / "README.md"
511583
with readme_path.open("w", encoding="utf-8") as handle:
512584
handle.write("\n".join(lines).rstrip() + "\n")
@@ -565,6 +637,22 @@ def main(
565637
"(default: predictions_json). Only used when --export-kind is predictions."
566638
),
567639
),
640+
subset_split: List[str] | None = typer.Option(
641+
None,
642+
"--subset-split",
643+
help=(
644+
"Route subsets to splits using pattern=split (fnmatch). "
645+
"Example: --subset-split pdf_val=validation --subset-split 'pdf_train_*'=train"
646+
),
647+
),
648+
fail_on_unmatched: bool = typer.Option(
649+
False,
650+
"--fail-on-unmatched/--allow-unmatched",
651+
help=(
652+
"Error if a subset does not match any --subset-split rule instead of "
653+
"falling back to the default --split."
654+
),
655+
),
568656
datasets_root: Path | None = typer.Option(
569657
None,
570658
"--datasets-root",
@@ -594,6 +682,7 @@ def main(
594682
output_dir = output_dir.expanduser().resolve()
595683
datasets_root = datasets_root.expanduser().resolve() if datasets_root else None
596684
staging_root = output_dir / "_staging"
685+
subset_split_rules = _parse_subset_split_rules(subset_split) if subset_split else []
597686

598687
if not deliveries_root.exists():
599688
raise typer.BadParameter(f"{deliveries_root} does not exist.")
@@ -645,30 +734,53 @@ def main(
645734
)
646735

647736
staging_dir = staging_root / "combined"
648-
stats = build_combined_dataset(
649-
subset_sources=sorted_subset_sources,
650-
staging_dir=staging_dir,
651-
output_root=output_dir,
652-
dataset_dir_name=dataset_dir_name,
653-
split=split,
654-
chunk_size=chunk_size,
655-
export_kind=export_kind,
656-
force=force,
657-
subset_assets=subset_assets,
658-
assets_required=datasets_root is not None,
737+
split_map = (
738+
_route_subsets_to_splits(
739+
subset_sources=sorted_subset_sources,
740+
rules=subset_split_rules,
741+
default_split=split,
742+
fail_on_unmatched=fail_on_unmatched,
743+
)
744+
if subset_split_rules
745+
else {split: sorted_subset_sources}
659746
)
660747

748+
build_results: List[BuildResult] = []
749+
multi_split = len(split_map) > 1
750+
751+
for split_name in sorted(split_map.keys()):
752+
staging_dir = staging_root / f"combined-{split_name}"
753+
dataset_dir = _dataset_dir_for_split(dataset_dir_name, split_name, multi_split)
754+
stats = build_combined_dataset(
755+
subset_sources=split_map[split_name],
756+
staging_dir=staging_dir,
757+
output_root=output_dir,
758+
dataset_dir_name=dataset_dir,
759+
split=split_name,
760+
chunk_size=chunk_size,
761+
export_kind=export_kind,
762+
force=force,
763+
subset_assets=subset_assets,
764+
assets_required=datasets_root is not None,
765+
)
766+
if stats is not None:
767+
config_name = split_name if multi_split or subset_split_rules else None
768+
config = stats.as_config(
769+
dataset_dir_name=dataset_dir,
770+
split=split_name,
771+
config_name=config_name,
772+
)
773+
build_results.append(BuildResult(config=config, stats=stats))
774+
661775
shutil.rmtree(staging_root, ignore_errors=True)
662776

663-
if not stats:
777+
if not build_results:
664778
typer.echo("No datasets were produced.", err=True)
665779
raise typer.Exit(code=1)
666780

667-
config = stats.as_config(dataset_dir_name, split)
668781
write_readme(
669782
output_root=output_dir,
670-
config=config,
671-
stats=stats,
783+
builds=build_results,
672784
license_name=license_name,
673785
export_kind=export_kind,
674786
custom_dirname=custom_dirname,

docling_eval/cli/main.py

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939
from docling.datamodel.vlm_model_specs import (
4040
SMOLDOCLING_TRANSFORMERS as smoldocling_vlm_conversion_options,
4141
)
42-
from docling.document_converter import FormatOption, PdfFormatOption
42+
from docling.document_converter import FormatOption, ImageFormatOption, PdfFormatOption
4343
from docling.models.factories import get_ocr_factory
4444
from docling.pipeline.vlm_pipeline import VlmPipeline
4545
from PyPDF2 import PdfReader, PdfWriter
@@ -414,7 +414,7 @@ def get_prediction_provider(
414414
return DoclingPredictionProvider(
415415
format_options={
416416
InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options),
417-
InputFormat.IMAGE: PdfFormatOption(pipeline_options=pipeline_options),
417+
InputFormat.IMAGE: ImageFormatOption(pipeline_options=pipeline_options),
418418
},
419419
do_visualization=do_visualization,
420420
ignore_missing_predictions=True,
@@ -444,7 +444,7 @@ def get_prediction_provider(
444444
return DoclingPredictionProvider(
445445
format_options={
446446
InputFormat.PDF: PdfFormatOption(pipeline_options=pipeline_options),
447-
InputFormat.IMAGE: PdfFormatOption(pipeline_options=pipeline_options),
447+
InputFormat.IMAGE: ImageFormatOption(pipeline_options=pipeline_options),
448448
},
449449
do_visualization=do_visualization,
450450
ignore_missing_predictions=True,
@@ -493,7 +493,7 @@ def get_prediction_provider(
493493
return DoclingPredictionProvider(
494494
format_options={
495495
InputFormat.PDF: PdfFormatOption(pipeline_options=pdf_pipeline_options),
496-
InputFormat.IMAGE: PdfFormatOption(
496+
InputFormat.IMAGE: ImageFormatOption(
497497
pipeline_options=ocr_pipeline_options
498498
),
499499
},
@@ -528,10 +528,14 @@ def get_prediction_provider(
528528
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
529529
)
530530

531+
image_format_option = ImageFormatOption(
532+
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
533+
)
534+
531535
return DoclingPredictionProvider(
532536
format_options={
533537
InputFormat.PDF: pdf_format_option,
534-
InputFormat.IMAGE: pdf_format_option,
538+
InputFormat.IMAGE: image_format_option,
535539
},
536540
do_visualization=do_visualization,
537541
ignore_missing_predictions=True,
@@ -575,10 +579,14 @@ def get_prediction_provider(
575579
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
576580
)
577581

582+
image_format_option = ImageFormatOption(
583+
pipeline_cls=VlmPipeline, pipeline_options=pipeline_options
584+
)
585+
578586
return DoclingPredictionProvider(
579587
format_options={
580588
InputFormat.PDF: pdf_format_option,
581-
InputFormat.IMAGE: pdf_format_option,
589+
InputFormat.IMAGE: image_format_option,
582590
},
583591
do_visualization=do_visualization,
584592
ignore_missing_predictions=True,

0 commit comments

Comments (0)