Skip to content

Commit e71efc4

Browse files
Nush395Torax team
authored andcommitted
Move static_runtime_params_slice from StepFunction __call__ to constructor.
PiperOrigin-RevId: 775283537
1 parent 2ca8e6b commit e71efc4

File tree

5 files changed

+23
-36
lines changed

5 files changed

+23
-36
lines changed

torax/_src/mhd/sawtooth/tests/sawtooth_model_test.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -111,13 +111,11 @@ def setUp(self):
111111
neoclassical_models=neoclassical_models,
112112
)
113113

114-
self.geometry_provider = torax_config.geometry.build_provider
115-
116-
self.static_runtime_params_slice = (
114+
geometry_provider = torax_config.geometry.build_provider
115+
static_runtime_params_slice = (
117116
build_runtime_params.build_static_params_from_config(torax_config)
118117
)
119-
120-
self.dynamic_runtime_params_slice_provider = (
118+
dynamic_runtime_params_slice_provider = (
121119
build_runtime_params.DynamicRuntimeParamsSliceProvider.from_config(
122120
torax_config
123121
)
@@ -127,24 +125,24 @@ def setUp(self):
127125
solver=solver,
128126
time_step_calculator=torax_config.time_step_calculator.time_step_calculator,
129127
mhd_models=mhd_models,
130-
geometry_provider=self.geometry_provider,
131-
dynamic_runtime_params_slice_provider=self.dynamic_runtime_params_slice_provider,
128+
static_runtime_params_slice=static_runtime_params_slice,
129+
geometry_provider=geometry_provider,
130+
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
132131
)
133132

134133
self.initial_state, self.initial_post_processed_outputs = (
135134
initial_state_lib.get_initial_state_and_post_processed_outputs(
136135
t=torax_config.numerics.t_initial,
137-
static_runtime_params_slice=self.static_runtime_params_slice,
138-
dynamic_runtime_params_slice_provider=self.dynamic_runtime_params_slice_provider,
139-
geometry_provider=self.geometry_provider,
136+
static_runtime_params_slice=static_runtime_params_slice,
137+
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
138+
geometry_provider=geometry_provider,
140139
step_fn=self.step_fn,
141140
)
142141
)
143142

144143
def test_sawtooth_crash(self):
145144
"""Tests that default values lead to crash and compares post-crash to ref."""
146145
output_state, _, sim_error = self.step_fn(
147-
static_runtime_params_slice=self.static_runtime_params_slice,
148146
input_state=self.initial_state,
149147
previous_post_processed_outputs=self.initial_post_processed_outputs,
150148
)
@@ -180,7 +178,6 @@ def test_no_sawtooth_crash(self):
180178
),
181179
)
182180
output_state, _, sim_error = self.step_fn(
183-
static_runtime_params_slice=self.static_runtime_params_slice,
184181
input_state=initial_state,
185182
previous_post_processed_outputs=self.initial_post_processed_outputs,
186183
)
@@ -197,7 +194,6 @@ def test_no_subsequent_sawtooth_crashes(self):
197194
"""Tests for no subsequent sawtooth crashes even if q in trigger condition."""
198195
# This crashes
199196
output_state0, post_processed_outputs0, _ = self.step_fn(
200-
static_runtime_params_slice=self.static_runtime_params_slice,
201197
input_state=self.initial_state,
202198
previous_post_processed_outputs=self.initial_post_processed_outputs,
203199
)
@@ -224,7 +220,6 @@ def test_no_subsequent_sawtooth_crashes(self):
224220

225221
with self.subTest('no_subsequent_sawtooth_crashes'):
226222
output_state_should_not_crash, _, sim_error = self.step_fn(
227-
static_runtime_params_slice=self.static_runtime_params_slice,
228223
input_state=new_input_state_should_not_crash,
229224
previous_post_processed_outputs=post_processed_outputs0,
230225
)
@@ -243,7 +238,6 @@ def test_no_subsequent_sawtooth_crashes(self):
243238

244239
with self.subTest('crashes_if_sawtooth_crash_is_false'):
245240
output_state_should_crash, _, sim_error = self.step_fn(
246-
static_runtime_params_slice=self.static_runtime_params_slice,
247241
input_state=new_input_state_should_crash,
248242
previous_post_processed_outputs=post_processed_outputs0,
249243
)

