Skip to content

Commit 7631232

Browse files
devmotionstoropoli
andauthored
Remove in-place updates and fix return values (#64)
* Remove in-place updates and fix return values * Bump version * Fix format * added tests for all ADs * test - just sample 2 observations single-threaded * fix format with empty line * remove rdcache * updating jose's email in Project.toml * Update Project.toml * oops @Assert is actually @test * see if tests passes without Tracker * removed zygote also to see if tests passes Co-authored-by: Jose Storopoli <[email protected]>
1 parent 3781afe commit 7631232

File tree

5 files changed

+72
-32
lines changed

5 files changed

+72
-32
lines changed

Project.toml

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,8 @@
11
name = "TuringGLM"
22
uuid = "0004c1f4-53c5-4d43-a221-a1dac6cf6b74"
3-
authors = ["Jose Storopoli <[email protected]>, Rik Huijzer <[email protected]>, and contributors"]
4-
version = "2.1.1"
3+
authors = ["Jose Storopoli <[email protected]>, Rik Huijzer <[email protected]>, and contributors"]
4+
version = "2.1.2"
5+
56

67
[deps]
78
Distributions = "31c24e10-a181-5473-b8eb-7969acd0382f"

src/turing_model.jl

Lines changed: 30 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -183,16 +183,16 @@ function _model(μ_X, σ_X, prior, intercept_ranef, idx, ::Type{Normal})
183183
α ~ prior.intercept
184184
β ~ filldist(prior.predictors, predictors)
185185
σ ~ Exponential(residual)
186-
μ = α .+ X * β
187-
if !isempty(intercept_ranef)
186+
if isempty(intercept_ranef)
187+
μ = α .+ X * β
188+
else
188189
τ ~ mad_y * truncated(TDist(3); lower=0)
189190
zⱼ ~ filldist(Normal(), n_gr)
190-
αⱼ = zⱼ .* τ
191-
μ .+= αⱼ[idxs]
191+
μ = α .+ τ .* getindex.((zⱼ,), idxs) .+ X * β
192192
end
193193
#TODO: implement random-effects slope
194194
y ~ MvNormal(μ, σ^2 * I)
195-
return (; α, β, σ, τ, zⱼ, αⱼ, y)
195+
return nothing
196196
end
197197
end
198198
function _model(μ_X, σ_X, prior, ::Type{Normal})
@@ -203,7 +203,7 @@ function _model(μ_X, σ_X, prior, ::Type{Normal})
203203
β ~ filldist(prior.predictors, predictors)
204204
σ ~ Exponential(residual)
205205
y ~ MvNormal.+ X * β, σ^2 * I)
206-
return (; α, β, σ, y)
206+
return nothing
207207
end
208208
end
209209

@@ -226,16 +226,16 @@ function _model(μ_X, σ_X, prior, intercept_ranef, idx, ::Type{TDist})
226226
β ~ filldist(prior.predictors, predictors)
227227
σ ~ Exponential(residual)
228228
ν ~ prior.auxiliary
229-
μ = α .+ X * β
230-
if !isempty(intercept_ranef)
229+
if isempty(intercept_ranef)
230+
μ = α .+ X * β
231+
else
231232
τ ~ mad_y * truncated(TDist(3); lower=0)
232233
zⱼ ~ filldist(Normal(), n_gr)
233-
αⱼ = zⱼ .* τ
234-
μ .+= αⱼ[idxs]
234+
μ = α .+ τ .* getindex.((zⱼ,), idxs) .+ X * β
235235
end
236236
#TODO: implement random-effects slope
237237
y ~ arraydist+ σ * TDist.(ν))
238-
return (; α, β, σ, ν, τ, zⱼ, αⱼ, y)
238+
return nothing
239239
end
240240
end
241241
function _model(μ_X, σ_X, prior, ::Type{TDist})
@@ -247,7 +247,7 @@ function _model(μ_X, σ_X, prior, ::Type{TDist})
247247
σ ~ Exponential(residual)
248248
ν ~ prior.auxiliary
249249
y ~ arraydist((α .+ X * β) .+ σ .* TDist.(ν))
250-
return (; α, β, σ, ν, y)
250+
return nothing
251251
end
252252
end
253253

@@ -267,16 +267,16 @@ function _model(μ_X, σ_X, prior, intercept_ranef, idx, ::Type{Bernoulli})
267267
)
268268
α ~ prior.intercept
269269
β ~ filldist(prior.predictors, predictors)
270-
μ = α .+ X * β
271-
if !isempty(intercept_ranef)
270+
if isempty(intercept_ranef)
271+
μ = α .+ X * β
272+
else
272273
τ ~ mad_y * truncated(TDist(3); lower=0)
273274
zⱼ ~ filldist(Normal(), n_gr)
274-
αⱼ = zⱼ .* τ
275-
μ .+= αⱼ[idxs]
275+
μ = α .+ τ .* getindex.((zⱼ,), idxs) .+ X * β
276276
end
277277
#TODO: implement random-effects slope
278278
y ~ arraydist(LazyArray(@~ BernoulliLogit.(μ)))
279-
return (; α, β, τ, zⱼ, αⱼ, y)
279+
return nothing
280280
end
281281
end
282282
function _model(μ_X, σ_X, prior, ::Type{Bernoulli})
@@ -286,7 +286,7 @@ function _model(μ_X, σ_X, prior, ::Type{Bernoulli})
286286
α ~ prior.intercept
287287
β ~ filldist(prior.predictors, predictors)
288288
y ~ arraydist(LazyArray(@~ BernoulliLogit.(α .+ X * β)))
289-
return (; α, β, y)
289+
return nothing
290290
end
291291
end
292292

