Skip to content

Commit 9bdec6f

Browse files
committed
Fix JET detected issues
1 parent 237821e commit 9bdec6f

File tree

2 files changed

+3
-5
lines changed

2 files changed

+3
-5
lines changed

src/AdvancedHMC.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,8 @@ using LinearAlgebra:
1313
cholesky,
1414
UniformScaling,
1515
logdet,
16-
tr
16+
tr,
17+
eigen
1718
using StatsFuns: logaddexp, logsumexp, loghalf
1819
using Random: Random, AbstractRNG
1920
using ProgressMeter: ProgressMeter

src/riemannian/hamiltonian.jl

Lines changed: 1 addition & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -89,16 +89,13 @@ function ∂H∂θ_cache(
8989
d = length(∂ℓπ∂θ)
9090
D = diagm((Q' * r) ./ softabsλ)
9191
term_2_cached = Q * D * J * D * Q'
92-
g = if isdiag
93-
-(∂ℓπ∂θ - 1 / 2 * diag(term_1_cached * ∂H∂θ) + 1 / 2 * diag(term_2 * ∂H∂θ))
94-
else
92+
g =
9593
-mapreduce(vcat, 1:d) do i
9694
∂H∂θᵢ = ∂H∂θ[:, :, i]
9795
# ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1)
9896
# NOTE Some further optimization can be done here: cache the 1st product all together
9997
∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly
10098
end
101-
end
10299

103100
dv = DualValue(ℓπ, g)
104101
return return_cache ? (dv, (; ℓπ, ∂ℓπ∂θ, ∂H∂θ, Q, softabsλ, J, term_1_cached)) : dv

0 commit comments

Comments
 (0)