torax/_src/orchestration/run_loop.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -152,7 +152,6 @@ def run_loop(
152152
_log_timestep(current_state)
153153

154154
current_state, post_processed_outputs, sim_error = step_fn(
155-
static_runtime_params_slice,
156155
current_state,
157156
post_processing_history[-1],
158157
)

torax/_src/orchestration/run_simulation.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ def prepare_simulation(
102102
solver=solver,
103103
time_step_calculator=torax_config.time_step_calculator.time_step_calculator,
104104
mhd_models=mhd_models,
105+
static_runtime_params_slice=static_runtime_params_slice,
105106
geometry_provider=geometry_provider,
106107
dynamic_runtime_params_slice_provider=dynamic_runtime_params_slice_provider,
107108
)

torax/_src/orchestration/step_function.py

Lines changed: 12 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ def __init__(
7676
solver: solver_lib.Solver,
7777
time_step_calculator: ts.TimeStepCalculator,
7878
mhd_models: mhd_base.MHDModels,
79+
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
7980
dynamic_runtime_params_slice_provider: build_runtime_params.DynamicRuntimeParamsSliceProvider,
8081
geometry_provider: geometry_provider_lib.GeometryProvider,
8182
):
@@ -84,7 +85,9 @@ def __init__(
8485
Args:
8586
solver: Evolves the core profiles.
8687
time_step_calculator: Calculates the dt for each time step.
87-
mhd_models: Collection of MHD models applied, e.g. sawtooth
88+
mhd_models: Collection of MHD models applied, e.g. sawtooth.
89+
static_runtime_params_slice: Static parameters that, if they change,
90+
should trigger a recompilation of the SimulationStepFn.
8891
dynamic_runtime_params_slice_provider: Object that returns a set of
8992
runtime parameters which may change from time step to time step or
9093
simulation run to run. If these runtime parameters change, it does NOT
@@ -105,6 +108,7 @@ def __init__(
105108
self._dynamic_runtime_params_slice_provider = (
106109
dynamic_runtime_params_slice_provider
107110
)
111+
self._static_runtime_params_slice = static_runtime_params_slice
108112

109113
@property
110114
def solver(self) -> solver_lib.Solver:
@@ -120,7 +124,6 @@ def time_step_calculator(self) -> ts.TimeStepCalculator:
120124

121125
def __call__(
122126
self,
123-
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
124127
input_state: sim_state.ToraxSimState,
125128
previous_post_processed_outputs: post_processing.PostProcessedOutputs,
126129
) -> tuple[
@@ -136,8 +139,6 @@ def __call__(
136139
sawtooth redistribution, at a t+dt set by the sawtooth model.
137140
138141
Args:
139-
static_runtime_params_slice: Static parameters that, if they change,
140-
should trigger a recompilation of the SimulationStepFn.
141142
input_state: State at the start of the time step, including the core
142143
profiles which are being evolved.
143144
previous_post_processed_outputs: Post-processed outputs from the previous
@@ -172,7 +173,7 @@ def __call__(
172173
# set to 0.
173174
explicit_source_profiles = source_profile_builders.build_source_profiles(
174175
dynamic_runtime_params_slice=dynamic_runtime_params_slice_t,
175-
static_runtime_params_slice=static_runtime_params_slice,
176+
static_runtime_params_slice=self._static_runtime_params_slice,
176177
geo=geo_t,
177178
core_profiles=input_state.core_profiles,
178179
source_models=self.solver.source_models,
@@ -209,7 +210,7 @@ def __call__(
209210
# previous_post_processed_outputs.
210211
output_state, post_processed_outputs = _sawtooth_step(
211212
sawtooth_solver=self.mhd_models.sawtooth,
212-
static_runtime_params_slice=static_runtime_params_slice,
213+
static_runtime_params_slice=self._static_runtime_params_slice,
213214
dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t,
214215
dynamic_runtime_params_slice_t_plus_crash_dt=dynamic_runtime_params_slice_t_plus_crash_dt,
215216
geo_t=geo_t,
@@ -246,7 +247,6 @@ def __call__(
246247

247248
x_new, solver_numeric_outputs = self.step(
248249
dt,
249-
static_runtime_params_slice,
250250
dynamic_runtime_params_slice_t,
251251
dynamic_runtime_params_slice_t_plus_dt,
252252
geo_t,
@@ -255,7 +255,7 @@ def __call__(
255255
explicit_source_profiles,
256256
)
257257

258-
if static_runtime_params_slice.adaptive_dt:
258+
if self._static_runtime_params_slice.adaptive_dt:
259259
# This is a no-op if
260260
# output_state.solver_numeric_outputs.solver_error_state == 0.
261261
(
@@ -268,7 +268,6 @@ def __call__(
268268
x_new,
269269
dt,
270270
solver_numeric_outputs,
271-
static_runtime_params_slice,
272271
dynamic_runtime_params_slice_t,
273272
dynamic_runtime_params_slice_t_plus_dt,
274273
geo_t,
@@ -282,7 +281,7 @@ def __call__(
282281
# the dt may be adaptive, we need to recompute it here.
283282
core_profiles_t_plus_dt = updaters.provide_core_profiles_t_plus_dt(
284283
dt=dt,
285-
static_runtime_params_slice=static_runtime_params_slice,
284+
static_runtime_params_slice=self._static_runtime_params_slice,
286285
dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t,
287286
dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt,
288287
geo_t_plus_dt=geo_t_plus_dt,
@@ -294,7 +293,7 @@ def __call__(
294293
dt=dt,
295294
x_new=x_new,
296295
solver_numeric_outputs=solver_numeric_outputs,
297-
static_runtime_params_slice=self.solver.static_runtime_params_slice,
296+
static_runtime_params_slice=self._static_runtime_params_slice,
298297
dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt,
299298
geometry_t_plus_dt=geo_t_plus_dt,
300299
explicit_source_profiles=explicit_source_profiles,
@@ -377,7 +376,6 @@ def init_time_step_calculator(
377376
def step(
378377
self,
379378
dt: jax.Array,
380-
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
381379
dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice,
382380
dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice,
383381
geo_t: geometry.Geometry,
@@ -395,8 +393,6 @@ def step(
395393
396394
Args:
397395
dt: Time step duration.
398-
static_runtime_params_slice: Static parameters that, if they change,
399-
should trigger a recompilation of the SimulationStepFn.
400396
dynamic_runtime_params_slice_t: Runtime parameters at time t.
401397
dynamic_runtime_params_slice_t_plus_dt: Runtime parameters at time t + dt.
402398
geo_t: The geometry of the torus during this time step of the simulation.
@@ -421,7 +417,7 @@ def step(
421417
# PDE system.
422418
core_profiles_t_plus_dt = updaters.provide_core_profiles_t_plus_dt(
423419
dt=dt,
424-
static_runtime_params_slice=static_runtime_params_slice,
420+
static_runtime_params_slice=self._static_runtime_params_slice,
425421
dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t,
426422
dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt,
427423
geo_t_plus_dt=geo_t_plus_dt,
@@ -447,7 +443,6 @@ def _adaptive_step(
447443
x_old: tuple[cell_variable.CellVariable, ...],
448444
dt: jax.Array,
449445
solver_numeric_outputs: state.SolverNumericOutputs,
450-
static_runtime_params_slice: runtime_params_slice.StaticRuntimeParamsSlice,
451446
dynamic_runtime_params_slice_t: runtime_params_slice.DynamicRuntimeParamsSlice,
452447
dynamic_runtime_params_slice_t_plus_dt: runtime_params_slice.DynamicRuntimeParamsSlice,
453448
geo_t: geometry.Geometry,
@@ -473,8 +468,6 @@ def _adaptive_step(
473468
dt: Time step duration for the initial step.
474469
solver_numeric_outputs: Solver-specific numeric outputs from the initial
475470
step.
476-
static_runtime_params_slice: Static parameters that, if they change,
477-
should trigger a recompilation of the SimulationStepFn.
478471
dynamic_runtime_params_slice_t: Runtime parameters at time t.
479472
dynamic_runtime_params_slice_t_plus_dt: Runtime parameters at time t +
480473
dt.
@@ -558,7 +551,7 @@ def body_fun(
558551

559552
core_profiles_t_plus_dt = updaters.provide_core_profiles_t_plus_dt(
560553
dt=dt,
561-
static_runtime_params_slice=static_runtime_params_slice,
554+
static_runtime_params_slice=self._static_runtime_params_slice,
562555
dynamic_runtime_params_slice_t=dynamic_runtime_params_slice_t,
563556
dynamic_runtime_params_slice_t_plus_dt=dynamic_runtime_params_slice_t_plus_dt,
564557
geo_t_plus_dt=geo_t_plus_dt,

torax/tests/sim_time_dependence_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -124,9 +124,9 @@ def _fake_run_loop(
124124
restart_case,
125125
geometry_provider,
126126
dynamic_runtime_params_slice_provider,
127+
static_runtime_params_slice,
127128
)
128129
output_state, post_processed_outputs, error = step_fn(
129-
static_runtime_params_slice,
130130
initial_state,
131131
initial_post_processed_outputs,
132132
)

0 commit comments

Comments
 (0)