Skip to content

Commit b85ddd7

Browse files
committed
add rayleigh taylor example
1 parent 122ea03 commit b85ddd7

File tree

8 files changed

+280
-15
lines changed

8 files changed

+280
-15
lines changed

adirondax/gravity.py

Lines changed: 26 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -9,8 +9,30 @@ def calculate_gravitational_potential(rho, k_sq, G, rho_bar):
99
return V
1010

1111

12-
def get_acceleration(V, kx, ky):
13-
V_hat = jnp.fft.fftn(V)
14-
ax = -jnp.real(jnp.fft.ifftn(1.0j * kx * V_hat))
15-
ay = -jnp.real(jnp.fft.ifftn(1.0j * ky * V_hat))
12+
def get_acceleration(V, kx, ky, dx, dy, bc_x_is_reflective, bc_y_is_reflective):
13+
if not bc_x_is_reflective or not bc_y_is_reflective:
14+
V_hat = jnp.fft.fftn(V)
15+
if bc_x_is_reflective:
16+
# 2nd order finite difference
17+
ax = -(jnp.roll(V, -1, axis=0) - jnp.roll(V, 1, axis=0)) / (2.0 * dx)
18+
# one-sided 2nd order difference at boundary
19+
ax = ax.at[0, :].set(-(-3.0 * V[0, :] + 4.0 * V[1, :] - V[2, :]) / (2.0 * dx))
20+
ax = ax.at[-1, :].set(
21+
-(3.0 * V[-1, :] - 4.0 * V[-2, :] + V[-3, :]) / (2.0 * dx)
22+
)
23+
else:
24+
# periodic
25+
ax = -jnp.real(jnp.fft.ifftn(1.0j * kx * V_hat))
26+
if bc_y_is_reflective:
27+
# 2nd order finite difference
28+
ay = -(jnp.roll(V, -1, axis=1) - jnp.roll(V, 1, axis=1)) / (2.0 * dy)
29+
# one-sided 2nd order difference at boundary
30+
ay = ay.at[:, 0].set(-(-3.0 * V[:, 0] + 4.0 * V[:, 1] - V[:, 2]) / (2.0 * dy))
31+
ay = ay.at[:, -1].set(
32+
-(3.0 * V[:, -1] - 4.0 * V[:, -2] + V[:, -3]) / (2.0 * dy)
33+
)
34+
else:
35+
# periodic
36+
ay = -jnp.real(jnp.fft.ifftn(1.0j * ky * V_hat))
37+
1638
return ax, ay

adirondax/hydro/euler2d.py

Lines changed: 95 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -100,11 +100,79 @@ def hydro_euler2d_timestep(rho, vx, vy, P, gamma, dx, dy):
100100
return dt
101101

102102

