Skip to content

Commit dd1d301

Browse files
committed
Overload AbstractPPL.condition and AbstractPPL.decondition (#337)
Fixes #336. Tests will fail until the Zygote bug is fixed... Maybe we should just mark them as broken so we can merge and release some PRs? Co-authored-by: David Widmann <[email protected]>
1 parent 7a8ba7e commit dd1d301

File tree

5 files changed

+28
-17
lines changed

5 files changed

+28
-17
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "DynamicPPL"
22
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
3-
version = "0.16.0"
3+
version = "0.16.1"
44

55
[deps]
66
AbstractMCMC = "80f14c24-f653-4e6a-9b94-39d6b0f70001"

src/contexts.jl

Lines changed: 17 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -366,11 +366,15 @@ otherwise return `context` which is [`DefaultContext`](@ref) by default.
366366
367367
See also: [`decondition`](@ref)
368368
"""
369-
condition(; values...) = condition(DefaultContext(), NamedTuple(values))
370-
condition(values::NamedTuple) = condition(DefaultContext(), values)
371-
condition(context::AbstractContext, values::NamedTuple{()}) = context
372-
condition(context::AbstractContext, values::NamedTuple) = ConditionContext(values, context)
373-
condition(context::AbstractContext; values...) = condition(context, NamedTuple(values))
369+
AbstractPPL.condition(; values...) = condition(DefaultContext(), NamedTuple(values))
370+
AbstractPPL.condition(values::NamedTuple) = condition(DefaultContext(), values)
371+
AbstractPPL.condition(context::AbstractContext, values::NamedTuple{()}) = context
372+
function AbstractPPL.condition(context::AbstractContext, values::NamedTuple)
373+
return ConditionContext(values, context)
374+
end
375+
function AbstractPPL.condition(context::AbstractContext; values...)
376+
return condition(context, NamedTuple(values))
377+
end
374378

375379
"""
376380
decondition(context::AbstractContext, syms...)
@@ -381,20 +385,22 @@ Note that this recursively traverses contexts, deconditioning all along the way.
381385
382386
See also: [`condition`](@ref)
383387
"""
384-
decondition(::IsLeaf, context, args...) = context
385-
function decondition(::IsParent, context, args...)
388+
AbstractPPL.decondition(::IsLeaf, context, args...) = context
389+
function AbstractPPL.decondition(::IsParent, context, args...)
386390
return setchildcontext(context, decondition(childcontext(context), args...))
387391
end
388-
decondition(context, args...) = decondition(NodeTrait(context), context, args...)
389-
function decondition(context::ConditionContext)
392+
function AbstractPPL.decondition(context, args...)
393+
return decondition(NodeTrait(context), context, args...)
394+
end
395+
function AbstractPPL.decondition(context::ConditionContext)
390396
return decondition(childcontext(context))
391397
end
392-
function decondition(context::ConditionContext, sym)
398+
function AbstractPPL.decondition(context::ConditionContext, sym)
393399
return condition(
394400
decondition(childcontext(context), sym), BangBang.delete!!(context.values, sym)
395401
)
396402
end
397-
function decondition(context::ConditionContext, sym, syms...)
403+
function AbstractPPL.decondition(context::ConditionContext, sym, syms...)
398404
return decondition(
399405
condition(
400406
decondition(childcontext(context), syms...),

src/model.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -254,8 +254,8 @@ From this we can tell what the correct way to condition `m` within `demo_inner`
254254
is in the two different models.
255255
256256
"""
257-
condition(model::Model; values...) = condition(model, NamedTuple(values))
258-
function condition(model::Model, values)
257+
AbstractPPL.condition(model::Model; values...) = condition(model, NamedTuple(values))
258+
function AbstractPPL.condition(model::Model, values)
259259
return contextualize(model, condition(model.context, values))
260260
end
261261

@@ -307,7 +307,7 @@ julia> model(rng)
307307
(m = 0.683947930996541, x = 10.0)
308308
```
309309
"""
310-
function decondition(model::Model, syms...)
310+
function AbstractPPL.decondition(model::Model, syms...)
311311
return contextualize(model, decondition(model.context, syms...))
312312
end
313313

test/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -534,7 +534,7 @@ end
534534

535535
# Ensure we can specialize on arguments.
536536
@model demo(x) = x ~ Normal()
537-
length(methods(demo))
537+
@test length(methods(demo)) == 4
538538
@test f(demo(1.0))
539539
f(::Model{typeof(demo),(:x,)}) = false
540540
@test !f(demo(1.0))

test/test_util.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -43,7 +43,12 @@ function test_model_ad(model, logp_manual)
4343

4444
y, back = Zygote.pullback(logp_model, x)
4545
@test y lp
46-
@test back(1)[1] grad
46+
# will be fixed by https://github.com/FluxML/Zygote.jl/pull/1104
47+
if Threads.nthreads() > 1
48+
@test_broken back(1)[1] grad
49+
else
50+
@test back(1)[1] grad
51+
end
4752
end
4853

4954
"""

0 commit comments

Comments
 (0)