Skip to content
Merged
Show file tree
Hide file tree
Changes from 8 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion .github/workflows/CI.yml
Original file line number Diff line number Diff line change
Expand Up @@ -73,7 +73,6 @@ jobs:

- uses: julia-actions/julia-runtest@v1
env:
GROUP: All
JULIA_NUM_THREADS: ${{ matrix.runner.num_threads }}

- uses: julia-actions/julia-processcoverage@v1
Expand Down
2 changes: 1 addition & 1 deletion .github/workflows/CompatHelper.yml
Original file line number Diff line number Diff line change
Expand Up @@ -14,4 +14,4 @@ jobs:
env:
GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }}
COMPATHELPER_PRIV: ${{ secrets.DOCUMENTER_KEY }}
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test", "test/turing"])'
run: julia -e 'using CompatHelper; CompatHelper.main(; subdirs = ["", "docs", "test"])'
2 changes: 0 additions & 2 deletions .github/workflows/JuliaPre.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,5 +25,3 @@ jobs:
- uses: julia-actions/cache@v2
- uses: julia-actions/julia-buildpkg@v1
- uses: julia-actions/julia-runtest@v1
env:
GROUP: DynamicPPL
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "DynamicPPL"
uuid = "366bfd00-2699-11ea-058f-f148b4cae6d8"
version = "0.31.4"
version = "0.32.0"

[deps]
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
Expand Down
5 changes: 0 additions & 5 deletions docs/src/api.md
Original file line number Diff line number Diff line change
Expand Up @@ -274,11 +274,6 @@ TypedVarInfo

One main characteristic of [`VarInfo`](@ref) is that samples are stored in a linearized form.

```@docs
link!
invlink!
```

