@@ -323,9 +323,157 @@ end
323323# ###
324324# ### Block Quasi-DEER for Leapfrog (Phase 4)
325325# ###
326- # ### Note: Block Quasi-DEER for leapfrog is more specialized and will be
327- # ### implemented in Phase 4 when we integrate with HMC.
328- # ###
326+
"""
    _deer_iteration(f, s0, trajectory, ω, method::BlockQuasiDEER;
                    jacobian_fn, jvp_fn, rng)

Run one DEER iteration with the Block Quasi-DEER method.

This method forwards to `_deer_iteration_block`, taking the Hessian-diagonal
callback, the step size and the inverse-mass vector from the fields of `method`
(`method.hessian_diag_fn`, `method.ε`, `method.M⁻¹`).

# Keywords
- `jacobian_fn`, `jvp_fn`, `rng`: required by the signature but not used by this
  method; they are not passed on to `_deer_iteration_block`.

# Returns
Whatever `_deer_iteration_block` returns for the same `f`, `s0`, `trajectory`, `ω`.
"""
function _deer_iteration(
    f,
    s0,
    trajectory,
    ω,
    method::BlockQuasiDEER;
    jacobian_fn,
    jvp_fn,
    rng,
)
    # Unpack the method configuration once, then delegate.
    diag_hessian = method.hessian_diag_fn
    step_size = method.ε
    inverse_mass = method.M⁻¹

    return _deer_iteration_block(
        f, s0, trajectory, ω;
        hessian_diag_fn=diag_hessian, ε=step_size, M⁻¹=inverse_mass,
    )
end
343+
"""
    _deer_iteration_block(f, s0, trajectory, ω; hessian_diag_fn, ε, M⁻¹)

Perform one Newton-style update of `trajectory` using a 2×2 block-diagonal
approximation of the leapfrog Jacobian.

Each state is laid out as `s = [θ; r]`: the first `D` entries are the position `θ`
and the remaining `D` entries are the momentum `r`, where `D = size(trajectory, 2) ÷ 2`.
For every timestep `t` the transition `f` is linearised around the previous state
(`s0` for `t == 1`, otherwise row `t - 1` of `trajectory`) with the per-dimension block

    J_d = [ 1        ε*M⁻¹_d           ]
          [ ε*H_d    1 + ε²*M⁻¹_d*H_d  ]

where `H = hessian_diag_fn(θ_prev)`. The affine offset is `u = f(s_prev, ω[t]) - J * s_prev`.
The resulting per-timestep affine maps are handed to `parallel_scan_block`, which
produces the updated trajectory.

# Arguments
- `f`: transition function, called as `f(s_prev, ω[t])`; its result is stored as a
  row of length `size(trajectory, 2)`.
- `s0`: initial state; its length must equal `size(trajectory, 2)`.
- `trajectory`: current trajectory guess, one state per row (`T_len × 2D`).
- `ω`: per-timestep inputs, indexed as `ω[t]` for `t in 1:T_len`.

# Keywords
- `hessian_diag_fn`: called as `hessian_diag_fn(θ_prev)`; its result is used as the
  diagonal `H`. NOTE(review): earlier notes describe `H` as the diagonal Hessian of
  `-log p`; confirm the sign convention against the leapfrog step that `f` implements.
- `ε`: step size.
- `M⁻¹`: diagonal of the inverse mass matrix, broadcast against length-`D` rows.

# Returns
A new `T_len × 2D` matrix. For an empty trajectory (`T_len == 0`) an empty
`0 × 2D` matrix is returned without calling `f` or `hessian_diag_fn`.

# Throws
- `DimensionMismatch` if the state dimension is odd or `length(s0)` differs from it.

Memory: O(T * D)
Work:   O(T * D) element-wise 2×2 block operations, plus the scan.
"""
function _deer_iteration_block(
    f,
    s0::AbstractVector{T},
    trajectory::AbstractMatrix{T},
    ω;
    hessian_diag_fn,
    ε::T,
    M⁻¹::AbstractVector{T},
) where {T}
    T_len, state_dim = size(trajectory)

    # Fail early with a clear message instead of an obscure broadcast error later.
    if isodd(state_dim)
        throw(DimensionMismatch(
            "state dimension must be even (position and momentum halves), " *
            "got $(state_dim)",
        ))
    end
    if length(s0) != state_dim
        throw(DimensionMismatch(
            "length(s0) = $(length(s0)) does not match the trajectory state " *
            "dimension $(state_dim)",
        ))
    end
    D = state_dim ÷ 2  # θ and r each have dimension D

    # Nothing to linearise: return an empty trajectory of the right shape.
    if T_len == 0
        return zeros(T, 0, state_dim)
    end

    # Transition values at every timestep.
    f_vals = zeros(T, T_len, state_dim)

    # Diagonals of the four Jacobian blocks and the affine offsets, per timestep.
    J_a = zeros(T, T_len, D)  # top-left
    J_b = zeros(T, T_len, D)  # top-right
    J_c = zeros(T, T_len, D)  # bottom-left
    J_e = zeros(T, T_len, D)  # bottom-right
    u_x = zeros(T, T_len, D)  # offset for position
    u_v = zeros(T, T_len, D)  # offset for momentum

    # Single pass: evaluate f, build the block Jacobian, and form the offsets.
    for t in 1:T_len
        # Plain (copied) slices are passed to the callbacks so they receive ordinary
        # vectors rather than views.
        s_prev = (t == 1) ? s0 : trajectory[t - 1, :]
        θ_prev = s_prev[1:D]
        r_prev = s_prev[(D + 1):end]

        f_vals[t, :] = f(s_prev, ω[t])
        H_diag = hessian_diag_fn(θ_prev)

        # J = [ I          ε*M⁻¹              ]
        #     [ ε*H_diag   I + ε²*M⁻¹*H_diag  ]
        J_a[t, :] .= one(T)
        J_b[t, :] .= ε .* M⁻¹
        J_c[t, :] .= ε .* H_diag
        J_e[t, :] .= one(T) .+ (ε^2) .* M⁻¹ .* H_diag

        # u = f(s_prev) - J * s_prev, using the stored (element type T) rows.
        @views begin
            u_x[t, :] .=
                f_vals[t, 1:D] .- (J_a[t, :] .* θ_prev .+ J_b[t, :] .* r_prev)
            u_v[t, :] .=
                f_vals[t, (D + 1):end] .- (J_c[t, :] .* θ_prev .+ J_e[t, :] .* r_prev)
        end
    end

    # One affine 2×2 block transform per timestep (rows are copied into vectors).
    transforms = [
        Block2x2AffineTransform(
            J_a[t, :], J_b[t, :], J_c[t, :], J_e[t, :],
            u_x[t, :], u_v[t, :],
        ) for t in 1:T_len
    ]

    # Split the initial state and solve the linearised recurrence.
    θ0 = s0[1:D]
    r0 = s0[(D + 1):end]
    trajectory_θ, trajectory_r = parallel_scan_block(transforms, θ0, r0)

    # Reassemble [θ r] rows.
    trajectory_new = zeros(T, T_len, state_dim)
    trajectory_new[:, 1:D] = trajectory_θ
    trajectory_new[:, (D + 1):end] = trajectory_r

    return trajectory_new
