Skip to content

Commit 40ec3c8

Browse files
tamaranormanTorax team
authored andcommitted
Pipe through prev_core_profiles to updaters where needed
PiperOrigin-RevId: 876278049
1 parent 8b65e14 commit 40ec3c8

File tree

8 files changed

+60
-10
lines changed

8 files changed

+60
-10
lines changed

torax/_src/core_profiles/updaters.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,8 @@ def update_core_profiles_during_step(
5353
runtime_params: runtime_params_lib.RuntimeParams,
5454
geo: geometry.Geometry,
5555
core_profiles: state.CoreProfiles,
56+
prev_core_profiles: state.CoreProfiles | None,
57+
dt: array_typing.FloatScalar | None,
5658
evolving_names: tuple[str, ...],
5759
) -> state.CoreProfiles:
5860
"""Returns the new core profiles after updating the evolving variables.
@@ -67,8 +69,13 @@ def update_core_profiles_during_step(
6769
runtime_params: The runtime params slice.
6870
geo: Magnetic geometry.
6971
core_profiles: The old set of core plasma profiles.
72+
prev_core_profiles: Core plasma profiles from the previous timestep if
73+
available, used to update the energy state.
74+
dt: The size of the last timestep, used to update the energy state.
7075
evolving_names: The names of the evolving variables.
7176
"""
77+
# Currently unused but will be used to update the energy state soon
78+
del prev_core_profiles, dt
7279

