@@ -452,36 +452,9 @@ def split_dataset(
         )
         return split_datasets

-    def _process_datasets_for_odm(
-        self,
-        processed_datasets: List[
-            Tuple[DataSetConfig, Union[DatasetDict, IterableDatasetDict]]
-        ],
-    ) -> Tuple[
-        Dict[str, Union[Dataset, IterableDataset]],
-        Dict[str, Union[Dataset, IterableDataset]],
-    ]:
-        train_split = "train"
-        eval_split = "test"
-        train_datasets_dict = {}
-        eval_datasets_dict = {}
-        for d, raw in processed_datasets:
-            if train_split in raw:
-                train_datasets_dict[d.name] = raw[train_split]
-            if eval_split in raw:
-                eval_datasets_dict[d.name] = raw[eval_split]
-        return train_datasets_dict, eval_datasets_dict
-
-    def _process_dataset_configs(
-        self, dataset_configs: List[DataSetConfig], odm_config=None
-    ) -> Union[
-        Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]],
-        Tuple[
-            Dict[str, Union[Dataset, IterableDataset]],
-            Dict[str, Union[Dataset, IterableDataset]],
-        ],
-    ]:
-
+    def _prepare_processed_datasets(
+        self, dataset_configs: List[DataSetConfig]
+    ) -> List[Tuple[DataSetConfig, Union[IterableDataset, Dataset]]]:
         if not dataset_configs:
             raise ValueError(
                 "No dataset configs provided. Provided Dataset configs is None."
@@ -530,13 +503,35 @@ def _process_dataset_configs(
 
             # Append the processed datasets to the final dict
             processed_datasets.append((d, raw_datasets))
-        if odm_config:
-            logger.info(
-                "Sampling probabilities are ignored if provided"
-                "and are not used for concatenation. Instead"
-                "online data mixing plugin handles it."
-            )
-            return self._process_datasets_for_odm(processed_datasets)
+        return processed_datasets
+
+    def _validate_sampling_ratios(self, sampling_ratios: List[float], train_datasets):
+        if len(sampling_ratios) > 0:
+            if len(sampling_ratios) < len(train_datasets):
+                raise ValueError(
+                    "Sampling probability should be specified for all datasets with train split"
+                )
+            if len(sampling_ratios) > len(train_datasets):
+                raise ValueError(
+                    "Sampling probability should only be specified for datasets with train split"
+                )
+            if sum(p for p in sampling_ratios) != 1:
+                raise ValueError(
+                    "Sampling probabilities for train datasets don't sum to 1"
+                )
+            return True
+
+    def _process_dataset_configs(
+        self, dataset_configs: List[DataSetConfig]
+    ) -> Tuple[
+        Union[Dataset, IterableDataset],
+        Union[Dataset, IterableDataset],
+        Dict[str, float],
+    ]:
+        train_split = "train"  # default
+        eval_split = "test"
+        processed_datasets = self._prepare_processed_datasets(dataset_configs)
+
         train_datasets = []
         train_sampling_probabilities = []
         validation_datasets = []
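
Note on the extracted `_validate_sampling_ratios` helper above: it keeps the original exact-equality check `sum(...) != 1`, which can reject valid ratios because of float rounding, and it only returns `True` when ratios are present (implicitly returning `None` otherwise). A standalone sketch of a tolerance-based variant, purely illustrative and not part of this PR:

```python
# Hypothetical, tolerance-based variant of the helper above; shown only to
# illustrate the float pitfall: sum([0.1] * 10) == 0.9999999999999999, so an
# exact `!= 1` check would reject ten datasets sampled at 0.1 each.
import math
from typing import List, Sequence


def validate_sampling_ratios(sampling_ratios: List[float], train_datasets: Sequence) -> bool:
    if not sampling_ratios:
        return False  # nothing to validate; caller skips sampling
    if len(sampling_ratios) != len(train_datasets):
        raise ValueError("one sampling ratio is required per train dataset")
    if not math.isclose(sum(sampling_ratios), 1.0):
        raise ValueError("sampling ratios for train datasets don't sum to 1")
    return True
```
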
@@ -557,25 +552,9 @@ def _process_dataset_configs(
                 )
 
         # quick check to see if we are sampling and if we need to throw error.
-        if len(train_sampling_probabilities) > 0:
-            if len(train_sampling_probabilities) < len(train_datasets):
-                raise ValueError(
-                    "Sampling probability should be specified for all datasets with train split"
-                )
-            if len(train_sampling_probabilities) > len(train_datasets):
-                raise ValueError(
-                    "Sampling probability should only be specified for datasets with train split"
-                )
-            if sum(p for p in train_sampling_probabilities) != 1:
-                raise ValueError(
-                    "Sampling probabilities for train datasets don't sum to 1"
-                )
-            sample_datasets = True
-            logger.info(
-                "Sampling ratios are specified; only train datasets will be interleaved."
-            )
-        else:
-            sample_datasets = False
+        sample_datasets = self._validate_sampling_ratios(
+            train_sampling_probabilities, train_datasets
+        )
 
         # Ensure again datasets are aligned before interleaving or concatenating
         maybe_align_datasets(train_datasets)
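
For context on what the `sample_datasets` flag gates further down (the interleaving call itself is outside this hunk): Hugging Face `datasets` provides `interleave_datasets`, which mixes datasets according to per-dataset probabilities. A minimal, self-contained illustration with invented data:

```python
# Minimal illustration of probability-weighted mixing with HF `datasets`;
# the datasets and ratios here are invented for the example.
from datasets import Dataset, interleave_datasets

ds_a = Dataset.from_dict({"text": [f"a{i}" for i in range(100)]})
ds_b = Dataset.from_dict({"text": [f"b{i}" for i in range(100)]})

# Roughly 70% of rows are drawn from ds_a and 30% from ds_b.
mixed = interleave_datasets([ds_a, ds_b], probabilities=[0.7, 0.3], seed=42)
print(mixed[0])
```
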
@@ -620,16 +599,14 @@ def _process_dataset_configs(
         if eval_dataset and isinstance(eval_dataset, IterableDataset):
             eval_dataset = resolve_iterable_dataset_features(eval_dataset)
 
-        return train_dataset, eval_dataset
+        return train_dataset, eval_dataset, None
 
     def process_dataset_configs(
-        self, dataset_configs: List[DataSetConfig], odm_config=None
-    ) -> Union[
-        Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]],
-        Tuple[
-            Dict[str, Union[Dataset, IterableDataset]],
-            Dict[str, Union[Dataset, IterableDataset]],
-        ],
+        self, dataset_configs: List[DataSetConfig]
+    ) -> Tuple[
+        Union[Dataset, IterableDataset],
+        Union[Dataset, IterableDataset],
+        Dict[str, float],
     ]:
         train_dataset = eval_dataset = None
 
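
The public contract of `process_dataset_configs` changes here: instead of a 2-tuple (or a pair of dicts in the old `odm_config` branch), callers now unpack a uniform 3-tuple whose third element is `None` on the base-class path. A sketch of the call-site shape after this change, with a stub standing in for the real preprocessor so the snippet runs on its own:

```python
# Hypothetical call-site shape after this change; the stub stands in for the
# real preprocessor and is not part of this PR.
from typing import Any, Dict, Optional, Tuple


def process_dataset_configs_stub(configs: list) -> Tuple[Any, Any, Optional[Dict[str, float]]]:
    return "train_ds", "eval_ds", None  # base class: no sampling weights


train_dataset, eval_dataset, sampling_weights = process_dataset_configs_stub([])
if sampling_weights is not None:
    # ODM path only: per-dataset weights go to the online mixing plugin.
    print("sampling weights:", sampling_weights)
```
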
@@ -643,14 +620,43 @@ def process_dataset_configs(
         # as we want to reuse HF cache and not redo computation on all nodes
         # For rationale see https://github.com/huggingface/trl/pull/3106
         with state.main_process_first():
-            train_dataset, eval_dataset = self._process_dataset_configs(
-                dataset_configs, odm_config
-            )
+            (
+                train_dataset,
+                eval_dataset,
+                sampling_weights,
+            ) = self._process_dataset_configs(dataset_configs)
 
         logger.info("Processed train dataset {}".format(train_dataset))
         logger.info("Processed eval dataset {}".format(eval_dataset))
 
-        return train_dataset, eval_dataset
+        return train_dataset, eval_dataset, sampling_weights
+
+
+class ODMDataPreProcessor(DataPreProcessor):
+    def _process_dataset_configs(
+        self, dataset_configs: List[DataSetConfig]
+    ) -> Tuple[
+        Dict[str, Union[Dataset, IterableDataset]],
+        Dict[str, Union[Dataset, IterableDataset]],
+        Dict[str, float],
+    ]:
+        processed_datasets = self._prepare_processed_datasets(dataset_configs)
+        train_split = "train"
+        eval_split = "test"
+        train_datasets_dict = {}
+        eval_datasets_dict = {}
+        sampling_weights_dict = {}
+        for d, raw in processed_datasets:
+            if d.sampling is not None and d.sampling > 0.0:
+                sampling_weights_dict[d.name] = d.sampling
+            if train_split in raw:
+                train_datasets_dict[d.name] = raw[train_split]
+            if eval_split in raw:
+                eval_datasets_dict[d.name] = raw[eval_split]
+        self._validate_sampling_ratios(
+            sampling_weights_dict.values(), train_datasets_dict.values()
+        )
+        return train_datasets_dict, eval_datasets_dict, sampling_weights_dict
 
 
 def get_datapreprocessor(
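
For reference, the shapes the new `ODMDataPreProcessor` returns: name-keyed split dicts plus a `{name: weight}` map, rather than merged splits. Invented data, just to show the structure:

```python
# Invented example of the three values the ODM subclass returns; dataset
# names and weights are made up for illustration.
from datasets import Dataset

train_dict = {
    "dolly": Dataset.from_dict({"text": ["x1", "x2"]}),
    "alpaca": Dataset.from_dict({"text": ["y1"]}),
}
eval_dict = {"dolly": Dataset.from_dict({"text": ["z1"]})}
weights = {"dolly": 0.6, "alpaca": 0.4}
assert abs(sum(weights.values()) - 1.0) < 1e-9
```
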
@@ -659,7 +665,10 @@ def get_datapreprocessor(
     processor: AutoProcessor = None,
     additional_data_handlers: Dict[str, DataHandler] = None,
 ) -> DataPreProcessor:
-    data_processor = DataPreProcessor(
+    data_processor_cls = DataPreProcessor
+    if processor_config.type == "odm":
+        data_processor_cls = ODMDataPreProcessor
+    data_processor = data_processor_cls(
         processor_config=processor_config,
         tokenizer=tokenizer,
         processor=processor,
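
The factory change above selects the preprocessor class by `processor_config.type` and constructs it with unchanged keyword arguments. A self-contained sketch of that dispatch pattern with stand-in classes (names invented, not this codebase's API):

```python
# Stand-in sketch of the class-dispatch pattern used in get_datapreprocessor;
# these classes are invented for illustration only.
class BasePreProcessor:
    def __init__(self, config_type: str):
        self.config_type = config_type


class ODMPreProcessor(BasePreProcessor):
    pass


def make_preprocessor(config_type: str) -> BasePreProcessor:
    # Subclass swap keeps the construction call identical for both paths.
    cls = ODMPreProcessor if config_type == "odm" else BasePreProcessor
    return cls(config_type=config_type)


assert isinstance(make_preprocessor("odm"), ODMPreProcessor)
assert not isinstance(make_preprocessor("default"), ODMPreProcessor)
```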