Skip to content

Commit 8ad4404

Browse files
committed
Update the modules
1 parent a11f134 commit 8ad4404

File tree

40 files changed

+675
-1454
lines changed

40 files changed

+675
-1454
lines changed

modules/common/evaluation.py

Lines changed: 72 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -23,82 +23,82 @@ def extract_formula_from_function(func: Callable):
2323
func_def = next((n for n in tree.body if isinstance(n, ast.FunctionDef)), None)
2424
if func_def is None:
2525
raise ValueError("No function definition found in source.")
26-
# Find the return statement
27-
return_node = next((n for n in ast.walk(func_def) if isinstance(n, ast.Return)), None)
28-
if return_node is None:
26+
# Find all return statements
27+
return_nodes = [n for n in ast.walk(func_def) if isinstance(n, ast.Return)]
28+
if not return_nodes:
2929
raise ValueError("No return statement found in function.")
30-
31-
def _resolve_wrapped_expression(expr: ast.AST) -> ast.AST:
32-
"""
33-
Attempt to unwrap simple wrappers around the core mathematical expression so
34-
the judge sees the actual formula instead of helper constructs.
35-
- float(value) -> value
36-
- return of a variable -> resolve latest assignment to that variable prior to return
37-
"""
38-
# Unwrap float(<expr>) calls
39-
if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name) and expr.func.id == 'float' and len(expr.args) == 1:
40-
return _resolve_wrapped_expression(expr.args[0])
41-
42-
# If returning a variable name, try to find its latest assignment before the return
43-
if isinstance(expr, ast.Name):
44-
var_name = expr.id
45-
return_lineno = getattr(return_node, 'lineno', None)
46-
# Search for last assignment to var_name before the return across all nested blocks
47-
last_assigned_expr = None
48-
last_assigned_lineno = -1
49-
if return_lineno is not None:
50-
for node in ast.walk(func_def):
51-
node_lineno = getattr(node, 'lineno', 0)
52-
if node_lineno and node_lineno < return_lineno:
53-
if isinstance(node, ast.Assign):
54-
# Only consider single-target assigns of a Name
55-
if (
56-
len(node.targets) == 1
57-
and isinstance(node.targets[0], ast.Name)
58-
and node.targets[0].id == var_name
59-
and node_lineno > last_assigned_lineno
60-
):
61-
last_assigned_expr = node.value
62-
last_assigned_lineno = node_lineno
63-
elif isinstance(node, ast.AnnAssign):
64-
if (
65-
isinstance(node.target, ast.Name)
66-
and node.target.id == var_name
67-
and node.value is not None
68-
and node_lineno > last_assigned_lineno
69-
):
70-
last_assigned_expr = node.value
71-
last_assigned_lineno = node_lineno
72-
if last_assigned_expr is not None:
73-
return _resolve_wrapped_expression(last_assigned_expr)
74-
# Fallback to original name if no assignment found
75-
return expr
76-
77-
# If expression is a conditional (a if cond else b), prefer the non-NaN branch heuristically
78-
if isinstance(expr, ast.IfExp):
79-
# Try both branches; prefer the one that isn't a NaN literal
80-
def is_nan_literal(e: ast.AST) -> bool:
81-
# Matches float('nan')
82-
return (
83-
isinstance(e, ast.Call)
84-
and isinstance(e.func, ast.Name)
85-
and e.func.id == 'float'
86-
and len(e.args) == 1
87-
and isinstance(e.args[0], ast.Constant)
88-
and isinstance(e.args[0].value, str)
89-
and e.args[0].value.lower() == 'nan'
90-
)
91-
# Prefer body if it isn't NaN, else orelse
92-
if not is_nan_literal(expr.body):
93-
return _resolve_wrapped_expression(expr.body)
94-
return _resolve_wrapped_expression(expr.orelse)
95-
96-
return expr
97-
98-
core_expr = _resolve_wrapped_expression(return_node.value)
30+
# Resolve each and collect non-constant candidates
31+
candidates = []
32+
for rn in return_nodes:
33+
resolved = _resolve_wrapped_expression(rn.value, func_def, getattr(rn, 'lineno', 0))
34+
if not isinstance(resolved, ast.Constant):
35+
candidates.append(resolved)
36+
if not candidates:
37+
raise ValueError("No non-constant return found.")
38+
# Pick the first non-constant (assuming it's the main formula)
39+
core_expr = candidates[0]
9940
formula_str = ast.unparse(core_expr)
10041
return formula_str
10142

