Skip to content
Closed
4 changes: 3 additions & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ Bijectors = "76274a88-744f-5084-9051-94815aaf08c4"
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"
MacroTools = "1914dd2f-81c6-5fcd-8719-6d5c9610ff09"
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Requires = "ae029012-a4dd-5104-9daa-d747884805df"
ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
Expand All @@ -22,6 +23,7 @@ julia = "1"
[extras]
AdvancedHMC = "0bf59076-c3b1-5ca4-86bd-e02cd72cde3d"
AdvancedMH = "5b7e9947-ddc0-4b3f-9b55-0d8042f74170"
BangBang = "198e06fe-97b7-11e9-32a5-e1d131e6ad66"
BinaryProvider = "b99e7846-7c00-51b0-8f62-c81ae34c0232"
DistributionsAD = "ced4e74d-a319-5a8a-b0ac-84af2272839c"
DocStringExtensions = "ffbed154-4ef7-542d-bbb7-c09d3a79fcae"
Expand All @@ -48,4 +50,4 @@ UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"

[targets]
test = ["AdvancedHMC", "AdvancedMH", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]
test = ["AdvancedHMC", "AdvancedMH", "BangBang", "DistributionsAD", "DocStringExtensions", "EllipticalSliceSampling", "ForwardDiff", "Libtask", "LinearAlgebra", "LogDensityProblems", "Logging", "MCMCChains", "Markdown", "PDMats", "ProgressLogging", "Random", "Reexport", "Requires", "SpecialFunctions", "Statistics", "StatsBase", "StatsFuns", "Test", "Tracker", "UUIDs", "Zygote"]
5 changes: 4 additions & 1 deletion src/DynamicPPL.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
using Distributions
using Bijectors
using MacroTools
using Requires

import AbstractMCMC
import ZygoteRules
Expand All @@ -28,6 +29,7 @@ import Base: Symbol,
export AbstractVarInfo,
VarInfo,
UntypedVarInfo,
MixedVarInfo,
getlogp,
setlogp!,
acclogp!,
Expand Down Expand Up @@ -116,8 +118,9 @@ include("sampler.jl")
include("varname.jl")
include("distribution_wrappers.jl")
include("contexts.jl")
include("varinfo.jl")
include("varinfo/varinfo.jl")
include("threadsafe.jl")
include("mixedvarinfo.jl")
include("context_implementations.jl")
include("compiler.jl")
include("prob_macro.jl")
Expand Down
2 changes: 1 addition & 1 deletion src/compiler.jl
Original file line number Diff line number Diff line change
Expand Up @@ -378,7 +378,7 @@ Convert the `value` to the correct type for the `sampler` and the `vi` object.
function matchingvalue(sampler, vi, value)
T = typeof(value)
if hasmissing(T)
return get_matching_type(sampler, vi, T)(value)
return convert(get_matching_type(sampler, vi, T), value)
else
return value
end
Expand Down
20 changes: 10 additions & 10 deletions src/context_implementations.jl
Original file line number Diff line number Diff line change
Expand Up @@ -24,14 +24,14 @@ function tilde(ctx::DefaultContext, sampler, right, vn::VarName, _, vi)
end
function tilde(ctx::PriorContext, sampler, right, vn::VarName, inds, vi)
if ctx.vars !== nothing
vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds))
vi[vn, right] = _getindex(getfield(ctx.vars, getsym(vn)), inds)
settrans!(vi, false, vn)
end
return _tilde(sampler, right, vn, vi)
end
function tilde(ctx::LikelihoodContext, sampler, right, vn::VarName, inds, vi)
if ctx.vars !== nothing
vi[vn] = vectorize(right, _getindex(getfield(ctx.vars, getsym(vn)), inds))
vi[vn, right] = _getindex(getfield(ctx.vars, getsym(vn)), inds)
settrans!(vi, false, vn)
end
return _tilde(sampler, NoDist(right), vn, vi)
Expand Down Expand Up @@ -127,11 +127,11 @@ function assume(
if spl isa SampleFromUniform || is_flagged(vi, vn, "del")
unset_flag!(vi, vn, "del")
r = init(dist, spl)
vi[vn] = vectorize(dist, r)
vi[vn, dist] = r
settrans!(vi, false, vn)
setorder!(vi, vn, get_num_produce(vi))
else
r = vi[vn]
r = vi[vn, dist]
end
else
r = init(dist, spl)
Expand Down Expand Up @@ -297,12 +297,12 @@ function get_and_set_val!(
r = init(dist, spl, n)
for i in 1:n
vn = vns[i]
vi[vn] = vectorize(dist, r[:, i])
vi[vn, dist] = r[:, i]
settrans!(vi, false, vn)
setorder!(vi, vn, get_num_produce(vi))
end
else
r = vi[vns]
r = vi[vns, dist]
end
else
r = init(dist, spl, n)
Expand Down Expand Up @@ -330,12 +330,12 @@ function get_and_set_val!(
for i in eachindex(vns)
vn = vns[i]
dist = dists isa AbstractArray ? dists[i] : dists
vi[vn] = vectorize(dist, r[i])
vi[vn, dist] = r[i]
settrans!(vi, false, vn)
setorder!(vi, vn, get_num_produce(vi))
end
else
r = reshape(vi[vec(vns)], size(vns))
r = vi[vns, dists]
end
else
f = (vn, dist) -> init(dist, spl)
Expand All @@ -354,7 +354,7 @@ function set_val!(
)
@assert size(val, 2) == length(vns)
foreach(enumerate(vns)) do (i, vn)
vi[vn] = val[:,i]
vi[vn, dist] = val[:,i]
end
return val
end
Expand All @@ -367,7 +367,7 @@ function set_val!(
@assert size(val) == size(vns)
foreach(CartesianIndices(val)) do ind
dist = dists isa AbstractArray ? dists[ind] : dists
vi[vns[ind]] = vectorize(dist, val[ind])
vi[vns[ind], dist] = val[ind]
end
return val
end
Expand Down
Loading