Skip to content

Commit dfa09d5

Browse files
committed
add adaptive timesteps
1 parent ee7889a commit dfa09d5

File tree

8 files changed

+121
-28
lines changed

8 files changed

+121
-28
lines changed

adirondax/hydro/euler2d.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -61,15 +61,21 @@ def get_flux(rho_L, rho_R, vx_L, vx_R, vy_L, vy_R, P_L, P_R, gamma):
6161
return flux_Mass, flux_Momx, flux_Momy, flux_Energy
6262

6363

64+
def hydro_euler2d_timestep(rho, vx, vy, P, gamma, dx):
65+
"""Calculate the simulation timestep based on CFL condition"""
66+
67+
# get time step (CFL) = dx / max signal speed
68+
dt = jnp.min(dx / (jnp.sqrt(gamma * P / rho) + jnp.sqrt(vx**2 + vy**2)))
69+
70+
return dt
71+
72+
6473
def hydro_euler2d_fluxes(rho, vx, vy, P, gamma, dx, dt):
6574
"""Take a simulation timestep"""
6675

6776
# get Conserved variables
6877
Mass, Momx, Momy, Energy = get_conserved(rho, vx, vy, P, gamma, dx**2)
6978

70-
# get time step (CFL) = dx / max signal speed
71-
# dt = courant_fac * jnp.min(dx / (jnp.sqrt(gamma * P / rho) + jnp.sqrt(vx**2 + vy**2)))
72-
7379
# calculate gradients
7480
rho_dx, rho_dy = get_gradient(rho, dx)
7581
vx_dx, vx_dy = get_gradient(vx, dx)

adirondax/hydro/mhd2d.py

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -421,16 +421,26 @@ def get_flux_hlld(
421421
return flux_Mass, flux_Momx, flux_Momy, flux_Energy, flux_By
422422

423423

424+
def hydro_mhd2d_timestep(rho, vx, vy, P, bx, by, gamma, dx):
425+
"""Calculate the simulation timestep based on CFL condition"""
426+
427+
# get time step (CFL) = dx / max signal speed
428+
Bx, By = get_avg(bx, by)
429+
dt = jnp.min(
430+
dx
431+
/ (jnp.sqrt(gamma * P / rho) + jnp.sqrt(vx**2 + vy**2 + (Bx**2 + By**2) / rho))
432+
)
433+
434+
return dt
435+
436+
424437
def hydro_mhd2d_fluxes(rho, vx, vy, P, bx, by, gamma, dx, dt):
425438
"""Take a simulation timestep"""
426439

427440
# get Conserved variables
428441
Bx, By = get_avg(bx, by)
429442
Mass, Momx, Momy, Energy = get_conserved(rho, vx, vy, P, Bx, By, gamma, dx**2)
430443

431-
# get time step (CFL) = dx / max signal speed
432-
# dt = courant_fac * jnp.min(dx / (jnp.sqrt(gamma * P / rho) + jnp.sqrt(vx**2 + vy**2)))
433-
434444
# calculate gradients
435445
rho_dx, rho_dy = get_gradient(rho, dx)
436446
vx_dx, vx_dy = get_gradient(vx, dx)

adirondax/params_default.json

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,14 @@
8585
"default": "./eos_table.h5",
8686
"description": "path to tabular EOS data file for 'tabular' gas."
8787
}
88+
},
89+
"riemann_solver": {
90+
"default": "llf",
91+
"description": "options: 'llf', 'hlld'."
92+
},
93+
"cfl": {
94+
"default": 0.5,
95+
"description": "CFL number for hydrodynamics timestep."
8896
}
8997
},
9098
"quantum": {

adirondax/quantum.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,7 @@ def quantum_drift(psi, k_sq, m_per_hbar, dt):
1313
psi_hat = jnp.exp(dt * (-1.0j * k_sq / m_per_hbar / 2.0)) * psi_hat
1414
psi = jnp.fft.ifftn(psi_hat)
1515
return psi
16+
17+
18+
def quantum_timestep(m_per_hbar, dx):
19+
return (m_per_hbar / 6.0) * (dx * dx)

adirondax/simulation.py

Lines changed: 84 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -2,9 +2,9 @@
22
import jax.numpy as jnp
33

44
from .constants import constants
5-
from .hydro.euler2d import hydro_euler2d_fluxes
6-
from .hydro.mhd2d import hydro_mhd2d_fluxes
7-
from .quantum import quantum_kick, quantum_drift
5+
from .hydro.euler2d import hydro_euler2d_fluxes, hydro_euler2d_timestep
6+
from .hydro.mhd2d import hydro_mhd2d_fluxes, hydro_mhd2d_timestep
7+
from .quantum import quantum_kick, quantum_drift, quantum_timestep
88
from .gravity import calculate_gravitational_potential
99
from .utils import set_up_parameters, print_parameters
1010

@@ -37,7 +37,7 @@ def __init__(self, params):
3737

3838
# simulation state
3939
self.state = {}
40-
self.state["t"] = jnp.array(0) + jnp.nan
40+
self.state["t"] = jnp.array(0.0) + jnp.nan
4141
if self.params["physics"]["hydro"]:
4242
self.state["rho"] = jnp.zeros(self.resolution) + jnp.nan
4343
self.state["vx"] = jnp.zeros(self.resolution) + jnp.nan
@@ -51,6 +51,9 @@ def __init__(self, params):
5151
jnp.zeros(self.resolution, dtype=jnp.complex64) + jnp.nan
5252
)
5353

