@@ -23,82 +23,82 @@ def extract_formula_from_function(func: Callable):
2323 func_def = next ((n for n in tree .body if isinstance (n , ast .FunctionDef )), None )
2424 if func_def is None :
2525 raise ValueError ("No function definition found in source." )
26- # Find the return statement
27- return_node = next (( n for n in ast .walk (func_def ) if isinstance (n , ast .Return )), None )
28- if return_node is None :
26+ # Find all return statements
27+ return_nodes = [ n for n in ast .walk (func_def ) if isinstance (n , ast .Return )]
28+ if not return_nodes :
2929 raise ValueError ("No return statement found in function." )
30-
31- def _resolve_wrapped_expression (expr : ast .AST ) -> ast .AST :
32- """
33- Attempt to unwrap simple wrappers around the core mathematical expression so
34- the judge sees the actual formula instead of helper constructs.
35- - float(value) -> value
36- - return of a variable -> resolve latest assignment to that variable prior to return
37- """
38- # Unwrap float(<expr>) calls
39- if isinstance (expr , ast .Call ) and isinstance (expr .func , ast .Name ) and expr .func .id == 'float' and len (expr .args ) == 1 :
40- return _resolve_wrapped_expression (expr .args [0 ])
41-
42- # If returning a variable name, try to find its latest assignment before the return
43- if isinstance (expr , ast .Name ):
44- var_name = expr .id
45- return_lineno = getattr (return_node , 'lineno' , None )
46- # Search for last assignment to var_name before the return across all nested blocks
47- last_assigned_expr = None
48- last_assigned_lineno = - 1
49- if return_lineno is not None :
50- for node in ast .walk (func_def ):
51- node_lineno = getattr (node , 'lineno' , 0 )
52- if node_lineno and node_lineno < return_lineno :
53- if isinstance (node , ast .Assign ):
54- # Only consider single-target assigns of a Name
55- if (
56- len (node .targets ) == 1
57- and isinstance (node .targets [0 ], ast .Name )
58- and node .targets [0 ].id == var_name
59- and node_lineno > last_assigned_lineno
60- ):
61- last_assigned_expr = node .value
62- last_assigned_lineno = node_lineno
63- elif isinstance (node , ast .AnnAssign ):
64- if (
65- isinstance (node .target , ast .Name )
66- and node .target .id == var_name
67- and node .value is not None
68- and node_lineno > last_assigned_lineno
69- ):
70- last_assigned_expr = node .value
71- last_assigned_lineno = node_lineno
72- if last_assigned_expr is not None :
73- return _resolve_wrapped_expression (last_assigned_expr )
74- # Fallback to original name if no assignment found
75- return expr
76-
77- # If expression is a conditional (a if cond else b), prefer the non-NaN branch heuristically
78- if isinstance (expr , ast .IfExp ):
79- # Try both branches; prefer the one that isn't a NaN literal
80- def is_nan_literal (e : ast .AST ) -> bool :
81- # Matches float('nan')
82- return (
83- isinstance (e , ast .Call )
84- and isinstance (e .func , ast .Name )
85- and e .func .id == 'float'
86- and len (e .args ) == 1
87- and isinstance (e .args [0 ], ast .Constant )
88- and isinstance (e .args [0 ].value , str )
89- and e .args [0 ].value .lower () == 'nan'
90- )
91- # Prefer body if it isn't NaN, else orelse
92- if not is_nan_literal (expr .body ):
93- return _resolve_wrapped_expression (expr .body )
94- return _resolve_wrapped_expression (expr .orelse )
95-
96- return expr
97-
98- core_expr = _resolve_wrapped_expression (return_node .value )
30+ # Resolve each and collect non-constant candidates
31+ candidates = []
32+ for rn in return_nodes :
33+ resolved = _resolve_wrapped_expression (rn .value , func_def , getattr (rn , 'lineno' , 0 ))
34+ if not isinstance (resolved , ast .Constant ):
35+ candidates .append (resolved )
36+ if not candidates :
37+ raise ValueError ("No non-constant return found." )
38+ # Pick the first non-constant (assuming it's the main formula)
39+ core_expr = candidates [0 ]
9940 formula_str = ast .unparse (core_expr )
10041 return formula_str
10142
43+ def _resolve_wrapped_expression (expr : ast .AST , func_def : ast .FunctionDef , return_lineno : int ) -> ast .AST :
44+ """
45+ Attempt to unwrap simple wrappers around the core mathematical expression so
46+ the judge sees the actual formula instead of helper constructs.
47+ - float(value) -> value
48+ - return of a variable -> resolve latest assignment to that variable prior to return
49+ """
50+ # Unwrap float(<expr>) calls
51+ if isinstance (expr , ast .Call ) and isinstance (expr .func , ast .Name ) and expr .func .id == 'float' and len (expr .args ) == 1 :
52+ return _resolve_wrapped_expression (expr .args [0 ], func_def , return_lineno )
53+ # If returning a variable name, try to find its latest assignment before the return
54+ if isinstance (expr , ast .Name ):
55+ var_name = expr .id
56+ last_assigned_expr = None
57+ last_assigned_lineno = - 1
58+ for node in ast .walk (func_def ):
59+ node_lineno = getattr (node , 'lineno' , 0 )
60+ if node_lineno and node_lineno < return_lineno :
61+ if isinstance (node , ast .Assign ):
62+ if (
63+ len (node .targets ) == 1
64+ and isinstance (node .targets [0 ], ast .Name )
65+ and node .targets [0 ].id == var_name
66+ and node_lineno > last_assigned_lineno
67+ ):
68+ last_assigned_expr = node .value
69+ last_assigned_lineno = node_lineno
70+ elif isinstance (node , ast .AnnAssign ):
71+ if (
72+ isinstance (node .target , ast .Name )
73+ and node .target .id == var_name
74+ and node .value is not None
75+ and node_lineno > last_assigned_lineno
76+ ):
77+ last_assigned_expr = node .value
78+ last_assigned_lineno = node_lineno
79+ if last_assigned_expr is not None :
80+ return _resolve_wrapped_expression (last_assigned_expr , func_def , return_lineno )
81+ # Fallback to original name if no assignment found
82+ return expr
83+ # If expression is a conditional (a if cond else b), prefer the non-NaN branch heuristically
84+ if isinstance (expr , ast .IfExp ):
85+ def is_nan_literal (e : ast .AST ) -> bool :
86+ # Matches float('nan')
87+ return (
88+ isinstance (e , ast .Call )
89+ and isinstance (e .func , ast .Name )
90+ and e .func .id == 'float'
91+ and len (e .args ) == 1
92+ and isinstance (e .args [0 ], ast .Constant )
93+ and isinstance (e .args [0 ].value , str )
94+ and e .args [0 ].value .lower () == 'nan'
95+ )
96+ # Prefer body if it isn't NaN, else orelse
97+ if not is_nan_literal (expr .body ):
98+ return _resolve_wrapped_expression (expr .body , func_def , return_lineno )
99+ return _resolve_wrapped_expression (expr .orelse , func_def , return_lineno )
100+ return expr
101+
102102def calculate_rmsle (y_true , y_pred ):
103103 """
104104 Calculate Root Mean Squared Logarithmic Error (RMSLE) between true and predicted values.
0 commit comments