@@ -306,16 +306,16 @@ function _model(μ_X, σ_X, prior, intercept_ranef, idx, ::Type{Poisson})
306306
)
307307
α ~ prior.intercept
308308
β ~ filldist(prior.predictors, predictors)
309-
μ = α .+ X * β
310-
if !isempty(intercept_ranef)
309+
if isempty(intercept_ranef)
310+
μ = α .+ X * β
311+
else
311312
τ ~ mad_y * truncated(TDist(3); lower=0)
312313
zⱼ ~ filldist(Normal(), n_gr)
313-
αⱼ = zⱼ .* τ
314-
μ .+= αⱼ[idxs]
314+
μ = α .+ τ .* getindex.((zⱼ,), idxs) .+ X * β
315315
end
316316
#TODO: implement random-effects slope
317317
y ~ arraydist(LazyArray(@~ LogPoisson.(μ)))
318-
return (; α, β, τ, zⱼ, αⱼ, y)
318+
return nothing
319319
end
320320
end
321321
function _model(μ_X, σ_X, prior, ::Type{Poisson})
@@ -325,7 +325,7 @@ function _model(μ_X, σ_X, prior, ::Type{Poisson})
325325
α ~ prior.intercept
326326
β ~ filldist(prior.predictors, predictors)
327327
y ~ arraydist(LazyArray(@~ LogPoisson.(α .+ X * β)))
328-
return (; α, β, y)
328+
return nothing
329329
end
330330
end
331331

@@ -347,16 +347,16 @@ function _model(μ_X, σ_X, prior, intercept_ranef, idx, ::Type{NegativeBinomial
347347
β ~ filldist(prior.predictors, predictors)
348348
ϕ⁻ ~ prior.auxiliary
349349
ϕ = 1 / ϕ⁻
350-
μ = α .+ X * β
351-
if !isempty(intercept_ranef)
350+
if isempty(intercept_ranef)
351+
μ = α .+ X * β
352+
else
352353
τ ~ mad_y * truncated(TDist(3); lower=0)
353354
zⱼ ~ filldist(Normal(), n_gr)
354-
αⱼ = zⱼ .* τ
355-
μ .+= αⱼ[idxs]
355+
μ = α .+ τ .* getindex.((zⱼ,), idxs) .+ X * β
356356
end
357357
#TODO: implement random-effects slope
358358
y ~ arraydist(LazyArray(@~ NegativeBinomial2.(exp.(μ), ϕ)))
359-
return (; α, β, ϕ, τ, zⱼ, αⱼ, y)
359+
return nothing
360360
end
361361
end
362362
function _model(μ_X, σ_X, prior, ::Type{NegativeBinomial})
@@ -366,7 +366,7 @@ function _model(μ_X, σ_X, prior, ::Type{NegativeBinomial})
366366
ϕ⁻ ~ prior.auxiliary
367367
ϕ = 1 / ϕ⁻
368368
y ~ arraydist(LazyArray(@~ NegativeBinomial2.(exp.(α .+ X * β), ϕ)))
369-
return (; α, β, ϕ, y)
369+
return nothing
370370
end
371371
end
372372

test/Project.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,10 +2,14 @@
22
CSV = "336ed68f-0bac-5ca0-87d4-7b16caf5d00b"
33
CategoricalArrays = "324d7699-5711-5eae-9e2f-1d82baa6b597"
44
DataFrames = "a93c6f00-e57d-5684-b7b6-d8193f3e46c0"
5+
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
56
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
7+
ReverseDiff = "37e2e3b7-166d-5795-8a7a-e32c996b4267"
68
Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
79
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
810
TimerOutputs = "a759f4b9-e2f1-59dc-863e-4aeb61b1ea8f"
11+
Tracker = "9f7883ad-71c0-57eb-9f7f-b5c9e6d3789c"
12+
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
913

1014
[compat]
1115
CSV = "0.9, 1"

test/ad_backends.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
@timed_testset "ad_backends" begin
2+
DATA_DIR = joinpath("..", "data")
3+
cheese = CSV.read(joinpath(DATA_DIR, "cheese.csv"), DataFrame)
4+
f = @formula(y ~ (1 | cheese) + background)
5+
m = turing_model(f, cheese)
6+
# only running 2 samples to test if the different ADs runs
7+
@timed_testset "ForwardDiff" begin
8+
Turing.setadbackend(:forwarddiff)
9+
chn = sample(m, NUTS(), 2)
10+
@test chn isa Chains
11+
end
12+
# TODO: fix Tracker tests
13+
# @timed_testset "Tracker" begin
14+
# using Tracker
15+
# Turing.setadbackend(:tracker)
16+
# chn = sample(m, NUTS(), 2)
17+
# @test chn isa Chains
18+
# end
19+
# TODO: fix Zygote tests
20+
# @timed_testset "Zygote" begin
21+
# using Zygote
22+
# Turing.setadbackend(:zygote)
23+
# chn = sample(m, NUTS(), 2)
24+
# @test chn isa Chains
25+
# end
26+
@timed_testset "ReverseDiff" begin
27+
using ReverseDiff
28+
Turing.setadbackend(:reversediff)
29+
chn = sample(m, NUTS(), 2)
30+
@test chn isa Chains
31+
end
32+
# go back to defaults
33+
Turing.setadbackend(:forwarddiff)
34+
end

test/runtests.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,7 @@ end
5050
include("utils.jl")
5151
include("priors.jl")
5252
include("turing_model.jl")
53+
include("ad_backends.jl")
5354
end
5455

5556
show(TIMEROUTPUT; compact=true, sortby=:firstexec)

0 commit comments

Comments
 (0)