1 change: 1 addition & 0 deletions trinity/common/config.py
@@ -487,6 +487,7 @@ class TrainerConfig:
     use_dynamic_bsz: bool = True
     ppo_max_token_len_per_gpu: int = 16384
     ulysses_sequence_parallel_size: int = 1  # sp size
+    fix_actor_microbatch_loss_scale: bool = False  # EXPERIMENTAL
     # TODO: extract more train-related params from underlying trainer engine
 
     save_strategy: SaveStrategy = SaveStrategy.UNRESTRICTED
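For reference, the new `TrainerConfig` field defaults to `False`, so existing configs keep the current behaviour and the fix is strictly opt-in. A minimal sketch of turning it on in Python (assuming `TrainerConfig` can be instantiated with its defaults; only the flag itself comes from this change):

```python
# Sketch: opting into the experimental microbatch loss-scale fix.
# Assumes TrainerConfig is constructible with its defaults; all other
# fields are unrelated to this change.
from trinity.common.config import TrainerConfig

trainer_cfg = TrainerConfig()
trainer_cfg.fix_actor_microbatch_loss_scale = True  # default is False (EXPERIMENTAL)
```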
5 changes: 5 additions & 0 deletions trinity/common/verl_config.py
@@ -136,6 +136,7 @@ class Actor:
     ppo_micro_batch_size_per_gpu: int = 1
     use_dynamic_bsz: Optional[bool] = None
     ppo_max_token_len_per_gpu: Optional[int] = None
+    fix_actor_microbatch_loss_scale: Optional[bool] = None  # EXPERIMENTAL
     grad_clip: Optional[float] = None
     ppo_epochs: int = 1
     shuffle: bool = False
@@ -425,6 +426,10 @@ def synchronize_config(self, config: Config) -> None:  # noqa: C901
             self.actor_rollout_ref.actor.ppo_max_token_len_per_gpu = (
                 config.trainer.ppo_max_token_len_per_gpu
             )
+        if self.actor_rollout_ref.actor.fix_actor_microbatch_loss_scale is None:
+            self.actor_rollout_ref.actor.fix_actor_microbatch_loss_scale = (
+                config.trainer.fix_actor_microbatch_loss_scale
+            )
         if self.actor_rollout_ref.actor.ulysses_sequence_parallel_size is None:
             self.actor_rollout_ref.actor.ulysses_sequence_parallel_size = (
                 config.trainer.ulysses_sequence_parallel_size
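As with the neighbouring actor overrides in `synchronize_config`, the verl-side field defaults to `None`, meaning "inherit the trainer-level value", while an explicit actor-level `True`/`False` wins. A self-contained sketch of that resolution rule (`resolve_flag` is a hypothetical helper, not part of the codebase):

```python
# Illustrative sketch of the None-means-inherit pattern used in synchronize_config.
# resolve_flag is hypothetical and exists only for this example.
from typing import Optional


def resolve_flag(actor_value: Optional[bool], trainer_value: bool) -> bool:
    # An explicit actor-level setting takes precedence; None falls back
    # to the trainer-level default.
    return trainer_value if actor_value is None else actor_value


assert resolve_flag(None, True) is True    # inherits the trainer setting
assert resolve_flag(False, True) is False  # explicit actor override wins
```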
34 changes: 28 additions & 6 deletions trinity/trainer/verl/dp_actor.py
@@ -88,6 +88,16 @@ def update_policy(self, data: DataProto):  # noqa: C901
 
         mini_batches = data.split(self.config.ppo_mini_batch_size)
 
+        # EXPERIMENTAL: apply loss scale fix
+        loss_agg_mode = (
+            self.policy_loss_fn.loss_agg_mode
+            if hasattr(self.policy_loss_fn, "loss_agg_mode")
+            else "token-mean"
+        )
+        do_fix_actor_microbatch_loss_scale = self.config.fix_actor_microbatch_loss_scale and (
+            loss_agg_mode == "token-mean"
+        )
+
         metrics = {}
         for _ in range(self.config.ppo_epochs):
             for batch_idx, mini_batch in enumerate(mini_batches):
@@ -104,6 +114,12 @@ def update_policy(self, data: DataProto):  # noqa: C901
                 )
                 micro_batches = mini_batch.split(self.config.ppo_micro_batch_size_per_gpu)
 
+                if do_fix_actor_microbatch_loss_scale:
+                    # calculate the total number of response tokens in the minibatch
+                    mini_batch_token_num = torch.sum(
+                        mini_batch.batch["response_mask"].to(get_device_id())
+                    ).item()  # TODO: double check this calculation
+
                 self.actor_optimizer.zero_grad()
 
                 for micro_batch in micro_batches:
@@ -156,13 +172,19 @@ def update_policy(self, data: DataProto):  # noqa: C901
                         )
                         policy_loss = policy_loss + kl_loss
 
-                    if self.config.use_dynamic_bsz:
-                        # relative to the dynamic bsz
-                        loss = policy_loss * (
-                            response_mask.shape[0] / self.config.ppo_mini_batch_size
-                        )
+                    # set loss scale for the microbatch
+                    if not do_fix_actor_microbatch_loss_scale:
+                        # original implementation of microbatch loss scale
+                        if self.config.use_dynamic_bsz:
+                            loss_scale = response_mask.shape[0] / self.config.ppo_mini_batch_size
+                        else:
+                            loss_scale = 1.0 / self.gradient_accumulation
                     else:
-                        loss = policy_loss / self.gradient_accumulation
+                        # EXPERIMENTAL: fix for token-mean loss aggregation
+                        # scale microbatch loss according to the number of tokens (rather than sequences)
+                        loss_scale = torch.sum(response_mask).item() / mini_batch_token_num
+
+                    loss = policy_loss * loss_scale
                     loss.backward()
 
                     append_to_dict(metrics, micro_batch_metrics)
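To see what the token-count `loss_scale` changes under "token-mean" aggregation, consider a minibatch whose microbatches have very different response lengths. The standalone sketch below uses illustrative numbers and plain torch (no trinity/verl imports) to compare the uniform `1 / gradient_accumulation` weighting from the non-dynamic-bsz path with the token-count weighting introduced here:

```python
# Standalone sketch: microbatch loss scaling under "token-mean" aggregation.
# Numbers are illustrative; nothing here imports trinity or verl.
import torch

# Two microbatches from one minibatch: 2 response tokens with per-token loss 4.0,
# and 10 response tokens with per-token loss 1.0.
mb_losses = [torch.full((2,), 4.0), torch.full((10,), 1.0)]
mini_batch_token_num = sum(mb.numel() for mb in mb_losses)  # 12

# Target: the exact token-mean over the whole minibatch.
exact = torch.cat(mb_losses).mean()  # (2 * 4.0 + 10 * 1.0) / 12 = 1.5

# Original (non-dynamic-bsz) scaling: each microbatch's token-mean is weighted
# by 1 / gradient_accumulation, which overweights the short microbatch.
gradient_accumulation = len(mb_losses)
original = sum(mb.mean() for mb in mb_losses) / gradient_accumulation  # (4.0 + 1.0) / 2 = 2.5

# Fixed scaling: weight each microbatch's token-mean by its share of minibatch
# tokens, which recovers the exact token-mean.
fixed = sum(mb.mean() * (mb.numel() / mini_batch_token_num) for mb in mb_losses)  # 1.5

print(f"exact={exact.item():.2f}  original={original.item():.2f}  fixed={fixed.item():.2f}")
```

With equal-length microbatches the two weightings coincide, so the discrepancy only shows up when response lengths vary across microbatches of the same minibatch.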