77from collections import defaultdict
88from dataclasses import dataclass
99from enum import Enum
10+ from fnmatch import fnmatch
1011from io import BytesIO
1112from pathlib import Path
1213from typing import Dict , Iterable , List , Sequence , Set
@@ -55,9 +56,12 @@ class CombinedBuildStats:
5556 submission_count : int
5657 subsets : List [str ]
5758
58- def as_config (self , dataset_dir_name : str , split : str ) -> ConfigEntry :
59+ def as_config (
60+ self , dataset_dir_name : str , split : str , config_name : str | None = None
61+ ) -> ConfigEntry :
62+ name = config_name if config_name is not None else "default"
5963 pattern = f"{ dataset_dir_name } /{ split } /shard_*.parquet"
60- return ConfigEntry (name = "default" , split = split , path_pattern = pattern )
64+ return ConfigEntry (name = name , split = split , path_pattern = pattern )
6165
6266
6367@dataclass (frozen = True )
@@ -67,6 +71,18 @@ class DocumentAsset:
6771 doc_hash : str | None
6872
6973
@dataclass(frozen=True)
class SplitRule:
    """A single ``pattern=split`` routing rule.

    Subsets whose name matches ``pattern`` (fnmatch-style glob) are routed
    to the split named ``split``.
    """

    # fnmatch-style glob matched against subset names
    pattern: str
    # target split name, e.g. "train" or "validation"
    split: str
78+
79+
@dataclass(frozen=True)
class BuildResult:
    """Pairs the README config entry with the build stats for one split."""

    # config entry rendered into the README YAML front matter
    config: ConfigEntry
    # statistics gathered while building this split's dataset
    stats: CombinedBuildStats