7380
updated_core_profiles = convertors.solver_x_tuple_to_core_profiles(
7481
x_new, evolving_names, core_profiles

torax/_src/fvm/calc_coeffs.py

Lines changed: 10 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@
1717
import functools
1818
import jax
1919
import jax.numpy as jnp
20+
from torax._src import array_typing
2021
from torax._src import constants
2122
from torax._src import physics_models as physics_models_lib
2223
from torax._src import state
@@ -62,6 +63,8 @@ def __call__(
6263
runtime_params: runtime_params_lib.RuntimeParams,
6364
geo: geometry.Geometry,
6465
core_profiles: state.CoreProfiles,
66+
prev_core_profiles: state.CoreProfiles | None,
67+
dt: array_typing.FloatScalar | None,
6568
x: tuple[cell_variable.CellVariable, ...],
6669
explicit_source_profiles: source_profiles_lib.SourceProfiles,
6770
allow_pereverzev: bool = False,
@@ -80,6 +83,9 @@ def __call__(
8083
state x.
8184
geo: The geometry of the system at this time step.
8285
core_profiles: The core profiles of the system at this time step.
86+
prev_core_profiles: The core profiles of the system at the previous
87+
time step.
88+
dt: The time step size.
8389
x: The state with cell-grid values of the evolving variables.
8490
explicit_source_profiles: Precomputed explicit source profiles. These
8591
profiles were configured to always depend on state and parameters at
@@ -107,8 +113,10 @@ def __call__(
107113
x,
108114
runtime_params,
109115
geo,
110-
core_profiles,
111-
self.evolving_names,
116+
core_profiles=core_profiles,
117+
prev_core_profiles=prev_core_profiles,
118+
dt=dt,
119+
evolving_names=self.evolving_names,
112120
)
113121
if allow_pereverzev:
114122
use_pereverzev = runtime_params.solver.use_pereverzev

torax/_src/fvm/newton_raphson_solve_block.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -158,7 +158,9 @@ def newton_raphson_solve_block(
158158
runtime_params_t,
159159
geo_t,
160160
core_profiles_t,
161-
x_old,
161+
prev_core_profiles=None,
162+
dt=None,
163+
x=x_old,
162164
explicit_source_profiles=explicit_source_profiles,
163165
explicit_call=True,
164166
)
@@ -173,8 +175,10 @@ def newton_raphson_solve_block(
173175
coeffs_exp_linear = coeffs_callback(
174176
runtime_params_t,
175177
geo_t,
176-
core_profiles_t,
177-
x_old,
178+
core_profiles=core_profiles_t,
179+
prev_core_profiles=None,
180+
dt=None,
181+
x=x_old,
178182
explicit_source_profiles=explicit_source_profiles,
179183
allow_pereverzev=True,
180184
explicit_call=True,
@@ -190,6 +194,7 @@ def newton_raphson_solve_block(
190194
geo_t_plus_dt=geo_t_plus_dt,
191195
x_old=x_old,
192196
x_new_guess=x_new_guess,
197+
core_profiles_t=core_profiles_t,
193198
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
194199
coeffs_exp=coeffs_exp_linear,
195200
coeffs_callback=coeffs_callback,
@@ -212,6 +217,7 @@ def newton_raphson_solve_block(
212217
runtime_params_t_plus_dt=runtime_params_t_plus_dt,
213218
geo_t_plus_dt=geo_t_plus_dt,
214219
x_old=x_old,
220+
core_profiles_t=core_profiles_t,
215221
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
216222
physics_models=physics_models,
217223
explicit_source_profiles=explicit_source_profiles,

torax/_src/fvm/optimizer_solve_block.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -122,7 +122,9 @@ def optimizer_solve_block(
122122
runtime_params_t,
123123
geo_t,
124124
core_profiles_t,
125-
x_old,
125+
prev_core_profiles=None,
126+
dt=None,
127+
x=x_old,
126128
explicit_source_profiles=explicit_source_profiles,
127129
explicit_call=True,
128130
)
@@ -139,7 +141,9 @@ def optimizer_solve_block(
139141
runtime_params_t,
140142
geo_t,
141143
core_profiles_t,
142-
x_old,
144+
prev_core_profiles=None,
145+
dt=None,
146+
x=x_old,
143147
explicit_source_profiles=explicit_source_profiles,
144148
allow_pereverzev=True,
145149
explicit_call=True,
@@ -154,6 +158,7 @@ def optimizer_solve_block(
154158
geo_t_plus_dt=geo_t_plus_dt,
155159
x_old=x_old,
156160
x_new_guess=x_new_guess,
161+
core_profiles_t=core_profiles_t,
157162
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
158163
coeffs_exp=coeffs_exp_linear,
159164
coeffs_callback=coeffs_callback,
@@ -178,6 +183,7 @@ def optimizer_solve_block(
178183
geo_t_plus_dt=geo_t_plus_dt,
179184
x_old=x_old,
180185
init_x_new_vec=init_x_new_vec,
186+
core_profiles_t=core_profiles_t,
181187
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
182188
explicit_source_profiles=explicit_source_profiles,
183189
physics_models=physics_models,

torax/_src/fvm/residual_and_loss.py

Lines changed: 12 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -199,6 +199,7 @@ def theta_method_block_residual(
199199
runtime_params_t_plus_dt: runtime_params_lib.RuntimeParams,
200200
geo_t_plus_dt: geometry.Geometry,
201201
x_old: tuple[cell_variable.CellVariable, ...],
202+
core_profiles_t: state.CoreProfiles,
202203
core_profiles_t_plus_dt: state.CoreProfiles,
203204
explicit_source_profiles: source_profiles.SourceProfiles,
204205
physics_models: physics_models_lib.PhysicsModels,
@@ -214,6 +215,8 @@ def theta_method_block_residual(
214215
runtime_params_t_plus_dt: Runtime parameters for time t + dt.
215216
geo_t_plus_dt: The geometry at time t + dt.
216217
x_old: The starting x defined as a tuple of CellVariables.
218+
core_profiles_t: Core plasma profiles which contain all available
219+
prescribed quantities at the start of the time step.
217220
core_profiles_t_plus_dt: Core plasma profiles which contain all available
218221
prescribed quantities at the end of the time step. This includes evolving
219222
boundary conditions and prescribed time-dependent profiles that are not
@@ -244,7 +247,9 @@ def theta_method_block_residual(
244247
runtime_params_t_plus_dt,
245248
geo_t_plus_dt,
246249
core_profiles_t_plus_dt,
247-
evolving_names,
250+
prev_core_profiles=core_profiles_t,
251+
dt=dt,
252+
evolving_names=evolving_names,
248253
)
249254
coeffs_new = calc_coeffs.calc_coeffs(
250255
runtime_params=runtime_params_t_plus_dt,
@@ -288,6 +293,7 @@ def theta_method_block_loss(
288293
runtime_params_t_plus_dt: runtime_params_lib.RuntimeParams,
289294
geo_t_plus_dt: geometry.Geometry,
290295
x_old: tuple[cell_variable.CellVariable, ...],
296+
core_profiles_t: state.CoreProfiles,
291297
core_profiles_t_plus_dt: state.CoreProfiles,
292298
explicit_source_profiles: source_profiles.SourceProfiles,
293299
physics_models: physics_models_lib.PhysicsModels,
@@ -303,6 +309,7 @@ def theta_method_block_loss(
303309
runtime_params_t_plus_dt: Runtime parameters for time t + dt.
304310
geo_t_plus_dt: geometry object at time t + dt.
305311
x_old: The starting x defined as a tuple of CellVariables.
312+
core_profiles_t: Core profiles from the previous time step.
306313
core_profiles_t_plus_dt: Core plasma profiles which contain all available
307314
prescribed quantities at the end of the time step. This includes evolving
308315
boundary conditions and prescribed time-dependent profiles that are not
@@ -324,6 +331,7 @@ def theta_method_block_loss(
324331
geo_t_plus_dt=geo_t_plus_dt,
325332
x_old=x_old,
326333
x_new_guess_vec=x_new_guess_vec,
334+
core_profiles_t=core_profiles_t,
327335
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
328336
explicit_source_profiles=explicit_source_profiles,
329337
physics_models=physics_models,
@@ -347,6 +355,7 @@ def jaxopt_solver(
347355
geo_t_plus_dt: geometry.Geometry,
348356
x_old: tuple[cell_variable.CellVariable, ...],
349357
init_x_new_vec: jax.Array,
358+
core_profiles_t: state.CoreProfiles,
350359
core_profiles_t_plus_dt: state.CoreProfiles,
351360
explicit_source_profiles: source_profiles.SourceProfiles,
352361
physics_models: physics_models_lib.PhysicsModels,
@@ -364,6 +373,7 @@ def jaxopt_solver(
364373
x_old: The starting x defined as a tuple of CellVariables.
365374
init_x_new_vec: Flattened array of initial guess of x_new for all evolving
366375
core profiles.
376+
core_profiles_t: Core profiles from the previous time step.
367377
core_profiles_t_plus_dt: Core plasma profiles which contain all available
368378
prescribed quantities at the end of the time step. This includes evolving
369379
boundary conditions and prescribed time-dependent profiles that are not
@@ -389,6 +399,7 @@ def jaxopt_solver(
389399
runtime_params_t_plus_dt=runtime_params_t_plus_dt,
390400
geo_t_plus_dt=geo_t_plus_dt,
391401
x_old=x_old,
402+
core_profiles_t=core_profiles_t,
392403
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
393404
explicit_source_profiles=explicit_source_profiles,
394405
physics_models=physics_models,

torax/_src/fvm/tests/fvm_test.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -286,6 +286,7 @@ def test_nonlinear_solve_block_loss_minimum(
286286
x_old=x_old,
287287
x_new_guess_vec=jnp.concatenate([var.value for var in x_new]),
288288
core_profiles_t_plus_dt=core_profiles,
289+
core_profiles_t=core_profiles,
289290
physics_models=physics_models,
290291
explicit_source_profiles=explicit_source_profiles,
291292
coeffs_old=coeffs,
@@ -299,6 +300,7 @@ def test_nonlinear_solve_block_loss_minimum(
299300
x_new_guess_vec=jnp.concatenate([var.value for var in x_new]),
300301
x_old=x_old,
301302
core_profiles_t_plus_dt=core_profiles,
303+
core_profiles_t=core_profiles,
302304
physics_models=physics_models,
303305
explicit_source_profiles=explicit_source_profiles,
304306
coeffs_old=coeffs,
@@ -512,6 +514,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
512514
x_old=(x_0,),
513515
x_new_guess_vec=x_0.value,
514516
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
517+
core_profiles_t=initial_core_profiles,
515518
physics_models=physics_models,
516519
explicit_source_profiles=explicit_source_profiles,
517520
coeffs_old=coeffs_old,
@@ -535,6 +538,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
535538
x_0, right_face_constraint=final_right_boundary
536539
),
537540
),
541+
core_profiles_t=initial_core_profiles,
538542
evolving_names=evolving_names,
539543
physics_models=physics_models,
540544
explicit_source_profiles=explicit_source_profiles,
@@ -553,6 +557,7 @@ def test_theta_residual_uses_updated_boundary_conditions(self):
553557
x_0, right_face_constraint=final_right_boundary
554558
),
555559
),
560+
core_profiles_t=initial_core_profiles,
556561
x_new_guess_vec=x_0.value,
557562
physics_models=physics_models,
558563
explicit_source_profiles=explicit_source_profiles,

torax/_src/solver/linear_theta_method.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -74,7 +74,9 @@ def _x_new(
7474
runtime_params_t,
7575
geo_t,
7676
core_profiles_t,
77-
x_old,
77+
prev_core_profiles=None,
78+
dt=None,
79+
x=x_old,
7880
explicit_source_profiles=explicit_source_profiles,
7981
allow_pereverzev=True,
8082
explicit_call=True,
@@ -90,6 +92,7 @@ def _x_new(
9092
geo_t_plus_dt=geo_t_plus_dt,
9193
x_old=x_old,
9294
x_new_guess=x_new_guess,
95+
core_profiles_t=core_profiles_t,
9396
core_profiles_t_plus_dt=core_profiles_t_plus_dt,
9497
coeffs_exp=coeffs_exp,
9598
coeffs_callback=coeffs_callback,

torax/_src/solver/predictor_corrector_method.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,7 @@ def predictor_corrector_method(
4444
geo_t_plus_dt: geometry.Geometry,
4545
x_old: tuple[cell_variable.CellVariable, ...],
4646
x_new_guess: tuple[cell_variable.CellVariable, ...],
47+
core_profiles_t: state.CoreProfiles,
4748
core_profiles_t_plus_dt: state.CoreProfiles,
4849
coeffs_exp: block_1d_coeffs.Block1DCoeffs,
4950
explicit_source_profiles: source_profiles.SourceProfiles,
@@ -60,6 +61,7 @@ def predictor_corrector_method(
6061
time t.
6162
x_new_guess: Tuple of CellVariables corresponding to the initial guess for
6263
the next time step.
64+
core_profiles_t: Core profiles at the current time step.
6365
core_profiles_t_plus_dt: Core profiles at the next time step.
6466
coeffs_exp: Block1DCoeffs PDE coefficients at beginning of timestep.
6567
explicit_source_profiles: Precomputed explicit source profiles. These
@@ -82,7 +84,9 @@ def loop_body(x_new_guess):
8284
runtime_params_t_plus_dt,
8385
geo_t_plus_dt,
8486
core_profiles_t_plus_dt,
85-
x_new_guess,
87+
prev_core_profiles=core_profiles_t,
88+
dt=dt,
89+
x=x_new_guess,
8690
explicit_source_profiles=explicit_source_profiles,
8791
allow_pereverzev=True,
8892
)

0 commit comments

Comments
 (0)