Skip to content

Commit 4729c6f

Browse files
committed
fix: function and argument types
Signed-off-by: Mehant Kammakomati <[email protected]>
1 parent 2f619f8 commit 4729c6f

File tree

1 file changed

+25
-5
lines changed

1 file changed

+25
-5
lines changed

tuning/data/data_processors.py

Lines changed: 25 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,7 @@
1313
# limitations under the License.
1414

1515
# Standard
16-
from typing import Dict, List, Union
16+
from typing import Dict, List, Tuple, Union
1717
import logging
1818
import os
1919

@@ -452,8 +452,16 @@ def split_dataset(
452452
)
453453
return split_datasets
454454

455-
def _process_datasets_for_odm(self, processed_datasets):
456-
train_split = "train" # default
455+
def _process_datasets_for_odm(
456+
self,
457+
processed_datasets: List[
458+
Tuple[DataSetConfig, Union[DatasetDict, IterableDatasetDict]]
459+
],
460+
) -> Tuple[
461+
Dict[str, Union[Dataset, IterableDataset]],
462+
Dict[str, Union[Dataset, IterableDataset]],
463+
]:
464+
train_split = "train"
457465
eval_split = "test"
458466
train_datasets_dict = {}
459467
eval_datasets_dict = {}
@@ -466,7 +474,13 @@ def _process_datasets_for_odm(self, processed_datasets):
466474

467475
def _process_dataset_configs(
468476
self, dataset_configs: List[DataSetConfig], odm_config=None
469-
) -> Union[Dataset, IterableDataset]:
477+
) -> Union[
478+
Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]],
479+
Tuple[
480+
Dict[str, Union[Dataset, IterableDataset]],
481+
Dict[str, Union[Dataset, IterableDataset]],
482+
],
483+
]:
470484

471485
if not dataset_configs:
472486
raise ValueError(
@@ -605,7 +619,13 @@ def _process_dataset_configs(
605619

606620
def process_dataset_configs(
607621
self, dataset_configs: List[DataSetConfig], odm_config=None
608-
) -> Union[Dataset, IterableDataset]:
622+
) -> Union[
623+
Tuple[Union[Dataset, IterableDataset], Union[Dataset, IterableDataset]],
624+
Tuple[
625+
Dict[str, Union[Dataset, IterableDataset]],
626+
Dict[str, Union[Dataset, IterableDataset]],
627+
],
628+
]:
609629
train_dataset = eval_dataset = None
610630

611631
# Use partial state as recommended by HF documentation for process control

0 commit comments

Comments
 (0)