Skip to content

Commit 5036586

Browse files
xukai92claude
andcommitted
Add parallel HMC sampler (Phase 4)
Implements parallel HMC using DEER algorithm with two approaches: Approach A - Parallelize across HMC steps: - parallel_hmc() runs T HMC steps in parallel - hmc_transition_soft() with soft MH gating for differentiability Approach B - Parallelize leapfrog integration: - parallel_leapfrog() parallelizes L leapfrog steps - leapfrog_transition() as DEER transition function Block Quasi-DEER for leapfrog: - BlockQuasiDEER type exploiting 2x2 block Jacobian structure - _deer_iteration_block() for efficient block-diagonal Newton - parallel_scan_block() using Block2x2AffineTransform 49 new tests (345 total parallel tests) Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 4593dbc commit 5036586

File tree

5 files changed

+1306
-5
lines changed

5 files changed

+1306
-5
lines changed

src/parallel/Parallel.jl

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,9 @@ include("deer.jl")
5050
# Parallel MALA
5151
include("mala.jl")
5252

53+
# Parallel HMC
54+
include("hmc.jl")
55+
5356
# Export types
5457
export AbstractParallelMethod, FullDEER, QuasiDEER, StochasticQuasiDEER, BlockQuasiDEER
5558

@@ -84,4 +87,12 @@ export MALARandomInputs, MALAConfig
8487
export sample_mala_inputs, mala_proposal, mala_transition
8588
export parallel_mala, sequential_mala
8689

90+
# Export HMC
91+
export HMCRandomInputs, HMCConfig
92+
export sample_hmc_inputs, hmc_transition, hmc_transition_soft
93+
export leapfrog_step, leapfrog_full, hmc_proposal
94+
export parallel_hmc, sequential_hmc
95+
export parallel_leapfrog, leapfrog_transition
96+
export hessian_diagonal_fd
97+
8798
end # module

src/parallel/deer.jl