84+
85+
7086def _find_overview_path (assets_root : Path ) -> Path | None :
7187 for candidate_name in ("cvat_annotation_overview.json" , "cvat_overview.json" ):
7288 candidate_path = assets_root / candidate_name
@@ -229,6 +245,57 @@ def discover_subset_sources(
229245 return subset_dirs
230246
231247
def _parse_subset_split_rules(rules: List[str]) -> List[SplitRule]:
    """Parse raw ``pattern=split`` CLI values into ``SplitRule`` objects.

    Raises ``typer.BadParameter`` when a value lacks the ``=`` separator or
    when either side of the separator is empty.
    """
    parsed: List[SplitRule] = []
    for raw in rules:
        pattern, separator, split = raw.partition("=")
        if not separator:
            raise typer.BadParameter(
                f"Invalid --subset-split value '{raw}'. Expected format pattern=split."
            )
        if not pattern or not split:
            raise typer.BadParameter(
                f"Invalid --subset-split value '{raw}'. Both pattern and split are required."
            )
        parsed.append(SplitRule(pattern=pattern, split=split))
    return parsed
262+
263+
264+ def _route_subsets_to_splits (
265+ subset_sources : Dict [str , List [Path ]],
266+ rules : List [SplitRule ],
267+ default_split : str ,
268+ * ,
269+ fail_on_unmatched : bool ,
270+ ) -> Dict [str , Dict [str , List [Path ]]]:
271+ split_map : Dict [str , Dict [str , List [Path ]]] = defaultdict (dict )
272+
273+ for subset_name , sources in subset_sources .items ():
274+ matched_split = None
275+ for rule in rules :
276+ if fnmatch (subset_name , rule .pattern ):
277+ matched_split = rule .split
278+ break
279+
280+ if matched_split is None :
281+ if fail_on_unmatched :
282+ raise RuntimeError (
283+ f"Subset '{ subset_name } ' did not match any --subset-split rule "
284+ "and --fail-on-unmatched was provided."
285+ )
286+ matched_split = default_split
287+
288+ split_map [matched_split ][subset_name ] = sources
289+
290+ return split_map
291+
292+
293+ def _dataset_dir_for_split (base_name : str , split : str , multi_split : bool ) -> str :
294+ if not multi_split :
295+ return base_name
296+ return f"{ base_name } -{ split } "
297+
298+
232299def ensure_clean_dir (path : Path ) -> None :
233300 if path .exists ():
234301 shutil .rmtree (path )
@@ -448,38 +515,40 @@ def build_combined_dataset(
448515 )
449516
450517
def render_configs_block(configs: Sequence[ConfigEntry]) -> List[str]:
    """Render the YAML ``configs:`` block for the README front matter.

    Emits one ``config_name`` entry per config, each pointing its split at
    the config's parquet path pattern.
    """
    rendered: List[str] = ["configs:"]
    for entry in configs:
        rendered.extend(
            [
                f"- config_name: {entry.name}",
                "  data_files:",
                f"  - split: {entry.split}",
                f"    path: {entry.path_pattern}",
            ]
        )
    return rendered
458526
459527
def render_dataset_info_block(configs: Sequence[ConfigEntry]) -> List[str]:
    """Render the YAML ``dataset_info:`` block listing each config's features."""
    rendered: List[str] = ["dataset_info:"]
    for entry in configs:
        rendered.append(f"- config_name: {entry.name}")
        rendered.append("  features:")
        # Re-invoke iter_dataset_features() per config in case it yields a
        # one-shot iterator rather than a reusable sequence.
        for feature_name, attr, value in iter_dataset_features():
            rendered.append(f"  - name: {feature_name}")
            rendered.append(f"    {attr}: {value}")
        # blank separator line after each config's feature list
        rendered.append("")
    return rendered
470539
471540
472541def write_readme (
473542 output_root : Path ,
474- config : ConfigEntry ,
475- stats : CombinedBuildStats ,
543+ builds : Sequence [BuildResult ],
476544 license_name : str ,
477545 export_kind : DeliveryExportKind ,
478546 custom_dirname : str | None = None ,
479547) -> None :
480548 lines : List [str ] = ["---" ]
481- lines .extend (render_configs_block (config ))
482- lines .extend (render_dataset_info_block (config ))
549+ configs = [build .config for build in builds ]
550+ lines .extend (render_configs_block (configs ))
551+ lines .extend (render_dataset_info_block (configs ))
483552 lines .append (f"license: { license_name } " )
484553 lines .append ("---" )
485554 lines .append ("" )
@@ -504,9 +573,12 @@ def write_readme(
504573 )
505574 lines .append ("" )
506575 lines .append ("## Dataset Statistics" )
507- lines .append (f"- Total records: { stats .record_count } " )
508- lines .append (f"- Total submissions: { stats .submission_count } " )
509- lines .append (f"- Subsets included: { ', ' .join (f'`{ s } `' for s in stats .subsets )} " )
576+ for build in builds :
577+ lines .append (
578+ f"- `{ build .config .name } ` ({ build .config .split } ): "
579+ f"{ build .stats .record_count } records from { build .stats .submission_count } submissions; "
580+ f"subsets: { ', ' .join (f'`{ s } `' for s in build .stats .subsets )} "
581+ )
510582 readme_path = output_root / "README.md"
511583 with readme_path .open ("w" , encoding = "utf-8" ) as handle :
512584 handle .write ("\n " .join (lines ).rstrip () + "\n " )
@@ -565,6 +637,22 @@ def main(
565637 "(default: predictions_json). Only used when --export-kind is predictions."
566638 ),
567639 ),
640+ subset_split : List [str ] | None = typer .Option (
641+ None ,
642+ "--subset-split" ,
643+ help = (
644+ "Route subsets to splits using pattern=split (fnmatch). "
645+ "Example: --subset-split pdf_val=validation --subset-split 'pdf_train_*'=train"
646+ ),
647+ ),
648+ fail_on_unmatched : bool = typer .Option (
649+ False ,
650+ "--fail-on-unmatched/--allow-unmatched" ,
651+ help = (
652+ "Error if a subset does not match any --subset-split rule instead of "
653+ "falling back to the default --split."
654+ ),
655+ ),
568656 datasets_root : Path | None = typer .Option (
569657 None ,
570658 "--datasets-root" ,
@@ -594,6 +682,7 @@ def main(
594682 output_dir = output_dir .expanduser ().resolve ()
595683 datasets_root = datasets_root .expanduser ().resolve () if datasets_root else None
596684 staging_root = output_dir / "_staging"
685+ subset_split_rules = _parse_subset_split_rules (subset_split ) if subset_split else []
597686
598687 if not deliveries_root .exists ():
599688 raise typer .BadParameter (f"{ deliveries_root } does not exist." )
@@ -645,30 +734,53 @@ def main(
645734 )
646735
647736 staging_dir = staging_root / "combined"
648- stats = build_combined_dataset (
649- subset_sources = sorted_subset_sources ,
650- staging_dir = staging_dir ,
651- output_root = output_dir ,
652- dataset_dir_name = dataset_dir_name ,
653- split = split ,
654- chunk_size = chunk_size ,
655- export_kind = export_kind ,
656- force = force ,
657- subset_assets = subset_assets ,
658- assets_required = datasets_root is not None ,
737+ split_map = (
738+ _route_subsets_to_splits (
739+ subset_sources = sorted_subset_sources ,
740+ rules = subset_split_rules ,
741+ default_split = split ,
742+ fail_on_unmatched = fail_on_unmatched ,
743+ )
744+ if subset_split_rules
745+ else {split : sorted_subset_sources }
659746 )
660747
748+ build_results : List [BuildResult ] = []
749+ multi_split = len (split_map ) > 1
750+
751+ for split_name in sorted (split_map .keys ()):
752+ staging_dir = staging_root / f"combined-{ split_name } "
753+ dataset_dir = _dataset_dir_for_split (dataset_dir_name , split_name , multi_split )
754+ stats = build_combined_dataset (
755+ subset_sources = split_map [split_name ],
756+ staging_dir = staging_dir ,
757+ output_root = output_dir ,
758+ dataset_dir_name = dataset_dir ,
759+ split = split_name ,
760+ chunk_size = chunk_size ,
761+ export_kind = export_kind ,
762+ force = force ,
763+ subset_assets = subset_assets ,
764+ assets_required = datasets_root is not None ,
765+ )
766+ if stats is not None :
767+ config_name = split_name if multi_split or subset_split_rules else None
768+ config = stats .as_config (
769+ dataset_dir_name = dataset_dir ,
770+ split = split_name ,
771+ config_name = config_name ,
772+ )
773+ build_results .append (BuildResult (config = config , stats = stats ))
774+
661775 shutil .rmtree (staging_root , ignore_errors = True )
662776
663- if not stats :
777+ if not build_results :
664778 typer .echo ("No datasets were produced." , err = True )
665779 raise typer .Exit (code = 1 )
666780
667- config = stats .as_config (dataset_dir_name , split )
668781 write_readme (
669782 output_root = output_dir ,
670- config = config ,
671- stats = stats ,
783+ builds = build_results ,
672784 license_name = license_name ,
673785 export_kind = export_kind ,
674786 custom_dirname = custom_dirname ,
0 commit comments