Skip to content

Commit 4e923e3

Browse files
Nush395Torax team
authored andcommitted
Create Static container for numerics and package components of slice into it.
This is to package all the current components of `Numerics` that go into the static slice into a single object (that future attributes can be easily added to as well). This also keeps consistency with the other object patterns (e.g. `profile_conditions.StaticRuntimeParams`) PiperOrigin-RevId: 775289301
1 parent e71efc4 commit 4e923e3

File tree

8 files changed

+49
-43
lines changed

8 files changed

+49
-43
lines changed

torax/_src/config/build_runtime_params.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -42,13 +42,9 @@ def build_static_params_from_config(
4242
},
4343
torax_mesh=config.geometry.build_provider.torax_mesh,
4444
solver=config.solver.build_static_params(),
45-
evolve_ion_heat=config.numerics.evolve_ion_heat,
46-
evolve_electron_heat=config.numerics.evolve_electron_heat,
47-
evolve_current=config.numerics.evolve_current,
48-
evolve_density=config.numerics.evolve_density,
4945
main_ion_names=config.plasma_composition.get_main_ion_names(),
5046
impurity_names=config.plasma_composition.get_impurity_names(),
51-
adaptive_dt=config.numerics.adaptive_dt,
47+
numerics=config.numerics.build_static_params(),
5248
profile_conditions=config.profile_conditions.build_static_params(),
5349
)
5450

torax/_src/config/numerics.py

Lines changed: 28 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -24,7 +24,10 @@
2424
# TODO(b/326578331): remove density reference from DynamicNumerics entirely.
2525
@chex.dataclass
2626
class DynamicNumerics:
27-
"""Generic numeric parameters for the simulation."""
27+
"""Generic numeric parameters for the simulation.
28+
29+
For definitions see `Numerics`.
30+
"""
2831

2932
t_initial: float
3033
t_final: float
@@ -40,6 +43,19 @@ class DynamicNumerics:
4043
calcphibdot: bool
4144

4245