end
440+
"""
    parallel_scan_block(transforms, θ0, r0)

Evaluate the prefix composition of 2×2 block affine transforms and apply each
cumulative transform to the initial state `(θ0, r0)`.

For timestep `t` the cumulative transform is built as
`compose(transforms[t], cumulative_{t-1})`, starting from `transforms[1]`, and the
outputs are

    θ_t = a .* θ0 .+ b .* r0 .+ u_x
    r_t = c .* θ0 .+ e .* r0 .+ u_v

using the fields `a`, `b`, `c`, `e`, `u_x`, `u_v` of that cumulative transform.

Note: despite the name, this implementation walks the timesteps sequentially; it
computes the same prefix results a parallel scan over `compose` would produce,
presumably relying on `compose` being associative (defined elsewhere — confirm).

# Arguments
- `transforms`: one `Block2x2AffineTransform` per timestep.
- `θ0`, `r0`: initial position and momentum; `length(θ0)` fixes the column count `D`.

# Returns
`(trajectory_θ, trajectory_r)`, each a `T_len × D` matrix with
`T_len = length(transforms)`. For empty `transforms` both matrices are `0 × D`.
"""
function parallel_scan_block(
    transforms::Vector{<:Block2x2AffineTransform{T}},
    θ0::AbstractVector{T},
    r0::AbstractVector{T},
) where {T}
    T_len = length(transforms)
    D = length(θ0)

    trajectory_θ = zeros(T, T_len, D)
    trajectory_r = zeros(T, T_len, D)

    # No timesteps: avoid indexing transforms[1] and return empty trajectories.
    if T_len == 0
        return (trajectory_θ, trajectory_r)
    end

    # Running prefix composition; applied to the initial state as soon as it is
    # available, so no intermediate vector of cumulative transforms is needed.
    cumulative = transforms[1]
    for t in 1:T_len
        if t > 1
            cumulative = compose(transforms[t], cumulative)
        end

        # Apply: [θ'] = [a b] [θ0] + [u_x]
        #        [r']   [c e] [r0]   [u_v]
        trajectory_θ[t, :] =
            cumulative.a .* θ0 .+ cumulative.b .* r0 .+ cumulative.u_x
        trajectory_r[t, :] =
            cumulative.c .* θ0 .+ cumulative.e .* r0 .+ cumulative.u_v
    end

    return (trajectory_θ, trajectory_r)
end
329477
330478# ###
331479# ### Utility Functions
0 commit comments