Skip to content

Commit a7b4905

Browse files
Nush395Torax team
authored andcommitted
Move dx from being a grid input to being a property.
Currently `dx` is always `1/nx`. Drive-by: fix formatting check. PiperOrigin-RevId: 778535641
1 parent 59614f1 commit a7b4905

File tree

11 files changed

+25
-56
lines changed

11 files changed

+25
-56
lines changed

torax/_src/array_typing.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
# ============================================================================
1515
"""Common types for using jaxtyping in TORAX."""
1616
from typing import TypeAlias
17+
1718
import jax
1819
import jaxtyping as jt
1920
import numpy as np

torax/_src/config/tests/build_runtime_params_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -33,7 +33,7 @@ class RuntimeParamsSliceTest(parameterized.TestCase):
3333

3434
def setUp(self):
3535
super().setUp()
36-
self._torax_mesh = torax_pydantic.Grid1D(nx=4, dx=0.25)
36+
self._torax_mesh = torax_pydantic.Grid1D(nx=4,)
3737

3838
def test_time_dependent_provider_is_time_dependent(self):
3939
"""Tests that the runtime_params slice provider is time dependent."""

torax/_src/geometry/circular_geometry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -45,9 +45,8 @@ def build_circular_geometry(
4545
"""
4646
# circular geometry assumption of r/a_minor = rho_norm, the normalized
4747
# toroidal flux coordinate.
48-
drho_norm = 1.0 / n_rho
4948
# Define mesh (Slab Uniform 1D with Jacobian = 1)
50-
mesh = torax_pydantic.Grid1D(nx=n_rho, dx=drho_norm)
49+
mesh = torax_pydantic.Grid1D(nx=n_rho,)
5150
# toroidal flux coordinate (rho) at boundary (last closed flux surface)
5251
rho_b = np.asarray(a_minor)
5352

torax/_src/geometry/standard_geometry.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1057,9 +1057,8 @@ def build_standard_geometry(
10571057
j_total = np.concatenate([np.array([j_total_face_axis]), j_total_face_bulk])
10581058

10591059
# fill geometry structure
1060-
drho_norm = float(rho_norm_intermediate[-1]) / intermediate.n_rho
10611060
# normalized grid
1062-
mesh = torax_pydantic.Grid1D(nx=intermediate.n_rho, dx=drho_norm)
1061+
mesh = torax_pydantic.Grid1D(nx=intermediate.n_rho,)
10631062
rho_b = rho_intermediate[-1] # radius denormalization constant
10641063
# helper variables for mesh cells and faces
10651064
rho_face_norm = mesh.face_centers

torax/_src/sources/tests/ion_cyclotron_source_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ def test_build_dynamic_params(self):
112112
self.assertIsInstance(source, self._source_config_class)
113113
torax_pydantic.set_grid(
114114
source,
115-
torax_pydantic.Grid1D(nx=4, dx=0.25),
115+
torax_pydantic.Grid1D(nx=4,),
116116
)
117117
dynamic_params = source.build_dynamic_params(t=0.0)
118118
self.assertIsInstance(

torax/_src/sources/tests/mavrin_impurity_radiation_heat_sink_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,7 @@ def test_correct_dynamic_params_built(self):
4343
})
4444
# Set the grid to allows the dynamic params to be built without making the
4545
# full config.
46-
torax_pydantic.set_grid(sources, torax_pydantic.Grid1D(nx=4, dx=0.25))
46+
torax_pydantic.set_grid(sources, torax_pydantic.Grid1D(nx=4,))
4747
runtime_params = getattr(sources, self._source_name).build_dynamic_params(
4848
t=0.0
4949
)

torax/_src/sources/tests/pydantic_model_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -162,7 +162,7 @@ def test_adding_a_source_with_prescribed_values(self):
162162
),
163163
},
164164
})
165-
mesh = torax_pydantic.Grid1D(nx=4, dx=0.25)
165+
mesh = torax_pydantic.Grid1D(nx=4,)
166166
torax_pydantic.set_grid(sources, mesh)
167167
source = sources.generic_current
168168
self.assertLen(source.prescribed_values, 1)

torax/_src/sources/tests/source_profiles_test.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -65,7 +65,7 @@ def test_summed_T_i_profiles_dont_change_when_jitting(self):
6565

6666
def test_merging_source_profiles(self):
6767
"""Tests that the implicit and explicit source profiles merge correctly."""
68-
torax_mesh = torax_pydantic.Grid1D(nx=10, dx=0.1)
68+
torax_mesh = torax_pydantic.Grid1D(nx=10,)
6969
sources = sources_pydantic_model.Sources.from_dict(
7070
default_sources.get_default_source_config()
7171
)

torax/_src/sources/tests/test_lib.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,7 @@ def test_build_dynamic_params(self):
5959
self.assertIsInstance(source, self._source_config_class)
6060
torax_pydantic.set_grid(
6161
source,
62-
torax_pydantic.Grid1D(nx=4, dx=0.25),
62+
torax_pydantic.Grid1D(nx=4,),
6363
)
6464
dynamic_params = source.build_dynamic_params(t=0.0)
6565
self.assertIsInstance(

torax/_src/torax_pydantic/interpolated_param_2d.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,11 +39,13 @@ class Grid1D(model_base.BaseModelFrozen):
3939
4040
Attributes:
4141
nx: Number of cells.
42-
dx: Distance between cell centers.
4342
"""
4443

4544
nx: typing_extensions.Annotated[pydantic.conint(ge=4), model_base.JAX_STATIC]
46-
dx: typing_extensions.Annotated[pydantic.PositiveFloat, model_base.JAX_STATIC]
45+
46+
@functools.cached_property
47+
def dx(self) -> float:
48+
return 1 / self.nx
4749

4850
@property
4951
def face_centers(self) -> np.ndarray:
@@ -441,7 +443,6 @@ def _update_rule(submodel):
441443
# the same NumPy arrays.
442444
new_grid = Grid1D.model_construct(
443445
nx=grid.nx,
444-
dx=grid.dx,
445446
face_centers=grid.face_centers,
446447
cell_centers=grid.cell_centers,
447448
)

0 commit comments

Comments
 (0)