Skip to content

Commit c728e95

Browse files
committed
Add tests for model conditioning syntax
1 parent 9a8a36b commit c728e95

File tree

1 file changed

+22
-0
lines changed

1 file changed

+22
-0
lines changed

test/model.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,28 @@ const GDEMO_DEFAULT = DynamicPPL.TestUtils.demo_assume_observe_literal()
100100
end
101101
end
102102

103+
@testset "model conditioning with various arguments" begin
104+
@model function demo_condition()
105+
x ~ Normal()
106+
y ~ Normal(x)
107+
end
108+
model = demo_condition()
109+
# Test that different syntaxes work and give the same underlying ConditionContext
110+
@testset "NamedTuple ConditionContext" begin
111+
expected_values = (y = 2,)
112+
@test condition(model, (y=2,)).context.values == expected_values
113+
@test condition(model, y=2).context.values == expected_values
114+
@test condition(model; y=2).context.values == expected_values
115+
@test (model | (y = 2, )).context.values == expected_values
116+
end
117+
@testset "AbstractDict ConditionContext" begin
118+
expected_values = Dict(@varname(y) => 2)
119+
@test condition(model, Dict(@varname(y) => 2)).context.values == expected_values
120+
@test condition(model, @varname(y) => 2).context.values == expected_values
121+
@test (model | (@varname(y) => 2, )).context.values == expected_values
122+
end
123+
end
124+
103125
@testset "DynamicPPL#684: threadsafe evaluation with multiple types" begin
104126
@model function multiple_types(x)
105127
ns ~ filldist(Normal(0, 2.0), 3)

0 commit comments

Comments
 (0)