103+
def add_ghost_cells(rho, vx, vy, P, axis):
104+
"""Add ghost cells for reflective boundary conditions along given axis"""
105+
106+
if axis == 0:
107+
# x-axis
108+
rho_new = jnp.concatenate((rho[0:1, :], rho, rho[-1:, :]), axis=0)
109+
vx_new = jnp.concatenate((-vx[0:1, :], vx, -vx[-1:, :]), axis=0)
110+
vy_new = jnp.concatenate((vy[0:1, :], vy, vy[-1:, :]), axis=0)
111+
P_new = jnp.concatenate((P[0:1, :], P, P[-1:, :]), axis=0)
112+
elif axis == 1:
113+
# y-axis
114+
rho_new = jnp.concatenate((rho[:, 0:1], rho, rho[:, -1:]), axis=1)
115+
vx_new = jnp.concatenate((vx[:, 0:1], vx, vx[:, -1:]), axis=1)
116+
vy_new = jnp.concatenate((-vy[:, 0:1], vy, -vy[:, -1:]), axis=1)
117+
P_new = jnp.concatenate((P[:, 0:1], P, P[:, -1:]), axis=1)
118+
119+
return rho_new, vx_new, vy_new, P_new
120+
121+
122+
def remove_ghost_cells(Mass, Momx, Momy, Energy, axis):
123+
"""Remove ghost cells for reflective boundary conditions along given axis"""
124+
125+
if axis == 0:
126+
# x-axis
127+
Mass_new = Mass[1:-1, :]
128+
Momx_new = Momx[1:-1, :]
129+
Momy_new = Momy[1:-1, :]
130+
Energy_new = Energy[1:-1, :]
131+
elif axis == 1:
132+
# y-axis
133+
Mass_new = Mass[:, 1:-1]
134+
Momx_new = Momx[:, 1:-1]
135+
Momy_new = Momy[:, 1:-1]
136+
Energy_new = Energy[:, 1:-1]
137+
138+
return Mass_new, Momx_new, Momy_new, Energy_new
139+
140+
141+
def set_ghost_gradients(f_dx, axis):
142+
"""Set gradients in ghost cells to (-1) x value of the first interior cell (f_dx already has ghost cells)"""
143+
144+
if axis == 0:
145+
f_dx = f_dx.at[0, :].set(-f_dx[1, :])
146+
f_dx = f_dx.at[-1, :].set(-f_dx[-2, :])
147+
elif axis == 1:
148+
f_dx = f_dx.at[:, 0].set(-f_dx[:, 1])
149+
f_dx = f_dx.at[:, -1].set(-f_dx[:, -2])
150+
151+
return f_dx
152+
153+
103154
def hydro_euler2d_fluxes(
104-
rho, vx, vy, P, gamma, dx, dy, dt, riemann_solver_type, use_slope_limiting
155+
rho,
156+
vx,
157+
vy,
158+
P,
159+
gamma,
160+
dx,
161+
dy,
162+
dt,
163+
riemann_solver_type,
164+
use_slope_limiting,
165+
bc_x_is_reflective,
166+
bc_y_is_reflective,
105167
):
106168
"""Take a simulation timestep"""
107169

170+
# Add Ghost Cells (if needed)
171+
if bc_x_is_reflective:
172+
rho, vx, vy, P = add_ghost_cells(rho, vx, vy, P, axis=0)
173+
if bc_y_is_reflective:
174+
rho, vx, vy, P = add_ghost_cells(rho, vx, vy, P, axis=1)
175+
108176
# get Conserved variables
109177
Mass, Momx, Momy, Energy = get_conserved(rho, vx, vy, P, gamma, dx * dy)
110178

