|
53 | 53 | from prime_rl.utils.temp_scheduling import compute_temperature |
54 | 54 | from prime_rl.utils.utils import ( |
55 | 55 | clean_exit, |
| 56 | + count_chinese_chars, |
56 | 57 | get_env_ids_to_install, |
57 | 58 | install_env, |
58 | 59 | resolve_latest_ckpt_step, |
@@ -489,6 +490,17 @@ def process_rollout(rollout: vf.State) -> list[TrainingSample] | None: |
489 | 490 | # Gather individual reward function metrics |
490 | 491 | metrics_df = pd.DataFrame([rollout["metrics"] for rollout in train_rollouts]) |
491 | 492 |
|
| 493 | + # Count Chinese characters in the decoded completion of each training rollout's last trajectory step |
| 494 | + chinese_stats = [] |
| 495 | + for rollout in train_rollouts: |
| 496 | + trajectory = rollout["trajectory"] |
| 497 | + last_step = trajectory[-1] |
| 498 | + tokens = last_step["tokens"] |
| 499 | + completion_text = tokenizer.decode(tokens["completion_ids"]) |
| 500 | + chinese_count, total_count = count_chinese_chars(completion_text) |
| 501 | + chinese_stats.append({"chinese_chars": chinese_count, "total_chars": total_count, "has_chinese": chinese_count > 0}) |
| 502 | + chinese_df = pd.DataFrame(chinese_stats) |
| 503 | + |
492 | 504 | val_results_df = ( |
493 | 505 | pd.DataFrame( |
494 | 506 | { |
@@ -568,6 +580,14 @@ def process_rollout(rollout: vf.State) -> list[TrainingSample] | None: |
568 | 580 | }, |
569 | 581 | # Env metrics |
570 | 582 | **{f"metrics/{metric}": metrics_df[metric].mean() for metric in metrics_df.columns}, |
| 583 | + # Chinese character metrics |
| 584 | + "chinese/char_count": chinese_df.chinese_chars.sum(), |
| 585 | + "chinese/char_ratio": ( |
| 586 | + chinese_df.chinese_chars.sum() / chinese_df.total_chars.sum() |
| 587 | + if chinese_df.total_chars.sum() > 0 |
| 588 | + else 0.0 |
| 589 | + ), |
| 590 | + "chinese/rollout_ratio": chinese_df.has_chinese.mean(), |
571 | 591 | # Time metrics |
572 | 592 | "time/step": step_time, |
573 | 593 | "time/generate_completions": generate_completions_time, |
|
0 commit comments