Skip to content

Commit 1194bd2

Browse files
Nush395Torax team
authored andcommitted
Remove unnecessary face_centers and cell_centers force update on set_grid.
PiperOrigin-RevId: 852228987
1 parent 2923d37 commit 1194bd2

File tree

82 files changed

+165
-148
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

82 files changed

+165
-148
lines changed

torax/_src/core_profiles/convertors.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -133,9 +133,9 @@ def scale_cell_variable(
133133

134134
return cell_variable.CellVariable(
135135
value=scaled_value,
136+
face_centers=cv.face_centers,
136137
left_face_constraint=scaled_left_face_constraint,
137138
left_face_grad_constraint=scaled_left_face_grad_constraint,
138139
right_face_constraint=scaled_right_face_constraint,
139140
right_face_grad_constraint=scaled_right_face_grad_constraint,
140-
dr=cv.dr,
141141
)

torax/_src/core_profiles/getters.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -72,10 +72,10 @@ def get_updated_ion_temperature(
7272
value = profile_conditions_params.T_i
7373
T_i = cell_variable.CellVariable(
7474
value=value,
75+
face_centers=geo.rho_face_norm,
7576
left_face_grad_constraint=jnp.zeros(()),
7677
right_face_grad_constraint=None,
7778
right_face_constraint=profile_conditions_params.T_i_right_bc,
78-
dr=geo.drho_norm,
7979
)
8080
return T_i
8181

@@ -99,10 +99,10 @@ def get_updated_electron_temperature(
9999

100100
T_e = cell_variable.CellVariable(
101101
value=value,
102+
face_centers=geo.rho_face_norm,
102103
left_face_grad_constraint=jnp.zeros(()),
103104
right_face_grad_constraint=None,
104105
right_face_constraint=profile_conditions_params.T_e_right_bc,
105-
dr=geo.drho_norm,
106106
)
107107
return T_e
108108

@@ -191,7 +191,7 @@ def get_updated_electron_density(
191191

192192
n_e = cell_variable.CellVariable(
193193
value=value,
194-
dr=geo.drho_norm,
194+
face_centers=geo.rho_face_norm,
195195
right_face_grad_constraint=None,
196196
right_face_constraint=n_e_right_bc,
197197
)
@@ -233,7 +233,7 @@ def get_updated_psi(
233233
)
234234
return cell_variable.CellVariable(
235235
value=value,
236-
dr=geo.drho_norm,
236+
face_centers=geo.rho_face_norm,
237237
right_face_grad_constraint=right_face_grad_constraint,
238238
right_face_constraint=right_face_constraint,
239239
)
@@ -250,9 +250,9 @@ def get_updated_toroidal_velocity(
250250
value = profile_conditions_params.toroidal_velocity
251251
toroidal_velocity = cell_variable.CellVariable(
252252
value=value,
253+
face_centers=geo.rho_face_norm,
253254
right_face_grad_constraint=None,
254255
right_face_constraint=profile_conditions_params.toroidal_velocity_right_bc,
255-
dr=geo.drho_norm,
256256
)
257257
return toroidal_velocity
258258

@@ -655,7 +655,7 @@ def get_updated_ions(
655655

656656
n_i = cell_variable.CellVariable(
657657
value=n_e.value * ion_properties.dilution_factor,
658-
dr=geo.drho_norm,
658+
face_centers=geo.rho_face_norm,
659659
right_face_grad_constraint=None,
660660
right_face_constraint=n_e.right_face_constraint
661661
* ion_properties.dilution_factor_edge,
@@ -676,7 +676,7 @@ def get_updated_ions(
676676

677677
n_impurity = cell_variable.CellVariable(
678678
value=n_impurity_value,
679-
dr=geo.drho_norm,
679+
face_centers=geo.rho_face_norm,
680680
right_face_grad_constraint=None,
681681
right_face_constraint=n_impurity_right_face_constraint,
682682
)

torax/_src/core_profiles/initialization.py

Lines changed: 7 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -90,11 +90,11 @@ def initial_core_profiles(
9090
# Initialise psi and derived quantities to zero before they are calculated.
9191
psidot = cell_variable.CellVariable(
9292
value=jnp.zeros_like(geo.rho, dtype=jax_utils.get_dtype()),
93-
dr=geo.drho_norm,
93+
face_centers=geo.rho_face_norm,
9494
)
9595
psi = cell_variable.CellVariable(
9696
value=jnp.zeros_like(geo.rho, dtype=jax_utils.get_dtype()),
97-
dr=geo.drho_norm,
97+
face_centers=geo.rho_face_norm,
9898
)
9999

100100
core_profiles = state.CoreProfiles(
@@ -190,7 +190,7 @@ def update_psi_from_j(
190190
# we set it to match the desired plasma current.
191191
right_face_grad_constraint = None
192192
right_face_constraint = (
193-
psi_value[-1] + dpsi_drhonorm_edge * geo.drho_norm / 2
193+
psi_value[-1] + dpsi_drhonorm_edge * geo.drho_norm[-1] / 2
194194
)
195195
else:
196196
# Use the dpsi/drho calculated above as the right face gradient constraint
@@ -199,7 +199,7 @@ def update_psi_from_j(
199199

200200
psi = cell_variable.CellVariable(
201201
value=psi_value,
202-
dr=geo.drho_norm,
202+
face_centers=geo.rho_face_norm,
203203
right_face_grad_constraint=right_face_grad_constraint,
204204
right_face_constraint=right_face_constraint,
205205
)
@@ -303,7 +303,7 @@ def _init_psi_and_psi_derived(
303303
right_face_grad_constraint = None
304304
right_face_constraint = (
305305
runtime_params.profile_conditions.psi[-1]
306-
+ dpsi_drhonorm_edge * geo.drho_norm / 2
306+
+ dpsi_drhonorm_edge * geo.drho_norm[-1] / 2
307307
)
308308
else:
309309
# Use the dpsi/drho calculated above as the right face gradient
@@ -313,9 +313,9 @@ def _init_psi_and_psi_derived(
313313

314314
psi = cell_variable.CellVariable(
315315
value=runtime_params.profile_conditions.psi,
316+
face_centers=geo.rho_face_norm,
316317
right_face_grad_constraint=right_face_grad_constraint,
317318
right_face_constraint=right_face_constraint,
318-
dr=geo.drho_norm,
319319
)
320320

321321
# Case 2: retrieving psi from the standard geometry input.
@@ -338,13 +338,13 @@ def _init_psi_and_psi_derived(
338338
# by make_ip_consistent
339339
psi = cell_variable.CellVariable(
340340
value=geo.psi_from_Ip, # Use psi from equilibrium
341+
face_centers=geo.rho_face_norm,
341342
right_face_grad_constraint=None
342343
if runtime_params.profile_conditions.use_v_loop_lcfs_boundary_condition
343344
else dpsi_drhonorm_edge,
344345
right_face_constraint=geo.psi_from_Ip_face[-1]
345346
if runtime_params.profile_conditions.use_v_loop_lcfs_boundary_condition
346347
else None,
347-
dr=geo.drho_norm,
348348
)
349349

350350
# Case 3: calculating j according to nu formula and psi from j.

torax/_src/core_profiles/tests/convertors_test.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -34,25 +34,25 @@ def setUp(self):
3434

3535
T_i = cell_variable.CellVariable(
3636
value=jnp.ones(self.geo.rho_norm.shape) * 1.0,
37-
dr=self.geo.drho_norm,
37+
face_centers=self.geo.rho_face_norm,
3838
right_face_constraint=jnp.array(0.5),
3939
right_face_grad_constraint=None,
4040
)
4141
T_e = cell_variable.CellVariable(
4242
value=jnp.ones(self.geo.rho_norm.shape) * 2.0,
43-
dr=self.geo.drho_norm,
43+
face_centers=self.geo.rho_face_norm,
4444
right_face_constraint=jnp.array(0.6),
4545
right_face_grad_constraint=None,
4646
)
4747
psi = cell_variable.CellVariable(
4848
value=jnp.ones(self.geo.rho_norm.shape) * 3.0,
49-
dr=self.geo.drho_norm,
49+
face_centers=self.geo.rho_face_norm,
5050
right_face_grad_constraint=jnp.array(0.7),
5151
right_face_constraint=None,
5252
)
5353
n_e = cell_variable.CellVariable(
5454
value=jnp.ones(self.geo.rho_norm.shape) * 4.0,
55-
dr=self.geo.drho_norm,
55+
face_centers=self.geo.rho_face_norm,
5656
right_face_constraint=jnp.array(0.8),
5757
right_face_grad_constraint=None,
5858
)

torax/_src/core_profiles/tests/getters_test.py

Lines changed: 16 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,7 @@ def test_only_updating_boundary_condition(self):
7373
left_face_grad_constraint=jnp.zeros(()),
7474
right_face_grad_constraint=None,
7575
right_face_constraint=jnp.array(1.0, dtype=jax_utils.get_dtype()),
76-
dr=self.geo.drho_norm,
76+
face_centers=self.geo.rho_face_norm,
7777
)
7878
result = getters.get_updated_ion_temperature(
7979
profile_conditions,
@@ -113,7 +113,7 @@ def test_only_updating_boundary_condition_electron_temperature(self):
113113
left_face_grad_constraint=jnp.zeros(()),
114114
right_face_grad_constraint=None,
115115
right_face_constraint=jnp.array(1.0, dtype=jax_utils.get_dtype()),
116-
dr=self.geo.drho_norm,
116+
face_centers=self.geo.rho_face_norm,
117117
)
118118
result = getters.get_updated_electron_temperature(
119119
profile_conditions,
@@ -168,7 +168,7 @@ def test_only_updating_boundary_condition_n_e(self):
168168
left_face_grad_constraint=jnp.zeros(()),
169169
right_face_grad_constraint=None,
170170
right_face_constraint=jnp.array(1.0, dtype=jax_utils.get_dtype()),
171-
dr=self.geo.drho_norm,
171+
face_centers=self.geo.rho_face_norm,
172172
)
173173
n_e = getters.get_updated_electron_density(
174174
profile_conditions,
@@ -402,7 +402,7 @@ def test_get_updated_ion_data(self):
402402
left_face_grad_constraint=jnp.zeros(()),
403403
right_face_grad_constraint=None,
404404
right_face_constraint=jnp.array(100.0, dtype=jax_utils.get_dtype()),
405-
dr=geo.drho_norm,
405+
face_centers=geo.rho_face_norm,
406406
)
407407
n_e = getters.get_updated_electron_density(
408408
runtime_params.profile_conditions,
@@ -561,13 +561,13 @@ def test_get_updated_ions_impurity_mixture(self):
561561
geo = torax_config.geometry.build_provider(t=0.0)
562562
T_e_cell_variable = cell_variable.CellVariable(
563563
value=jnp.full_like(geo.rho_norm, T_e),
564-
dr=geo.drho_norm,
564+
face_centers=geo.rho_face_norm,
565565
right_face_constraint=T_e,
566566
right_face_grad_constraint=None,
567567
)
568568
n_e_cell_variable = cell_variable.CellVariable(
569569
value=jnp.full_like(geo.rho_norm, n_e),
570-
dr=geo.drho_norm,
570+
face_centers=geo.rho_face_norm,
571571
right_face_constraint=n_e,
572572
right_face_grad_constraint=None,
573573
)
@@ -666,13 +666,13 @@ def test_get_updated_ions_impurity_mixture_radially_dependent(self):
666666
geo = torax_config.geometry.build_provider(t=0.0)
667667
T_e_cell_variable = cell_variable.CellVariable(
668668
value=jnp.full_like(geo.rho_norm, T_e),
669-
dr=geo.drho_norm,
669+
face_centers=geo.rho_face_norm,
670670
right_face_constraint=T_e,
671671
right_face_grad_constraint=None,
672672
)
673673
n_e_cell_variable = cell_variable.CellVariable(
674674
value=jnp.full_like(geo.rho_norm, n_e),
675-
dr=geo.drho_norm,
675+
face_centers=geo.rho_face_norm,
676676
right_face_constraint=n_e,
677677
right_face_grad_constraint=None,
678678
)
@@ -790,13 +790,13 @@ def _run_get_updated_ions(torax_config):
790790

791791
t_e_cell_variable = cell_variable.CellVariable(
792792
value=jnp.full_like(geo.rho_norm, t_e_keV),
793-
dr=geo.drho_norm,
793+
face_centers=geo.rho_face_norm,
794794
right_face_constraint=t_e_keV,
795795
right_face_grad_constraint=None,
796796
)
797797
n_e_cell_variable = cell_variable.CellVariable(
798798
value=jnp.full_like(geo.rho_norm, n_e_val),
799-
dr=geo.drho_norm,
799+
face_centers=geo.rho_face_norm,
800800
right_face_constraint=n_e_val,
801801
right_face_grad_constraint=None,
802802
)
@@ -908,13 +908,13 @@ def _run_get_updated_ions(torax_config):
908908

909909
t_e_cell_variable = cell_variable.CellVariable(
910910
value=jnp.full_like(geo.rho_norm, t_e_keV),
911-
dr=geo.drho_norm,
911+
face_centers=geo.rho_face_norm,
912912
right_face_constraint=t_e_keV,
913913
right_face_grad_constraint=None,
914914
)
915915
n_e_cell_variable = cell_variable.CellVariable(
916916
value=jnp.full_like(geo.rho_norm, n_e_val),
917-
dr=geo.drho_norm,
917+
face_centers=geo.rho_face_norm,
918918
right_face_constraint=n_e_val,
919919
right_face_grad_constraint=None,
920920
)
@@ -978,13 +978,13 @@ def _run_get_updated_ions(torax_config):
978978

979979
t_e_cell_variable = cell_variable.CellVariable(
980980
value=jnp.full_like(geo.rho_norm, t_e_keV),
981-
dr=geo.drho_norm,
981+
face_centers=geo.rho_face_norm,
982982
right_face_constraint=t_e_keV,
983983
right_face_grad_constraint=None,
984984
)
985985
n_e_cell_variable = cell_variable.CellVariable(
986986
value=jnp.full_like(geo.rho_norm, n_e_val),
987-
dr=geo.drho_norm,
987+
face_centers=geo.rho_face_norm,
988988
right_face_constraint=n_e_val,
989989
right_face_grad_constraint=None,
990990
)
@@ -1079,13 +1079,13 @@ def _run_get_updated_ions(torax_config):
10791079

10801080
t_e_cell_variable = cell_variable.CellVariable(
10811081
value=jnp.full_like(geo.rho_norm, t_e_keV),
1082-
dr=geo.drho_norm,
1082+
face_centers=geo.rho_face_norm,
10831083
right_face_constraint=t_e_keV,
10841084
right_face_grad_constraint=None,
10851085
)
10861086
n_e_cell_variable = cell_variable.CellVariable(
10871087
value=jnp.full_like(geo.rho_norm, n_e_val),
1088-
dr=geo.drho_norm,
1088+
face_centers=geo.rho_face_norm,
10891089
right_face_constraint=n_e_val,
10901090
right_face_grad_constraint=None,
10911091
)

torax/_src/core_profiles/tests/updaters_test.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,13 +40,13 @@ def setUp(self):
4040

4141
T_e = cell_variable.CellVariable(
4242
value=jnp.ones_like(self.geo.rho_norm),
43-
dr=self.geo.drho_norm,
43+
face_centers=self.geo.rho_face_norm,
4444
right_face_constraint=1.0,
4545
right_face_grad_constraint=None,
4646
)
4747
n_e = cell_variable.CellVariable(
4848
value=jnp.ones_like(self.geo.rho_norm),
49-
dr=self.geo.drho_norm,
49+
face_centers=self.geo.rho_face_norm,
5050
right_face_constraint=1.0,
5151
right_face_grad_constraint=None,
5252
)

torax/_src/fvm/cell_variable.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,8 @@ class CellVariable:
4343
4444
Attributes:
4545
value: Value of this variable at each cell grid point.
46-
dr: Distance between cell centers.
46+
face_centers: Locations of the face centers. This array should have length
47+
len(value) + 1. Supports both uniform and non-uniform grids.
4748
left_face_constraint: An optional scalar specifying the value of the
4849
leftmost face. Defaults to None, signifying no constraint. The user can
4950
modify this field at any time, but when face_grad is called exactly one of
@@ -56,8 +57,9 @@ class CellVariable:
5657
right_face_grad_constraint: Analogous to left_face_grad_constraint but for
5758
the right face, see left_face_grad_constraint.
5859
"""
60+
5961
value: jt.Float[chex.Array, 'cell']
60-
dr: jt.Float[chex.Array, '']
62+
face_centers: jt.Float[chex.Array, 'face']
6163
left_face_constraint: jt.Float[chex.Array, ''] | None = None
6264
right_face_constraint: jt.Float[chex.Array, ''] | None = None
6365
left_face_grad_constraint: jt.Float[chex.Array, ''] | None = (
@@ -69,11 +71,6 @@ class CellVariable:
6971
# Can't make the above default values be jax zeros because that would be a
7072
# call to jax before absl.app.run
7173

72-
@property
73-
def face_centers(self) -> jt.Float[chex.Array, 'face']:
74-
"""Locations of the face centers."""
75-
return jnp.linspace(0.0, len(self.value) * self.dr, num=len(self.value) + 1)
76-
7774
@property
7875
def cell_centers(self) -> jt.Float[chex.Array, 'cell']:
7976
"""Locations of the cell centers."""

torax/_src/fvm/convection_terms.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -118,10 +118,10 @@ def peclet_to_alpha(p):
118118
left_v = v_face[:-1]
119119
right_v = v_face[1:]
120120

121-
diag = (left_alpha * left_v - right_alpha * right_v) / var.dr
122-
above = -(1.0 - right_alpha) * right_v / var.dr
121+
diag = (left_alpha * left_v - right_alpha * right_v) / var.cell_widths
122+
above = -(1.0 - right_alpha) * right_v / var.cell_widths
123123
above = above[:-1]
124-
below = (1.0 - left_alpha) * left_v / var.dr
124+
below = (1.0 - left_alpha) * left_v / var.cell_widths
125125
below = below[1:]
126126
mat = math_utils.tridiag(diag, above, below)
127127

0 commit comments

Comments
 (0)