Skip to content

Commit 1565acf

Browse files
authored
Merge pull request #39 from TuringLang/yebai-patch-1
replace `filldist` with `product_distribution`
2 parents ff99b21 + d82d3a0 commit 1565acf

8 files changed

+17
-23
lines changed

Project.toml

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,6 @@ DynamicPPL = "366bfd00-2699-11ea-058f-f148b4cae6d8"
99
Enzyme = "7da242da-08ed-463a-9acd-ee780be4f1d9"
1010
FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000"
1111
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
12-
LazyArrays = "5078a376-72f3-5289-bfd5-ec5146d43c02"
1312
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1413
LogDensityProblems = "6fdf6af0-433a-55f7-b3ed-c6c6e0b8df7c"
1514
MLDatasets = "eb30cadb-4394-5ae3-aed4-317e484a6458"

models/dppl_gauss_unknown.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@ y = randn() .+ s * randn(n)
66
N = length(y)
77
m ~ Normal(0, 1)
88
s ~ truncated(Cauchy(0, 5); lower=0)
9-
y ~ filldist(Normal(m, s), N)
9+
y ~ product_distribution(fill(Normal(m, s), N))
1010
end
1111

1212
model = dppl_gauss_unknown(y)

models/dppl_hier_poisson.jl

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,3 @@
1-
using LazyArrays
21
using Turing: LogPoisson
32

43
nd, ns = 5, 10
@@ -13,15 +12,13 @@ y = mapreduce(λi -> rand(Poisson(λi), nd), vcat, λ)
1312
x = repeat(logpop, inner=nd)
1413
idx = repeat(collect(1:ns), inner=nd)
1514

16-
lazyarray(f, x) = LazyArray(Base.broadcasted(f, x))
17-
1815
@model function dppl_hier_poisson(y, x, idx, ns)
1916
a0 ~ Normal(0, 10)
2017
a1 ~ Normal(0, 1)
2118
a0_sig ~ truncated(Cauchy(0, 1); lower=0)
22-
a0s ~ filldist(Normal(0, a0_sig), ns)
19+
a0s ~ product_distribution(fill(Normal(0, a0_sig), ns))
2320
alpha = a0 .+ a0s[idx] .+ a1 * x
24-
y ~ arraydist(lazyarray(LogPoisson, alpha))
21+
y ~ product_distribution(LogPoisson.(alpha))
2522
end
2623

2724
model = dppl_hier_poisson(y, x, idx, ns)

models/dppl_high_dim_gauss.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
@model function dppl_high_dim_gauss(D)
2-
m ~ filldist(Normal(0, 1), D)
2+
m ~ product_distribution(fill(Normal(0, 1), D))
33
end
44

55
model = dppl_high_dim_gauss(10_000)

models/dppl_hmm_semisup.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -28,8 +28,8 @@ for t in 2:T_unsup
2828
end
2929

3030
@model function dppl_hmm_semisup(K, T, T_unsup, w, z, u, alpha, beta)
31-
theta ~ filldist(Dirichlet(alpha), K)
32-
phi ~ filldist(Dirichlet(beta), K)
31+
theta ~ product_distribution(fill(Dirichlet(alpha), K))
32+
phi ~ product_distribution(fill(Dirichlet(beta), K))
3333
for t in 1:T
3434
w[t] ~ Categorical(phi[:, z[t]]);
3535
end

models/dppl_lda.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@ for i in 1:m
2121
end
2222

2323
@model function dppl_lda(k, m, w, doc, alpha, beta)
24-
theta ~ filldist(Dirichlet(alpha), m)
25-
phi ~ filldist(Dirichlet(beta), k)
24+
theta ~ product_distribution(fill(Dirichlet(alpha), m))
25+
phi ~ product_distribution(fill(Dirichlet(beta), k))
2626
log_phi_dot_theta = log.(phi * theta)
2727
@addlogprob! sum(log_phi_dot_theta[CartesianIndex.(w, doc)])
2828
end

models/dppl_logistic_regression.jl

Lines changed: 6 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,21 +1,19 @@
11
using StatsFuns: logistic
2-
using LazyArrays
2+
3+
function safelogistic(x::T) where {T}
4+
logistic(x) * (1 - 2 * eps(T)) + eps(T)
5+
end
36

47
d, n = 100, 10_000
58
X = randn(d, n)
69
w = randn(d)
710
y = Int.(logistic.(X' * w) .> 0.5)
811

9-
function safelogistic(x::T) where {T}
10-
logistic(x) * (1 - 2 * eps(T)) + eps(T)
11-
end
12-
13-
lazyarray(f, x) = LazyArray(Base.broadcasted(f, x))
1412

1513
@model function dppl_logistic_regression(Xt, y)
1614
N, D = size(Xt)
17-
w ~ filldist(Normal(), D)
18-
y ~ arraydist(lazyarray(x -> Bernoulli(safelogistic(x)), Xt * w))
15+
w ~ product_distribution(Normal.(zeros(D)))
16+
y ~ product_distribution(Bernoulli.(safelogistic.(Xt * w)))
1917
end
2018

2119
model = dppl_logistic_regression(X', y)

models/dppl_naive_bayes.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -16,12 +16,12 @@ image = transform(pca, image_raw)
1616
# Take only the first 1000 images and vectorise
1717
N = 1000
1818
image_subset = image[:, 1:N]'
19-
image_vec = vec(image_subset[:, :])
19+
image_vec = image_subset[:, :]
2020
labels = labels[1:N]
2121

2222
@model function dppl_naive_bayes(image_vec, labels, C, D)
23-
m ~ filldist(Normal(0, 10), C, D)
24-
image_vec ~ MvNormal(vec(m[labels, :]), I)
23+
m ~ product_distribution(fill(Normal(0, 10), C, D))
24+
image_vec ~ product_distribution(Normal.(m[labels, :]))
2525
end
2626

2727
model = dppl_naive_bayes(image_vec, labels, C, D)

0 commit comments

Comments
 (0)