43+
def _resolve_wrapped_expression(expr: ast.AST, func_def: ast.FunctionDef, return_lineno: int) -> ast.AST:
44+
"""
45+
Attempt to unwrap simple wrappers around the core mathematical expression so
46+
the judge sees the actual formula instead of helper constructs.
47+
- float(value) -> value
48+
- return of a variable -> resolve latest assignment to that variable prior to return
49+
"""
50+
# Unwrap float(<expr>) calls
51+
if isinstance(expr, ast.Call) and isinstance(expr.func, ast.Name) and expr.func.id == 'float' and len(expr.args) == 1:
52+
return _resolve_wrapped_expression(expr.args[0], func_def, return_lineno)
53+
# If returning a variable name, try to find its latest assignment before the return
54+
if isinstance(expr, ast.Name):
55+
var_name = expr.id
56+
last_assigned_expr = None
57+
last_assigned_lineno = -1
58+
for node in ast.walk(func_def):
59+
node_lineno = getattr(node, 'lineno', 0)
60+
if node_lineno and node_lineno < return_lineno:
61+
if isinstance(node, ast.Assign):
62+
if (
63+
len(node.targets) == 1
64+
and isinstance(node.targets[0], ast.Name)
65+
and node.targets[0].id == var_name
66+
and node_lineno > last_assigned_lineno
67+
):
68+
last_assigned_expr = node.value
69+
last_assigned_lineno = node_lineno
70+
elif isinstance(node, ast.AnnAssign):
71+
if (
72+
isinstance(node.target, ast.Name)
73+
and node.target.id == var_name
74+
and node.value is not None
75+
and node_lineno > last_assigned_lineno
76+
):
77+
last_assigned_expr = node.value
78+
last_assigned_lineno = node_lineno
79+
if last_assigned_expr is not None:
80+
return _resolve_wrapped_expression(last_assigned_expr, func_def, return_lineno)
81+
# Fallback to original name if no assignment found
82+
return expr
83+
# If expression is a conditional (a if cond else b), prefer the non-NaN branch heuristically
84+
if isinstance(expr, ast.IfExp):
85+
def is_nan_literal(e: ast.AST) -> bool:
86+
# Matches float('nan')
87+
return (
88+
isinstance(e, ast.Call)
89+
and isinstance(e.func, ast.Name)
90+
and e.func.id == 'float'
91+
and len(e.args) == 1
92+
and isinstance(e.args[0], ast.Constant)
93+
and isinstance(e.args[0].value, str)
94+
and e.args[0].value.lower() == 'nan'
95+
)
96+
# Prefer body if it isn't NaN, else orelse
97+
if not is_nan_literal(expr.body):
98+
return _resolve_wrapped_expression(expr.body, func_def, return_lineno)
99+
return _resolve_wrapped_expression(expr.orelse, func_def, return_lineno)
100+
return expr
101+
102102
def calculate_rmsle(y_true, y_pred):
103103
"""
104104
Calculate Root Mean Squared Logarithmic Error (RMSLE) between true and predicted values.

modules/m0_gravity/core.py

Lines changed: 3 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -59,10 +59,8 @@ def _run_orbital_experiment(
5959
Returns:
6060
dict: Time series data with keys 'time', 'position', 'velocity', all as JSON-serializable lists (no NumPy arrays).
6161
"""
62-
# Initialize arrays for time series data
6362
num_steps = int(duration / time_step)
6463

