Skip to content

Commit 17c8c5d

Browse files
committed
rm simple unpack from ham flows
1 parent e257fab commit 17c8c5d

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

example/demo_hamiltonian_flow.jl

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ using Optimisers, ADTypes
44
using Mooncake
55
using Bijectors
66
using Bijectors: partition, combine, PartitionMask
7-
using SimpleUnPack: @unpack
87

98
using NormalizingFlows
109

@@ -62,7 +61,7 @@ function _leapfrog(
6261
end
6362

6463
function Bijectors.transform(lf::LeapFrog{T}, z::AbstractVector{T}) where {T<:Real}
65-
@unpack dim, logϵ, L, ∇logp = lf
64+
(; dim, logϵ, L, ∇logp) = lf
6665
@assert length(z) == 2dim "dimension of input must be even, z = [x, ρ]"
6766

6867
ϵ = _get_stepsize(lf)
@@ -73,7 +72,7 @@ end
7372

7473
function Bijectors.transform(ilf::Inverse{<:LeapFrog{T}}, z::AbstractVector{T}) where {T<:Real}
7574
lf = ilf.orig
76-
@unpack dim, logϵ, L, ∇logp = lf
75+
(; dim, logϵ, L, ∇logp) = lf
7776
@assert length(z) == 2dim "dimension of input must be even, z = [x, ρ]"
7877

7978
ϵ = _get_stepsize(lf)
@@ -123,6 +122,9 @@ function logp_joint(z::AbstractVector{T}) where {T<:Real}
123122
logp_ρ = sum(logpdf(Normal(), ρ))
124123
return logp_x + logp_ρ
125124
end
125+
126+
# the score function is the gradient of the logpdf.
127+
# In all the synthetic targets, the score function is only implemented for the Banana target
126128
∇logp = Base.Fix1(score, target)
127129

128130
######################################

0 commit comments

Comments
 (0)