22import jax .numpy as jnp
33
44from .constants import constants
5- from .hydro .euler2d import hydro_euler2d_fluxes
6- from .hydro .mhd2d import hydro_mhd2d_fluxes
7- from .quantum import quantum_kick , quantum_drift
5+ from .hydro .euler2d import hydro_euler2d_fluxes , hydro_euler2d_timestep
6+ from .hydro .mhd2d import hydro_mhd2d_fluxes , hydro_mhd2d_timestep
7+ from .quantum import quantum_kick , quantum_drift , quantum_timestep
88from .gravity import calculate_gravitational_potential
99from .utils import set_up_parameters , print_parameters
1010
@@ -37,7 +37,7 @@ def __init__(self, params):
3737
3838 # simulation state
3939 self .state = {}
40- self .state ["t" ] = jnp .array (0 ) + jnp .nan
40+ self .state ["t" ] = jnp .array (0.0 ) + jnp .nan
4141 if self .params ["physics" ]["hydro" ]:
4242 self .state ["rho" ] = jnp .zeros (self .resolution ) + jnp .nan
4343 self .state ["vx" ] = jnp .zeros (self .resolution ) + jnp .nan
@@ -51,6 +51,9 @@ def __init__(self, params):
5151 jnp .zeros (self .resolution , dtype = jnp .complex64 ) + jnp .nan
5252 )
5353
54+ # extra info to keep track of
55+ self .state ["steps_taken" ] = jnp .array (0 ) + jnp .nan
56+
5457 @property
5558 def resolution (self ):
5659 """
@@ -72,6 +75,13 @@ def dim(self):
7275 """
7376 return len (self .resolution )
7477
78+ @property
79+ def steps_taken (self ):
80+ """
81+ Return the number of steps taken in the simulation
82+ """
83+ return self .state ["steps_taken" ]
84+
7585 @property
7686 def params (self ):
7787 """
@@ -155,11 +165,8 @@ def _evolve(self, state):
155165 nt = self .params ["time" ]["num_timesteps" ]
156166 t_span = self .params ["time" ]["span" ]
157167
158- fixed_timestepping = True if nt > 0 else False
159- if fixed_timestepping :
160- dt = t_span / nt
161-
162- assert fixed_timestepping # XXX for now
168+ use_adaptive_timesteps = True if nt < 1 else False
169+ dt = 0.0 if use_adaptive_timesteps else t_span / nt
163170
164171 # Physics flags
165172 use_hydro = self .params ["physics" ]["hydro" ]
@@ -168,6 +175,9 @@ def _evolve(self, state):
168175 use_gravity = self .params ["physics" ]["gravity" ]
169176
170177 gamma = self .params ["hydro" ]["eos" ]["gamma" ]
178+ cfl = self .params ["hydro" ]["cfl" ]
179+
180+ m_per_hbar = 1.0 # XXX
171181
172182 # Precompute Fourier space variables
173183 k_sq = None
@@ -180,27 +190,56 @@ def _evolve(self, state):
180190 if use_gravity :
181191 V = self ._calc_grav_potential (state , k_sq , use_quantum , use_hydro )
182192
183- # Build the carry: (state, V, k_sq)
184- carry = (state , V , k_sq )
193+ # Build the carry:
194+ carry = (state , dt , V , k_sq )
185195
186- def step_fn (carry , _ ):
196+ def step_fn (carry ):
187197 """
188198 Pure step function: advances state by one timestep.
189- Returns new carry and None (no stacked outputs).
190199 """
191- state , V , k_sq = carry
200+ state , dt , V , k_sq = carry
192201
193202 # Create new state dict to avoid mutation
194203 new_state = {}
195204
205+ # Get the timestep
206+ if use_adaptive_timesteps :
207+ dt = jnp .inf
208+ if use_hydro :
209+ if use_magnetic :
210+ dt_hydro = hydro_mhd2d_timestep (
211+ state ["rho" ],
212+ state ["vx" ],
213+ state ["vy" ],
214+ state ["P" ],
215+ state ["bx" ],
216+ state ["by" ],
217+ gamma ,
218+ dx ,
219+ )
220+ else :
221+ dt_hydro = hydro_euler2d_timestep (
222+ state ["rho" ],
223+ state ["vx" ],
224+ state ["vy" ],
225+ state ["P" ],
226+ gamma ,
227+ dx ,
228+ )
229+ dt = jnp .minimum (dt , cfl * dt_hydro )
230+ if use_quantum :
231+ dt_quantum = quantum_timestep (m_per_hbar , dx )
232+ dt = jnp .minimum (dt , dt_quantum )
233+ dt = jnp .minimum (dt , t_span - state ["t" ])
234+
196235 # Kick (half-step) - quantum + gravity
197236 psi = state .get ("psi" )
198237 if use_quantum and use_gravity and psi is not None :
199- psi = quantum_kick (psi , V , 1.0 , dt / 2.0 )
238+ psi = quantum_kick (psi , V , m_per_hbar , dt / 2.0 )
200239
201240 # Drift (full-step) - quantum
202241 if use_quantum and psi is not None :
203- psi = quantum_drift (psi , k_sq , 1.0 , dt )
242+ psi = quantum_drift (psi , k_sq , m_per_hbar , dt )
204243
205244 if use_quantum :
206245 new_state ["psi" ] = psi
@@ -254,21 +293,47 @@ def step_fn(carry, _):
254293 # Update time
255294 new_state ["t" ] = state ["t" ] + dt
256295
257- return (new_state , new_V , k_sq ), None
296+ # Update diagnostics
297+ new_state ["steps_taken" ] = state ["steps_taken" ] + 1
298+
299+ return (new_state , dt , new_V , k_sq )
258300
259301 # Run the entire loop as a single JIT-compiled function
260302 def run_loop (carry ):
261- final_carry , _ = jax .lax .scan (step_fn , carry , xs = None , length = nt )
303+ if use_adaptive_timesteps :
304+ # def cond_fn(carry):
305+ # state, _, _, _ = carry
306+ # return state["t"] < t_span * (1.0 - 1e-10)
307+
308+ # final_carry = jax.lax.while_loop(cond_fn, step_fn, carry)
309+
310+ # do a simple while loop
311+ state , _ , _ , _ = carry
312+ while state ["t" ] < t_span * (1.0 - 1e-10 ):
313+ carry = step_fn (carry )
314+ state , _ , _ , _ = carry
315+ final_carry = carry
316+ else :
317+
318+ def step_fn_stacked (carry , _ ):
319+ # Returns new carry and None (no stacked outputs) for jax.lax.scan()
320+ return step_fn (carry ), None
321+
322+ final_carry , _ = jax .lax .scan (
323+ step_fn_stacked , carry , xs = None , length = nt
324+ )
262325 return final_carry
263326
264327 # Execute the compiled loop
265- state , _ , _ = run_loop (carry )
328+ state , _ , _ , _ = run_loop (carry )
266329
267330 return state
268331
269332 def run (self ):
270333 """
271334 Run the simulation
272335 """
336+ self .state ["steps_taken" ] = 0
273337 self .state = self ._evolve (self .state )
274338 jax .block_until_ready (self .state )
339+ # assert jnp.isfinite(self.state["t"]), "state['t'] is NaN/infinity"
0 commit comments