```@docs
set_flag!
unset_flag!
Expand Down
62 changes: 46 additions & 16 deletions src/test_utils/models.jl
Original file line number Diff line number Diff line change
Expand Up @@ -323,28 +323,30 @@ function varnames(model::Model{typeof(demo_assume_dot_observe)})
return [@varname(s), @varname(m)]
end

@model function demo_assume_observe_literal()
# `assume` and literal `observe`
@model function demo_assume_multivariate_observe_literal()
# multivariate `assume` and literal `observe`
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
m ~ MvNormal(zeros(2), Diagonal(s))
[1.5, 2.0] ~ MvNormal(m, Diagonal(s))

return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
end
function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
function logprior_true(model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m)
s_dist = product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
m_dist = MvNormal(zeros(2), Diagonal(s))
return logpdf(s_dist, s) + logpdf(m_dist, m)
end
function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
function loglikelihood_true(
model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m
)
return logpdf(MvNormal(m, Diagonal(s)), [1.5, 2.0])
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_assume_observe_literal)}, s, m
model::Model{typeof(demo_assume_multivariate_observe_literal)}, s, m
)
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_assume_observe_literal)})
function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)})
return [@varname(s), @varname(m)]
end

Expand Down Expand Up @@ -377,26 +379,50 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_index_literal)})
return [@varname(s[1]), @varname(s[2]), @varname(m[1]), @varname(m[2])]
end

@model function demo_assume_literal_dot_observe()
@model function demo_assume_observe_literal()
# univariate `assume` and literal `observe`
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
1.5 ~ Normal(m, sqrt(s))
2.0 ~ Normal(m, sqrt(s))

return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
end
function logprior_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
end
function loglikelihood_true(model::Model{typeof(demo_assume_observe_literal)}, s, m)
return logpdf(Normal(m, sqrt(s)), 1.5) + logpdf(Normal(m, sqrt(s)), 2.0)
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_assume_observe_literal)}, s, m
)
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_assume_observe_literal)})
return [@varname(s), @varname(m)]
end

@model function demo_assume_dot_observe_literal()
# `assume` and literal `dot_observe`
s ~ InverseGamma(2, 3)
m ~ Normal(0, sqrt(s))
[1.5, 2.0] .~ Normal(m, sqrt(s))

return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
end
function logprior_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m)
function logprior_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m)
return logpdf(InverseGamma(2, 3), s) + logpdf(Normal(0, sqrt(s)), m)
end
function loglikelihood_true(model::Model{typeof(demo_assume_literal_dot_observe)}, s, m)
function loglikelihood_true(model::Model{typeof(demo_assume_dot_observe_literal)}, s, m)
return loglikelihood(Normal(m, sqrt(s)), [1.5, 2.0])
end
function logprior_true_with_logabsdet_jacobian(
model::Model{typeof(demo_assume_literal_dot_observe)}, s, m
model::Model{typeof(demo_assume_dot_observe_literal)}, s, m
)
return _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
end
function varnames(model::Model{typeof(demo_assume_literal_dot_observe)})
function varnames(model::Model{typeof(demo_assume_dot_observe_literal)})
return [@varname(s), @varname(m)]
end

Expand Down Expand Up @@ -574,8 +600,9 @@ const DemoModels = Union{
Model{typeof(demo_assume_multivariate_observe)},
Model{typeof(demo_dot_assume_observe_index)},
Model{typeof(demo_assume_dot_observe)},
Model{typeof(demo_assume_literal_dot_observe)},
Model{typeof(demo_assume_dot_observe_literal)},
Model{typeof(demo_assume_observe_literal)},
Model{typeof(demo_assume_multivariate_observe_literal)},
Model{typeof(demo_dot_assume_observe_index_literal)},
Model{typeof(demo_assume_submodel_observe_index_literal)},
Model{typeof(demo_dot_assume_observe_submodel)},
Expand All @@ -585,7 +612,9 @@ const DemoModels = Union{
}

const UnivariateAssumeDemoModels = Union{
Model{typeof(demo_assume_dot_observe)},Model{typeof(demo_assume_literal_dot_observe)}
Model{typeof(demo_assume_dot_observe)},
Model{typeof(demo_assume_dot_observe_literal)},
Model{typeof(demo_assume_observe_literal)},
}
function posterior_mean(model::UnivariateAssumeDemoModels)
return (s=49 / 24, m=7 / 6)
Expand All @@ -609,7 +638,7 @@ const MultivariateAssumeDemoModels = Union{
Model{typeof(demo_assume_index_observe)},
Model{typeof(demo_assume_multivariate_observe)},
Model{typeof(demo_dot_assume_observe_index)},
Model{typeof(demo_assume_observe_literal)},
Model{typeof(demo_assume_multivariate_observe_literal)},
Model{typeof(demo_dot_assume_observe_index_literal)},
Model{typeof(demo_assume_submodel_observe_index_literal)},
Model{typeof(demo_dot_assume_observe_submodel)},
Expand Down Expand Up @@ -759,9 +788,10 @@ const DEMO_MODELS = (
demo_assume_multivariate_observe(),
demo_dot_assume_observe_index(),
demo_assume_dot_observe(),
demo_assume_observe_literal(),
demo_assume_multivariate_observe_literal(),
demo_dot_assume_observe_index_literal(),
demo_assume_literal_dot_observe(),
demo_assume_dot_observe_literal(),
demo_assume_observe_literal(),
demo_assume_submodel_observe_index_literal(),
demo_dot_assume_observe_submodel(),
demo_dot_assume_dot_observe_matrix(),
Expand Down
44 changes: 0 additions & 44 deletions src/varinfo.jl
Original file line number Diff line number Diff line change
Expand Up @@ -1215,27 +1215,6 @@ function link!!(
return Accessors.@set vi.varinfo = DynamicPPL.link!!(t, vi.varinfo, spl, model)
end

"""
link!(vi::VarInfo, spl::Sampler)

Transform the values of the random variables sampled by `spl` in `vi` from the support
of their distributions to the Euclidean space and set their corresponding `"trans"`
flag values to `true`.
"""
function link!(vi::VarInfo, spl::AbstractSampler)
Base.depwarn(
"`link!(varinfo, sampler)` is deprecated, use `link!!(varinfo, sampler, model)` instead.",
:link!,
)
return _link!(vi, spl)
end
function link!(vi::VarInfo, spl::AbstractSampler, spaceval::Val)
Base.depwarn(
"`link!(varinfo, sampler, spaceval)` is deprecated, use `link!!(varinfo, sampler, model)` instead.",
:link!,
)
return _link!(vi, spl, spaceval)
end
function _link!(vi::UntypedVarInfo, spl::AbstractSampler)
# TODO: Change to a lazy iterator over `vns`
vns = _getvns(vi, spl)
Expand Down Expand Up @@ -1313,29 +1292,6 @@ function maybe_invlink_before_eval!!(vi::VarInfo, context::AbstractContext, mode
return maybe_invlink_before_eval!!(t, vi, context, model)
end

"""
invlink!(vi::VarInfo, spl::AbstractSampler)

Transform the values of the random variables sampled by `spl` in `vi` from the
Euclidean space back to the support of their distributions and sets their corresponding
`"trans"` flag values to `false`.
"""
function invlink!(vi::VarInfo, spl::AbstractSampler)
Base.depwarn(
"`invlink!(varinfo, sampler)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.",
:invlink!,
)
return _invlink!(vi, spl)
end

