Skip to content

Commit 4593dbc

Browse files
xukai92claude
andcommitted
Add parallel MALA sampler (Phase 3)
Implement parallelized Metropolis-Adjusted Langevin Algorithm using DEER. MALA components in src/parallel/mala.jl: - MALARandomInputs: Pre-sampled (ξ, u) pairs for proposals and accept-reject - MALAConfig: Configuration struct with step size, log density, gradient - mala_proposal(): Langevin proposal x̃ = x + ε∇log p(x) + √(2ε)ξ - mala_log_acceptance_ratio(): MH ratio with forward/backward densities - soft_gate(): Differentiable accept-reject using sigmoid + straight-through - mala_transition(): Complete MALA step combining proposal and accept-reject - parallel_mala(): Run full MALA chain in parallel via DEER - sequential_mala(): Reference implementation for testing The stop-gradient trick enables computing Jacobians through the non-differentiable accept-reject step by using a soft sigmoid gate in the backward pass while keeping hard decisions in the forward pass. Tests: 31 new tests covering proposal mechanics, sequential/parallel equivalence, various target distributions, acceptance behavior, and sample statistics. Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 656fd8a commit 4593dbc

File tree

4 files changed

+748
-13
lines changed

4 files changed

+748
-13
lines changed

docs/parallel_mcmc_implementation_plan.md

Lines changed: 19 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -70,21 +70,23 @@ This document tracks the implementation of DEER (Doubly Efficient Estimation via
7070

7171
---
7272

73-
### Phase 3: MALA Integration
73+
### Phase 3: MALA Integration
7474
> Parallelized Metropolis-Adjusted Langevin Algorithm
7575
76-
- [ ] **3.1 MALA Transition Function**
77-
- [ ] Proposal: x̃ = x + ε∇log p(x) + √(2ε)ξ
78-
- [ ] Acceptance ratio computation
79-
- [ ] Stop-gradient trick for differentiable accept-reject
80-
- [ ] Soft gating: g = σ(log α - log u) with straight-through estimator
76+
- [x] **3.1 MALA Transition Function**
77+
- [x] Proposal: x̃ = x + ε∇log p(x) + √(2ε)ξ
78+
- [x] Acceptance ratio computation (forward and backward proposal densities)
79+
- [x] Stop-gradient trick for differentiable accept-reject
80+
- [x] Soft gating with sigmoid and straight-through estimator
8181

82-
- [ ] **3.2 Parallel MALA Sampler**
83-
- [ ] Pre-sample all random inputs (ξ for proposals, u for accept-reject)
84-
- [ ] Integrate with DEER framework
85-
- [ ] Return full chain from parallel computation
82+
- [x] **3.2 Parallel MALA Sampler**
83+
- [x] MALARandomInputs type for pre-sampled (ξ, u) pairs
84+
- [x] sample_mala_inputs() for batch sampling
85+
- [x] parallel_mala() integrating with DEER framework
86+
- [x] sequential_mala() for reference/testing
87+
- [x] Convenience API with automatic input sampling
8688

87-
- [ ] **3.3 MALA-specific Optimizations**
89+
- [ ] **3.3 MALA-specific Optimizations** (deferred)
8890
- [ ] Preconditioning with Hessian eigendecomposition (optional)
8991

9092
---
@@ -197,14 +199,14 @@ src/
197199
│ ├── scan.jl # Parallel scan implementations ✅
198200
│ ├── jacobian.jl # Jacobian computation utilities ✅
199201
│ ├── deer.jl # Core DEER algorithm ✅
200-
│ ├── mala.jl # Parallel MALA (TODO)
202+
│ ├── mala.jl # Parallel MALA
201203
│ └── hmc.jl # Parallel HMC (TODO)
202204
test/
203205
├── parallel/
204206
│ ├── test_scan.jl # ✅ 141 tests passing
205207
│ ├── test_jacobian.jl # ✅ 57 tests passing
206208
│ ├── test_deer.jl # ✅ 67 tests passing
207-
│ ├── test_mala.jl # (TODO)
209+
│ ├── test_mala.jl # ✅ 31 tests passing
208210
│ └── test_hmc.jl # (TODO)
209211
```
210212
@@ -249,6 +251,10 @@ test/
249251
| 2026-01-20 | 2 | DEER algorithm core | ✅ | Newton iteration, Full/Quasi/Stochastic DEER |
250252
| 2026-01-20 | 7.1 | DEER algorithm tests | ✅ | 67 tests passing |
251253
| 2026-01-20 | 2 | **Phase 2 Complete** | ✅ | Core DEER algorithm working |
254+
| 2026-01-20 | 3 | MALA transition function | ✅ | Proposal, acceptance, soft gating |
255+
| 2026-01-20 | 3 | Parallel MALA sampler | ✅ | Integrated with DEER framework |
256+
| 2026-01-20 | 7.1 | MALA tests | ✅ | 31 tests passing |
257+
| 2026-01-20 | 3 | **Phase 3 Complete** | ✅ | Parallel MALA working |
252258
253259
---
254260

src/parallel/Parallel.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,6 +47,9 @@ include("jacobian.jl")
4747
# Core DEER algorithm
4848
include("deer.jl")
4949

50+
# Parallel MALA
51+
include("mala.jl")
52+
5053
# Export types
5154
export AbstractParallelMethod, FullDEER, QuasiDEER, StochasticQuasiDEER, BlockQuasiDEER
5255

@@ -76,4 +79,9 @@ export hessian_diagonal, batch_hessian_diagonals
7679
export DEERResult
7780
export deer, deer_with_settings, sequential_mcmc
7881

82+
# Export MALA
83+
export MALARandomInputs, MALAConfig
84+
export sample_mala_inputs, mala_proposal, mala_transition
85+
export parallel_mala, sequential_mala
86+
7987
end # module

0 commit comments

Comments
 (0)