@@ -121,6 +189,18 @@ def hydro_euler2d_fluxes(
121189
vy_dx, vy_dy = slope_limit(vy, vy_dx, vy_dy, dx, dy)
122190
P_dx, P_dy = slope_limit(P, P_dx, P_dy, dx, dy)
123191

192+
# set ghost cell gradients
193+
if bc_x_is_reflective:
194+
rho_dx = set_ghost_gradients(rho_dx, axis=0)
195+
vx_dx = set_ghost_gradients(vx_dx, axis=0)
196+
vy_dx = set_ghost_gradients(vy_dx, axis=0)
197+
P_dx = set_ghost_gradients(P_dx, axis=0)
198+
if bc_y_is_reflective:
199+
rho_dy = set_ghost_gradients(rho_dy, axis=1)
200+
vx_dy = set_ghost_gradients(vx_dy, axis=1)
201+
vy_dy = set_ghost_gradients(vy_dy, axis=1)
202+
P_dy = set_ghost_gradients(P_dy, axis=1)
203+
124204
# extrapolate half-step in time
125205
rho_prime = rho - 0.5 * dt * (vx * rho_dx + rho * vx_dx + vy * rho_dy + rho * vy_dy)
126206
vx_prime = vx - 0.5 * dt * (vx * vx_dx + vy * vx_dy + (1.0 / rho) * P_dx)
@@ -167,15 +247,23 @@ def hydro_euler2d_fluxes(
167247
Momy = apply_fluxes(Momy, flux_Momy_X, flux_Momy_Y, dx, dy, dt)
168248
Energy = apply_fluxes(Energy, flux_Energy_X, flux_Energy_Y, dx, dy, dt)
169249

250+
# remove ghost cells
251+
if bc_x_is_reflective:
252+
Mass, Momx, Momy, Energy = remove_ghost_cells(Mass, Momx, Momy, Energy, axis=0)
253+
if bc_y_is_reflective:
254+
Mass, Momx, Momy, Energy = remove_ghost_cells(Mass, Momx, Momy, Energy, axis=1)
255+
170256
rho, vx, vy, P = get_primitive(Mass, Momx, Momy, Energy, gamma, dx * dy)
171257

172258
return rho, vx, vy, P
173259

174260

175-
def hydro_euler2d_accelerate(rho, vx, vy, P, ax, ay, gamma, dt):
176-
e = P / ((gamma - 1.0) * rho)
177-
e_new = e + dt * (ax * vx + ay * vy)
178-
P_new = (gamma - 1.0) * rho * e_new
179-
vx_new = vx + dt * ax
180-
vy_new = vy + dt * ay
261+
def hydro_euler2d_accelerate(rho, vx, vy, P, ax, ay, gamma, dx, dy, dt):
262+
Mass, Momx, Momy, Energy = get_conserved(rho, vx, vy, P, gamma, dx * dy)
263+
264+
Energy += dt * (Momx * ax + Momy * ay)
265+
Momx += dt * Mass * ax
266+
Momy += dt * Mass * ay
267+
268+
_, vx_new, vy_new, P_new = get_primitive(Mass, Momx, Momy, Energy, gamma, dx * dy)
181269
return vx_new, vy_new, P_new

adirondax/simulation.py

Lines changed: 35 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,19 @@ def __init__(self, params):
4747
):
4848
raise ValueError("'hlld' riemann solver only exists for magnetic=True")
4949

50+
if (
51+
self.params["mesh"]["boundary_condition"][0] != "periodic"
52+
or self.params["mesh"]["boundary_condition"][1] != "periodic"
53+
):
54+
if self.params["physics"]["quantum"]:
55+
raise NotImplementedError(
56+
"Quantum only implemented for periodic boundary conditions."
57+
)
58+
if self.params["physics"]["gravity"]:
59+
raise NotImplementedError(
60+
"Gravity only implemented for periodic boundary conditions."
61+
)
62+
5063
if self.params["output"]["save"] and self.params["time"]["num_timesteps"] > 0:
5164
if (
5265
self.params["time"]["num_timesteps"]
@@ -198,10 +211,16 @@ def _evolve(self, state):
198211
dy = Ly / ny
199212
nt = self.params["time"]["num_timesteps"]
200213
t_span = self.params["time"]["span"]
214+
bc_x = self.params["mesh"]["boundary_condition"][0]
215+
bc_y = self.params["mesh"]["boundary_condition"][1]
201216

202217
use_adaptive_timesteps = True if nt < 1 else False
203218
dt_ref = jnp.nan if use_adaptive_timesteps else t_span / nt
204219

220+
# boundary conditions
221+
bc_x_is_reflective = True if bc_x == "reflective" else False
222+
bc_y_is_reflective = True if bc_y == "reflective" else False
223+
205224
# Physics flags
206225
use_hydro = self.params["physics"]["hydro"]
207226
use_magnetic = self.params["physics"]["magnetic"]
@@ -290,14 +309,25 @@ def _kick(state, k_sq, dt):
290309
# apply
291310
if use_gravity or use_external_potential:
292311
if use_quantum:
293-
state["psi"] = quantum_kick(state["psi"], V, m_per_hbar, dt / 2.0)
312+
state["psi"] = quantum_kick(state["psi"], V, m_per_hbar, dt)
294313
if use_hydro:
295314
if use_magnetic:
296315
raise NotImplementedError("implement me.")
297316
kx, ky = self.kgrid
298-
ax, ay = get_acceleration(V, kx, ky)
317+
ax, ay = get_acceleration(
318+
V, kx, ky, dx, dy, bc_x_is_reflective, bc_y_is_reflective
319+
)
299320
state["vx"], state["vy"], state["P"] = hydro_euler2d_accelerate(
300-
state["rho"], state["vx"], state["vy"], state["P"], ax, ay, dt
321+
state["rho"],
322+
state["vx"],
323+
state["vy"],
324+
state["P"],
325+
ax,
326+
ay,
327+
gamma,
328+
dx,
329+
dy,
330+
dt,
301331
)
302332

303333
def _drift(state, k_sq, dt):
@@ -342,6 +372,8 @@ def _drift(state, k_sq, dt):
342372
dt,
343373
riemann_solver_type,
344374
use_slope_limiting,
375+
bc_x_is_reflective,
376+
bc_y_is_reflective,
345377
)
346378
)
347379