Lines changed: 151 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -323,9 +323,157 @@ end
323323
####
324324
#### Block Quasi-DEER for Leapfrog (Phase 4)
325325
####
326-
#### Note: Block Quasi-DEER for leapfrog is more specialized and will be
327-
#### implemented in Phase 4 when we integrate with HMC.
328-
####
326+
327+
"""
328+
_deer_iteration(f, s0, trajectory, ω, method::BlockQuasiDEER; kwargs...)
329+
330+
Dispatch for Block Quasi-DEER method.
331+
"""
332+
function _deer_iteration(
333+
f, s0, trajectory, ω, method::BlockQuasiDEER;
334+
jacobian_fn, jvp_fn, rng
335+
)
336+
return _deer_iteration_block(
337+
f, s0, trajectory, ω;
338+
hessian_diag_fn=method.hessian_diag_fn,
339+
ε=method.ε,
340+
M⁻¹=method.M⁻¹
341+
)
342+
end
343+
344+
"""
345+
_deer_iteration_block(f, s0, trajectory, ω; hessian_diag_fn, ε, M⁻¹)
346+
347+
One Newton iteration using 2×2 block-diagonal Jacobian structure for leapfrog.
348+
349+
The state is s = [θ; r] where θ is position and r is momentum.
350+
The Jacobian has 2×2 block structure per dimension:
351+
352+
J_d = [ 1 ε*M⁻¹_d ]
353+
[ ε*H_d 1 + ε²*M⁻¹_d*H_d ]
354+
355+
where H_d is the d-th diagonal element of the Hessian of -log p.
356+
357+
Memory: O(T * D)
358+
Work: O(T * D) for 2×2 block operations in scan
359+
"""
360+
function _deer_iteration_block(
361+
f,
362+
s0::AbstractVector{T},
363+
trajectory::AbstractMatrix{T},
364+
ω;
365+
hessian_diag_fn,
366+
ε::T,
367+
M⁻¹::AbstractVector{T},
368+
) where {T}
369+
T_len, state_dim = size(trajectory)
370+
D = state_dim ÷ 2 # θ and r each have dimension D
371+
372+
# Allocate arrays
373+
f_vals = zeros(T, T_len, state_dim)
374+
375+
# Store block Jacobian components for each timestep
376+
J_a = zeros(T, T_len, D) # Top-left diagonal
377+
J_b = zeros(T, T_len, D) # Top-right diagonal
378+
J_c = zeros(T, T_len, D) # Bottom-left diagonal
379+
J_e = zeros(T, T_len, D) # Bottom-right diagonal
380+
u_x = zeros(T, T_len, D) # Offset for position
381+
u_v = zeros(T, T_len, D) # Offset for momentum
382+
383+
# Step 1: Evaluate f and compute block Jacobians at all timesteps
384+
for t in 1:T_len
385+
s_prev = (t == 1) ? s0 : trajectory[t - 1, :]
386+
387+
# Evaluate transition function
388+
f_vals[t, :] = f(s_prev, ω[t])
389+
390+
# Extract position from previous state
391+
θ_prev = s_prev[1:D]
392+
393+
# Compute Hessian diagonal at previous position
394+
H_diag = hessian_diag_fn(θ_prev)
395+
396+
# Block Jacobian structure for leapfrog:
397+
# J = [ I ε*M⁻¹ ]
398+
# [ ε*H_diag I + ε²*M⁻¹*H_diag ]
399+
J_a[t, :] .= one(T)
400+
J_b[t, :] .= ε .* M⁻¹
401+
J_c[t, :] .= ε .* H_diag
402+
J_e[t, :] .= one(T) .+^2) .* M⁻¹ .* H_diag
403+
end
404+
405+
# Step 2: Compute offsets u = f(s_prev) - J * s_prev
406+
for t in 1:T_len
407+
s_prev = (t == 1) ? s0 : trajectory[t - 1, :]
408+
θ_prev = s_prev[1:D]
409+
r_prev = s_prev[(D+1):end]
410+
411+
f_θ = f_vals[t, 1:D]
412+
f_r = f_vals[t, (D+1):end]
413+
414+
# u_x = f_θ - (J_a * θ_prev + J_b * r_prev)
415+
# u_v = f_r - (J_c * θ_prev + J_e * r_prev)
416+
u_x[t, :] = f_θ .- (J_a[t, :] .* θ_prev .+ J_b[t, :] .* r_prev)
417+
u_v[t, :] = f_r .- (J_c[t, :] .* θ_prev .+ J_e[t, :] .* r_prev)
418+
end
419+
420+
# Step 3: Build block transforms and solve via parallel scan
421+
transforms = [Block2x2AffineTransform(
422+
J_a[t, :], J_b[t, :], J_c[t, :], J_e[t, :],
423+
u_x[t, :], u_v[t, :]
424+
) for t in 1:T_len]
425+
426+
# Initial state split
427+
θ0 = s0[1:D]
428+
r0 = s0[(D+1):end]
429+
430+
# Run parallel scan
431+
trajectory_θ, trajectory_r = parallel_scan_block(transforms, θ0, r0)
432+
433+
# Combine into trajectory
434+
trajectory_new = zeros(T, T_len, state_dim)
435+
trajectory_new[:, 1:D] = trajectory_θ
436+
trajectory_new[:, (D+1):end] = trajectory_r
437+
438+
return trajectory_new
439+
end
440+
441+
"""
442+
parallel_scan_block(transforms, θ0, r0)
443+
444+
Parallel scan for 2×2 block transforms.
445+
446+
Returns (trajectory_θ, trajectory_r) where each is a T_len × D matrix.
447+
"""
448+
function parallel_scan_block(
449+
transforms::Vector{<:Block2x2AffineTransform{T}},
450+
θ0::AbstractVector{T},
451+
r0::AbstractVector{T},
452+
) where {T}
453+
T_len = length(transforms)
454+
D = length(θ0)
455+
456+
# Run prefix sum to get cumulative transforms
457+
prefix = Vector{Block2x2AffineTransform{T}}(undef, T_len)
458+
prefix[1] = transforms[1]
459+
for t in 2:T_len
460+
prefix[t] = compose(transforms[t], prefix[t-1])
461+
end
462+
463+
# Apply each cumulative transform to initial state
464+
trajectory_θ = zeros(T, T_len, D)
465+
trajectory_r = zeros(T, T_len, D)
466+
467+
for t in 1:T_len
468+
tr = prefix[t]
469+
# Apply: [θ'] = [a b] [θ0] + [u_x]
470+
# [r'] [c e] [r0] [u_v]
471+
trajectory_θ[t, :] = tr.a .* θ0 .+ tr.b .* r0 .+ tr.u_x
472+
trajectory_r[t, :] = tr.c .* θ0 .+ tr.e .* r0 .+ tr.u_v
473+
end
474+
475+
return trajectory_θ, trajectory_r
476+
end
329477

330478
####
331479
#### Utility Functions

0 commit comments

Comments
 (0)