Skip to content

Commit b3778ff

Browse files
committed
Run formatter
1 parent 019e41b commit b3778ff

15 files changed

+98
-90
lines changed

benchmarks/src/Models.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -47,7 +47,7 @@ A short model that tries to cover many DynamicPPL features.
4747
Includes scalar, vector univariate, and multivariate variables; ~, .~, and loops; allocating
4848
a variable vector; observations passed as arguments, and as literals.
4949
"""
50-
@model function smorgasbord(x, y, ::Type{TV}=Vector{Float64}) where {TV}
50+
@model function smorgasbord(x, y, (::Type{TV})=Vector{Float64}) where {TV}
5151
@assert length(x) == length(y)
5252
m ~ truncated(Normal(); lower=0)
5353
means ~ product_distribution(fill(Exponential(m), length(x)))
@@ -68,7 +68,7 @@ The second variable, `o`, is meant to be conditioned on after model instantiatio
6868
6969
See `multivariate` for a version that uses `product_distribution` rather than loops.
7070
"""
71-
@model function loop_univariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
71+
@model function loop_univariate(num_dims, (::Type{TV})=Vector{Float64}) where {TV}
7272
a = TV(undef, num_dims)
7373
o = TV(undef, num_dims)
7474
for i in 1:num_dims
@@ -88,7 +88,7 @@ The second variable, `o`, is meant to be conditioned on after model instantiatio
8888
8989
See `loop_univariate` for a version that uses loops rather than `product_distribution`.
9090
"""
91-
@model function multivariate(num_dims, ::Type{TV}=Vector{Float64}) where {TV}
91+
@model function multivariate(num_dims, (::Type{TV})=Vector{Float64}) where {TV}
9292
a = TV(undef, num_dims)
9393
o = TV(undef, num_dims)
9494
a ~ product_distribution(fill(Normal(0, 1), num_dims))
@@ -118,7 +118,7 @@ end
118118
A model with random variables that have changing support under linking, or otherwise
119119
complicated bijectors.
120120
"""
121-
@model function dynamic(::Type{T}=Vector{Float64}) where {T}
121+
@model function dynamic((::Type{T})=Vector{Float64}) where {T}
122122
eta ~ truncated(Normal(); lower=0.0, upper=0.1)
123123
mat1 ~ LKJCholesky(4, eta)
124124
mat2 ~ InverseWishart(3.2, cholesky([1.0 0.5; 0.5 1.0]))

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -161,8 +161,8 @@ function _predictive_samples_to_arrays(predictive_samples)
161161

162162
variable_names = collect(variable_names_set)
163163
variable_values = [
164-
get(sample_dicts[i], key, missing) for i in eachindex(sample_dicts),
165-
key in variable_names
164+
get(sample_dicts[i], key, missing) for
165+
i in eachindex(sample_dicts), key in variable_names
166166
]
167167

168168
return variable_names, variable_values