54+
# extra info to keep track of
55+
self.state["steps_taken"] = jnp.array(0) + jnp.nan
56+
5457
@property
5558
def resolution(self):
5659
"""
@@ -72,6 +75,13 @@ def dim(self):
7275
"""
7376
return len(self.resolution)
7477

78+
@property
79+
def steps_taken(self):
80+
"""
81+
Return the number of steps taken in the simulation
82+
"""
83+
return self.state["steps_taken"]
84+
7585
@property
7686
def params(self):
7787
"""
@@ -155,11 +165,8 @@ def _evolve(self, state):
155165
nt = self.params["time"]["num_timesteps"]
156166
t_span = self.params["time"]["span"]
157167

158-
fixed_timestepping = True if nt > 0 else False
159-
if fixed_timestepping:
160-
dt = t_span / nt
161-
162-
assert fixed_timestepping # XXX for now
168+
use_adaptive_timesteps = True if nt < 1 else False
169+
dt = 0.0 if use_adaptive_timesteps else t_span / nt
163170

164171
# Physics flags
165172
use_hydro = self.params["physics"]["hydro"]
@@ -168,6 +175,9 @@ def _evolve(self, state):
168175
use_gravity = self.params["physics"]["gravity"]
169176

170177
gamma = self.params["hydro"]["eos"]["gamma"]
178+
cfl = self.params["hydro"]["cfl"]
179+
180+
m_per_hbar = 1.0 # XXX
171181

172182
# Precompute Fourier space variables
173183
k_sq = None
@@ -180,27 +190,56 @@ def _evolve(self, state):
180190
if use_gravity:
181191
V = self._calc_grav_potential(state, k_sq, use_quantum, use_hydro)
182192

183-
# Build the carry: (state, V, k_sq)
184-
carry = (state, V, k_sq)
193+
# Build the carry:
194+
carry = (state, dt, V, k_sq)
185195

186-
def step_fn(carry, _):
196+
def step_fn(carry):
187197
"""
188198
Pure step function: advances state by one timestep.
189-
Returns new carry and None (no stacked outputs).
190199
"""
191-
state, V, k_sq = carry
200+
state, dt, V, k_sq = carry
192201

193202
# Create new state dict to avoid mutation
194203
new_state = {}
195204

205+
# Get the timestep
206+
if use_adaptive_timesteps:
207+
dt = jnp.inf
208+
if use_hydro:
209+
if use_magnetic:
210+
dt_hydro = hydro_mhd2d_timestep(
211+
state["rho"],
212+
state["vx"],
213+
state["vy"],
214+
state["P"],
215+
state["bx"],
216+
state["by"],
217+
gamma,
218+
dx,
219+
)
220+
else:
221+
dt_hydro = hydro_euler2d_timestep(
222+
state["rho"],
223+
state["vx"],
224+
state["vy"],
225+
state["P"],
226+
gamma,
227+
dx,
228+
)
229+
dt = jnp.minimum(dt, cfl * dt_hydro)
230+
if use_quantum:
231+
dt_quantum = quantum_timestep(m_per_hbar, dx)
232+
dt = jnp.minimum(dt, dt_quantum)
233+
dt = jnp.minimum(dt, t_span - state["t"])
234+
196235
# Kick (half-step) - quantum + gravity
197236
psi = state.get("psi")
198237
if use_quantum and use_gravity and psi is not None:
199-
psi = quantum_kick(psi, V, 1.0, dt / 2.0)
238+
psi = quantum_kick(psi, V, m_per_hbar, dt / 2.0)
200239

201240
# Drift (full-step) - quantum
202241
if use_quantum and psi is not None:
203-
psi = quantum_drift(psi, k_sq, 1.0, dt)
242+
psi = quantum_drift(psi, k_sq, m_per_hbar, dt)
204243

205244
if use_quantum:
206245
new_state["psi"] = psi
@@ -254,21 +293,47 @@ def step_fn(carry, _):
254293
# Update time
255294
new_state["t"] = state["t"] + dt
256295

257-
return (new_state, new_V, k_sq), None
296+
# Update diagnostics
297+
new_state["steps_taken"] = state["steps_taken"] + 1
298+
299+
return (new_state, dt, new_V, k_sq)
258300

259301
# Run the entire loop as a single JIT-compiled function
260302
def run_loop(carry):
261-
final_carry, _ = jax.lax.scan(step_fn, carry, xs=None, length=nt)
303+
if use_adaptive_timesteps:
304+
# def cond_fn(carry):
305+
# state, _, _, _ = carry
306+
# return state["t"] < t_span * (1.0 - 1e-10)
307+
308+
# final_carry = jax.lax.while_loop(cond_fn, step_fn, carry)
309+
310+
# do a simple while loop
311+
state, _, _, _ = carry
312+
while state["t"] < t_span * (1.0 - 1e-10):
313+
carry = step_fn(carry)
314+
state, _, _, _ = carry
315+
final_carry = carry
316+
else:
317+
318+
def step_fn_stacked(carry, _):
319+
# Returns new carry and None (no stacked outputs) for jax.lax.scan()
320+
return step_fn(carry), None
321+
322+
final_carry, _ = jax.lax.scan(
323+
step_fn_stacked, carry, xs=None, length=nt
324+
)
262325
return final_carry
263326

264327
# Execute the compiled loop
265-
state, _, _ = run_loop(carry)
328+
state, _, _, _ = run_loop(carry)
266329

267330
return state
268331

269332
def run(self):
270333
"""
271334
Run the simulation
272335
"""
336+
self.state["steps_taken"] = 0
273337
self.state = self._evolve(self.state)
274338
jax.block_until_ready(self.state)
339+
# assert jnp.isfinite(self.state["t"]), "state['t'] is NaN/infinity"
13 Bytes
Loading

examples/orszag_tang/orszag_tang.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,6 @@
2323
def set_up_simulation():
2424
# Define the parameters for the simulation
2525
n = 512
26-
nt = 100 * int(n / 32)
2726
t_stop = 0.5
2827
gamma = 5.0 / 3.0
2928
box_size = 1.0
@@ -41,10 +40,10 @@ def set_up_simulation():
4140
},
4241
"time": {
4342
"span": t_stop,
44-
"num_timesteps": nt,
4543
},
4644
"hydro": {
4745
"eos": {"type": "ideal", "gamma": gamma},
46+
"cfl": 0.6,
4847
},
4948
}
5049

@@ -90,7 +89,8 @@ def main():
9089
# Evolve the system
9190
t0 = time.time()
9291
sim.run()
93-
print("Run time (s): ", time.time() - t0)
92+
print("Steps taken:", sim.steps_taken)
93+
print("Run time (s):", time.time() - t0)
9494

9595
make_plot(sim)
9696

examples/orszag_tang/output.png

-404 Bytes
Loading

0 commit comments

Comments
 (0)