@@ -106,8 +106,7 @@ function deer(
106106 for iter in 1 : max_iters
107107 # Run one Newton iteration based on method
108108 trajectory_new = _deer_iteration (
109- f, s0, trajectory, ω, method;
110- jacobian_fn= jacobian_fn, jvp_fn= jvp_fn, rng= rng
109+ f, s0, trajectory, ω, method; jacobian_fn= jacobian_fn, jvp_fn= jvp_fn, rng= rng
111110 )
112111
113112 # Check convergence
@@ -139,27 +138,19 @@ end
139138# ### Newton Iteration Dispatch
140139# ###
141140
142- function _deer_iteration (
143- f, s0, trajectory, ω, method:: FullDEER ;
144- jacobian_fn, jvp_fn, rng
145- )
141+ function _deer_iteration (f, s0, trajectory, ω, method:: FullDEER ; jacobian_fn, jvp_fn, rng)
146142 return _deer_iteration_full (f, s0, trajectory, ω; jacobian_fn= jacobian_fn)
147143end
148144
149- function _deer_iteration (
150- f, s0, trajectory, ω, method:: QuasiDEER ;
151- jacobian_fn, jvp_fn, rng
152- )
145+ function _deer_iteration (f, s0, trajectory, ω, method:: QuasiDEER ; jacobian_fn, jvp_fn, rng)
153146 return _deer_iteration_quasi (f, s0, trajectory, ω; jacobian_fn= jacobian_fn)
154147end
155148
156149function _deer_iteration (
157- f, s0, trajectory, ω, method:: StochasticQuasiDEER ;
158- jacobian_fn, jvp_fn, rng
150+ f, s0, trajectory, ω, method:: StochasticQuasiDEER ; jacobian_fn, jvp_fn, rng
159151)
160152 return _deer_iteration_stochastic (
161- f, s0, trajectory, ω;
162- jvp_fn= jvp_fn, rng= rng, n_samples= method. n_samples
153+ f, s0, trajectory, ω; jvp_fn= jvp_fn, rng= rng, n_samples= method. n_samples
163154 )
164155end
165156
@@ -176,11 +167,7 @@ Memory: O(T * D²)
176167Work: O(T * D³) for matrix multiplications in scan
177168"""
178169function _deer_iteration_full (
179- f,
180- s0:: AbstractVector{T} ,
181- trajectory:: AbstractMatrix{T} ,
182- ω;
183- jacobian_fn= jacobian_fd,
170+ f, s0:: AbstractVector{T} , trajectory:: AbstractMatrix{T} , ω; jacobian_fn= jacobian_fd
184171) where {T}
185172 T_len, D = size (trajectory)
186173
@@ -227,11 +214,7 @@ Memory: O(T * D)
227214Work: O(T * D) for elementwise operations in scan
228215"""
229216function _deer_iteration_quasi (
230- f,
231- s0:: AbstractVector{T} ,
232- trajectory:: AbstractMatrix{T} ,
233- ω;
234- jacobian_fn= jacobian_fd,
217+ f, s0:: AbstractVector{T} , trajectory:: AbstractMatrix{T} , ω; jacobian_fn= jacobian_fd
235218) where {T}
236219 T_len, D = size (trajectory)
237220
@@ -304,7 +287,9 @@ function _deer_iteration_stochastic(
304287
305288 # Estimate Jacobian diagonal via Hutchinson's method
306289 f_t (s) = f (s, ω[t])
307- J_diag[t, :] = hutchinson_diagonal (f_t, s_prev, jvp_fn; rng= rng, n_samples= n_samples)
290+ J_diag[t, :] = hutchinson_diagonal (
291+ f_t, s_prev, jvp_fn; rng= rng, n_samples= n_samples
292+ )
308293 end
309294
310295 # Step 2: Compute inputs u_t = f_t(s_{t-1}) - diag(J_t) .* s_{t-1}
@@ -330,14 +315,16 @@ end
330315Dispatch for Block Quasi-DEER method.
331316"""
332317function _deer_iteration (
333- f, s0, trajectory, ω, method:: BlockQuasiDEER ;
334- jacobian_fn, jvp_fn, rng
318+ f, s0, trajectory, ω, method:: BlockQuasiDEER ; jacobian_fn, jvp_fn, rng
335319)
336320 return _deer_iteration_block (
337- f, s0, trajectory, ω;
321+ f,
322+ s0,
323+ trajectory,
324+ ω;
338325 hessian_diag_fn= method. hessian_diag_fn,
339326 ε= method. ε,
340- M⁻¹= method. M⁻¹
327+ M⁻¹= method. M⁻¹,
341328 )
342329end
343330
@@ -406,10 +393,10 @@ function _deer_iteration_block(
406393 for t in 1 : T_len
407394 s_prev = (t == 1 ) ? s0 : trajectory[t - 1 , :]
408395 θ_prev = s_prev[1 : D]
409- r_prev = s_prev[(D+ 1 ): end ]
396+ r_prev = s_prev[(D + 1 ): end ]
410397
411398 f_θ = f_vals[t, 1 : D]
412- f_r = f_vals[t, (D+ 1 ): end ]
399+ f_r = f_vals[t, (D + 1 ): end ]
413400
414401 # u_x = f_θ - (J_a * θ_prev + J_b * r_prev)
415402 # u_v = f_r - (J_c * θ_prev + J_e * r_prev)
@@ -418,22 +405,23 @@ function _deer_iteration_block(
418405 end
419406
420407 # Step 3: Build block transforms and solve via parallel scan
421- transforms = [Block2x2AffineTransform (
422- J_a[t, :], J_b[t, :], J_c[t, :], J_e[t, :],
423- u_x[t, :], u_v[t, :]
424- ) for t in 1 : T_len]
408+ transforms = [
409+ Block2x2AffineTransform (
410+ J_a[t, :], J_b[t, :], J_c[t, :], J_e[t, :], u_x[t, :], u_v[t, :]
411+ ) for t in 1 : T_len
412+ ]
425413
426414 # Initial state split
427415 θ0 = s0[1 : D]
428- r0 = s0[(D+ 1 ): end ]
416+ r0 = s0[(D + 1 ): end ]
429417
430418 # Run parallel scan
431419 trajectory_θ, trajectory_r = parallel_scan_block (transforms, θ0, r0)
432420
433421 # Combine into trajectory
434422 trajectory_new = zeros (T, T_len, state_dim)
435423 trajectory_new[:, 1 : D] = trajectory_θ
436- trajectory_new[:, (D+ 1 ): end ] = trajectory_r
424+ trajectory_new[:, (D + 1 ): end ] = trajectory_r
437425
438426 return trajectory_new
439427end
@@ -457,7 +445,7 @@ function parallel_scan_block(
457445 prefix = Vector {Block2x2AffineTransform{T}} (undef, T_len)
458446 prefix[1 ] = transforms[1 ]
459447 for t in 2 : T_len
460- prefix[t] = compose (transforms[t], prefix[t- 1 ])
448+ prefix[t] = compose (transforms[t], prefix[t - 1 ])
461449 end
462450
463451 # Apply each cumulative transform to initial state
@@ -485,19 +473,17 @@ end
485473Run DEER with settings from a ParallelMCMCSettings struct.
486474"""
487475function deer_with_settings (
488- f,
489- s0:: AbstractVector ,
490- T_len:: Int ,
491- ω,
492- settings:: ParallelMCMCSettings ;
493- kwargs...
476+ f, s0:: AbstractVector , T_len:: Int , ω, settings:: ParallelMCMCSettings ; kwargs...
494477)
495478 return deer (
496- f, s0, T_len, ω;
479+ f,
480+ s0,
481+ T_len,
482+ ω;
497483 method= settings. method,
498484 tol= settings. tol,
499485 max_iters= settings. max_iters,
500- kwargs...
486+ kwargs... ,
501487 )
502488end
503489
@@ -508,12 +494,7 @@ Run MCMC sequentially (for comparison/testing).
508494
509495Returns the trajectory as a (T × D) matrix.
510496"""
511- function sequential_mcmc (
512- f,
513- s0:: AbstractVector{T} ,
514- T_len:: Int ,
515- ω,
516- ) where {T}
497+ function sequential_mcmc (f, s0:: AbstractVector{T} , T_len:: Int , ω) where {T}
517498 D = length (s0)
518499 trajectory = zeros (T, T_len, D)
519500
0 commit comments