Skip to content

Commit cce5e53

Browse files
xukai92 and claude committed
Fix test suite integration and apply JuliaFormatter
- Fix @testset conflict: guard against ReTest in common.jl
- Guard against multiple includes of common.jl
- Apply JuliaFormatter (Blue style) to all parallel source and test files
- Format docs/parallel_mcmc_implementation_plan.md

Co-Authored-By: Claude Opus 4.6 <noreply@anthropic.com>
1 parent f19ac33 commit cce5e53

File tree

14 files changed

+418
-399
lines changed

14 files changed

+418
-399
lines changed

docs/parallel_mcmc_implementation_plan.md

Lines changed: 233 additions & 208 deletions
Large diffs are not rendered by default.

src/parallel/Parallel.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,8 +36,8 @@ using LinearAlgebra
3636
using Random
3737

3838
# Check if we're a submodule of AdvancedHMC
39-
const IS_SUBMODULE = parentmodule(@__MODULE__) !== Main &&
40-
nameof(parentmodule(@__MODULE__)) === :AdvancedHMC
39+
const IS_SUBMODULE =
40+
parentmodule(@__MODULE__) !== Main && nameof(parentmodule(@__MODULE__)) === :AdvancedHMC
4141

4242
# Import dependencies based on context
4343
if IS_SUBMODULE

src/parallel/abstractmcmc.jl

Lines changed: 21 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -43,9 +43,7 @@ result = sample(model, sampler, 1000)
4343
```
4444
"""
4545
struct ParallelHMCSampler{
46-
T<:AbstractFloat,
47-
M<:AbstractParallelMethod,
48-
I<:Union{Symbol,AbstractMetric},
46+
T<:AbstractFloat,M<:AbstractParallelMethod,I<:Union{Symbol,AbstractMetric}
4947
} <: AbstractParallelSampler
5048
ε::T
5149
L::Int
@@ -99,9 +97,7 @@ result = sample(model, sampler, 1000)
9997
```
10098
"""
10199
struct ParallelMALASampler{
102-
T<:AbstractFloat,
103-
M<:AbstractParallelMethod,
104-
I<:Union{Symbol,AbstractMetric},
100+
T<:AbstractFloat,M<:AbstractParallelMethod,I<:Union{Symbol,AbstractMetric}
105101
} <: AbstractParallelSampler
106102
ε::T
107103
method::M
@@ -244,7 +240,7 @@ function parallel_sample(
244240

245241
# Create log density and gradient functions
246242
logp = x -> LogDensityProblems.logdensity(logdensity, x)
247-
∇logp = function(x)
243+
∇logp = function (x)
248244
_, grad = LogDensityProblems.logdensity_and_gradient(logdensity, x)
249245
return grad
250246
end
@@ -267,11 +263,14 @@ function parallel_sample(
267263

268264
# Run parallel HMC
269265
result = parallel_hmc(
270-
config, s0, N, ω;
266+
config,
267+
s0,
268+
N,
269+
ω;
271270
method=sampler.method,
272271
tol=sampler.tol,
273272
max_iters=sampler.max_iters,
274-
verbose=verbose
273+
verbose=verbose,
275274
)
276275

277276
# Estimate acceptance rate (from soft gating, approximate)
@@ -283,7 +282,7 @@ function parallel_sample(
283282
result.converged,
284283
result.iterations,
285284
result.max_residual,
286-
acceptance_rate
285+
acceptance_rate,
287286
)
288287
end
289288

@@ -305,7 +304,7 @@ function parallel_sample(
305304

306305
# Create log density and gradient functions
307306
logp = x -> LogDensityProblems.logdensity(logdensity, x)
308-
∇logp = function(x)
307+
∇logp = function (x)
309308
_, grad = LogDensityProblems.logdensity_and_gradient(logdensity, x)
310309
return grad
311310
end
@@ -325,11 +324,14 @@ function parallel_sample(
325324

326325
# Run parallel MALA
327326
result = parallel_mala(
328-
config, s0, N, ω;
327+
config,
328+
s0,
329+
N,
330+
ω;
329331
method=sampler.method,
330332
tol=sampler.tol,
331333
max_iters=sampler.max_iters,
332-
verbose=verbose
334+
verbose=verbose,
333335
)
334336

335337
# Estimate acceptance rate
@@ -340,7 +342,7 @@ function parallel_sample(
340342
result.converged,
341343
result.iterations,
342344
result.max_residual,
343-
acceptance_rate
345+
acceptance_rate,
344346
)
345347
end
346348

@@ -370,7 +372,7 @@ get_samples(state::ParallelSamplerState) = state.trajectory
370372
Extract samples after discarding burn-in period.
371373
"""
372374
function get_samples(state::ParallelSamplerState, burn_in::Int)
373-
return state.trajectory[(burn_in+1):end, :]
375+
return state.trajectory[(burn_in + 1):end, :]
374376
end
375377

