Skip to content

Commit 15ba3a1

Browse files
committed
track chinese characters
1 parent e5566f9 commit 15ba3a1

File tree

2 files changed

+45
-0
lines changed

2 files changed

+45
-0
lines changed

src/prime_rl/orchestrator/orchestrator.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,7 @@
5353
from prime_rl.utils.temp_scheduling import compute_temperature
5454
from prime_rl.utils.utils import (
5555
clean_exit,
56+
count_chinese_chars,
5657
get_env_ids_to_install,
5758
install_env,
5859
resolve_latest_ckpt_step,
@@ -489,6 +490,17 @@ def process_rollout(rollout: vf.State) -> list[TrainingSample] | None:
489490
# Gather individual reward function metrics
490491
metrics_df = pd.DataFrame([rollout["metrics"] for rollout in train_rollouts])
491492

493+
# Count Chinese characters in completions
494+
chinese_stats = []
495+
for rollout in train_rollouts:
496+
trajectory = rollout["trajectory"]
497+
last_step = trajectory[-1]
498+
tokens = last_step["tokens"]
499+
completion_text = tokenizer.decode(tokens["completion_ids"])
500+
chinese_count, total_count = count_chinese_chars(completion_text)
501+
chinese_stats.append({"chinese_chars": chinese_count, "total_chars": total_count, "has_chinese": chinese_count > 0})
502+
chinese_df = pd.DataFrame(chinese_stats)
503+
492504
val_results_df = (
493505
pd.DataFrame(
494506
{
@@ -568,6 +580,14 @@ def process_rollout(rollout: vf.State) -> list[TrainingSample] | None:
568580
},
569581
# Env metrics
570582
**{f"metrics/{metric}": metrics_df[metric].mean() for metric in metrics_df.columns},
583+
# Chinese character metrics
584+
"chinese/char_count": chinese_df.chinese_chars.sum(),
585+
"chinese/char_ratio": (
586+
chinese_df.chinese_chars.sum() / chinese_df.total_chars.sum()
587+
if chinese_df.total_chars.sum() > 0
588+
else 0.0
589+
),
590+
"chinese/rollout_ratio": chinese_df.has_chinese.mean(),
571591
# Time metrics
572592
"time/step": step_time,
573593
"time/generate_completions": generate_completions_time,

src/prime_rl/utils/utils.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -300,3 +300,28 @@ def get_env_ids_to_install(env_configs: list[EnvConfig] | list[EvalEnvConfig]) -
300300
if "/" in env_config.id:
301301
env_ids_to_install.add(env_config.id)
302302
return env_ids_to_install
303+
304+
305+
def is_chinese_char(char: str) -> bool:
306+
"""Check if a character is Chinese (CJK Unified Ideographs and extensions)."""
307+
code_point = ord(char)
308+
return (
309+
0x4E00 <= code_point <= 0x9FFF # CJK Unified Ideographs
310+
or 0x3400 <= code_point <= 0x4DBF # CJK Extension A
311+
or 0x20000 <= code_point <= 0x2A6DF # CJK Extension B
312+
or 0x2A700 <= code_point <= 0x2B73F # CJK Extension C
313+
or 0x2B740 <= code_point <= 0x2B81F # CJK Extension D
314+
or 0x2B820 <= code_point <= 0x2CEAF # CJK Extension E
315+
or 0xF900 <= code_point <= 0xFAFF # CJK Compatibility Ideographs
316+
or 0x2F800 <= code_point <= 0x2FA1F # CJK Compatibility Ideographs Supplement
317+
)
318+
319+
320+
def count_chinese_chars(text: str) -> tuple[int, int]:
321+
"""Count Chinese characters in text.
322+
323+
Returns:
324+
Tuple of (chinese_char_count, total_char_count)
325+
"""
326+
chinese_count = sum(1 for char in text if is_chinese_char(char))
327+
return chinese_count, len(text)

0 commit comments

Comments
 (0)