Skip to content

Commit d47b50d

Browse files
committed
Fix mean mismatch metrics
1 parent dceb2d6 commit d47b50d

File tree

2 files changed

+25
-9
lines changed

2 files changed

+25
-9
lines changed

xtuner/v1/rl/base/rollout_is.py

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -467,7 +467,7 @@ def compute_mismatch_metrics(
467467
- When Speed Kills Stability: https://yingru.notion.site/When-Speed-Kills-Stability-271211a558b7808d8b12d403fd15edda
468468
"""
469469
metrics = {}
470-
470+
metrics["valid"] = response_mask.any().item()
471471
# 1. Training policy perplexity (always available)
472472
# Formula: exp(-1/|T| * Σ log π_training(y_t|y_<t))
473473
# where |T| is the number of tokens generated by the model
@@ -529,18 +529,32 @@ def compute_mismatch_metrics(
529529

530530
def merge_rollout_is_metrics(rollout_is_metrics: list[dict[str, float]], device="cuda") -> dict[str, float]:
531531
metrics = {}
532-
for key in rollout_is_metrics[0].keys():
533-
all_values = [m[key] for m in rollout_is_metrics]
532+
keys = [k for k in rollout_is_metrics[0].keys() if k != "mismatch/valid"]
533+
534+
for key in keys:
535+
values = []
536+
valids = []
537+
for m in rollout_is_metrics:
538+
is_valid = m.get("mismatch/valid", True)
539+
valids.append(float(is_valid))
540+
values.append(m[key] if is_valid else 0.0) # set to 0.0 if invalid
541+
value_tensor = torch.tensor(values, dtype=torch.float32, device=device)
542+
valid_tensor = torch.tensor(valids, dtype=torch.float32, device=device)
543+
544+
# Aggregate across all processes
534545
if "max" in key:
535-
max_value = torch.tensor(all_values).max().to(torch.float32).to(device)
546+
max_value = value_tensor.max()
536547
dist.all_reduce(max_value, op=dist.ReduceOp.MAX)
537548
metrics[key] = max_value.item()
538549
elif "min" in key:
539-
min_value = torch.tensor(all_values).min().to(torch.float32).to(device)
550+
min_value = value_tensor.min()
540551
dist.all_reduce(min_value, op=dist.ReduceOp.MIN)
541552
metrics[key] = min_value.item()
542553
else:
543-
mean_value = torch.tensor(all_values).mean().to(torch.float32).to(device)
544-
dist.all_reduce(mean_value, op=dist.ReduceOp.AVG)
545-
metrics[key] = mean_value.item()
554+
sum_value = value_tensor.sum()
555+
count_value = valid_tensor.sum()
556+
dist.all_reduce(sum_value, op=dist.ReduceOp.SUM)
557+
dist.all_reduce(count_value, op=dist.ReduceOp.SUM)
558+
metrics[key] = sum_value.item() / count_value.item() if count_value.item() > 0 else 0.0
559+
546560
return metrics

xtuner/v1/rl/base/worker.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -477,7 +477,9 @@ def fit(self, data_batches: list[WorkerInputItem], rollout_idx: int):
477477
rollout_is_metrics = merge_rollout_is_metrics(all_rollout_is_metrics, DEVICE)
478478
if len(rollout_is_metrics) > 0:
479479
logger_msg += f"\n rollout importance sampling metrics:\n{json.dumps(rollout_is_metrics, indent=4)}"
480-
self.logger.info(logger_msg)
480+
481+
if self.rank == 0:
482+
self.logger.info(logger_msg)
481483

482484
if self._has_ref:
483485
# ref logprobs are updated in place in compute_actor_logprobs

0 commit comments

Comments
 (0)