Skip to content

Commit d9799d9

Browse files
authored (author name missing from page extraction)
Fix microbatch loss scale when loss_agg_mode is "token-mean" (#336)
1 parent 1120aed commit d9799d9

File tree

3 files changed

+34
-6
lines changed

3 files changed

+34
-6
lines changed

trinity/common/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -646,6 +646,7 @@ class TrainerConfig:
646646
# if None, automatically set to ceil(2 * model.max_model_len / ulysses_sequence_parallel_size)
647647
max_token_len_per_gpu: Optional[int] = None
648648
ulysses_sequence_parallel_size: int = 1 # sp size
649+
fix_actor_microbatch_loss_scale: bool = False # EXPERIMENTAL
649650
# TODO: extract more train-related params from underlying trainer engine
650651

651652
save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED

trinity/common/verl_config.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,6 +136,7 @@ class Actor:
136136
ppo_micro_batch_size_per_gpu: int = 1
137137
use_dynamic_bsz: Optional[bool] = None
138138
ppo_max_token_len_per_gpu: Optional[int] = None
139+
fix_actor_microbatch_loss_scale: Optional[bool] = None # EXPERIMENTAL
139140
grad_clip: Optional[float] = None
140141
ppo_epochs: int = 1
141142
shuffle: bool = False
@@ -427,6 +428,10 @@ def synchronize_config(self, config: Config) -> None: # noqa: C901
427428
self.actor_rollout_ref.actor.ppo_max_token_len_per_gpu = (
428429
config.trainer.max_token_len_per_gpu
429430
)
431+
if self.actor_rollout_ref.actor.fix_actor_microbatch_loss_scale is None:
432+
self.actor_rollout_ref.actor.fix_actor_microbatch_loss_scale = (
433+
config.trainer.fix_actor_microbatch_loss_scale
434+
)
430435
if self.actor_rollout_ref.actor.ulysses_sequence_parallel_size is None:
431436
self.actor_rollout_ref.actor.ulysses_sequence_parallel_size = (
432437
config.trainer.ulysses_sequence_parallel_size

trinity/trainer/verl/dp_actor.py

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -88,6 +88,16 @@ def update_policy(self, data: DataProto): # noqa: C901
8888

8989
mini_batches = data.split(self.config.ppo_mini_batch_size)
9090

91+
# EXPERIMENTAL: apply loss scale fix
92+
loss_agg_mode = (
93+
self.policy_loss_fn.loss_agg_mode
94+
if hasattr(self.policy_loss_fn, "loss_agg_mode")
95+
else "token-mean"
96+
)
97+
do_fix_actor_microbatch_loss_scale = self.config.fix_actor_microbatch_loss_scale and (
98+
loss_agg_mode == "token-mean"
99+
)
100+
91101
metrics = {}
92102
for _ in range(self.config.ppo_epochs):
93103
for batch_idx, mini_batch in enumerate(mini_batches):
@@ -104,6 +114,12 @@ def update_policy(self, data: DataProto): # noqa: C901
104114
)
105115
micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
106116

117+
if do_fix_actor_microbatch_loss_scale:
118+
# calculate the total number of response tokens in the minibatch
119+
mini_batch_token_num = torch.sum(
120+
mini_batch.batch["response_mask"].to(get_device_id())
121+
).item()
122+
107123
self.actor_optimizer.zero_grad()
108124

109125
for micro_batch in micro_batches:
@@ -156,13 +172,19 @@ def update_policy(self, data: DataProto): # noqa: C901
156172
)
157173
policy_loss = policy_loss + kl_loss
158174

159-
if self.config.use_dynamic_bsz:
160-
# relative to the dynamic bsz
161-
loss = policy_loss * (
162-
response_mask.shape[0] / self.config.ppo_mini_batch_size
163-
)
175+
# set loss scale for the microbatch
176+
if not do_fix_actor_microbatch_loss_scale:
177+
# original implementation of microbatch loss scale
178+
if self.config.use_dynamic_bsz:
179+
loss_scale = response_mask.shape[0] / self.config.ppo_mini_batch_size
180+
else:
181+
loss_scale = 1.0 / self.gradient_accumulation
164182
else:
165-
loss = policy_loss / self.gradient_accumulation
183+
# EXPERIMENTAL: fix for token-mean loss aggregation
184+
# scale microbatch loss according to the number of tokens (rather than sequences)
185+
loss_scale = torch.sum(response_mask).item() / (mini_batch_token_num + 1e-6)
186+
187+
loss = policy_loss * loss_scale
166188
loss.backward()
167189

168190
append_to_dict(metrics, micro_batch_metrics)

0 commit comments

Comments (0)