function invlink!(vi::VarInfo, spl::AbstractSampler, spaceval::Val)
Base.depwarn(
"`invlink!(varinfo, sampler, spaceval)` is deprecated, use `invlink!!(varinfo, sampler, model)` instead.",
:invlink!,
)
return _invlink!(vi, spl, spaceval)
end

function _invlink!(vi::UntypedVarInfo, spl::AbstractSampler)
vns = _getvns(vi, spl)
if istrans(vi, vns[1])
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ LogDensityProblemsAD = "996a588d-648d-4e1f-a8f0-a84b347e47b1"
MCMCChains = "c7f686f2-ff18-58e9-bc7b-31028e88f75d"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
OrderedCollections = "bac558e1-5e72-5ebc-8fee-abe8a469f55d"
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
Expand Down
39 changes: 39 additions & 0 deletions test/ad.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,4 +34,43 @@
end
end
end

@testset "Turing#2151: ReverseDiff compilation & eltype(vi, spl)" begin
# Failing model
t = 1:0.05:8
σ = 0.3
y = @. rand(sin(t) + Normal(0, σ))
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
# Priors
α ~ Normal(y[1], 0.001)
τ ~ Exponential(1)
η ~ filldist(Normal(0, 1), TT - 1)
σ ~ Exponential(1)
# create latent variable
x = Vector{T}(undef, TT)
x[1] = α
for t in 2:TT
x[t] = x[t - 1] + η[t - 1] * τ
end
# measurement model
y ~ MvNormal(x, σ^2 * I)
return x
end
model = state_space(y, length(t))

# Dummy sampling algorithm for testing. The test case can only be replicated
# with a custom sampler, it doesn't work with SampleFromPrior(). We need to
# overload assume so that model evaluation doesn't fail due to a lack
# of implementation
struct MyEmptyAlg end
DynamicPPL.getspace(::DynamicPPL.Sampler{MyEmptyAlg}) = ()
DynamicPPL.assume(rng, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi) =
DynamicPPL.assume(dist, vn, vi)

# Compiling the ReverseDiff tape used to fail here
spl = Sampler(MyEmptyAlg())
vi = VarInfo(model)
ldf = DynamicPPL.LogDensityFunction(vi, model, SamplingContext(spl))
@test LogDensityProblemsAD.ADgradient(AutoReverseDiff(; compile=true), ldf) isa Any
end
end
Loading
Loading