@@ -89,7 +89,7 @@ def get_datasets_list(
8989
9090def get_finetuning_dataloader (
9191 args : TrainingArgs | InferenceArgs , split : DatasetSplit , mode : Mode , tokenizer : TOKENIZER_TYPE
92- ) -> tuple [ ResumableDataLoader ] :
92+ ) -> ResumableDataLoader :
9393 """prepares datasets and sampler
9494
9595 Args:
@@ -99,7 +99,7 @@ def get_finetuning_dataloader(
9999 tokenizer (TOKENIZER_TYPE): tokenizer
100100
101101 Returns:
102- tuple[ ResumableDataLoader] : dataloader for a blended dataset
102+ ResumableDataLoader: dataloader for a blended dataset
103103 """
104104
105105 assert mode == Mode .training , "blended dataset is only supported in training mode"
@@ -121,7 +121,7 @@ def get_finetuning_dataloader(
121121
122122def get_pretraining_dataloaders (
123123 args : TrainingArgs , tokenizer : TOKENIZER_TYPE , consumed_samples : int
124- ) -> tuple [ResumableDataLoader ]:
124+ ) -> tuple [ResumableDataLoader , list [ ResumableDataLoader ], list [ ResumableDataLoader ] ]:
125125 if args .datasets [0 ].class_name == "MegatronDataset" :
126126 dataloaders = get_megatron_gpt_dataloaders (args , tokenizer , consumed_samples = consumed_samples )
127127 elif args .datasets [0 ].class_name == "IBMDataset" :
@@ -132,7 +132,7 @@ def get_pretraining_dataloaders(
132132
133133def _get_dispatching_dataloader (
134134 args : TrainingArgs | InferenceArgs , split : DatasetSplit , mode : Mode , tokenizer : TOKENIZER_TYPE
135- ) -> tuple [ ResumableDataLoader ] :
135+ ) -> ResumableDataLoader :
136136 micro_batch_size = args .training_parameters .micro_batch_size
137137
138138 num_ranks_per_node = torch .cuda .device_count ()
@@ -211,7 +211,7 @@ def _get_source_broadcast_mapping() -> dict:
211211
212212def _get_non_dispatching_dataloader (
213213 args : TrainingArgs | InferenceArgs , split : DatasetSplit , mode : Mode , tokenizer : TOKENIZER_TYPE
214- ) -> tuple [ ResumableDataLoader ] :
214+ ) -> ResumableDataLoader :
215215 micro_batch_size = args .training_parameters .micro_batch_size
216216
217217 datasets_list , data_sampling_ratios = get_datasets_list (
0 commit comments