@@ -452,35 +452,9 @@ def split_dataset(
452452 )
453453 return split_datasets
454454
455- def _process_datasets_for_odm (
456- self ,
457- processed_datasets : List [
458- Tuple [DataSetConfig , Union [DatasetDict , IterableDatasetDict ]]
459- ],
460- ) -> Tuple [
461- Dict [str , Union [Dataset , IterableDataset ]],
462- Dict [str , Union [Dataset , IterableDataset ]],
463- ]:
464- train_split = "train"
465- eval_split = "test"
466- train_datasets_dict = {}
467- eval_datasets_dict = {}
468- for d , raw in processed_datasets :
469- if train_split in raw :
470- train_datasets_dict [d .name ] = raw [train_split ]
471- if eval_split in raw :
472- eval_datasets_dict [d .name ] = raw [eval_split ]
473- return train_datasets_dict , eval_datasets_dict
474-
475455 def _process_dataset_configs (
476- self , dataset_configs : List [DataSetConfig ], odm_config = None
477- ) -> Union [
478- Tuple [Union [Dataset , IterableDataset ], Union [Dataset , IterableDataset ]],
479- Tuple [
480- Dict [str , Union [Dataset , IterableDataset ]],
481- Dict [str , Union [Dataset , IterableDataset ]],
482- ],
483- ]:
456+ self , dataset_configs : List [DataSetConfig ]
457+ ) -> Tuple [Union [Dataset , IterableDataset ], Union [Dataset , IterableDataset ]]:
484458
485459 if not dataset_configs :
486460 raise ValueError (
@@ -530,13 +504,7 @@ def _process_dataset_configs(
530504
531505 # Append the processed datasets to the final dict
532506 processed_datasets .append ((d , raw_datasets ))
533- if odm_config :
534- logger .info (
535- "Sampling probabilities are ignored if provided"
536- "and are not used for concatenation. Instead"
537- "online data mixing plugin handles it."
538- )
539- return self ._process_datasets_for_odm (processed_datasets )
507+
540508 train_datasets = []
541509 train_sampling_probabilities = []
542510 validation_datasets = []
@@ -623,14 +591,8 @@ def _process_dataset_configs(
623591 return train_dataset , eval_dataset
624592
625593 def process_dataset_configs (
626- self , dataset_configs : List [DataSetConfig ], odm_config = None
627- ) -> Union [
628- Tuple [Union [Dataset , IterableDataset ], Union [Dataset , IterableDataset ]],
629- Tuple [
630- Dict [str , Union [Dataset , IterableDataset ]],
631- Dict [str , Union [Dataset , IterableDataset ]],
632- ],
633- ]:
594+ self , dataset_configs : List [DataSetConfig ]
595+ ) -> Tuple [Union [Dataset , IterableDataset ], Union [Dataset , IterableDataset ]]:
634596 train_dataset = eval_dataset = None
635597
636598 # Use partial state as recommended by HF documentation for process control
@@ -643,23 +605,108 @@ def process_dataset_configs(
643605 # as we want to reuse HF cache and not redo computation on all nodes
644606 # For rationale see https://github.com/huggingface/trl/pull/3106
645607 with state .main_process_first ():
646- train_dataset , eval_dataset = self ._process_dataset_configs (
647- dataset_configs , odm_config
648- )
608+ train_dataset , eval_dataset = self ._process_dataset_configs (dataset_configs )
649609
650610 logger .info ("Processed train dataset {}" .format (train_dataset ))
651611 logger .info ("Processed eval dataset {}" .format (eval_dataset ))
652612
653613 return train_dataset , eval_dataset
654614
655615
class ODMDataPreProcessor(DataPreProcessor):
    """Preprocessor variant for online data mixing (ODM).

    Unlike the base class, the per-dataset splits are NOT concatenated or
    interleaved here. Each dataset's "train"/"test" splits are returned in
    dictionaries keyed by dataset name so the online data mixing plugin can
    perform sampling across datasets at training time.
    """

    def _process_datasets_for_odm(
        self,
        processed_datasets: List[
            Tuple[DataSetConfig, Union[DatasetDict, IterableDatasetDict]]
        ],
    ) -> Tuple[
        Dict[str, Union[Dataset, IterableDataset]],
        Dict[str, Union[Dataset, IterableDataset]],
    ]:
        """Regroup processed per-dataset dicts into name-keyed train/eval maps.

        Args:
            processed_datasets: Pairs of (dataset config, processed splits).

        Returns:
            Tuple of (train datasets, eval datasets), each mapping the
            dataset's configured name to its "train"/"test" split. Datasets
            missing a split are simply absent from the corresponding dict.
        """
        train_split = "train"
        eval_split = "test"
        train_datasets_dict = {}
        eval_datasets_dict = {}
        for d, raw in processed_datasets:
            if train_split in raw:
                train_datasets_dict[d.name] = raw[train_split]
            if eval_split in raw:
                eval_datasets_dict[d.name] = raw[eval_split]
        return train_datasets_dict, eval_datasets_dict

    def _process_dataset_configs(
        self, dataset_configs: List[DataSetConfig]
    ) -> Tuple[
        Dict[str, Union[Dataset, IterableDataset]],
        Dict[str, Union[Dataset, IterableDataset]],
    ]:
        """Load, split, and run data handlers for each dataset config.

        Args:
            dataset_configs: Dataset configurations to process.

        Returns:
            Name-keyed (train, eval) dataset dicts from
            :meth:`_process_datasets_for_odm`.

        Raises:
            ValueError: If ``dataset_configs`` is empty or None.
        """

        if not dataset_configs:
            raise ValueError(
                "No dataset configs provided. Provided Dataset configs is None."
            )

        train_split = "train"  # default
        eval_split = "test"

        processed_datasets = []

        logger.info("Starting DataPreProcessor...")
        # Now Iterate over the multiple datasets provided to us to process
        for d in dataset_configs:
            logger.info("Loading the dataset - %s", d.name)

            # In future the streaming etc go as kwargs of this function
            loaded_dataset = self.load_dataset(d, self.processor_config.streaming)
            logger.info("Loaded raw dataset : %s", str(loaded_dataset))

            if d.split is not None:
                loaded_dataset = self.split_dataset(d, loaded_dataset)

            # Create a raw dataset which is a Dict container to house all Datasets
            raw_datasets = (
                IterableDatasetDict()
                if isinstance(loaded_dataset, (IterableDataset, IterableDatasetDict))
                else DatasetDict()
            )

            # Only the train/test splits are relevant for ODM; drop any others.
            splits_to_keep = [train_split, eval_split]
            if isinstance(loaded_dataset, (Dataset, IterableDataset)):
                # Assume all is train split
                raw_datasets[train_split] = loaded_dataset
            else:
                for k, v in loaded_dataset.items():
                    if k in splits_to_keep:
                        raw_datasets[k] = v

            if d.data_handlers:  # Execute the datahandlers
                for data_handler_config in d.data_handlers:
                    raw_datasets = self._execute_data_handlers(
                        raw_datasets=raw_datasets,
                        data_handler_config=data_handler_config,
                        datasetName=d.name,
                    )

            # Append the processed datasets to the final dict
            processed_datasets.append((d, raw_datasets))
        # NOTE: fixed missing spaces between the adjacent string literals below;
        # previously the message rendered as "...if providedand are not used..."
        logger.info(
            "Sampling probabilities are ignored if provided "
            "and are not used for concatenation. Instead "
            "online data mixing plugin handles it."
        )
        return self._process_datasets_for_odm(processed_datasets)
698+
699+
656700def get_datapreprocessor (
657701 processor_config : DataPreProcessorConfig ,
658702 tokenizer : AutoTokenizer ,
659703 processor : AutoProcessor = None ,
660704 additional_data_handlers : Dict [str , DataHandler ] = None ,
661705) -> DataPreProcessor :
662- data_processor = DataPreProcessor (
706+ data_processor_cls = DataPreProcessor
707+ if processor_config .type == "odm" :
708+ data_processor_cls = ODMDataPreProcessor
709+ data_processor = data_processor_cls (
663710 processor_config = processor_config ,
664711 tokenizer = tokenizer ,
665712 processor = processor ,
0 commit comments