Skip to content

Commit bc92248

Browse files
committed
Replace remaining instances of @SubModel
1 parent 0c266a6 commit bc92248

File tree

3 files changed

+16
-22
lines changed

3 files changed

+16
-22
lines changed

src/contexts.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ adds the `Prefix` to all parameters.
244244
This context is useful in nested models to ensure that the names of the parameters are
245245
unique.
246246
247-
See also: [`@submodel`](@ref)
247+
See also: [`to_submodel`](@ref)
248248
"""
249249
struct PrefixContext{Prefix,C} <: AbstractContext
250250
context::C

src/test_utils/models.jl

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -437,7 +437,8 @@ end
437437

438438
@model function demo_assume_submodel_observe_index_literal()
439439
# Submodel prior
440-
@submodel s, m = _prior_dot_assume()
440+
priors ~ to_submodel(_prior_dot_assume(), false)
441+
s, m = priors
441442
1.5 ~ Normal(m[1], sqrt(s[1]))
442443
2.0 ~ Normal(m[2], sqrt(s[2]))
443444

@@ -475,7 +476,9 @@ end
475476
m .~ Normal.(0, sqrt.(s))
476477

477478
# Submodel likelihood
478-
@submodel _likelihood_mltivariate_observe(s, m, x)
479+
# With to_submodel, we have to have a left-hand side variable to
480+
# capture the result, so we just use a dummy variable
481+
_ignore ~ to_submodel(_likelihood_mltivariate_observe(s, m, x))
479482

480483
return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
481484
end

test/compiler.jl

Lines changed: 10 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -409,7 +409,7 @@ module Issue537 end
409409
return x ~ Normal()
410410
end
411411
@model function demo2(x, y)
412-
@submodel demo1(x)
412+
_ignore ~ to_submodel(demo1(x), false)
413413
return y ~ Uniform()
414414
end
415415
# No observation.
@@ -441,7 +441,7 @@ module Issue537 end
441441

442442
# Check values makes sense.
443443
@model function demo3(x, y)
444-
@submodel demo1(x)
444+
_ignore ~ to_submodel(demo1(x), false)
445445
return y ~ Normal(x)
446446
end
447447
m = demo3(1000.0, missing)
@@ -453,12 +453,10 @@ module Issue537 end
453453
x ~ Normal()
454454
return x
455455
end
456-
457456
@model function demo_useval(x, y)
458-
@submodel prefix = "sub1" x1 = demo_return(x)
459-
@submodel prefix = "sub2" x2 = demo_return(y)
460-
461-
return z ~ Normal(x1 + x2 + 100, 1.0)
457+
sub1 ~ to_submodel(demo_return(x))
458+
sub2 ~ to_submodel(demo_return(y))
459+
return z ~ Normal(sub1 + sub2 + 100, 1.0)
462460
end
463461
m = demo_useval(missing, missing)
464462
vi = VarInfo(m)
@@ -472,21 +470,18 @@ module Issue537 end
472470
@model function AR1(num_steps, α, μ, σ, ::Type{TV}=Vector{Float64}) where {TV}
473471
η ~ MvNormal(zeros(num_steps), I)
474472
δ = sqrt(1 - α^2)
475-
476473
x = TV(undef, num_steps)
477474
x[1] = η[1]
478475
@inbounds for t in 2:num_steps
479476
x[t] = @. α * x[t - 1] + δ * η[t]
480477
end
481-
482478
return @. μ + σ * x
483479
end
484480

485481
@model function demo(y)
486482
α ~ Uniform()
487483
μ ~ Normal()
488484
σ ~ truncated(Normal(), 0, Inf)
489-
490485
num_steps = length(y[1])
491486
num_obs = length(y)
492487
@inbounds for i in 1:num_obs
@@ -613,14 +608,11 @@ module Issue537 end
613608
@model demo() = x ~ Normal()
614609
retval, svi = DynamicPPL.evaluate!!(demo(), SimpleVarInfo(), SamplingContext())
615610

616-
# Return-value when using `@submodel`
611+
# Return-value when using `to_submodel`
617612
@model inner() = x ~ Normal()
618-
# Without assignment.
619-
@model outer() = @submodel inner()
620-
@test outer()() isa Real
621-
622-
# With assignment.
623-
@model outer() = @submodel x = inner()
613+
@model function outer()
614+
_ignore ~ to_submodel(inner())
615+
end
624616
@test outer()() isa Real
625617

626618
# Edge-cases.
@@ -720,8 +712,7 @@ module Issue537 end
720712
return (; x, y)
721713
end
722714
@model function demo_tracked_submodel()
723-
@submodel (x, y) = demo_tracked()
724-
return (; x, y)
715+
vals ~ to_submodel(demo_tracked(), false)
725716
end
726717
for model in [demo_tracked(), demo_tracked_submodel()]
727718
# Make sure it's runnable and `y` is present in the return-value.

0 commit comments

Comments
 (0)