Skip to content

Commit 3b8dcea

Browse files
Fix depreciation for AutoReverseDiff. (#638)
* Update DynamicPPLReverseDiffExt.jl * Update Project.toml * Update ext/DynamicPPLReverseDiffExt.jl Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> * Update DynamicPPLReverseDiffExt.jl * Update Project.toml --------- Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com>
1 parent 949d153 commit 3b8dcea

File tree

2 files changed

+4
-4
lines changed

2 files changed

+4
-4
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ DynamicPPLReverseDiffExt = ["ReverseDiff"]
4141
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
4242

4343
[compat]
44-
ADTypes = "0.2, 1"
44+
ADTypes = "1"
4545
AbstractMCMC = "5"
4646
AbstractPPL = "0.8.4"
4747
Accessors = "0.1"

ext/DynamicPPLReverseDiffExt.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,12 @@ else
99
end
1010

1111
function LogDensityProblemsAD.ADgradient(
12-
ad::ADTypes.AutoReverseDiff, ℓ::DynamicPPL.LogDensityFunction
13-
)
12+
ad::ADTypes.AutoReverseDiff{Tcompile}, ℓ::DynamicPPL.LogDensityFunction
13+
) where {Tcompile}
1414
return LogDensityProblemsAD.ADgradient(
1515
Val(:ReverseDiff),
1616
ℓ;
17-
compile=Val(ad.compile),
17+
compile=Val(Tcompile),
1818
# `getparams` can return `Vector{Real}`, in which case, `ReverseDiff` will initialize the gradients to Integer 0
1919
# because at https://github.com/JuliaDiff/ReverseDiff.jl/blob/c982cde5494fc166965a9d04691f390d9e3073fd/src/tracked.jl#L473
2020
# `zero(D)` will return 0 when D is Real.

0 commit comments

Comments
 (0)