Skip to content

Commit fc6cae9

Browse files
authored
Make ZygoteRules and ChainRulesCore weak dependencies (#564)
* Make ZygoteRules and ChainRulesCore weak dependencies * Fix format * Add another non-differentiable to CRC extension * Perform coverage analysis on all Julia versions
1 parent 03e4ba2 commit fc6cae9

File tree

7 files changed

+70
-39
lines changed

7 files changed

+70
-39
lines changed

.github/workflows/CI.yml

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -56,19 +56,14 @@ jobs:
5656
${{ runner.os }}-
5757
- uses: julia-actions/julia-buildpkg@latest
5858
- uses: julia-actions/julia-runtest@latest
59-
with:
60-
coverage: ${{ matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1 }}
6159
env:
6260
GROUP: All
6361
JULIA_NUM_THREADS: ${{ matrix.num_threads }}
6462
- uses: julia-actions/julia-processcoverage@v1
65-
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
6663
- uses: codecov/codecov-action@v1
67-
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
6864
with:
6965
file: lcov.info
7066
- uses: coverallsapp/github-action@master
71-
if: matrix.version == '1' && matrix.os == 'ubuntu-latest' && matrix.num_threads == 1
7267
with:
7368
github-token: ${{ secrets.GITHUB_TOKEN }}
7469
path-to-lcov: lcov.info

Project.toml

Lines changed: 10 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.24.2"
3+
version = "0.24.3"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"
@@ -23,12 +23,16 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
2323
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2424

2525
[weakdeps]
26-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
26+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
2727
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
28+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
29+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2830

2931
[extensions]
30-
DynamicPPLMCMCChainsExt = ["MCMCChains"]
32+
DynamicPPLChainRulesCoreExt = ["ChainRulesCore"]
3133
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
34+
DynamicPPLMCMCChainsExt = ["MCMCChains"]
35+
DynamicPPLZygoteRulesExt = ["ZygoteRules"]
3236

3337
[compat]
3438
AbstractMCMC = "5"
@@ -54,5 +58,7 @@ Test = "1.6"
5458
julia = "1.6"
5559

5660
[extras]
57-
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
61+
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
5862
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
63+
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
64+
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

ext/DynamicPPLChainRulesCoreExt.jl

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
module DynamicPPLChainRulesCoreExt
2+
3+
if isdefined(Base, :get_extension)
4+
using DynamicPPL: DynamicPPL, BangBang, Distributions
5+
using ChainRulesCore: ChainRulesCore
6+
else
7+
using ..DynamicPPL: DynamicPPL, BangBang, Distributions
8+
using ..ChainRulesCore: ChainRulesCore
9+
end
10+
11+
# See https://github.com/TuringLang/Turing.jl/issues/1199
12+
ChainRulesCore.@non_differentiable BangBang.push!!(
13+
vi::DynamicPPL.VarInfo,
14+
vn::DynamicPPL.VarName,
15+
r,
16+
dist::Distributions.Distribution,
17+
gidset::Set{DynamicPPL.Selector},
18+
)
19+
20+
ChainRulesCore.@non_differentiable DynamicPPL.updategid!(
21+
vi::DynamicPPL.AbstractVarInfo, vn::DynamicPPL.VarName, spl::DynamicPPL.Sampler
22+
)
23+
24+
# No need + causes issues for some AD backends, e.g. Zygote.
25+
ChainRulesCore.@non_differentiable DynamicPPL.infer_nested_eltype(x)
26+
27+
end # module

ext/DynamicPPLZygoteRulesExt.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
module DynamicPPLZygoteRulesExt
2+
3+
if isdefined(Base, :get_extension)
4+
using DynamicPPL: DynamicPPL, Distributions
5+
using ZygoteRules: ZygoteRules
6+
else
7+
using ..DynamicPPL: DynamicPPL, Distributions
8+
using ..ZygoteRules: ZygoteRules
9+
end
10+
11+
# https://github.com/TuringLang/Turing.jl/issues/1595
12+
ZygoteRules.@adjoint function DynamicPPL.dot_observe(
13+
spl::Union{DynamicPPL.SampleFromPrior,DynamicPPL.SampleFromUniform},
14+
dists::AbstractArray{<:Distributions.Distribution},
15+
value::AbstractArray,
16+
vi,
17+
)
18+
function dot_observe_fallback(spl, dists, value, vi)
19+
DynamicPPL.increment_num_produce!(vi)
20+
return sum(map(Distributions.loglikelihood, dists, value)), vi
21+
end
22+
return ZygoteRules.pullback(__context__, dot_observe_fallback, spl, dists, value, vi)
23+
end
24+
25+
end # module

src/DynamicPPL.jl

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -9,11 +9,9 @@ using OrderedCollections: OrderedDict
99

1010
using AbstractMCMC: AbstractMCMC
1111
using BangBang: BangBang, push!!, empty!!, setindex!!
12-
using ChainRulesCore: ChainRulesCore
1312
using MacroTools: MacroTools
1413
using ConstructionBase: ConstructionBase
1514
using Setfield: Setfield
16-
using ZygoteRules: ZygoteRules
1715
using LogDensityProblems: LogDensityProblems
1816

1917
using LinearAlgebra: LinearAlgebra, Cholesky
@@ -171,7 +169,6 @@ include("simple_varinfo.jl")
171169
include("context_implementations.jl")
172170
include("compiler.jl")
173171
include("prob_macro.jl")
174-
include("compat/ad.jl")
175172
include("loglikelihoods.jl")
176173
include("submodel_macro.jl")
177174
include("test_utils.jl")
@@ -186,12 +183,18 @@ end
186183

187184
@static if !isdefined(Base, :get_extension)
188185
function __init__()
189-
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
190-
"../ext/DynamicPPLMCMCChainsExt.jl"
186+
@require ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" include(
187+
"../ext/DynamicPPLChainRulesCoreExt.jl"
191188
)
192189
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
193190
"../ext/DynamicPPLEnzymeCoreExt.jl"
194191
)
192+
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
193+
"../ext/DynamicPPLMCMCChainsExt.jl"
194+
)
195+
@require ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" include(
196+
"../ext/DynamicPPLZygoteRulesExt.jl"
197+
)
195198
end
196199
end
197200

src/compat/ad.jl

Lines changed: 0 additions & 22 deletions
This file was deleted.

src/utils.jl

Lines changed: 0 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -883,9 +883,6 @@ end
883883
# Handle `AbstractDict` differently since `eltype` results in a `Pair`.
884884
infer_nested_eltype(::Type{<:AbstractDict{<:Any,ET}}) where {ET} = infer_nested_eltype(ET)
885885

886-
# No need + causes issues for some AD backends, e.g. Zygote.
887-
ChainRulesCore.@non_differentiable infer_nested_eltype(x)
888-
889886
"""
890887
varname_leaves(vn::VarName, val)
891888

0 commit comments

Comments
 (0)