src/extract_priors.jl

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -105,8 +105,9 @@ julia> length(extract_priors(rng, model)[@varname(x)])
105105
9
106106
```
107107
"""
108-
extract_priors(args::Union{Model,AbstractVarInfo}...) =
108+
function extract_priors(args::Union{Model,AbstractVarInfo}...)
109109
extract_priors(Random.default_rng(), args...)
110+
end
110111
function extract_priors(rng::Random.AbstractRNG, model::Model)
111112
context = PriorExtractorContext(SamplingContext(rng))
112113
evaluate!!(model, VarInfo(), context)

src/logdensityfunction.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -245,9 +245,11 @@ model.
245245
246246
By default, this just returns the input unchanged.
247247
"""
248-
tweak_adtype(
248+
function tweak_adtype(
249249
adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext
250-
) = adtype
250+
)
251+
adtype
252+
end
251253

252254
"""
253255
use_closure(adtype::ADTypes.AbstractADType)

src/model.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,9 @@ Return a `Model` which now treats variables on the right-hand side as observatio
9696
9797
See [`condition`](@ref) for more information and examples.
9898
"""
99-
Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}}) =
99+
function Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}})
100100
condition(model, values)
101+
end
101102

102103
"""
103104
condition(model::Model; values...)
@@ -1467,5 +1468,6 @@ ERROR: ArgumentError: `~` with a model on the right-hand side of an observe stat
14671468
[...]
14681469
```
14691470
"""
1470-
to_submodel(model::Model, auto_prefix::Bool=true) =
1471+
function to_submodel(model::Model, auto_prefix::Bool=true)
14711472
to_sampleable(returned(model), auto_prefix)
1473+
end

src/simple_varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -244,7 +244,7 @@ function SimpleVarInfo{T}(
244244
end
245245

246246
# Constructor from `VarInfo`.
247-
function SimpleVarInfo(vi::TypedVarInfo, ::Type{D}=NamedTuple; kwargs...) where {D}
247+
function SimpleVarInfo(vi::TypedVarInfo, (::Type{D})=NamedTuple; kwargs...) where {D}
248248
return SimpleVarInfo{eltype(getlogp(vi))}(vi, D; kwargs...)
249249
end
250250
function SimpleVarInfo{T}(

src/test_utils/models.jl

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ x[4:5] ~ Dirichlet([1.0, 2.0])
4949
```
5050
"""
5151
@model function demo_one_variable_multiple_constraints(
52-
::Type{TV}=Vector{Float64}
52+
(::Type{TV})=Vector{Float64}
5353
) where {TV}
5454
x = TV(undef, 5)
5555
x[1] ~ Normal()
@@ -186,7 +186,9 @@ function _demo_logprior_true_with_logabsdet_jacobian(model, s, m)
186186
return (s=s_unconstrained, m=m), logprior_true(model, s, m) - Δlogp
187187
end
188188

189-
@model function demo_dot_assume_observe(x=[1.5, 2.0], ::Type{TV}=Vector{Float64}) where {TV}
189+
@model function demo_dot_assume_observe(
190+
x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
191+
) where {TV}
190192
# `dot_assume` and `observe`
191193
s = TV(undef, length(x))
192194
m = TV(undef, length(x))
@@ -212,7 +214,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe)})
212214
end
213215

214216
@model function demo_assume_index_observe(
215-
x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
217+
x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
216218
) where {TV}
217219
# `assume` with indexing and `observe`
218220
s = TV(undef, length(x))
@@ -268,7 +270,7 @@ function varnames(model::Model{typeof(demo_assume_multivariate_observe)})
268270
end
269271

270272
@model function demo_dot_assume_observe_index(
271-
x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
273+
x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
272274
) where {TV}
273275
# `dot_assume` and `observe` with indexing
274276
s = TV(undef, length(x))
@@ -348,7 +350,9 @@ function varnames(model::Model{typeof(demo_assume_multivariate_observe_literal)}
348350
return [@varname(s), @varname(m)]
349351
end
350352

351-
@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV}
353+
@model function demo_dot_assume_observe_index_literal(
354+
(::Type{TV})=Vector{Float64}
355+
) where {TV}
352356
# `dot_assume` and literal `observe` with indexing
353357
s = TV(undef, 2)
354358
m = TV(undef, 2)
@@ -425,7 +429,7 @@ function varnames(model::Model{typeof(demo_assume_dot_observe_literal)})
425429
end
426430

427431
# Only used as a submodel
428-
@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
432+
@model function _prior_dot_assume((::Type{TV})=Vector{Float64}) where {TV}
429433
s = TV(undef, 2)
430434
s .~ InverseGamma(2, 3)
431435
m = TV(undef, 2)
@@ -466,7 +470,7 @@ end
466470
end
467471

468472
@model function demo_dot_assume_observe_submodel(
469-
x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
473+
x=[1.5, 2.0], (::Type{TV})=Vector{Float64}
470474
) where {TV}
471475
s = TV(undef, length(x))
472476
s .~ InverseGamma(2, 3)
@@ -496,7 +500,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_submodel)})
496500
end
497501

498502
@model function demo_dot_assume_observe_matrix_index(
499-
x=transpose([1.5 2.0;]), ::Type{TV}=Vector{Float64}
503+
x=transpose([1.5 2.0;]), (::Type{TV})=Vector{Float64}
500504
) where {TV}
501505
s = TV(undef, length(x))
502506
s .~ InverseGamma(2, 3)
@@ -525,7 +529,7 @@ function varnames(model::Model{typeof(demo_dot_assume_observe_matrix_index)})
525529
end
526530

