
Commit eb18c64

[grpo] check eval_dataset length (#4781)
* check eval_dataset length
* check val_dataset in train_args
* set default split_dataset_ratio to 0 for GRPO
* fix generation_batch_size check
Parent: 4ddb7fa

File tree: 4 files changed (+11 lines, -1 line)

4 files changed

+11
-1
lines changed

swift/llm/argument/train_args.py (2 additions, 0 deletions)

@@ -169,6 +169,8 @@ def __post_init__(self) -> None:

         if getattr(self, 'accelerator_config', None) is None:
             self.accelerator_config = {'dispatch_batches': False}
+        if self.split_dataset_ratio == 0 and not self.val_dataset:
+            self.eval_strategy = 'no'
         self.training_args = TrainerFactory.get_training_args(self)
         self.training_args.remove_unused_columns = False
         self._add_version()
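
The effect of the new guard can be shown in isolation: when no validation data will exist (no split and no explicit val_dataset), evaluation is switched off rather than failing later. A minimal sketch, assuming a standalone helper (the name resolve_eval_strategy and the 'steps' default are illustrative, not part of swift):

def resolve_eval_strategy(split_dataset_ratio: float,
                          val_dataset: list,
                          eval_strategy: str = 'steps') -> str:
    # With split_dataset_ratio == 0 nothing is carved off the training set,
    # so without an explicit val_dataset there is no data to evaluate on.
    if split_dataset_ratio == 0 and not val_dataset:
        return 'no'
    return eval_strategy

print(resolve_eval_strategy(0.0, []))             # 'no'    -> evaluation disabled
print(resolve_eval_strategy(0.0, ['val.jsonl']))  # 'steps' -> explicit val set keeps eval
print(resolve_eval_strategy(0.01, []))            # 'steps' -> a ratio split keeps eval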

swift/trainers/arguments.py (1 addition, 0 deletions)

@@ -225,6 +225,7 @@ class GRPOArgumentsMixin:

     # dataset
     dataset_shuffle: Optional[bool] = True
+    split_dataset_ratio: float = 0.0


 @dataclass

swift/trainers/rlhf_arguments.py (1 addition, 1 deletion)

@@ -80,7 +80,7 @@ def check_num_generations(self):
         # check num_generations for trl < 0.18
         num_processes = self.world_size

-        if self.generation_batch_size % self.per_device_train_batch_size * num_processes != 0:
+        if self.generation_batch_size % (self.per_device_train_batch_size * num_processes) != 0:
             raise ValueError(
                 f'generation_batch_size ({self.generation_batch_size}) must be divisible by the global batch size '
                 f'({self.per_device_train_batch_size * num_processes}).')
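
The added parentheses fix an operator-precedence bug: in Python, % and * share the same precedence and evaluate left to right, so the old expression computed (generation_batch_size % per_device_train_batch_size) * num_processes rather than a modulo against the global batch size. A small illustration with made-up numbers (not taken from the commit):

generation_batch_size = 20
per_device_train_batch_size = 4
num_processes = 2                      # intended global batch size: 4 * 2 = 8

old_check = generation_batch_size % per_device_train_batch_size * num_processes
# (20 % 4) * 2 == 0 -> the old check misses that 20 is not divisible by 8

new_check = generation_batch_size % (per_device_train_batch_size * num_processes)
# 20 % (4 * 2) == 4 -> the fixed check correctly flags the mismatch

print(old_check, new_check)  # 0 4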

swift/trainers/rlhf_trainer/grpo_trainer.py (7 additions, 0 deletions)

@@ -188,6 +188,13 @@ def __init__(self,
         vllm_client = kwargs.pop('vllm_client')  # for external vllm

         super().__init__(model, ref_model, *_args, **kwargs)
+        if self.args.eval_strategy != 'no':
+            total_eval_batch_size = self.args.per_device_eval_batch_size * \
+                self.accelerator.num_processes // self.args.num_generations
+            assert len(self.eval_dataset) >= total_eval_batch_size, (
+                f'eval_dataset size {len(self.eval_dataset)} is smaller than '
+                f'total_eval_batch_size {total_eval_batch_size}. '
+                f'Please increase the size of eval_dataset or set a larger value for split_dataset_ratio.')
         # Multi-step
         self.num_iterations = args.num_iterations  # = 𝜇 in the GRPO paper
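
To see when the new assertion fires, here is a worked example with illustrative settings (none of these numbers come from the commit):

per_device_eval_batch_size = 4
num_processes = 8
num_generations = 8

# Each eval step consumes this many distinct prompts, since every prompt is
# expanded into num_generations completions spread across the processes.
total_eval_batch_size = per_device_eval_batch_size * num_processes // num_generations
print(total_eval_batch_size)  # 4

# An eval_dataset with fewer than 4 samples now fails fast at trainer
# construction instead of erroring part-way through training.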
