Skip to content

Commit 75c2380

Browse files
committed
Fix source
1 parent 5fdc20d commit 75c2380

File tree

2 files changed

+9
-8
lines changed

2 files changed

+9
-8
lines changed

Project.toml

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,16 +18,17 @@ StatsBase = "2913bbd2-ae8a-5f71-8c99-4fb6c76f3a91"
1818
StatsFuns = "4c63d2b9-4356-54db-8cca-17b64c39e42c"
1919
VecTargets = "8a639fad-7908-4fe4-8003-906e9297f002"
2020

21+
[sources]
22+
VecTargets = {url = "https://github.com/chalk-lab/VecTargets.jl", rev = "main"}
23+
24+
2125
[weakdeps]
2226
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
2327
ComponentArrays = "b0b7db55-cfe3-40fc-9ded-d10e2dbeff66"
2428
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
2529
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
2630
OrdinaryDiffEq = "1dea7af3-3e70-54e6-95c3-0bf5283fa5ed"
2731

28-
[sources]
29-
VecTargets = {rev = "main", url = "https://github.com/chalk-lab/VecTargets.jl"}
30-
3132
[extensions]
3233
AdvancedHMCADTypesExt = "ADTypes"
3334
AdvancedHMCComponentArraysExt = "ComponentArrays"

src/riemannian/hamiltonian.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -93,11 +93,11 @@ function ∂H∂θ_cache(
9393
-(∂ℓπ∂θ - 1 / 2 * diag(term_1_cached * ∂H∂θ) + 1 / 2 * diag(term_2 * ∂H∂θ))
9494
else
9595
-mapreduce(vcat, 1:d) do i
96-
∂H∂θᵢ = ∂H∂θ[:, :, i]
97-
# ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1)
98-
# NOTE Some further optimization can be done here: cache the 1st product all together
99-
∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly
100-
end
96+
∂H∂θᵢ = ∂H∂θ[:, :, i]
97+
# ∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * M' * (J .* (Q' * ∂H∂θᵢ * Q)) * M # (v1)
98+
# NOTE Some further optimization can be done here: cache the 1st product all together
99+
∂ℓπ∂θ[i] - 1 / 2 * tr(term_1_cached * ∂H∂θᵢ) + 1 / 2 * tr(term_2_cached * ∂H∂θᵢ) # (v2) cache friendly
100+
end
101101
end
102102

103103
dv = DualValue(ℓπ, g)

0 commit comments

Comments
 (0)