527531
@model function demo_assume_matrix_observe_matrix_index(
528-
x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64}
532+
x=transpose([1.5 2.0;]), (::Type{TV})=Array{Float64}
529533
) where {TV}
530534
n = length(x)
531535
d = n ÷ 2

src/varinfo.jl

Lines changed: 42 additions & 50 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ end
327327
# TODO(mhauru) Note that this could still generate an empty metadata object if none
328328
# of the lenses in `vns` are in `metadata`. Not sure if that's okay. Checking for
329329
# emptiness would make this type unstable again.
330-
:((; $sym=subset(metadata.$sym, vns)))
330+
:((; ($sym)=subset(metadata.$sym, vns)))
331331
else
332332
:(NamedTuple{}())
333333
end
@@ -708,8 +708,9 @@ findinds(vnv::VarNamedVector) = 1:length(vnv.varnames)
708708
709709
Return a `NamedTuple` of the variables in `vi` grouped by symbol.
710710
"""
711-
all_varnames_grouped_by_symbol(vi::TypedVarInfo) =
711+
function all_varnames_grouped_by_symbol(vi::TypedVarInfo)
712712
all_varnames_grouped_by_symbol(vi.metadata)
713+
end
713714

714715
@generated function all_varnames_grouped_by_symbol(md::NamedTuple{names}) where {names}
715716
expr = Expr(:tuple)
@@ -981,25 +982,22 @@ end
981982
if !(f in vns_names)
982983
continue
983984
end
984-
push!(
985-
expr.args,
986-
quote
987-
f_vns = vi.metadata.$f.vns
988-
f_vns = filter_subsumed(vns.$f, f_vns)
989-
if !isempty(f_vns)
990-
if !istrans(vi, f_vns[1])
991-
# Iterate over all `f_vns` and transform
992-
for vn in f_vns
993-
f = internal_to_linked_internal_transform(vi, vn)
994-
_inner_transform!(vi, vn, f)
995-
settrans!!(vi, true, vn)
996-
end
997-
else
998-
@warn("[DynamicPPL] attempt to link a linked vi")
985+
push!(expr.args, quote
986+
f_vns = vi.metadata.$f.vns
987+
f_vns = filter_subsumed(vns.$f, f_vns)
988+
if !isempty(f_vns)
989+
if !istrans(vi, f_vns[1])
990+
# Iterate over all `f_vns` and transform
991+
for vn in f_vns
992+
f = internal_to_linked_internal_transform(vi, vn)
993+
_inner_transform!(vi, vn, f)
994+
settrans!!(vi, true, vn)
999995
end
996+
else
997+
@warn("[DynamicPPL] attempt to link a linked vi")
1000998
end
1001-
end,
1002-
)
999+
end
1000+
end)
10031001
end
10041002
return expr
10051003
end
@@ -1085,23 +1083,20 @@ end
10851083
continue
10861084
end
10871085

1088-
push!(
1089-
expr.args,
1090-
quote
1091-
f_vns = vi.metadata.$f.vns
1092-
f_vns = filter_subsumed(vns.$f, f_vns)
1093-
if istrans(vi, f_vns[1])
1094-
# Iterate over all `f_vns` and transform
1095-
for vn in f_vns
1096-
f = linked_internal_to_internal_transform(vi, vn)
1097-
_inner_transform!(vi, vn, f)
1098-
settrans!!(vi, false, vn)
1099-
end
1100-
else
1101-
@warn("[DynamicPPL] attempt to invlink an invlinked vi")
1086+
push!(expr.args, quote
1087+
f_vns = vi.metadata.$f.vns
1088+
f_vns = filter_subsumed(vns.$f, f_vns)
1089+
if istrans(vi, f_vns[1])
1090+
# Iterate over all `f_vns` and transform
1091+
for vn in f_vns
1092+
f = linked_internal_to_internal_transform(vi, vn)
1093+
_inner_transform!(vi, vn, f)
1094+
settrans!!(vi, false, vn)
11021095
end
1103-
end,
1104-
)
1096+
else
1097+
@warn("[DynamicPPL] attempt to invlink an invlinked vi")
1098+
end
1099+
end)
11051100
end
11061101
return expr
11071102
end
@@ -1774,23 +1769,20 @@ end
17741769
f_idcs = :(idcs.$f)
17751770
f_orders = :(metadata.$f.orders)
17761771
f_flags = :(metadata.$f.flags)
1777-
push!(
1778-
expr.args,
1779-
quote
1780-
# Set the flag for variables with symbol `f`
1781-
if num_produce == 0
1782-
for i in length($f_idcs):-1:1
1783-
$f_flags["del"][$f_idcs[i]] = true
1784-
end
1785-
else
1786-
for i in 1:length($f_orders)
1787-
if i in $f_idcs && $f_orders[i] > num_produce
1788-
$f_flags["del"][i] = true
1789-
end
1772+
push!(expr.args, quote
1773+
# Set the flag for variables with symbol `f`
1774+
if num_produce == 0
1775+
for i in length($f_idcs):-1:1
1776+
$f_flags["del"][$f_idcs[i]] = true
1777+
end
1778+
else
1779+
for i in 1:length($f_orders)
1780+
if i in $f_idcs && $f_orders[i] > num_produce
1781+
$f_flags["del"][i] = true
17901782
end
17911783
end
1792-
end,
1793-
)
1784+
end
1785+
end)
17941786
end
17951787
return expr
17961788
end

test/ad.jl

Lines changed: 9 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ using DynamicPPL: LogDensityFunction
7171
t = 1:0.05:8
7272
σ = 0.3
7373
y = @. rand(sin(t) + Normal(0, σ))
74-
@model function state_space(y, TT, ::Type{T}=Float64) where {T}
74+
@model function state_space(y, TT, (::Type{T})=Float64) where {T}
7575
# Priors
7676
α ~ Normal(y[1], 0.001)
7777
τ ~ Exponential(1)
@@ -94,9 +94,11 @@ using DynamicPPL: LogDensityFunction
9494
# overload assume so that model evaluation doesn't fail due to a lack
9595
# of implementation
9696
struct MyEmptyAlg end
97-
DynamicPPL.assume(
97+
function DynamicPPL.assume(
9898
::Random.AbstractRNG, ::DynamicPPL.Sampler{MyEmptyAlg}, dist, vn, vi
99-
) = DynamicPPL.assume(dist, vn, vi)
99+
)
100+
return DynamicPPL.assume(dist, vn, vi)
101+
end
100102

101103
# Compiling the ReverseDiff tape used to fail here
102104
spl = Sampler(MyEmptyAlg())
@@ -117,7 +119,7 @@ using DynamicPPL: LogDensityFunction
117119
return LogDensityProblems.logdensity_and_gradient(ldf, m[:])
118120
end
119121

120-
@model function scalar_matrix_model(::Type{T}=Float64) where {T<:Real}
122+
@model function scalar_matrix_model((::Type{T})=Float64) where {T<:Real}
121123
m = Matrix{T}(undef, 2, 3)
122124
return m ~ filldist(MvNormal(zeros(2), I), 3)
123125
end
@@ -126,14 +128,14 @@ using DynamicPPL: LogDensityFunction
126128
scalar_matrix_model, test_m, ref_adtype
127129
)
128130

129-
@model function matrix_model(::Type{T}=Matrix{Float64}) where {T}
131+
@model function matrix_model((::Type{T})=Matrix{Float64}) where {T}
130132
m = T(undef, 2, 3)
131133
return m ~ filldist(MvNormal(zeros(2), I), 3)
132134
end
133135

134136
matrix_model_reference = eval_logp_and_grad(matrix_model, test_m, ref_adtype)
135137

136-
@model function scalar_array_model(::Type{T}=Float64) where {T<:Real}
138+
@model function scalar_array_model((::Type{T})=Float64) where {T<:Real}
137139
m = Array{T}(undef, 2, 3)
138140
return m ~ filldist(MvNormal(zeros(2), I), 3)
139141
end
@@ -142,7 +144,7 @@ using DynamicPPL: LogDensityFunction
142144
scalar_array_model, test_m, ref_adtype
143145
)
144146

145-
@model function array_model(::Type{T}=Array{Float64}) where {T}
147+
@model function array_model((::Type{T})=Array{Float64}) where {T}
146148
m = T(undef, 2, 3)
147149
return m ~ filldist(MvNormal(zeros(2), I), 3)
148150
end

0 commit comments

Comments
 (0)