Skip to content
Merged
Show file tree
Hide file tree
Changes from 5 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions src/test_utils/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,12 @@ A model with variables `m` and `x` with `x` having support depending on `m`.
"""
@model function demo_dynamic_constraint()
m ~ Normal()
x ~ truncated(Normal(), m, Inf)
x ~ truncated(Normal(); lower=m)

return (m=m, x=x)
end
function logprior_true(model::Model{typeof(demo_dynamic_constraint)}, m, x)
return logpdf(Normal(), m) + logpdf(truncated(Normal(), m, Inf), x)
return logpdf(Normal(), m) + logpdf(truncated(Normal(); lower=m), x)
end
function loglikelihood_true(model::Model{typeof(demo_dynamic_constraint)}, m, x)
return zero(float(eltype(m)))
Expand All @@ -30,7 +30,7 @@ end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_dynamic_constraint)}, m, x
)
b_x = Bijectors.bijector(truncated(Normal(), m, Inf))
b_x = Bijectors.bijector(truncated(Normal(); lower=m))
x_unconstrained, Δlogp = Bijectors.with_logabsdet_jacobian(b_x, x)
return (m=m, x=x_unconstrained), logprior_true(model, m, x) - Δlogp
end
Expand Down
2 changes: 1 addition & 1 deletion test/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -501,7 +501,7 @@ module Issue537 end
@model function demo(y)
α ~ Uniform()
μ ~ Normal()
σ ~ truncated(Normal(), 0, Inf)
σ ~ truncated(Normal(); lower=0)
num_steps = length(y[1])
num_obs = length(y)
@inbounds for i in 1:num_obs
Expand Down
16 changes: 4 additions & 12 deletions test/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -69,14 +69,6 @@ end
# Test Base functions:
# string, Symbol, ==, hash, in, keys, haskey, isempty, push!!, empty!!,
# getindex, setindex!, getproperty, setproperty!
csym = gensym()
vn1 = @varname x[1][2]
@test string(vn1) == "x[1][2]"
@test Symbol(vn1) == Symbol("x[1][2]")

vn2 = @varname x[1][2]
@test vn2 == vn1
@test hash(vn2) == hash(vn1)

function test_base(vi_original)
vi = deepcopy(vi_original)
Expand Down Expand Up @@ -179,14 +171,14 @@ end
@testset "setval! & setval_and_resample!" begin
@model function testmodel(x)
n = length(x)
s ~ truncated(Normal(), 0, Inf)
s ~ truncated(Normal(); lower=0)
m ~ MvNormal(zeros(n), I)
return x ~ MvNormal(m, s^2 * I)
end

@model function testmodel_univariate(x, ::Type{TV}=Vector{Float64}) where {TV}
n = length(x)
s ~ truncated(Normal(), 0, Inf)
s ~ truncated(Normal(); lower=0)

m = TV(undef, n)
for i in eachindex(m)
Expand Down Expand Up @@ -444,10 +436,10 @@ end
end

@testset "istrans" begin
@model demo_constrained() = x ~ truncated(Normal(), 0, Inf)
@model demo_constrained() = x ~ truncated(Normal(); lower=0)
model = demo_constrained()
vn = @varname(x)
dist = truncated(Normal(), 0, Inf)
dist = truncated(Normal(); lower=0)

### `VarInfo`
# Need to run once since we can't specify that we want to _sample_
Expand Down
Loading