65-
# Validate parameters to prevent empty arrays
6664
if num_steps <= 0:
6765
return {
6866
'time': [],
@@ -74,50 +72,39 @@ def _run_orbital_experiment(
7472
positions = np.zeros((num_steps, 2))
7573
velocities = np.zeros((num_steps, 2))
7674

77-
# Set initial conditions
78-
# mass1 at origin, mass2 at (distance, 0)
7975
positions[0] = np.array([distance, 0.0])
80-
# Initial velocity perpendicular to radius (for circular orbit)
8176
velocities[0] = np.array([0.0, initial_velocity])
8277

83-
# Time evolution using Verlet integration
8478
for i in range(1, num_steps):
85-
# Calculate acceleration
8679
acc = calculate_acceleration_2d(
8780
mass1, mass2,
88-
np.array([0.0, 0.0]), # mass1 at origin
81+
np.array([0.0, 0.0]),
8982
positions[i-1],
9083
force_law
91-
)[1] # We only need acc of mass2
84+
)[1]
9285

93-
# Update position and velocity
9486
pos_new, vel_half = verlet_integration_2d(
9587
positions[i-1],
9688
velocities[i-1],
9789
acc,
9890
time_step
9991
)
10092

101-
# Calculate new acceleration for velocity update
10293
acc_new = calculate_acceleration_2d(
10394
mass1, mass2,
10495
np.array([0.0, 0.0]),
10596
pos_new,
10697
force_law
10798
)[1]
10899

109-
# Final velocity update
110100
vel_new = vel_half + 0.5 * acc_new * time_step
111101

112-
# Store results
113102
positions[i] = pos_new
114103
velocities[i] = vel_new
115104

116-
# Add noise to measurements
117105
noisy_positions = inject_noise(positions, noise_level, ABSOLUTE_POSITION_PRECISION)
118106
noisy_velocities = inject_noise(velocities, noise_level, ABSOLUTE_VELOCITY_PRECISION)
119107

120-
# Downsample to at most 20 data points
121108
max_points = 20
122109
if len(times) > max_points:
123110
times = times[:max_points]
@@ -157,10 +144,8 @@ def _run_linear_experiment(
157144
Returns:
158145
dict: Time series data with keys 'time', 'position', 'velocity', all as JSON-serializable lists (no NumPy arrays).
159146
"""
160-
# Initialize arrays for time series data
161147
num_steps = int(duration / time_step)
162148

163-
# Validate parameters to prevent empty arrays
164149
if num_steps <= 0:
165150
return {
166151
'time': [],
@@ -173,48 +158,39 @@ def _run_linear_experiment(
173158
velocities = np.zeros(num_steps)
174159
accelerations = np.zeros(num_steps)
175160

176-
# Set initial conditions
177161
positions[0] = distance
178162
velocities[0] = initial_velocity
179163

180-
# Initial acceleration
181164
acc0 = calculate_acceleration_1d(
182165
mass1, mass2,
183166
positions[0],
184167
force_law
185-
)[1] # We only need acc of mass2
168+
)[1]
186169
accelerations[0] = acc0
187170

188-
# Time evolution using Verlet integration
189171
for i in range(1, num_steps):
190-
# Update position and velocity
191172
pos_new, vel_half = verlet_integration_1d(
192173
positions[i-1],
193174
velocities[i-1],
194175
accelerations[i-1],
195176
time_step
196177
)
197178

198-
# Calculate new acceleration for velocity update
199179
acc_new = calculate_acceleration_1d(
200180
mass1, mass2,
201181
pos_new,
202182
force_law
203183
)[1]
204184

205-
# Final velocity update
206185
vel_new = vel_half + 0.5 * acc_new * time_step
207186

208-
# Store results
209187
positions[i] = pos_new
210188
velocities[i] = vel_new
211189
accelerations[i] = acc_new
212190

213-
# Add noise to measurements
214191
noisy_positions = inject_noise(positions, noise_level, ABSOLUTE_POSITION_PRECISION)
215192
noisy_velocities = inject_noise(velocities, noise_level, ABSOLUTE_VELOCITY_PRECISION)
216193

217-
# Downsample to at most 20 data points
218194
max_points = 20
219195
if len(times) > max_points:
220196
times = times[:max_points]

modules/m0_gravity/laws.py

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -3,64 +3,64 @@
33
from typing import Dict, List, Tuple, Callable, Optional
44

55
# --- Environment Constants ---
6-
HIDDEN_CONSTANT_C = 6.674e-5
6+
HIDDEN_CONSTANT = 6.674e-5
77

88
# --- v0 laws ---
99
def _ground_truth_law_easy_v0(mass1: float, mass2: float, distance: float) -> float:
1010
"""Easy law: F = C * m1 * m2 / r^1.5"""
1111
if distance <= 0 or mass1 <= 0 or mass2 <= 0:
1212
return 0.0
13-
return (HIDDEN_CONSTANT_C * mass1 * mass2) / (distance ** 1.5)
13+
return (HIDDEN_CONSTANT * mass1 * mass2) / (distance ** 1.5)
1414

1515
def _ground_truth_law_medium_v0(mass1: float, mass2: float, distance: float) -> float:
1616
"""Medium law: F = C * (m1 * m2)^2 / r^1.5"""
1717
if distance <= 0 or mass1 <= 0 or mass2 <= 0:
1818
return 0.0
19-
return (HIDDEN_CONSTANT_C * (mass1 * mass2) ** 2) / (distance ** 1.5)
19+
return (HIDDEN_CONSTANT * (mass1 * mass2) ** 2) / (distance ** 1.5)
2020

2121
def _ground_truth_law_hard_v0(mass1: float, mass2: float, distance: float) -> float:
2222
"""Hard law: F = C * (m1 + m2)^2 / r^1.5"""
2323
if distance <= 0 or mass1 <= 0 or mass2 <= 0:
2424
return 0.0
25-
return (HIDDEN_CONSTANT_C * (mass1 + mass2) ** 2) / (distance ** 1.5)
25+
return (HIDDEN_CONSTANT * (mass1 + mass2) ** 2) / (distance ** 1.5)
2626

2727
# --- v1 laws ---
2828
def _ground_truth_law_easy_v1(mass1: float, mass2: float, distance: float) -> float:
2929
"""Easy law: F = C * m1 / r^2"""
3030
if distance <= 0 or mass1 <= 0 or mass2 <= 0:
3131
return 0.0
32-
return (HIDDEN_CONSTANT_C * mass1) / (distance ** 2)
32+
return (HIDDEN_CONSTANT * mass1) / (distance ** 2)
3333

3434
def _ground_truth_law_medium_v1(mass1: float, mass2: float, distance: float) -> float:
3535
"""Medium law: F = C * m1 / r^2.6"""
3636
if distance <= 0 or mass1 <= 0 or mass2 <= 0:
3737
return 0.0
38-
return (HIDDEN_CONSTANT_C * mass1) / (distance ** 2.6)
38+
return (HIDDEN_CONSTANT * mass1) / (distance ** 2.6)
3939

4040
def _ground_truth_law_hard_v1(mass1: float, mass2: float, distance: float) -> float:
4141
"""Hard law: F = C * m1^1.3 / r^2.6"""
4242
if distance <= 0 or mass1 <= 0 or mass2 <= 0:
4343
return 0.0
44-
return (HIDDEN_CONSTANT_C * mass1 ** 1.3) / (distance ** 2.6)
44+
return (HIDDEN_CONSTANT * mass1 ** 1.3) / (distance ** 2.6)
4545

4646
# --- v2 laws ---
4747
def _ground_truth_law_easy_v2(mass1: float, mass2: float, distance: float) -> float:
4848
"""Easy law: F = C * (m1^2 * m2^2) / r^2"""
4949
if distance <= 0 or mass1 <= 0 or mass2 <= 0:
5050
return 0.0
51-
return (HIDDEN_CONSTANT_C * (mass1 ** 2 * mass2 ** 2)) / (distance ** 2)
51+
return (HIDDEN_CONSTANT * (mass1 ** 2 * mass2 ** 2)) / (distance ** 2)
5252

5353
def _ground_truth_law_medium_v2(mass1: float, mass2: float, distance: float) -> float:
5454
"""Medium law: F = C * (m1^2 * m2^2) * r^2"""
5555
if distance <= 0 or mass1 <= 0 or mass2 <= 0:
5656
return 0.0
57-
return (HIDDEN_CONSTANT_C * (mass1 ** 2 * mass2 ** 2)) * (distance ** 2)
57+
return (HIDDEN_CONSTANT * (mass1 ** 2 * mass2 ** 2)) * (distance ** 2)
5858

5959
def _ground_truth_law_hard_v2(mass1: float, mass2: float, distance: float) -> float:
6060
"""Hard law: F = C * (m1^2 + m2^2) * r^2"""
6161
if distance <= 0 or mass1 <= 0 or mass2 <= 0:
6262
return 0.0
63-
return (HIDDEN_CONSTANT_C * (mass1 ** 2 + mass2 ** 2)) * (distance ** 2)
63+
return (HIDDEN_CONSTANT * (mass1 ** 2 + mass2 ** 2)) * (distance ** 2)
6464

6565
# --- Law Registry ---
6666

modules/m0_gravity/m0_types.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -2,12 +2,12 @@
22

33
# Default parameters for orbital motion
44
TWO_DIM_DEFAULTS = {
5-
'time_step': 0.1, # seconds
6-
'duration': 10.0, # seconds
5+
'time_step': 0.1,
6+
'duration': 10.0,
77
}
88

99
# Default parameters for linear motion
1010
LINEAR_DEFAULTS = {
11-
'time_step': 0.01, # seconds
12-
'duration': 5.0, # seconds
11+
'time_step': 0.01,
12+
'duration': 5.0,
1313
}

0 commit comments

Comments
 (0)