Skip to content

Commit 03e4ba2

Browse files
wsmosesdevmotion
andauthored
Mark Sampling context as not needing derivatives (#556)
* Mark Sampling context as not needing derivatives * Mark Sampling context as not needing derivatives * Fix format * Fix Project.toml * Qualify SamplingContext --------- Co-authored-by: David Widmann <[email protected]>
1 parent 34be85c commit 03e4ba2

File tree

5 files changed

+24
-0
lines changed

5 files changed

+24
-0
lines changed

Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -24,9 +24,11 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"
2424

2525
[weakdeps]
2626
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
27+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
2728

2829
[extensions]
2930
DynamicPPLMCMCChainsExt = ["MCMCChains"]
31+
DynamicPPLEnzymeCoreExt = ["EnzymeCore"]
3032

3133
[compat]
3234
AbstractMCMC = "5"
@@ -38,6 +40,7 @@ Compat = "4"
3840
ConstructionBase = "1.5.4"
3941
Distributions = "0.25"
4042
DocStringExtensions = "0.9"
43+
EnzymeCore = "0.6"
4144
LogDensityProblems = "2"
4245
MCMCChains = "6"
4346
MacroTools = "0.5.6"
@@ -52,3 +55,4 @@ julia = "1.6"
5255

5356
[extras]
5457
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
58+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"

ext/DynamicPPLEnzymeCoreExt.jl

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,13 @@
1+
module DynamicPPLEnzymeCoreExt
2+
3+
if isdefined(Base, :get_extension)
4+
using DynamicPPL: DynamicPPL
5+
using EnzymeCore
6+
else
7+
using ..DynamicPPL: DynamicPPL
8+
using ..EnzymeCore
9+
end
10+
11+
@inline EnzymeCore.EnzymeRules.inactive_type(::Type{<:DynamicPPL.SamplingContext}) = true
12+
13+
end

src/DynamicPPL.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -189,6 +189,9 @@ end
189189
@require MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d" include(
190190
"../ext/DynamicPPLMCMCChainsExt.jl"
191191
)
192+
@require EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869" include(
193+
"../ext/DynamicPPLEnzymeCoreExt.jl"
194+
)
192195
end
193196
end
194197

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
77
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
88
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
99
Documenter = "e30172f5-a6a5-5a46-863b-614d45cd2de4"
10+
EnzymeCore = "f151be2c-9106-41f4-ab19-57ee4f262869"
1011
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1213
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"

test/contexts.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,8 @@ using DynamicPPL:
1616
hasconditioned_nested,
1717
getconditioned_nested
1818

19+
using EnzymeCore
20+
1921
# Dummy context to test nested behaviors.
2022
struct ParentContext{C<:AbstractContext} <: AbstractContext
2123
context::C
@@ -252,6 +254,7 @@ end
252254
@test SamplingContext(Random.default_rng(), DefaultContext()) == context
253255
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
254256
@test SamplingContext(SampleFromPrior(), DefaultContext()) == context
257+
@test EnzymeCore.EnzymeRules.inactive_type(typeof(context))
255258
end
256259

257260
@testset "FixedContext" begin

0 commit comments

Comments
 (0)