Skip to content

Commit b8ccbce

Browse files
committed
Fix imports
1 parent 19fb396 commit b8ccbce

File tree

3 files changed

+21
-18
lines changed

3 files changed

+21
-18
lines changed

src/DynamicPPL.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,8 @@
11
module DynamicPPL
22

33
using AbstractMCMC: AbstractSampler, AbstractChains, AbstractModel
4-
using Distributions: UnivariateDistribution,
5-
MultivariateDistribution,
6-
MatrixDistribution,
7-
Distribution
8-
using Bijectors: link, invlink
4+
using Distributions
5+
using Bijectors
96
using MacroTools
107

118
import Base: string,

src/context_implementations.jl

Lines changed: 13 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,9 @@
1+
using Distributions: UnivariateDistribution,
2+
MultivariateDistribution,
3+
MatrixDistribution,
4+
Distribution
5+
6+
17
alg_str(spl::Sampler) = string(nameof(typeof(spl.alg)))
28

39
# utility funcs for querying sampler information
@@ -93,9 +99,9 @@ function assume(
9399
end
94100
# NOTE: The importance weight is not correctly computed here because
95101
# r is genereated from some uniform distribution which is different from the prior
96-
# acclogp!(vi, logpdf_with_trans(dist, r, istrans(vi, vn)))
102+
# acclogp!(vi, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn)))
97103

98-
return r, logpdf_with_trans(dist, r, istrans(vi, vn))
104+
return r, Bijectors.logpdf_with_trans(dist, r, istrans(vi, vn))
99105
end
100106

101107
function observe(
@@ -105,7 +111,7 @@ function observe(
105111
vi::VarInfo,
106112
)
107113
increment_num_produce!(vi)
108-
return logpdf(dist, value)
114+
return Distributions.logpdf(dist, value)
109115
end
110116

111117
# .~ functions
@@ -209,7 +215,7 @@ function dot_assume(
209215
)
210216
@assert length(dist) == size(var, 1)
211217
r = get_and_set_val!(vi, vns, dist, spl)
212-
lp = sum(logpdf_with_trans(dist, r, istrans(vi, vns[1])))
218+
lp = sum(Bijectors.logpdf_with_trans(dist, r, istrans(vi, vns[1])))
213219
var .= r
214220
return var, lp
215221
end
@@ -222,7 +228,7 @@ function dot_assume(
222228
)
223229
r = get_and_set_val!(vi, vns, dists, spl)
224230
# Make sure `r` is not a matrix for multivariate distributions
225-
lp = sum(logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
231+
lp = sum(Bijectors.logpdf_with_trans.(dists, r, istrans(vi, vns[1])))
226232
var .= r
227233
return var, lp
228234
end
@@ -353,7 +359,7 @@ function dot_observe(
353359
increment_num_produce!(vi)
354360
DynamicPPL.DEBUG && @debug "dist = $dist"
355361
DynamicPPL.DEBUG && @debug "value = $value"
356-
return sum(logpdf(dist, value))
362+
return sum(Distributions.logpdf(dist, value))
357363
end
358364
function dot_observe(
359365
spl::Union{SampleFromPrior, SampleFromUniform},
@@ -364,7 +370,7 @@ function dot_observe(
364370
increment_num_produce!(vi)
365371
DynamicPPL.DEBUG && @debug "dists = $dists"
366372
DynamicPPL.DEBUG && @debug "value = $value"
367-
return sum(logpdf.(dists, value))
373+
return sum(Distributions.logpdf.(dists, value))
368374
end
369375
function dot_observe(
370376
spl::Sampler,

src/varinfo.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -772,7 +772,7 @@ function link!(vi::UntypedVarInfo, spl::Sampler)
772772
for vn in vns
773773
dist = getdist(vi, vn)
774774
# TODO: Use inplace versions to avoid allocations
775-
setval!(vi, vectorize(dist, link(dist, reconstruct(dist, getval(vi, vn)))), vn)
775+
setval!(vi, vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn)
776776
settrans!(vi, true, vn)
777777
end
778778
else
@@ -793,7 +793,7 @@ end
793793
# Iterate over all `f_vns` and transform
794794
for vn in f_vns
795795
dist = getdist(vi, vn)
796-
setval!(vi, vectorize(dist, link(dist, reconstruct(dist, getval(vi, vn)))), vn)
796+
setval!(vi, vectorize(dist, Bijectors.link(dist, reconstruct(dist, getval(vi, vn)))), vn)
797797
settrans!(vi, true, vn)
798798
end
799799
else
@@ -818,7 +818,7 @@ function invlink!(vi::UntypedVarInfo, spl::AbstractSampler)
818818
if istrans(vi, vns[1])
819819
for vn in vns
820820
dist = getdist(vi, vn)
821-
setval!(vi, vectorize(dist, invlink(dist, reconstruct(dist, getval(vi, vn)))), vn)
821+
setval!(vi, vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn)
822822
settrans!(vi, false, vn)
823823
end
824824
else
@@ -839,7 +839,7 @@ end
839839
# Iterate over all `f_vns` and transform
840840
for vn in f_vns
841841
dist = getdist(vi, vn)
842-
setval!(vi, vectorize(dist, invlink(dist, reconstruct(dist, getval(vi, vn)))), vn)
842+
setval!(vi, vectorize(dist, Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn)))), vn)
843843
settrans!(vi, false, vn)
844844
end
845845
else
@@ -894,14 +894,14 @@ function getindex(vi::AbstractVarInfo, vn::VarName)
894894
@assert haskey(vi, vn) "[DynamicPPL] attempted to replay unexisting variables in VarInfo"
895895
dist = getdist(vi, vn)
896896
return istrans(vi, vn) ?
897-
invlink(dist, reconstruct(dist, getval(vi, vn))) :
897+
Bijectors.invlink(dist, reconstruct(dist, getval(vi, vn))) :
898898
reconstruct(dist, getval(vi, vn))
899899
end
900900
function getindex(vi::AbstractVarInfo, vns::Vector{<:VarName})
901901
@assert haskey(vi, vns[1]) "[DynamicPPL] attempted to replay unexisting variables in VarInfo"
902902
dist = getdist(vi, vns[1])
903903
return istrans(vi, vns[1]) ?
904-
invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) :
904+
Bijectors.invlink(dist, reconstruct(dist, getval(vi, vns), length(vns))) :
905905
reconstruct(dist, getval(vi, vns), length(vns))
906906
end
907907

0 commit comments

Comments
 (0)