Skip to content

Commit ecc388a

Browse files
xukai92yebaigithub-actions[bot]
authored
feat: support position-dependent kinetic (#369)
* feat: support position-dependent kinetic Signed-off-by: Kai Xu <[email protected]> * Update src/hamiltonian.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/hamiltonian.jl Co-authored-by: Hong Ge <[email protected]> * Update src/hamiltonian.jl Co-authored-by: Hong Ge <[email protected]> * Update src/metric.jl Co-authored-by: Hong Ge <[email protected]> * Update src/metric.jl Co-authored-by: Hong Ge <[email protected]> * Update src/metric.jl Co-authored-by: Hong Ge <[email protected]> * fix: add position-independent methods back for leapfrog comptaibility Signed-off-by: Kai Xu <[email protected]> * Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/metric.jl * Update src/metric.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/metric.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update src/metric.jl Co-authored-by: Hong Ge <[email protected]> --------- Signed-off-by: Kai Xu <[email protected]> Co-authored-by: Hong Ge <[email protected]> Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 6111133 commit ecc388a

File tree

2 files changed

+30
-4
lines changed

2 files changed

+30
-4
lines changed

src/hamiltonian.jl

Lines changed: 13 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,12 @@ end
4545
∂H∂r(h::Hamiltonian{<:DenseEuclideanMetric,<:GaussianKinetic}, r::AbstractVecOrMat) =
4646
h.metric.M⁻¹ * r
4747

48+
# TODO (kai) make the order of θ and r consistent with neg_energy
49+
# TODO (kai) add stricter types to block hamiltonian.jl#L37 from working on unknown metric/kinetic
50+
# The gradient of a position-dependent Hamiltonian system depends on both θ and r.
51+
∂H∂θ(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂θ(h, θ)
52+
∂H∂r(h::Hamiltonian, θ::AbstractVecOrMat, r::AbstractVecOrMat) = ∂H∂r(h, r)
53+
4854
struct PhasePoint{T<:AbstractVecOrMat{<:AbstractFloat},V<:DualValue}
4955
θ::T # Position variables / model parameters.
5056
r::T # Momentum variables
@@ -156,7 +162,7 @@ phasepoint(
156162
rng::Union{AbstractRNG,AbstractVector{<:AbstractRNG}},
157163
θ::AbstractVecOrMat{T},
158164
h::Hamiltonian,
159-
) where {T<:Real} = phasepoint(h, θ, rand(rng, h.metric, h.kinetic))
165+
) where {T<:Real} = phasepoint(h, θ, rand(rng, h.metric, h.kinetic, θ))
160166

161167
abstract type AbstractMomentumRefreshment end
162168

@@ -168,7 +174,7 @@ refresh(
168174
::FullMomentumRefreshment,
169175
h::Hamiltonian,
170176
z::PhasePoint,
171-
) = phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic))
177+
) = phasepoint(h, z.θ, rand(rng, h.metric, h.kinetic, z.θ))
172178

173179
"""
174180
$(TYPEDEF)
@@ -196,4 +202,8 @@ refresh(
196202
ref::PartialMomentumRefreshment,
197203
h::Hamiltonian,
198204
z::PhasePoint,
199-
) = phasepoint(h, z.θ, ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic))
205+
) = phasepoint(
206+
h,
207+
z.θ,
208+
ref.α * z.r + sqrt(1 - ref.α^2) * rand(rng, h.metric, h.kinetic, z.θ),
209+
)

src/metric.jl

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -129,7 +129,7 @@ function _rand(
129129
return r
130130
end
131131

132-
# TODO The rand interface should be updated by rand from momentum distribution + optional affine transformation by metric
132+
# TODO (kai) The rand interface should be updated as "rand from momentum distribution + optional affine transformation by metric"
133133
Base.rand(rng::AbstractRNG, metric::AbstractMetric, kinetic::AbstractKinetic) =
134134
_rand(rng, metric, kinetic) # this disambiguity is required by Random.rand
135135
Base.rand(
@@ -139,3 +139,19 @@ Base.rand(
139139
) = _rand(rng, metric, kinetic)
140140
Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic) =
141141
rand(GLOBAL_RNG, metric, kinetic)
142+
143+
# ignore θ by default unless defined by the specific kinetic (i.e. not position-dependent)
144+
Base.rand(
145+
rng::AbstractRNG,
146+
metric::AbstractMetric,
147+
kinetic::AbstractKinetic,
148+
θ::AbstractVecOrMat,
149+
) = rand(rng, metric, kinetic) # this disambiguity is required by Random.rand
150+
Base.rand(
151+
rng::AbstractVector{<:AbstractRNG},
152+
metric::AbstractMetric,
153+
kinetic::AbstractKinetic,
154+
θ::AbstractVecOrMat,
155+
) = rand(rng, metric, kinetic)
156+
Base.rand(metric::AbstractMetric, kinetic::AbstractKinetic, θ::AbstractVecOrMat) =
157+
rand(metric, kinetic)

0 commit comments

Comments
 (0)