Skip to content

Commit 9dd161c

Browse files
committed
do not carry dt
1 parent df454ad commit 9dd161c

File tree

3 files changed

+25
-30
lines changed

3 files changed

+25
-30
lines changed

adirondax/hydro/common2d.py

Lines changed: 16 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -49,47 +49,40 @@ def slope_limit(f, dx, f_dx, f_dy):
4949
"""
5050
Apply slope limiter to slopes
5151
"""
52-
53-
f_dx = (
54-
jnp.maximum(
52+
f_new_dx = (
53+
f_dx
54+
* jnp.maximum(
5555
0.0,
5656
jnp.minimum(
5757
1.0, ((f - jnp.roll(f, 1, axis=0)) / dx) / (f_dx + 1.0e-8 * (f_dx == 0))
5858
),
5959
)
60-
* f_dx
61-
)
62-
f_dx = (
63-
jnp.maximum(
60+
* jnp.maximum(
6461
0.0,
6562
jnp.minimum(
6663
1.0,
6764
(-(f - jnp.roll(f, -1, axis=0)) / dx) / (f_dx + 1.0e-8 * (f_dx == 0)),
6865
),
6966
)
70-
* f_dx
7167
)
72-
f_dy = (
73-
jnp.maximum(
68+
f_new_dy = (
69+
f_dy
70+
* jnp.maximum(
7471
0.0,
7572
jnp.minimum(
7673
1.0, ((f - jnp.roll(f, 1, axis=1)) / dx) / (f_dy + 1.0e-8 * (f_dy == 0))
7774
),
7875
)
79-
* f_dy
80-
)
81-
f_dy = (
82-
jnp.maximum(
76+
* jnp.maximum(
8377
0.0,
8478
jnp.minimum(
8579
1.0,
8680
(-(f - jnp.roll(f, -1, axis=1)) / dx) / (f_dy + 1.0e-8 * (f_dy == 0)),
8781
),
8882
)
89-
* f_dy
9083
)
9184

92-
return f_dx, f_dy
85+
return f_new_dx, f_new_dy
9386

9487

9588
def extrapolate_to_face(f, f_dx, f_dy, dx):
@@ -110,9 +103,11 @@ def apply_fluxes(F, flux_F_X, flux_F_Y, dx, dt):
110103
"""
111104
Apply fluxes to conserved variables
112105
"""
113-
F += -dt * dx * flux_F_X
114-
F += dt * dx * jnp.roll(flux_F_X, 1, axis=0)
115-
F += -dt * dx * flux_F_Y
116-
F += dt * dx * jnp.roll(flux_F_Y, 1, axis=1)
106+
F_new = F + (dt * dx) * (
107+
-flux_F_X
108+
+ jnp.roll(flux_F_X, 1, axis=0)
109+
- flux_F_Y
110+
+ jnp.roll(flux_F_Y, 1, axis=1)
111+
)
117112

118-
return F
113+
return F_new

adirondax/simulation.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -175,7 +175,7 @@ def _evolve(self, state):
175175
t_span = self.params["time"]["span"]
176176

177177
use_adaptive_timesteps = True if nt < 1 else False
178-
dt = 0.0 if use_adaptive_timesteps else t_span / nt
178+
dt_ref = jnp.nan if use_adaptive_timesteps else t_span / nt
179179

180180
# Physics flags
181181
use_hydro = self.params["physics"]["hydro"]
@@ -201,18 +201,19 @@ def _evolve(self, state):
201201
V = self._calc_grav_potential(state, k_sq, use_quantum, use_hydro)
202202

203203
# Build the carry:
204-
carry = (state, dt, V, k_sq)
204+
carry = (state, V, k_sq)
205205

206206
def step_fn(carry):
207207
"""
208208
Pure step function: advances state by one timestep.
209209
"""
210-
state, dt, V, k_sq = carry
210+
state, V, k_sq = carry
211211

212212
# Create new state dict to avoid mutation
213213
new_state = {}
214214

215215
# Get the timestep
216+
dt = dt_ref
216217
if use_adaptive_timesteps:
217218
dt = jnp.inf
218219
if use_hydro:
@@ -307,22 +308,22 @@ def step_fn(carry):
307308
# Update diagnostics
308309
new_state["steps_taken"] = state["steps_taken"] + 1
309310

310-
return (new_state, dt, new_V, k_sq)
311+
return (new_state, new_V, k_sq)
311312

312313
# Run the entire loop as a single JIT-compiled function
313314
def run_loop(carry):
314315
if use_adaptive_timesteps:
315316
# def cond_fn(carry):
316-
# state, _, _, _ = carry
317+
# state, _, _ = carry
317318
# return state["t"] < t_span * (1.0 - 1e-10)
318319

319320
# final_carry = jax.lax.while_loop(cond_fn, step_fn, carry)
320321

321322
# do a simple while loop
322-
state, _, _, _ = carry
323+
state, _, _ = carry
323324
while state["t"] < t_span * (1.0 - 1e-10):
324325
carry = step_fn(carry)
325-
state, _, _, _ = carry
326+
state, _, _ = carry
326327
final_carry = carry
327328
else:
328329

@@ -336,7 +337,7 @@ def step_fn_stacked(carry, _):
336337
return final_carry
337338

338339
# Execute the compiled loop
339-
state, _, _, _ = run_loop(carry)
340+
state, _, _ = run_loop(carry)
340341

341342
return state
342343

@@ -347,4 +348,3 @@ def run(self):
347348
self.state["steps_taken"] = 0
348349
self.state = self._evolve(self.state)
349350
jax.block_until_ready(self.state)
350-
# assert jnp.isfinite(self.state["t"]), "state['t'] is NaN/infinity"

examples/orszag_tang/output.png

13 Bytes
Loading

0 commit comments

Comments
 (0)