46+
@chex.dataclass(frozen=True)
47+
class StaticNumerics:
48+
"""Static numerics parameters for the simulation.
49+
50+
For definitions see `Numerics`.
51+
"""
52+
evolve_ion_heat: bool
53+
evolve_electron_heat: bool
54+
evolve_current: bool
55+
evolve_density: bool
56+
adaptive_dt: bool
57+
58+
4359
class Numerics(torax_pydantic.BaseModelFrozen):
4460
"""Generic numeric parameters for the simulation.
4561
@@ -119,7 +135,7 @@ def build_dynamic_params(
119135
self,
120136
t: chex.Numeric,
121137
) -> DynamicNumerics:
122-
"""Builds a DynamicNumerics."""
138+
"""Builds a DynamicNumerics object for time t."""
123139
return DynamicNumerics(
124140
t_initial=self.t_initial,
125141
t_final=self.t_final,
@@ -134,3 +150,13 @@ def build_dynamic_params(
134150
adaptive_T_source_prefactor=self.adaptive_T_source_prefactor,
135151
adaptive_n_source_prefactor=self.adaptive_n_source_prefactor,
136152
)
153+
154+
def build_static_params(self) -> StaticNumerics:
155+
"""Builds a StaticNumerics object."""
156+
return StaticNumerics(
157+
evolve_ion_heat=self.evolve_ion_heat,
158+
evolve_electron_heat=self.evolve_electron_heat,
159+
evolve_current=self.evolve_current,
160+
evolve_density=self.evolve_density,
161+
adaptive_dt=self.adaptive_dt,
162+
)

torax/_src/config/runtime_params_slice.py

Lines changed: 2 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -113,15 +113,6 @@ class StaticRuntimeParamsSlice:
113113
sources: Mapping[str, sources_params.StaticRuntimeParams]
114114
# Torax mesh used to construct the geometry.
115115
torax_mesh: torax_pydantic.Grid1D
116-
# Solve the ion heat equation (ion temperature evolves over time)
117-
evolve_ion_heat: bool
118-
# Solve the electron heat equation (electron temperature evolves over time)
119-
evolve_electron_heat: bool
120-
# Solve the current equation (psi evolves over time driven by the solver;
121-
# q and s evolve over time as a function of psi)
122-
evolve_current: bool
123-
# Solve the density equation (n evolves over time)
124-
evolve_density: bool
125116
# Ion symbols for main ion and impurity (which each could be mixtures of ions)
126117
# These are static to simplify source functions for fusion power and radiation
127118
# which are species-dependent.
@@ -130,24 +121,17 @@ class StaticRuntimeParamsSlice:
130121
main_ion_names: tuple[str, ...]
131122
impurity_names: tuple[str, ...]
132123
profile_conditions: profile_conditions.StaticRuntimeParams
133-
# Iterative reduction of dt if nonlinear step does not converge,
134-
# If nonlinear step does not converge, then the step is redone
135-
# iteratively at successively lower dt until convergence is reached
136-
adaptive_dt: bool
124+
numerics: numerics.StaticNumerics
137125

138126
def __hash__(self):
139127
return hash((
140128
self.solver,
141129
tuple(sorted(self.sources.items())), # Hashable version of sources
142130
hash(self.torax_mesh), # Grid1D has a hash method defined.
143-
self.evolve_ion_heat,
144-
self.evolve_electron_heat,
145-
self.evolve_current,
146-
self.evolve_density,
147131
self.main_ion_names,
148132
self.impurity_names,
149-
self.adaptive_dt,
150133
self.profile_conditions,
134+
self.numerics,
151135
))
152136

153137
def validate_new(self, new_params: typing_extensions.Self):

torax/_src/core_profiles/updaters.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -91,21 +91,21 @@ def get_prescribed_core_profile_values(
9191
"""
9292
# If profiles are not evolved, they can still potential be time-evolving,
9393
# depending on the runtime params. If so, they are updated below.
94-
if not static_runtime_params_slice.evolve_ion_heat:
94+
if not static_runtime_params_slice.numerics.evolve_ion_heat:
9595
T_i = getters.get_updated_ion_temperature(
9696
dynamic_runtime_params_slice.profile_conditions, geo
9797
).value
9898
else:
9999
T_i = core_profiles.T_i.value
100-
if not static_runtime_params_slice.evolve_electron_heat:
100+
if not static_runtime_params_slice.numerics.evolve_electron_heat:
101101
T_e_cell_variable = getters.get_updated_electron_temperature(
102102
dynamic_runtime_params_slice.profile_conditions, geo
103103
)
104104
T_e = T_e_cell_variable.value
105105
else:
106106
T_e_cell_variable = core_profiles.T_e
107107
T_e = T_e_cell_variable.value
108-
if not static_runtime_params_slice.evolve_density:
108+
if not static_runtime_params_slice.numerics.evolve_density:
109109
n_e_cell_variable = getters.get_updated_electron_density(
110110
static_runtime_params_slice,
111111
dynamic_runtime_params_slice.profile_conditions,

torax/_src/mhd/sawtooth/simple_redistribution.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def __call__(
6565
"""
6666

6767
# No sawtooth redistribution if current is not being evolved.
68-
if not static_runtime_params_slice.evolve_current:
68+
if not static_runtime_params_slice.numerics.evolve_current:
6969
return core_profiles_t
7070

7171
assert dynamic_runtime_params_slice.mhd.sawtooth is not None
@@ -88,7 +88,7 @@ def __call__(
8888
indices = jnp.arange(geo.rho_norm.shape[0])
8989
redistribution_mask = indices < idx_mixing
9090

91-
if static_runtime_params_slice.evolve_density:
91+
if static_runtime_params_slice.numerics.evolve_density:
9292
n_e_redistributed = flatten_profile.flatten_density_profile(
9393
rho_norm_q1,
9494
mixing_radius,
@@ -99,7 +99,7 @@ def __call__(
9999
)
100100
else:
101101
n_e_redistributed = core_profiles_t.n_e
102-
if static_runtime_params_slice.evolve_electron_heat:
102+
if static_runtime_params_slice.numerics.evolve_electron_heat:
103103
te_redistributed = flatten_profile.flatten_temperature_profile(
104104
rho_norm_q1,
105105
mixing_radius,
@@ -113,8 +113,8 @@ def __call__(
113113
else:
114114
te_redistributed = core_profiles_t.T_e
115115
if (
116-
static_runtime_params_slice.evolve_density
117-
or static_runtime_params_slice.evolve_electron_heat
116+
static_runtime_params_slice.numerics.evolve_density
117+
or static_runtime_params_slice.numerics.evolve_electron_heat
118118
):
119119
ions_redistributed = getters.get_updated_ions(
120120
static_runtime_params_slice,
@@ -136,7 +136,7 @@ def __call__(
136136
Z_eff=core_profiles_t.Z_eff,
137137
Z_eff_face=core_profiles_t.Z_eff_face,
138138
)
139-
if static_runtime_params_slice.evolve_ion_heat:
139+
if static_runtime_params_slice.numerics.evolve_ion_heat:
140140
ti_redistributed = flatten_profile.flatten_temperature_profile(
141141
rho_norm_q1,
142142
mixing_radius,

torax/_src/orchestration/step_function.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,7 +255,7 @@ def __call__(
255255
explicit_source_profiles,
256256
)
257257

258-
if self._static_runtime_params_slice.adaptive_dt:
258+
if self._static_runtime_params_slice.numerics.adaptive_dt:
259259
# This is a no-op if
260260
# output_state.solver_numeric_outputs.solver_error_state == 0.
261261
(

torax/_src/solver/solver.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -88,13 +88,13 @@ def __eq__(self, other: typing_extensions.Self) -> bool:
8888
def evolving_names(self) -> tuple[str, ...]:
8989
"""The names of core_profiles variables that are evolved by the solver."""
9090
evolving_names = []
91-
if self.static_runtime_params_slice.evolve_ion_heat:
91+
if self.static_runtime_params_slice.numerics.evolve_ion_heat:
9292
evolving_names.append('T_i')
93-
if self.static_runtime_params_slice.evolve_electron_heat:
93+
if self.static_runtime_params_slice.numerics.evolve_electron_heat:
9494
evolving_names.append('T_e')
95-
if self.static_runtime_params_slice.evolve_current:
95+
if self.static_runtime_params_slice.numerics.evolve_current:
9696
evolving_names.append('psi')
97-
if self.static_runtime_params_slice.evolve_density:
97+
if self.static_runtime_params_slice.numerics.evolve_density:
9898
evolving_names.append('n_e')
9999
return tuple(evolving_names)
100100

torax/_src/sources/qei_source.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,12 @@ def _model_based_qei(
119119
if (
120120
# if only a single heat equation is being evolved
121121
(
122-
static_runtime_params_slice.evolve_ion_heat
123-
and not static_runtime_params_slice.evolve_electron_heat
122+
static_runtime_params_slice.numerics.evolve_ion_heat
123+
and not static_runtime_params_slice.numerics.evolve_electron_heat
124124
)
125125
or (
126-
static_runtime_params_slice.evolve_electron_heat
127-
and not static_runtime_params_slice.evolve_ion_heat
126+
static_runtime_params_slice.numerics.evolve_electron_heat
127+
and not static_runtime_params_slice.numerics.evolve_ion_heat
128128
)
129129
):
130130
explicit_i = qei_coef * core_profiles.T_e.value

0 commit comments

Comments
 (0)