examples/rayleigh_taylor/README.md

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Rayleigh-Taylor
2+
3+
Simulate the Rayleigh-Taylor Instability (Euler equations)
4+
5+
Philip Mocz (2025)
6+
7+
Usage:
8+
9+
```console
10+
python rayleigh_taylor.py
11+
```
12+
13+
Takes around 7 seconds to run on my macbook (cpu).
14+
15+
16+
## Simulation snapshots
17+
18+
<div style="display:flex;flex-wrap:wrap;gap:8px">
19+
<img src="output.png" alt="output" width="45%"/>
20+
</div>

examples/rayleigh_taylor/movie.gif

734 KB
Loading
55.7 KB
Loading
Lines changed: 103 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,103 @@
1+
import jax.numpy as jnp
2+
3+
# TODO: REMOVE THE FOLLOWING LINES
4+
import sys
5+
6+
sys.path.append("../../")
7+
8+
import adirondax as adx
9+
import time
10+
import matplotlib.pyplot as plt
11+
12+
"""
13+
Simulate the Rayleigh-Taylor Instability
14+
15+
Philip Mocz (2025)
16+
"""
17+
18+
19+
def set_up_simulation():
20+
# Define the parameters for the simulation
21+
nx = 64
22+
ny = 192
23+
nt = 10000 # -1
24+
t_stop = 15.0
25+
26+
params = {
27+
"physics": {
28+
"hydro": True,
29+
"external_potential": True,
30+
},
31+
"mesh": {
32+
"type": "cartesian",
33+
"resolution": [nx, ny],
34+
"box_size": [0.5, 1.5],
35+
"boundary_condition": ["periodic", "reflective"],
36+
},
37+
"time": {
38+
"span": t_stop,
39+
"num_timesteps": nt,
40+
},
41+
"output": {
42+
"num_checkpoints": 100,
43+
"save": True,
44+
"plot_dynamic_range": 2.0,
45+
},
46+
"hydro": {
47+
"eos": {"type": "ideal", "gamma": 1.4},
48+
"slope_limiting": False,
49+
},
50+
}
51+
52+
# Initialize the simulation
53+
sim = adx.Simulation(params)
54+
55+
# Set initial conditions
56+
# (heavy fluid on top of light)
57+
sim.state["t"] = 0.0
58+
X, Y = sim.mesh
59+
w0 = 0.0025
60+
P0 = 2.5
61+
g = 0.1
62+
sim.state["rho"] = 1.0 + (Y > 0.75)
63+
sim.state["vx"] = jnp.zeros(X.shape)
64+
sim.state["vy"] = (
65+
w0 * (1.0 - jnp.cos(4.0 * jnp.pi * X)) * (1.0 - jnp.cos(4.0 * jnp.pi * Y / 3.0))
66+
)
67+
sim.state["P"] = P0 - g * (Y - 0.75) * sim.state["rho"]
68+
69+
# external potential
70+
def external_potential(x, y):
71+
V = g * y
72+
return V
73+
74+
sim.external_potential = external_potential
75+
76+
return sim
77+
78+
79+
def make_plot(sim):
80+
# Plot the solution
81+
plt.figure(figsize=(4, 6), dpi=80)
82+
plt.imshow(sim.state["rho"].T, cmap="jet", vmin=0.8, vmax=2.2)
83+
plt.gca().invert_yaxis()
84+
plt.colorbar(label="density")
85+
plt.tight_layout()
86+
plt.savefig("output.png", dpi=240)
87+
plt.show()
88+
89+
90+
def main():
91+
sim = set_up_simulation()
92+
93+
# Evolve the system
94+
t0 = time.time()
95+
sim.run()
96+
print("Run time (s): ", time.time() - t0)
97+
print("Steps taken:", sim.steps_taken)
98+
99+
make_plot(sim)
100+
101+
102+
if __name__ == "__main__":
103+
main()

scripts/make_gif.sh

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
# Usage: ./make_gif.sh /path/to/folder output.gif
44

55
folder="$1"
6-
output="${2:-output.gif}"
6+
output="${2:-movie.gif}"
77

88
if [ -z "$folder" ]; then
99
echo "Usage: $0 /path/to/folder [output.gif]"

0 commit comments

Comments
 (0)