Skip to content

Commit 85eeb8d

Browse files
authored
fix: more robust fp8 rollout metric check (#1307)
Signed-off-by: Terry Kong <terryk@nvidia.com>
1 parent 905a224 commit 85eeb8d

File tree

5 files changed

+469
-14
lines changed

5 files changed

+469
-14
lines changed

examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v2.yaml renamed to examples/configs/recipes/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.yaml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
defaults: ../../grpo_math_1B.yaml
22
grpo:
3-
num_prompts_per_step: 64
43
num_generations_per_prompt: 32
54
max_num_steps: 500
65
loss_fn:

tests/check_metrics.py

Lines changed: 54 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414
import argparse
15+
import builtins
1516
import json
1617
import statistics
1718
import sys
@@ -23,24 +24,48 @@
2324
# Custom functions for working with dictionary values
2425
def min(value):
    """Return the minimum value in a dictionary.

    Args:
        value: Dictionary whose values can each be converted with ``float()``.

    Returns:
        The smallest of the dictionary's values, as a float.
    """
    # This function shadows the builtin ``min`` at module level, so the real
    # one is reached through the ``builtins`` module. ``__builtins__`` is
    # avoided because it is a CPython implementation detail: it is the module
    # object in ``__main__`` but can be a plain dict in imported modules,
    # where attribute access would fail.
    return builtins.min(float(v) for v in value.values())
2728

2829

2930
def max(value):
    """Return the maximum value in a dictionary.

    Args:
        value: Dictionary whose values can each be converted with ``float()``.

    Returns:
        The largest of the dictionary's values, as a float.
    """
    # This function shadows the builtin ``max`` at module level, so the real
    # one is reached through the ``builtins`` module. ``__builtins__`` is
    # avoided because it is a CPython implementation detail: it is the module
    # object in ``__main__`` but can be a plain dict in imported modules,
    # where attribute access would fail.
    return builtins.max(float(v) for v in value.values())
3233

3334

34-
def mean(value, range_start=1, range_end=0):
35+
def ratio_above(value, threshold):
    """Return the ratio of values that are >= threshold.

    Args:
        value: Dictionary of step -> value
        threshold: Threshold value to compare against

    Returns:
        Float between 0.0 and 1.0 representing the proportion of values >= threshold.
        An empty dictionary yields 0.0.
    """
    vals = [float(v) for v in value.values()]
    if not vals:
        # Nothing to compare against; also avoids dividing by zero below.
        return 0.0
    # The comparison is inclusive: a value equal to the threshold is counted.
    count_above = sum(1 for v in vals if v >= threshold)
    return count_above / len(vals)
50+
51+
52+
def mean(value, range_start=1, range_end=0, ignore_top_p=0.0):
3553
"""Return the mean of values (or a range of values) in a dictionary.
3654
3755
Note:
3856
step, and ranges, are 1 indexed. Range_end is exclusive.
3957
range_end=0 means to include until the last step in the run
58+
59+
Args:
60+
value: Dictionary of step -> value
61+
range_start: Starting step (1-indexed, default=1)
62+
range_end: Ending step (1-indexed, exclusive, 0 means last step)
63+
ignore_top_p: Proportion of top outliers to ignore (0.0-1.0, default=0.0)
64+
E.g., 0.05 ignores the top 5% of values
4065
"""
4166

4267
## find potential offset that might arise from resuming from a checkpoint
43-
max_step_reached = __builtins__.max([int(s) for s in value.keys()])
68+
max_step_reached = builtins.max([int(s) for s in value.keys()])
4469
## this is the number of steps that occurred prior to resuming
4570
offset = max_step_reached - len(value)
4671

@@ -55,6 +80,20 @@ def mean(value, range_start=1, range_end=0):
5580
if range_start <= int(step) and int(step) < range_end:
5681
vals.append(float(v))
5782

83+
# Validate ignore_top_p parameter
84+
if not 0.0 <= ignore_top_p <= 1.0:
85+
raise ValueError(
86+
f"ignore_top_p must be between 0.0 and 1.0, got {ignore_top_p}"
87+
)
88+
89+
# Filter out top outliers if requested
90+
if ignore_top_p > 0.0 and len(vals) > 0:
91+
# Sort values and determine cutoff index
92+
sorted_vals = sorted(vals)
93+
cutoff_idx = int(len(sorted_vals) * (1.0 - ignore_top_p))
94+
# Take only values up to the cutoff (excluding top p%)
95+
vals = sorted_vals[:cutoff_idx] if cutoff_idx > 0 else sorted_vals[:1]
96+
5897
return statistics.mean(vals)
5998

6099

@@ -65,17 +104,23 @@ def evaluate_check(data: dict, check: str) -> tuple[bool, str, object]:
65104
Tuple of (passed, message, value)
66105
"""
67106
# Create a local context with our custom functions and the data
68-
local_context = {"data": data, "min": min, "max": max, "mean": mean}
107+
local_context = {
108+
"data": data,
109+
"min": min,
110+
"max": max,
111+
"mean": mean,
112+
"ratio_above": ratio_above,
113+
}
69114

70115
# Extract the value expression from the check
71116
value_expr = check.split(">")[0].split("<")[0].split("==")[0].strip()
72117

73118
try:
74119
# Try to get the value first
75-
value = eval(value_expr, {"__builtins__": __builtins__}, local_context)
120+
value = eval(value_expr, {"__builtins__": builtins}, local_context)
76121

77122
# Then evaluate the check
78-
result = eval(check, {"__builtins__": __builtins__}, local_context)
123+
result = eval(check, {"__builtins__": builtins}, local_context)
79124
if result:
80125
return True, f"PASS: {check}", value
81126
else:
@@ -107,6 +152,8 @@ def main():
107152
# Use helper functions
108153
python check_metrics.py results.json "min(data['class_f1']) > 0.6"
109154
python check_metrics.py results.json "mean(data['accuracies']) > 0.85"
155+
python check_metrics.py results.json "mean(data['loss'], ignore_top_p=0.05) < 1.5"
156+
python check_metrics.py results.json "ratio_above(data['error'], 1.05) < 0.02"
110157
"""
111158
parser.formatter_class = argparse.RawDescriptionHelpFormatter
112159
args = parser.parse_args()

tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v2.sh renamed to tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,10 +4,10 @@ source $SCRIPT_DIR/common.env
44

55
# ===== BEGIN CONFIG =====
66
NUM_NODES=1
7-
STEPS_PER_RUN=40
8-
MAX_STEPS=40
7+
STEPS_PER_RUN=100
8+
MAX_STEPS=100
99
NUM_RUNS=$(( (MAX_STEPS + STEPS_PER_RUN - 1) / STEPS_PER_RUN )) # Round up
10-
NUM_MINUTES=120
10+
NUM_MINUTES=180
1111
# ===== END CONFIG =====
1212

1313
exit_if_max_steps_reached
@@ -33,7 +33,9 @@ uv run tests/json_dump_tb_logs.py $LOG_DIR --output_path $JSON_METRICS
3333

3434
# Only run metrics if the target step is reached
3535
if [[ $(jq 'to_entries | .[] | select(.key == "train/loss") | .value | keys | map(tonumber) | max' $JSON_METRICS) -ge $MAX_STEPS ]]; then
36+
# With only a few steps, the logprob can have spikes that move the average up.
3637
uv run tests/check_metrics.py $JSON_METRICS \
37-
'mean(data["train/token_mult_prob_error"]) < 1.1' \
38-
'data["train/token_mult_prob_error"]["40"] < 1.1'
38+
'mean(data["train/token_mult_prob_error"], ignore_top_p=0.05) < 1.1' \
39+
'ratio_above(data["train/token_mult_prob_error"], 1.1) < 0.1'
40+
# ratio_above @ 1.1 was 0.03,0.06,0.05: 3sigma ~=0.1
3941
fi

tests/test_suites/nightly.txt

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,7 +38,7 @@ tests/test_suites/llm/grpo-gspo-deepscaler-1.5b-8K.sh
3838
tests/test_suites/llm/grpo-math-qwen3-30ba3b-megatron-tp4-32k.sh
3939

4040
# FP8
41-
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v2.sh
41+
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-rollouts.v3.sh
4242
tests/test_suites/llm/grpo-llama3.1-8b-instruct-1n8g-megatron-fp8-e2e.sh
4343

4444
# Non-colocated

0 commit comments

Comments
 (0)