Skip to content

Commit e366079

Browse files
committed
Add always_use_return = true to format config
1 parent 2168a86 commit e366079

18 files changed

+44
-36
lines changed

.JuliaFormatter.toml

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +1,6 @@
11
style="blue"
22
format_markdown = true
3+
# The below should actually be part of Blue according to
4+
# https://github.com/JuliaDiff/BlueStyle?tab=readme-ov-file#method-definitions
5+
# but JuliaFormatter v2.10 doesn't enforce it.
6+
always_use_return = true

ext/DynamicPPLMCMCChainsExt.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -254,7 +254,7 @@ function DynamicPPL.returned(model::DynamicPPL.Model, chain_full::MCMCChains.Cha
254254
DynamicPPL.setval_and_resample!(varinfo, chain, sample_idx, chain_idx)
255255
# NOTE: Some of the varialbes can be a view into the `varinfo`, so we need to
256256
# `deepcopy` the `varinfo` before passing it to the `model`.
257-
model(deepcopy(varinfo))
257+
return model(deepcopy(varinfo))
258258
end
259259
end
260260

src/compiler.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -635,7 +635,7 @@ end
635635

636636
function namedtuple_from_splitargs(splitargs)
637637
names = map(splitargs) do (arg_name, arg_type, is_splat, default)
638-
is_splat ? Symbol("#splat#$(arg_name)") : arg_name
638+
return is_splat ? Symbol("#splat#$(arg_name)") : arg_name
639639
end
640640
names_expr = Expr(:tuple, map(QuoteNode, names)...)
641641
vals = Expr(:tuple, map(first, splitargs)...)

src/debug_utils.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -521,7 +521,7 @@ function has_static_constraints(
521521
rng::Random.AbstractRNG, model::Model; num_evals=5, kwargs...
522522
)
523523
results = map(1:num_evals) do _
524-
check_model_and_trace(rng, model; kwargs...)
524+
return check_model_and_trace(rng, model; kwargs...)
525525
end
526526
issuccess = all(first, results)
527527
issuccess || throw(ArgumentError("model check failed"))
@@ -530,7 +530,7 @@ function has_static_constraints(
530530
traces = map(last, results)
531531
dists_per_trace = map(distributions_in_trace, traces)
532532
transforms = map(dists_per_trace) do dists
533-
map(DynamicPPL.link_transform, dists)
533+
return map(DynamicPPL.link_transform, dists)
534534
end
535535

536536
# Check if the distributions are the same across all runs.

src/extract_priors.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -106,7 +106,7 @@ julia> length(extract_priors(rng, model)[@varname(x)])
106106
```
107107
"""
108108
function extract_priors(args::Union{Model,AbstractVarInfo}...)
109-
extract_priors(Random.default_rng(), args...)
109+
return extract_priors(Random.default_rng(), args...)
110110
end
111111
function extract_priors(rng::Random.AbstractRNG, model::Model)
112112
context = PriorExtractorContext(SamplingContext(rng))

src/logdensityfunction.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -248,7 +248,7 @@ By default, this just returns the input unchanged.
248248
function tweak_adtype(
249249
adtype::ADTypes.AbstractADType, ::Model, ::AbstractVarInfo, ::AbstractContext
250250
)
251-
adtype
251+
return adtype
252252
end
253253

254254
"""

src/model.jl

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -97,7 +97,7 @@ Return a `Model` which now treats variables on the right-hand side as observatio
9797
See [`condition`](@ref) for more information and examples.
9898
"""
9999
function Base.:|(model::Model, values::Union{Pair,Tuple,NamedTuple,AbstractDict{<:VarName}})
100-
condition(model, values)
100+
return condition(model, values)
101101
end
102102

103103
"""
@@ -1069,7 +1069,7 @@ function logjoint(model::Model, chain::AbstractMCMC.AbstractChains)
10691069
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
10701070
vn_parent in keys(var_info)
10711071
)
1072-
logjoint(model, argvals_dict)
1072+
return logjoint(model, argvals_dict)
10731073
end
10741074
end
10751075

@@ -1116,7 +1116,7 @@ function logprior(model::Model, chain::AbstractMCMC.AbstractChains)
11161116
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
11171117
vn_parent in keys(var_info)
11181118
)
1119-
logprior(model, argvals_dict)
1119+
return logprior(model, argvals_dict)
11201120
end
11211121
end
11221122

@@ -1163,7 +1163,7 @@ function Distributions.loglikelihood(model::Model, chain::AbstractMCMC.AbstractC
11631163
values_from_chain(var_info, vn_parent, chain, chain_idx, iteration_idx) for
11641164
vn_parent in keys(var_info)
11651165
)
1166-
loglikelihood(model, argvals_dict)
1166+
return loglikelihood(model, argvals_dict)
11671167
end
11681168
end
11691169

@@ -1469,5 +1469,5 @@ ERROR: ArgumentError: `~` with a model on the right-hand side of an observe stat
14691469
```
14701470
"""
14711471
function to_submodel(model::Model, auto_prefix::Bool=true)
1472-
to_sampleable(returned(model), auto_prefix)
1472+
return to_sampleable(returned(model), auto_prefix)
14731473
end

src/model_utils.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -204,6 +204,8 @@ function value_iterator_from_chain(vi::AbstractVarInfo, chain)
204204
return Iterators.map(
205205
Iterators.product(1:size(chain, 1), 1:size(chain, 3))
206206
) do (iteration_idx, chain_idx)
207-
values_from_chain!(vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}())
207+
return values_from_chain!(
208+
vi, chain, chain_idx, iteration_idx, OrderedDict{VarName,Any}()
209+
)
208210
end
209211
end

src/simple_varinfo.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -315,7 +315,7 @@ function Base.getindex(vi::SimpleVarInfo, vn::VarName, dist::Distribution)
315315
end
316316
function Base.getindex(vi::SimpleVarInfo, vns::Vector{<:VarName}, dist::Distribution)
317317
vals_linked = mapreduce(vcat, vns) do vn
318-
getindex(vi, vn, dist)
318+
return getindex(vi, vn, dist)
319319
end
320320
return recombine(dist, vals_linked, length(vns))
321321
end
@@ -362,7 +362,7 @@ function BangBang.setindex!!(vi::SimpleVarInfo{<:AbstractDict}, val, vn::VarName
362362
# Attempt to split into `parent` and `child` optic.
363363
parent, child, issuccess = splitoptic(getoptic(vn)) do optic
364364
o = optic === nothing ? identity : optic
365-
haskey(dict, VarName(vn, o))
365+
return haskey(dict, VarName(vn, o))
366366
end
367367
# When combined with `VarInfo`, `nothing` is equivalent to `identity`.
368368
keyoptic = parent === nothing ? identity : parent

src/test_utils/varinfo.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ function setup_varinfos(
5858
svi_vnv_ref,
5959
)) do vi
6060
# Set them all to the same values.
61-
DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
61+
return DynamicPPL.setlogp!!(update_values!!(vi, example_values, varnames), lp)
6262
end
6363

6464
if include_threadsafe

0 commit comments

Comments
 (0)