Skip to content

Commit ee7889a

Browse files
committed
overhaul param handling
1 parent 51ca27e commit ee7889a

File tree

9 files changed

+125
-100
lines changed

9 files changed

+125
-100
lines changed

adirondax/params_default.json

Lines changed: 24 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
{
22
"physics": {
33
"hydro": {
4-
"default": true,
4+
"default": false,
55
"description": "switch on for hydrodynamics."
66
},
77
"magnetic": {
@@ -32,21 +32,21 @@
3232
},
3333
"box_size": {
3434
"default": [1.0, 1.0],
35-
"description": "domain box size."
35+
"description": "domain box size: (x,y), (x,y,z), (r,z)."
3636
},
3737
"resolution": {
3838
"default": [32, 32],
3939
"description": "resolution for each dimension."
4040
}
4141
},
4242
"time": {
43-
"start": {
43+
"span": {
4444
"default": 0.0,
45-
"description": "simulation start time."
45+
"description": "simulation span time."
4646
},
47-
"end": {
48-
"default": 1.0,
49-
"description": "simulation end time."
47+
"num_timesteps": {
48+
"default": -1,
49+
"description": "set to a positive value to fix the number of (equi-spaced) timesteps."
5050
}
5151
},
5252
"output": {
@@ -69,8 +69,22 @@
6969
},
7070
"hydro": {
7171
"eos": {
72-
"default": {"type": "ideal", "gamma": 1.66666667},
73-
"description": "equations of state: {'type': 'ideal', 'gamma':}, {'type': 'isothermal', 'sound_speed':}."
72+
"type": {
73+
"default": "ideal",
74+
"description": "options: ideal, isothermal, tabular."
75+
},
76+
"gamma": {
77+
"default": 1.66666667,
78+
"description": "adiabatic index for 'ideal' gas."
79+
},
80+
"sound_speed": {
81+
"default": 1.0,
82+
"description": "isothermal sound speed for 'isothermal' gas."
83+
},
84+
"table_path": {
85+
"default": "./eos_table.h5",
86+
"description": "path to tabular EOS data file for 'tabular' gas."
87+
}
7488
}
7589
},
7690
"quantum": {
@@ -97,4 +111,4 @@
97111
"default": "unknown",
98112
"description": "adirondax version used (auto detected)."
99113
}
100-
}
114+
}

adirondax/simulation.py

Lines changed: 77 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -1,12 +1,12 @@
11
import jax
22
import jax.numpy as jnp
3-
import copy
43

54
from .constants import constants
65
from .hydro.euler2d import hydro_euler2d_fluxes
76
from .hydro.mhd2d import hydro_mhd2d_fluxes
87
from .quantum import quantum_kick, quantum_drift
98
from .gravity import calculate_gravitational_potential
9+
from .utils import set_up_parameters, print_parameters
1010

1111

1212
class Simulation:
@@ -20,67 +20,89 @@ class Simulation:
2020
"""
2121

2222
def __init__(self, params):
23-
# simulation parameters
24-
self._params = copy.deepcopy(params)
25-
self._nt = params["simulation"]["n_timestep"]
26-
self._dt = params["simulation"]["timestep"]
27-
self._dim = len(params["mesh"]["resolution"])
28-
self._nx = params["mesh"]["resolution"][0]
29-
self._Lx = params["mesh"]["boxsize"][0]
30-
self._dx = self._Lx / self._nx
31-
if self._dim > 1:
32-
self._ny = params["mesh"]["resolution"][1]
33-
self._Ly = params["mesh"]["boxsize"][1]
34-
self._dy = self._Ly / self._ny
35-
if self._dim == 3:
36-
self._nz = params["mesh"]["resolution"][2]
37-
self._Lz = params["mesh"]["boxsize"][2]
38-
self._dz = self._Lz / self._nz
23+
# start from default simulation parameters and update with user params
24+
self._params = set_up_parameters(params)
25+
26+
# additional checks
27+
if len(self.resolution) != len(self.box_size):
28+
raise ValueError("'resolution' and 'box_size' must have same shape")
29+
30+
if self.dim == 3:
31+
raise NotImplementedError("3D is not yet implemented.")
32+
33+
# print info
34+
if jax.process_index() == 0:
35+
print("Simulation parameters:")
36+
print_parameters(self.params)
3937

4038
# simulation state
4139
self.state = {}
42-
self.state["t"] = jnp.array(0.0)
43-
if params["physics"]["hydro"]:
44-
self.state["rho"] = jnp.zeros((self._nx, self._ny))
45-
self.state["vx"] = jnp.zeros((self._nx, self._ny))
46-
self.state["vy"] = jnp.zeros((self._nx, self._ny))
47-
self.state["P"] = jnp.zeros((self._nx, self._ny))
48-
if params["physics"]["magnetic"]:
49-
self.state["bx"] = jnp.zeros((self._nx, self._ny))
50-
self.state["by"] = jnp.zeros((self._nx, self._ny))
51-
if params["physics"]["quantum"]:
52-
self.state["psi"] = jnp.zeros((self._nx, self._ny), dtype=jnp.complex64)
40+
self.state["t"] = jnp.array(0) + jnp.nan
41+
if self.params["physics"]["hydro"]:
42+
self.state["rho"] = jnp.zeros(self.resolution) + jnp.nan
43+
self.state["vx"] = jnp.zeros(self.resolution) + jnp.nan
44+
self.state["vy"] = jnp.zeros(self.resolution) + jnp.nan
45+
self.state["P"] = jnp.zeros(self.resolution) + jnp.nan
46+
if self.params["physics"]["magnetic"]:
47+
self.state["bx"] = jnp.zeros(self.resolution) + jnp.nan
48+
self.state["by"] = jnp.zeros(self.resolution) + jnp.nan
49+
if self.params["physics"]["quantum"]:
50+
self.state["psi"] = (
51+
jnp.zeros(self.resolution, dtype=jnp.complex64) + jnp.nan
52+
)
5353

5454
@property
55-
def nt(self):
56-
return self._nt
55+
def resolution(self):
56+
"""
57+
Return the resolution (per dimension) of the simulation
58+
"""
59+
return self.params["mesh"]["resolution"]
5760

5861
@property
59-
def dt(self):
60-
return self._dt
62+
def box_size(self):
63+
"""
64+
Return the box size of the simulation
65+
"""
66+
return self.params["mesh"]["box_size"]
6167

6268
@property
6369
def dim(self):
64-
return self._dim
70+
"""
71+
Return the dimension of the simulation
72+
"""
73+
return len(self.resolution)
6574

6675
@property
6776
def params(self):
77+
"""
78+
Return the parameters of the simulation
79+
"""
6880
return self._params
6981

7082
@property
7183
def mesh(self):
72-
dx = self._dx
73-
dy = self._dy
74-
xlin = jnp.linspace(0.5 * dx, self._Lx - 0.5 * dx, self._nx)
75-
ylin = jnp.linspace(0.5 * dy, self._Ly - 0.5 * dy, self._ny)
84+
"""
85+
Return the simulation mesh
86+
"""
87+
Lx = self.box_size[0]
88+
Ly = self.box_size[1]
89+
nx = self.resolution[0]
90+
ny = self.resolution[1]
91+
dx = Lx / nx
92+
dy = Ly / ny
93+
xlin = jnp.linspace(0.5 * dx, Lx - 0.5 * dx, nx)
94+
ylin = jnp.linspace(0.5 * dy, Ly - 0.5 * dy, ny)
7695
X, Y = jnp.meshgrid(xlin, ylin, indexing="ij")
7796
return X, Y
7897

7998
@property
8099
def kgrid(self):
81-
n = self._nx
82-
L = self._Lx
83-
klin = 2.0 * jnp.pi / L * jnp.arange(-n / 2, n / 2)
100+
"""
101+
Return the simulation spectral grid
102+
"""
103+
Lx = self.box_size[0]
104+
nx = self.resolution[0]
105+
klin = (2.0 * jnp.pi / Lx) * jnp.arange(-nx / 2, nx / 2)
84106
kx, ky = jnp.meshgrid(klin, klin)
85107
kx = jnp.fft.ifftshift(kx)
86108
ky = jnp.fft.ifftshift(ky)
@@ -99,6 +121,9 @@ def _calc_grav_potential(self, state, k_sq, use_quantum, use_hydro):
99121

100122
@property
101123
def potential(self):
124+
"""
125+
Return the gravitational potential
126+
"""
102127
kx, ky = self.kgrid
103128
k_sq = kx**2 + ky**2
104129
return self._calc_grav_potential(
@@ -124,17 +149,25 @@ def _evolve(self, state):
124149
"""
125150

126151
# Simulation parameters
127-
dt = self._dt
128-
nt = self._nt
129-
dx = self._dx
152+
Lx = self.box_size[0]
153+
nx = self.resolution[0]
154+
dx = Lx / nx
155+
nt = self.params["time"]["num_timesteps"]
156+
t_span = self.params["time"]["span"]
157+
158+
fixed_timestepping = True if nt > 0 else False
159+
if fixed_timestepping:
160+
dt = t_span / nt
161+
162+
assert fixed_timestepping # XXX for now
130163

131164
# Physics flags
132165
use_hydro = self.params["physics"]["hydro"]
133166
use_magnetic = self.params["physics"]["magnetic"]
134167
use_quantum = self.params["physics"]["quantum"]
135168
use_gravity = self.params["physics"]["gravity"]
136169

137-
gamma = self.params["hydro"]["eos"]["gamma"] if use_hydro else None
170+
gamma = self.params["hydro"]["eos"]["gamma"]
138171

139172
# Precompute Fourier space variables
140173
k_sq = None

adirondax/utils.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import json
77
from importlib.metadata import version
88
import jax
9-
import jax.numpy as jnp
109

1110

1211
def print_parameters(params):

examples/kelvin_helmholtz_instability/kelvin_helmholtz_instability.py

Lines changed: 4 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -21,24 +21,19 @@ def set_up_simulation():
2121
n = 256
2222
nt = 1500 * int(n / 128)
2323
t_stop = 2.0
24-
dt = t_stop / nt
2524

2625
params = {
2726
"physics": {
2827
"hydro": True,
29-
"magnetic": False,
30-
"quantum": False,
31-
"gravity": False,
3228
},
3329
"mesh": {
3430
"type": "cartesian",
3531
"resolution": [n, n],
36-
"boxsize": [1.0, 1.0],
32+
"box_size": [1.0, 1.0],
3733
},
38-
"simulation": {
39-
"stop_time": t_stop,
40-
"timestep": dt,
41-
"n_timestep": nt,
34+
"time": {
35+
"span": t_stop,
36+
"num_timesteps": nt,
4237
},
4338
"hydro": {
4439
"eos": {"type": "ideal", "gamma": 5.0 / 3.0},

examples/orszag_tang/orszag_tang.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -25,7 +25,6 @@ def set_up_simulation():
2525
n = 512
2626
nt = 100 * int(n / 32)
2727
t_stop = 0.5
28-
dt = t_stop / nt
2928
gamma = 5.0 / 3.0
3029
box_size = 1.0
3130
dx = box_size / n
@@ -34,18 +33,15 @@ def set_up_simulation():
3433
"physics": {
3534
"hydro": True,
3635
"magnetic": True,
37-
"quantum": False,
38-
"gravity": False,
3936
},
4037
"mesh": {
4138
"type": "cartesian",
4239
"resolution": [n, n],
43-
"boxsize": [box_size, box_size],
40+
"box_size": [box_size, box_size],
4441
},
45-
"simulation": {
46-
"stop_time": t_stop,
47-
"timestep": dt,
48-
"n_timestep": nt,
42+
"time": {
43+
"span": t_stop,
44+
"num_timesteps": nt,
4945
},
5046
"hydro": {
5147
"eos": {"type": "ideal", "gamma": gamma},
8 Bytes
Loading

examples/schrodinger_poisson/schrodinger_poisson.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -25,24 +25,20 @@ def set_up_simulation():
2525
n = 128
2626
nt = 100 * int(n / 128)
2727
t_stop = 0.03
28-
dt = t_stop / nt
2928

3029
params = {
3130
"physics": {
32-
"hydro": False,
33-
"magnetic": False,
3431
"quantum": True,
3532
"gravity": True,
3633
},
3734
"mesh": {
3835
"type": "cartesian",
3936
"resolution": [n, n],
40-
"boxsize": [1.0, 1.0],
37+
"box_size": [1.0, 1.0],
4138
},
42-
"simulation": {
43-
"stop_time": t_stop,
44-
"timestep": dt,
45-
"n_timestep": nt,
39+
"time": {
40+
"span": t_stop,
41+
"num_timesteps": nt,
4642
},
4743
}
4844

@@ -59,8 +55,8 @@ def solve_inverse_problem(sim):
5955
rho_target = 1.0 - 0.5 * (rho_target - 0.5)
6056
rho_target /= jnp.mean(rho_target)
6157

62-
assert rho_target.shape[0] == sim.params["mesh"]["resolution"][0]
63-
assert rho_target.shape[1] == sim.params["mesh"]["resolution"][1]
58+
assert rho_target.shape[0] == sim.resolution[0]
59+
assert rho_target.shape[1] == sim.resolution[1]
6460

6561
# Define the loss function for the optimization
6662
@jax.jit

examples/schrodinger_poisson/schrodinger_poisson_optax.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -27,24 +27,20 @@ def set_up_simulation():
2727
n = 128
2828
nt = 100 * int(n / 128)
2929
t_stop = 0.03
30-
dt = t_stop / nt
3130

3231
params = {
3332
"physics": {
34-
"hydro": False,
35-
"magnetic": False,
3633
"quantum": True,
3734
"gravity": True,
3835
},
3936
"mesh": {
4037
"type": "cartesian",
4138
"resolution": [n, n],
42-
"boxsize": [1.0, 1.0],
39+
"box_size": [1.0, 1.0],
4340
},
44-
"simulation": {
45-
"stop_time": t_stop,
46-
"timestep": dt,
47-
"n_timestep": nt,
41+
"time": {
42+
"span": t_stop,
43+
"num_timesteps": nt,
4844
},
4945
}
5046

@@ -110,8 +106,8 @@ def solve_inverse_problem(sim):
110106
rho_target = 1.0 - 0.5 * (rho_target - 0.5)
111107
rho_target /= jnp.mean(rho_target)
112108

113-
assert rho_target.shape[0] == sim.params["mesh"]["resolution"][0]
114-
assert rho_target.shape[1] == sim.params["mesh"]["resolution"][1]
109+
assert rho_target.shape[0] == sim.resolution[0]
110+
assert rho_target.shape[1] == sim.resolution[1]
115111

116112
# Define the loss function for the optimization
117113
@jax.jit

0 commit comments

Comments
 (0)