@@ -56,6 +56,8 @@ class GRPOArguments(GRPOArgumentsMixin):
5656 # multi step
5757 num_iterations : int = 1
5858
59+ truncation_strategy : Optional [Literal ['delete' , 'left' , 'right' ]] = None
60+
5961
6062@dataclass
6163class RLHFArguments (GRPOArguments , PPOArguments , RewardModelArguments , TrainArguments ):
@@ -108,7 +110,6 @@ def __post_init__(self):
108110 self ._init_ppo ()
109111 self ._set_default ()
110112 super ().__post_init__ ()
111- self ._init_grpo_ds3 ()
112113 self ._check_rlhf ()
113114 self ._check_grpo ()
114115
@@ -139,7 +140,11 @@ def _init_grpo(self):
139140 self .gradient_accumulation_steps = 1
140141 self .remove_unused_columns = False
141142 logger .info (f'Setting args.remove_unused_columns: { self .remove_unused_columns } ' )
142- self .truncation_strategy = 'left' # Used for trimming the excessively long parts of a prompt.
143+ if self .truncation_strategy is None :
144+ self .truncation_strategy = 'left'
145+ assert self .truncation_strategy == 'left' , \
146+ "GRPO requires `truncation_strategy='left'`, " \
147+ f"Current value: `truncation_strategy='{ self .truncation_strategy } '`."
143148 if self .beta is None :
144149 self .beta = 0.04 # https://arxiv.org/abs/2402.03300
145150 if self .async_generate :
@@ -189,11 +194,6 @@ def _set_default(self):
189194 elif self .rlhf_type in ['kto' ]:
190195 self .loss_type = 'kto'
191196
192- def _init_grpo_ds3 (self ):
193- if self .rlhf_type == 'grpo' and self .deepspeed :
194- if 'zero_optimization' in self .deepspeed and self .deepspeed ['zero_optimization' ]['stage' ] == 3 :
195- self .deepspeed ['zero_optimization' ]['stage3_prefetch_bucket_size' ] = 0
196-
197197 def _check_rlhf (self ):
198198 if self .sequence_parallel_size > 1 :
199199 raise ValueError ('RLHF do not support sequence parallel' )
0 commit comments