77from dataclasses import dataclass , field
88from datetime import datetime
99from enum import Enum
10- from typing import Any , Dict , List , Optional
10+ from typing import Any , Dict , List , Optional , Union
1111
1212from omegaconf import OmegaConf
1313
@@ -108,6 +108,10 @@ class StorageConfig:
108108 path : Optional [str ] = None
109109 repeat_times : Optional [int ] = None
110110
111+ # For shuffle
112+ shuffle : bool = False
113+ seed : int = 42
114+
111115 # For continuing training
112116 index : int = 0
113117
@@ -369,7 +373,8 @@ class ClusterConfig:
369373class ExplorerInput :
370374 """Config for explorer input."""
371375
372- taskset : StorageConfig = field (default_factory = StorageConfig )
376+ taskset : Optional [StorageConfig ] = None
377+ tasksets : List [StorageConfig ] = field (default_factory = list )
373378 eval_tasksets : List [StorageConfig ] = field (default_factory = list )
374379 # The following args provide default values for the corresponding args in `taskset` and `eval_tasksets`
375380 default_workflow_type : Optional [str ] = None
@@ -630,40 +635,44 @@ def _check_buffer(self) -> None: # noqa: C901
630635 trainer_input = self .buffer .trainer_input
631636 experience_buffer = trainer_input .experience_buffer
632637 explorer_input = self .buffer .explorer_input
633- taskset = explorer_input .taskset
634638
635- if self .mode != "train" and not taskset .path :
636- raise ValueError (
637- "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
638- )
639- if not taskset .name :
640- taskset .name = "taskset"
641- if taskset .repeat_times is None or taskset .repeat_times != self .algorithm .repeat_times :
642- taskset .repeat_times = self .algorithm .repeat_times
643- logger .info (
644- "`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`"
645- f" (={ self .algorithm .repeat_times } )."
639+ if len (explorer_input .tasksets ) == 0 and explorer_input .taskset :
640+ explorer_input .tasksets .append (explorer_input .taskset )
641+ tasksets = explorer_input .tasksets
642+
643+ for taskset in tasksets :
644+ if self .mode != "train" and not taskset .path :
645+ raise ValueError (
646+ "`buffer.explorer_input.taskset.path` is required, please set it to the path of the taskset."
647+ )
648+ if not taskset .name :
649+ taskset .name = "taskset"
650+ if taskset .repeat_times is None or taskset .repeat_times != self .algorithm .repeat_times :
651+ taskset .repeat_times = self .algorithm .repeat_times
652+ logger .info (
653+ "`buffer.explorer_input.taskset.repeat_times` is set to `algorithm.repeat_times`"
654+ f" (={ self .algorithm .repeat_times } )."
655+ )
656+ if self .mode == "train" :
657+ assert (
658+ experience_buffer is not None
659+ ), "`buffer.trainer_input.experience_buffer` is required when `mode` is `train`."
660+ experience_buffer .total_epochs = self .buffer .total_epochs
661+ experience_buffer .total_steps = self .buffer .total_steps
662+ else :
663+ taskset .is_eval = False
664+ taskset .total_epochs = self .buffer .total_epochs
665+ taskset .total_steps = self .buffer .total_steps
666+
667+ set_if_none (taskset , "default_workflow_type" , explorer_input .default_workflow_type )
668+ set_if_none (
669+ taskset , "default_eval_workflow_type" , explorer_input .default_eval_workflow_type
646670 )
647- if self .mode == "train" :
648- assert (
649- experience_buffer is not None
650- ), "`buffer.trainer_input.experience_buffer` is required when `mode` is `train`."
651- experience_buffer .total_epochs = self .buffer .total_epochs
652- experience_buffer .total_steps = self .buffer .total_steps
653- else :
654- taskset .is_eval = False
655- taskset .total_epochs = self .buffer .total_epochs
656- taskset .total_steps = self .buffer .total_steps
657-
658- set_if_none (taskset , "default_workflow_type" , explorer_input .default_workflow_type )
659- set_if_none (
660- taskset , "default_eval_workflow_type" , explorer_input .default_eval_workflow_type
661- )
662- set_if_none (taskset , "default_reward_fn_type" , explorer_input .default_reward_fn_type )
663- set_if_none (taskset .format , "system_prompt" , explorer_input .system_prompt )
664- set_if_none (taskset .format , "reply_prefix" , explorer_input .reply_prefix )
665- set_if_none (taskset , "ray_namespace" , self .ray_namespace )
666- set_if_none (taskset .rollout_args , "max_tokens" , self .model .max_response_tokens )
671+ set_if_none (taskset , "default_reward_fn_type" , explorer_input .default_reward_fn_type )
672+ set_if_none (taskset .format , "system_prompt" , explorer_input .system_prompt )
673+ set_if_none (taskset .format , "reply_prefix" , explorer_input .reply_prefix )
674+ set_if_none (taskset , "ray_namespace" , self .ray_namespace )
675+ set_if_none (taskset .rollout_args , "max_tokens" , self .model .max_response_tokens )
667676
668677 remained_tasksets = []
669678 for idx , dataset in enumerate (explorer_input .eval_tasksets ):
@@ -730,8 +739,8 @@ def _check_buffer(self) -> None: # noqa: C901
730739 task_pipeline = self .data_processor .task_pipeline
731740 if task_pipeline is not None :
732741 if task_pipeline .output is None :
733- if taskset .path is not None :
734- task_pipeline .output = taskset
742+ if tasksets [ 0 ] .path is not None :
743+ task_pipeline .output = tasksets [ 0 ]
735744 elif (
736745 experience_buffer .schema_type in {"dpo" , "sft" }
737746 and experience_buffer .path is not None
@@ -740,7 +749,7 @@ def _check_buffer(self) -> None: # noqa: C901
740749 else :
741750 raise ValueError (
742751 "`data_processor.task_pipeline.output` is required when both "
743- "`buffer.explorer_input.taskset .path` and `buffer.trainer_input.experience_buffer.path` are "
752+ "`buffer.explorer_input.tasksets[0] .path` and `buffer.trainer_input.experience_buffer.path` are "
744753 "None"
745754 )
746755 if task_pipeline .output .path and os .path .exists (task_pipeline .output .path ):
0 commit comments