376378
####
@@ -396,10 +398,7 @@ function Base.iterate(state::ParallelSamplerState, i::Int)
396398
return nothing
397399
end
398400
θ = state.trajectory[i, :]
399-
stat = (
400-
iteration=i,
401-
converged=state.converged,
402-
)
401+
stat = (iteration=i, converged=state.converged)
403402
return ParallelTransition(θ, stat), i + 1
404403
end
405404

@@ -429,7 +428,9 @@ struct SimpleLogDensity{F,G}
429428
end
430429

431430
LogDensityProblems.dimension(ld::SimpleLogDensity) = ld.dim
432-
LogDensityProblems.capabilities(::Type{<:SimpleLogDensity}) = LogDensityProblems.LogDensityOrder{1}()
431+
function LogDensityProblems.capabilities(::Type{<:SimpleLogDensity})
432+
LogDensityProblems.LogDensityOrder{1}()
433+
end
433434
LogDensityProblems.logdensity(ld::SimpleLogDensity, x) = ld.logp(x)
434435
function LogDensityProblems.logdensity_and_gradient(ld::SimpleLogDensity, x)
435436
return ld.logp(x), ld.∇logp(x)

src/parallel/deer.jl

Lines changed: 33 additions & 52 deletions
Original file line numberDiff line numberDiff line change
@@ -106,8 +106,7 @@ function deer(
106106
for iter in 1:max_iters
107107
# Run one Newton iteration based on method
108108
trajectory_new = _deer_iteration(
109-
f, s0, trajectory, ω, method;
110-
jacobian_fn=jacobian_fn, jvp_fn=jvp_fn, rng=rng
109+
f, s0, trajectory, ω, method; jacobian_fn=jacobian_fn, jvp_fn=jvp_fn, rng=rng
111110
)
112111

113112
# Check convergence
@@ -139,27 +138,19 @@ end
139138
#### Newton Iteration Dispatch
140139
####
141140

142-
function _deer_iteration(
143-
f, s0, trajectory, ω, method::FullDEER;
144-
jacobian_fn, jvp_fn, rng
145-
)
141+
function _deer_iteration(f, s0, trajectory, ω, method::FullDEER; jacobian_fn, jvp_fn, rng)
146142
return _deer_iteration_full(f, s0, trajectory, ω; jacobian_fn=jacobian_fn)
147143
end
148144

149-
function _deer_iteration(
150-
f, s0, trajectory, ω, method::QuasiDEER;
151-
jacobian_fn, jvp_fn, rng
152-
)
145+
function _deer_iteration(f, s0, trajectory, ω, method::QuasiDEER; jacobian_fn, jvp_fn, rng)
153146
return _deer_iteration_quasi(f, s0, trajectory, ω; jacobian_fn=jacobian_fn)
154147
end
155148

156149
function _deer_iteration(
157-
f, s0, trajectory, ω, method::StochasticQuasiDEER;
158-
jacobian_fn, jvp_fn, rng
150+
f, s0, trajectory, ω, method::StochasticQuasiDEER; jacobian_fn, jvp_fn, rng
159151
)
160152
return _deer_iteration_stochastic(
161-
f, s0, trajectory, ω;
162-
jvp_fn=jvp_fn, rng=rng, n_samples=method.n_samples
153+
f, s0, trajectory, ω; jvp_fn=jvp_fn, rng=rng, n_samples=method.n_samples
163154
)
164155
end
165156

@@ -176,11 +167,7 @@ Memory: O(T * D²)
176167
Work: O(T * D³) for matrix multiplications in scan
177168
"""
178169
function _deer_iteration_full(
179-
f,
180-
s0::AbstractVector{T},
181-
trajectory::AbstractMatrix{T},
182-
ω;
183-
jacobian_fn=jacobian_fd,
170+
f, s0::AbstractVector{T}, trajectory::AbstractMatrix{T}, ω; jacobian_fn=jacobian_fd
184171
) where {T}
185172
T_len, D = size(trajectory)
186173

@@ -227,11 +214,7 @@ Memory: O(T * D)
227214
Work: O(T * D) for elementwise operations in scan
228215
"""
229216
function _deer_iteration_quasi(
230-
f,
231-
s0::AbstractVector{T},
232-
trajectory::AbstractMatrix{T},
233-
ω;
234-
jacobian_fn=jacobian_fd,
217+
f, s0::AbstractVector{T}, trajectory::AbstractMatrix{T}, ω; jacobian_fn=jacobian_fd
235218
) where {T}
236219
T_len, D = size(trajectory)
237220

@@ -304,7 +287,9 @@ function _deer_iteration_stochastic(
304287

305288
# Estimate Jacobian diagonal via Hutchinson's method
306289
f_t(s) = f(s, ω[t])
307-
J_diag[t, :] = hutchinson_diagonal(f_t, s_prev, jvp_fn; rng=rng, n_samples=n_samples)
290+
J_diag[t, :] = hutchinson_diagonal(
291+
f_t, s_prev, jvp_fn; rng=rng, n_samples=n_samples
292+
)
308293
end
309294

310295
# Step 2: Compute inputs u_t = f_t(s_{t-1}) - diag(J_t) .* s_{t-1}
@@ -330,14 +315,16 @@ end
330315
Dispatch for Block Quasi-DEER method.
331316
"""
332317
function _deer_iteration(
333-
f, s0, trajectory, ω, method::BlockQuasiDEER;
334-
jacobian_fn, jvp_fn, rng
318+
f, s0, trajectory, ω, method::BlockQuasiDEER; jacobian_fn, jvp_fn, rng
335319
)
336320
return _deer_iteration_block(
337-
f, s0, trajectory, ω;
321+
f,
322+
s0,
323+
trajectory,
324+
ω;
338325
hessian_diag_fn=method.hessian_diag_fn,
339326
ε=method.ε,
340-
M⁻¹=method.M⁻¹
327+
M⁻¹=method.M⁻¹,
341328
)
342329
end
343330

@@ -406,10 +393,10 @@ function _deer_iteration_block(
406393
for t in 1:T_len
407394
s_prev = (t == 1) ? s0 : trajectory[t - 1, :]
408395
θ_prev = s_prev[1:D]
409-
r_prev = s_prev[(D+1):end]
396+
r_prev = s_prev[(D + 1):end]
410397

411398
f_θ = f_vals[t, 1:D]
412-
f_r = f_vals[t, (D+1):end]
399+
f_r = f_vals[t, (D + 1):end]
413400

414401
# u_x = f_θ - (J_a * θ_prev + J_b * r_prev)
415402
# u_v = f_r - (J_c * θ_prev + J_e * r_prev)
@@ -418,22 +405,23 @@ function _deer_iteration_block(
418405
end
419406

420407
# Step 3: Build block transforms and solve via parallel scan
421-
transforms = [Block2x2AffineTransform(
422-
J_a[t, :], J_b[t, :], J_c[t, :], J_e[t, :],
423-
u_x[t, :], u_v[t, :]
424-
) for t in 1:T_len]
408+
transforms = [
409+
Block2x2AffineTransform(
410+
J_a[t, :], J_b[t, :], J_c[t, :], J_e[t, :], u_x[t, :], u_v[t, :]
411+
) for t in 1:T_len
412+
]
425413

426414
# Initial state split
427415
θ0 = s0[1:D]
428-
r0 = s0[(D+1):end]
416+
r0 = s0[(D + 1):end]
429417

430418
# Run parallel scan
431419
trajectory_θ, trajectory_r = parallel_scan_block(transforms, θ0, r0)
432420

433421
# Combine into trajectory
434422
trajectory_new = zeros(T, T_len, state_dim)
435423
trajectory_new[:, 1:D] = trajectory_θ
436-
trajectory_new[:, (D+1):end] = trajectory_r
424+
trajectory_new[:, (D + 1):end] = trajectory_r
437425

438426
return trajectory_new
439427
end
@@ -457,7 +445,7 @@ function parallel_scan_block(
457445
prefix = Vector{Block2x2AffineTransform{T}}(undef, T_len)
458446
prefix[1] = transforms[1]
459447
for t in 2:T_len
460-
prefix[t] = compose(transforms[t], prefix[t-1])
448+
prefix[t] = compose(transforms[t], prefix[t - 1])
461449
end
462450

463451
# Apply each cumulative transform to initial state
@@ -485,19 +473,17 @@ end
485473
Run DEER with settings from a ParallelMCMCSettings struct.
486474
"""
487475
function deer_with_settings(
488-
f,
489-
s0::AbstractVector,
490-
T_len::Int,
491-
ω,
492-
settings::ParallelMCMCSettings;
493-
kwargs...
476+
f, s0::AbstractVector, T_len::Int, ω, settings::ParallelMCMCSettings; kwargs...
494477
)
495478
return deer(
496-
f, s0, T_len, ω;
479+
f,
480+
s0,
481+
T_len,
482+
ω;
497483
method=settings.method,
498484
tol=settings.tol,
499485
max_iters=settings.max_iters,
500-
kwargs...
486+
kwargs...,
501487
)
502488
end
503489

@@ -508,12 +494,7 @@ Run MCMC sequentially (for comparison/testing).
508494
509495
Returns the trajectory as a (T × D) matrix.
510496
"""
511-
function sequential_mcmc(
512-
f,
513-
s0::AbstractVector{T},
514-
T_len::Int,
515-
ω,
516-
) where {T}
497+
function sequential_mcmc(f, s0::AbstractVector{T}, T_len::Int, ω) where {T}
517498
D = length(s0)
518499
trajectory = zeros(T, T_len, D)
519500

0 commit comments

Comments (0)