1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414import argparse
15+ import builtins
1516import json
1617import statistics
1718import sys
2324# Custom functions for working with dictionary values
2425def min (value ):
2526 """Return the minimum value in a dictionary."""
26- return __builtins__ .min (float (v ) for v in value .values ())
27+ return builtins .min (float (v ) for v in value .values ())
2728
2829
2930def max (value ):
3031 """Return the maximum value in a dictionary."""
31- return __builtins__ .max (float (v ) for v in value .values ())
32+ return builtins .max (float (v ) for v in value .values ())
3233
3334
34- def mean (value , range_start = 1 , range_end = 0 ):
35+ def ratio_above (value , threshold ):
36+ """Return the ratio of values that are >= threshold.
37+
38+ Args:
39+ value: Dictionary of step -> value
40+ threshold: Threshold value to compare against
41+
42+ Returns:
43+ Float between 0.0 and 1.0 representing the proportion of values >= threshold
44+ """
45+ vals = [float (v ) for v in value .values ()]
46+ if len (vals ) == 0 :
47+ return 0.0
48+ count_above = sum (1 for v in vals if v >= threshold )
49+ return count_above / len (vals )
50+
51+
52+ def mean (value , range_start = 1 , range_end = 0 , ignore_top_p = 0.0 ):
3553 """Return the mean of values (or a range of values) in a dictionary.
3654
3755 Note:
3856 step, and ranges, are 1 indexed. Range_end is exclusive.
3957 range_end=0 means to include until the last step in the run
58+
59+ Args:
60+ value: Dictionary of step -> value
61+ range_start: Starting step (1-indexed, default=1)
62+ range_end: Ending step (1-indexed, exclusive, 0 means last step)
63+ ignore_top_p: Proportion of top outliers to ignore (0.0-1.0, default=0.0)
64+ E.g., 0.05 ignores the top 5% of values
4065 """
4166
4267 ## find potential offset that might arise from resuming from a checkpoint
43- max_step_reached = __builtins__ .max ([int (s ) for s in value .keys ()])
68+ max_step_reached = builtins .max ([int (s ) for s in value .keys ()])
4469 ## this is the number of steps that occurred prior to resuming
4570 offset = max_step_reached - len (value )
4671
@@ -55,6 +80,20 @@ def mean(value, range_start=1, range_end=0):
5580 if range_start <= int (step ) and int (step ) < range_end :
5681 vals .append (float (v ))
5782
83+ # Validate ignore_top_p parameter
84+ if not 0.0 <= ignore_top_p <= 1.0 :
85+ raise ValueError (
86+ f"ignore_top_p must be between 0.0 and 1.0, got { ignore_top_p } "
87+ )
88+
89+ # Filter out top outliers if requested
90+ if ignore_top_p > 0.0 and len (vals ) > 0 :
91+ # Sort values and determine cutoff index
92+ sorted_vals = sorted (vals )
93+ cutoff_idx = int (len (sorted_vals ) * (1.0 - ignore_top_p ))
94+ # Take only values up to the cutoff (excluding top p%)
95+ vals = sorted_vals [:cutoff_idx ] if cutoff_idx > 0 else sorted_vals [:1 ]
96+
5897 return statistics .mean (vals )
5998
6099
@@ -65,17 +104,23 @@ def evaluate_check(data: dict, check: str) -> tuple[bool, str, object]:
65104 Tuple of (passed, message, value)
66105 """
67106 # Create a local context with our custom functions and the data
68- local_context = {"data" : data , "min" : min , "max" : max , "mean" : mean }
107+ local_context = {
108+ "data" : data ,
109+ "min" : min ,
110+ "max" : max ,
111+ "mean" : mean ,
112+ "ratio_above" : ratio_above ,
113+ }
69114
70115 # Extract the value expression from the check
71116 value_expr = check .split (">" )[0 ].split ("<" )[0 ].split ("==" )[0 ].strip ()
72117
73118 try :
74119 # Try to get the value first
75- value = eval (value_expr , {"__builtins__" : __builtins__ }, local_context )
120+ value = eval (value_expr , {"__builtins__" : builtins }, local_context )
76121
77122 # Then evaluate the check
78- result = eval (check , {"__builtins__" : __builtins__ }, local_context )
123+ result = eval (check , {"__builtins__" : builtins }, local_context )
79124 if result :
80125 return True , f"PASS: { check } " , value
81126 else :
@@ -107,6 +152,8 @@ def main():
107152 # Use helper functions
108153 python check_metrics.py results.json "min(data['class_f1']) > 0.6"
109154 python check_metrics.py results.json "mean(data['accuracies']) > 0.85"
155+ python check_metrics.py results.json "mean(data['loss'], ignore_top_p=0.05) < 1.5"
156+ python check_metrics.py results.json "ratio_above(data['error'], 1.05) < 0.02"
110157 """
111158 parser .formatter_class = argparse .RawDescriptionHelpFormatter
112159 args = parser .parse_args ()
0 commit comments