@@ -76,6 +76,7 @@ def __init__(
7676 solver : solver_lib .Solver ,
7777 time_step_calculator : ts .TimeStepCalculator ,
7878 mhd_models : mhd_base .MHDModels ,
79+ static_runtime_params_slice : runtime_params_slice .StaticRuntimeParamsSlice ,
7980 dynamic_runtime_params_slice_provider : build_runtime_params .DynamicRuntimeParamsSliceProvider ,
8081 geometry_provider : geometry_provider_lib .GeometryProvider ,
8182 ):
@@ -84,7 +85,9 @@ def __init__(
8485 Args:
8586 solver: Evolves the core profiles.
8687 time_step_calculator: Calculates the dt for each time step.
87- mhd_models: Collection of MHD models applied, e.g. sawtooth
88+ mhd_models: Collection of MHD models applied, e.g. sawtooth.
89+ static_runtime_params_slice: Static parameters that, if they change,
90+ should trigger a recompilation of the SimulationStepFn.
8891 dynamic_runtime_params_slice_provider: Object that returns a set of
8992 runtime parameters which may change from time step to time step or
9093 simulation run to run. If these runtime parameters change, it does NOT
@@ -105,6 +108,7 @@ def __init__(
105108 self ._dynamic_runtime_params_slice_provider = (
106109 dynamic_runtime_params_slice_provider
107110 )
111+ self ._static_runtime_params_slice = static_runtime_params_slice
108112
109113 @property
110114 def solver (self ) -> solver_lib .Solver :
@@ -120,7 +124,6 @@ def time_step_calculator(self) -> ts.TimeStepCalculator:
120124
121125 def __call__ (
122126 self ,
123- static_runtime_params_slice : runtime_params_slice .StaticRuntimeParamsSlice ,
124127 input_state : sim_state .ToraxSimState ,
125128 previous_post_processed_outputs : post_processing .PostProcessedOutputs ,
126129 ) -> tuple [
@@ -136,8 +139,6 @@ def __call__(
136139 sawtooth redistribution, at a t+dt set by the sawtooth model.
137140
138141 Args:
139- static_runtime_params_slice: Static parameters that, if they change,
140- should trigger a recompilation of the SimulationStepFn.
141142 input_state: State at the start of the time step, including the core
142143 profiles which are being evolved.
143144 previous_post_processed_outputs: Post-processed outputs from the previous
@@ -172,7 +173,7 @@ def __call__(
172173 # set to 0.
173174 explicit_source_profiles = source_profile_builders .build_source_profiles (
174175 dynamic_runtime_params_slice = dynamic_runtime_params_slice_t ,
175- static_runtime_params_slice = static_runtime_params_slice ,
176+ static_runtime_params_slice = self . _static_runtime_params_slice ,
176177 geo = geo_t ,
177178 core_profiles = input_state .core_profiles ,
178179 source_models = self .solver .source_models ,
@@ -209,7 +210,7 @@ def __call__(
209210 # previous_post_processed_outputs.
210211 output_state , post_processed_outputs = _sawtooth_step (
211212 sawtooth_solver = self .mhd_models .sawtooth ,
212- static_runtime_params_slice = static_runtime_params_slice ,
213+ static_runtime_params_slice = self . _static_runtime_params_slice ,
213214 dynamic_runtime_params_slice_t = dynamic_runtime_params_slice_t ,
214215 dynamic_runtime_params_slice_t_plus_crash_dt = dynamic_runtime_params_slice_t_plus_crash_dt ,
215216 geo_t = geo_t ,
@@ -246,7 +247,6 @@ def __call__(
246247
247248 x_new , solver_numeric_outputs = self .step (
248249 dt ,
249- static_runtime_params_slice ,
250250 dynamic_runtime_params_slice_t ,
251251 dynamic_runtime_params_slice_t_plus_dt ,
252252 geo_t ,
@@ -255,7 +255,7 @@ def __call__(
255255 explicit_source_profiles ,
256256 )
257257
258- if static_runtime_params_slice .adaptive_dt :
258+ if self . _static_runtime_params_slice .adaptive_dt :
259259 # This is a no-op if
260260 # output_state.solver_numeric_outputs.solver_error_state == 0.
261261 (
@@ -268,7 +268,6 @@ def __call__(
268268 x_new ,
269269 dt ,
270270 solver_numeric_outputs ,
271- static_runtime_params_slice ,
272271 dynamic_runtime_params_slice_t ,
273272 dynamic_runtime_params_slice_t_plus_dt ,
274273 geo_t ,
@@ -282,7 +281,7 @@ def __call__(
282281 # the dt may be adaptive, we need to recompute it here.
283282 core_profiles_t_plus_dt = updaters .provide_core_profiles_t_plus_dt (
284283 dt = dt ,
285- static_runtime_params_slice = static_runtime_params_slice ,
284+ static_runtime_params_slice = self . _static_runtime_params_slice ,
286285 dynamic_runtime_params_slice_t = dynamic_runtime_params_slice_t ,
287286 dynamic_runtime_params_slice_t_plus_dt = dynamic_runtime_params_slice_t_plus_dt ,
288287 geo_t_plus_dt = geo_t_plus_dt ,
@@ -294,7 +293,7 @@ def __call__(
294293 dt = dt ,
295294 x_new = x_new ,
296295 solver_numeric_outputs = solver_numeric_outputs ,
297- static_runtime_params_slice = self .solver . static_runtime_params_slice ,
296+ static_runtime_params_slice = self ._static_runtime_params_slice ,
298297 dynamic_runtime_params_slice_t_plus_dt = dynamic_runtime_params_slice_t_plus_dt ,
299298 geometry_t_plus_dt = geo_t_plus_dt ,
300299 explicit_source_profiles = explicit_source_profiles ,
@@ -377,7 +376,6 @@ def init_time_step_calculator(
377376 def step (
378377 self ,
379378 dt : jax .Array ,
380- static_runtime_params_slice : runtime_params_slice .StaticRuntimeParamsSlice ,
381379 dynamic_runtime_params_slice_t : runtime_params_slice .DynamicRuntimeParamsSlice ,
382380 dynamic_runtime_params_slice_t_plus_dt : runtime_params_slice .DynamicRuntimeParamsSlice ,
383381 geo_t : geometry .Geometry ,
@@ -395,8 +393,6 @@ def step(
395393
396394 Args:
397395 dt: Time step duration.
398- static_runtime_params_slice: Static parameters that, if they change,
399- should trigger a recompilation of the SimulationStepFn.
400396 dynamic_runtime_params_slice_t: Runtime parameters at time t.
401397 dynamic_runtime_params_slice_t_plus_dt: Runtime parameters at time t + dt.
402398 geo_t: The geometry of the torus during this time step of the simulation.
@@ -421,7 +417,7 @@ def step(
421417 # PDE system.
422418 core_profiles_t_plus_dt = updaters .provide_core_profiles_t_plus_dt (
423419 dt = dt ,
424- static_runtime_params_slice = static_runtime_params_slice ,
420+ static_runtime_params_slice = self . _static_runtime_params_slice ,
425421 dynamic_runtime_params_slice_t = dynamic_runtime_params_slice_t ,
426422 dynamic_runtime_params_slice_t_plus_dt = dynamic_runtime_params_slice_t_plus_dt ,
427423 geo_t_plus_dt = geo_t_plus_dt ,
@@ -447,7 +443,6 @@ def _adaptive_step(
447443 x_old : tuple [cell_variable .CellVariable , ...],
448444 dt : jax .Array ,
449445 solver_numeric_outputs : state .SolverNumericOutputs ,
450- static_runtime_params_slice : runtime_params_slice .StaticRuntimeParamsSlice ,
451446 dynamic_runtime_params_slice_t : runtime_params_slice .DynamicRuntimeParamsSlice ,
452447 dynamic_runtime_params_slice_t_plus_dt : runtime_params_slice .DynamicRuntimeParamsSlice ,
453448 geo_t : geometry .Geometry ,
@@ -473,8 +468,6 @@ def _adaptive_step(
473468 dt: Time step duration for the initial step.
474469 solver_numeric_outputs: Solver-specific numeric outputs from the initial
475470 step.
476- static_runtime_params_slice: Static parameters that, if they change,
477- should trigger a recompilation of the SimulationStepFn.
478471 dynamic_runtime_params_slice_t: Runtime parameters at time t.
479472 dynamic_runtime_params_slice_t_plus_dt: Runtime parameters at time t +
480473 dt.
@@ -558,7 +551,7 @@ def body_fun(
558551
559552 core_profiles_t_plus_dt = updaters .provide_core_profiles_t_plus_dt (
560553 dt = dt ,
561- static_runtime_params_slice = static_runtime_params_slice ,
554+ static_runtime_params_slice = self . _static_runtime_params_slice ,
562555 dynamic_runtime_params_slice_t = dynamic_runtime_params_slice_t ,
563556 dynamic_runtime_params_slice_t_plus_dt = dynamic_runtime_params_slice_t_plus_dt ,
564557 geo_t_plus_dt = geo_t_plus_dt ,
0 commit comments