Skip to content

Commit 199ba71

Browse files
run JuliaFormatter
1 parent d77bbac commit 199ba71

10 files changed

+55
-44
lines changed

JuliaBUGS/ext/JuliaBUGSAdvancedHMCExt.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -44,9 +44,7 @@ function _gibbs_internal_hmc(
4444
)
4545
# Create gradient model on-the-fly using DifferentiationInterface
4646
x = getparams(cond_model)
47-
prep = DI.prepare_gradient(
48-
_logdensity_switched, ad_backend, x, DI.Constant(cond_model)
49-
)
47+
prep = DI.prepare_gradient(_logdensity_switched, ad_backend, x, DI.Constant(cond_model))
5048
ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model)
5149
logdensitymodel = AbstractMCMC.LogDensityModel(ad_model)
5250

JuliaBUGS/ext/JuliaBUGSAdvancedMHExt.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,9 +56,7 @@ function _gibbs_internal_mh(
5656
)
5757
# Create gradient model on-the-fly using DifferentiationInterface
5858
x = getparams(cond_model)
59-
prep = DI.prepare_gradient(
60-
_logdensity_switched, ad_backend, x, DI.Constant(cond_model)
61-
)
59+
prep = DI.prepare_gradient(_logdensity_switched, ad_backend, x, DI.Constant(cond_model))
6260
ad_model = BUGSModelWithGradient(ad_backend, prep, cond_model)
6361
logdensitymodel = AbstractMCMC.LogDensityModel(ad_model)
6462

JuliaBUGS/ext/JuliaBUGSMCMCChainsExt.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,12 @@ module JuliaBUGSMCMCChainsExt
22

33
using AbstractMCMC
44
using JuliaBUGS
5-
using JuliaBUGS: BUGSModel, BUGSModelWithGradient, find_generated_quantities_variables, evaluate!!, getparams
5+
using JuliaBUGS:
6+
BUGSModel,
7+
BUGSModelWithGradient,
8+
find_generated_quantities_variables,
9+
evaluate!!,
10+
getparams
611
using JuliaBUGS.AbstractPPL
712
using JuliaBUGS.Accessors
813
using JuliaBUGS.LogDensityProblemsAD

JuliaBUGS/src/JuliaBUGS.jl

Lines changed: 11 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -320,14 +320,14 @@ function compile(
320320
),
321321
)
322322
base_model = BUGSModel(g, nonmissing_eval_env, model_def, data, initial_params)
323-
323+
324324
# If adtype provided, wrap with gradient capabilities
325325
if adtype !== nothing
326326
# Convert symbol to ADType if needed
327327
adtype_obj = _resolve_adtype(adtype)
328328
return _wrap_with_gradient(base_model, adtype_obj)
329329
end
330-
330+
331331
return base_model
332332
end
333333

@@ -344,17 +344,19 @@ Supported symbol shortcuts:
344344
"""
345345
function _resolve_adtype(adtype::Symbol)
346346
if adtype === :ReverseDiff
347-
return ADTypes.AutoReverseDiff(compile=true)
347+
return ADTypes.AutoReverseDiff(; compile=true)
348348
elseif adtype === :ForwardDiff
349349
return ADTypes.AutoForwardDiff()
350350
elseif adtype === :Zygote
351351
return ADTypes.AutoZygote()
352352
elseif adtype === :Enzyme
353353
return ADTypes.AutoEnzyme()
354354
else
355-
error("Unknown AD backend symbol: $adtype. " *
356-
"Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " *
357-
"Or use an ADTypes object like AutoReverseDiff(compile=true).")
355+
error(
356+
"Unknown AD backend symbol: $adtype. " *
357+
"Supported symbols: :ReverseDiff, :ForwardDiff, :Zygote, :Enzyme. " *
358+
"Or use an ADTypes object like AutoReverseDiff(compile=true).",
359+
)
358360
end
359361
end
360362

@@ -366,17 +368,13 @@ function _wrap_with_gradient(base_model::Model.BUGSModel, adtype::ADTypes.Abstra
366368
# Get initial parameters for preparation
367369
# Use invokelatest to handle world age issues with generated functions
368370
x = Base.invokelatest(getparams, base_model)
369-
371+
370372
# Prepare gradient using DifferentiationInterface
371373
# Use invokelatest to handle world age issues when calling logdensity during preparation
372374
prep = Base.invokelatest(
373-
DI.prepare_gradient,
374-
Model._logdensity_switched,
375-
adtype,
376-
x,
377-
DI.Constant(base_model)
375+
DI.prepare_gradient, Model._logdensity_switched, adtype, x, DI.Constant(base_model)
378376
)
379-
377+
380378
return Model.BUGSModelWithGradient(adtype, prep, base_model)
381379
end
382380
# function compile(

JuliaBUGS/src/model/logdensityproblems.jl

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -102,7 +102,11 @@ function LogDensityProblems.logdensity_and_gradient(
102102
)
103103
else
104104
return DI.value_and_gradient(
105-
_logdensity_switched, model.prep, model.backend, x, DI.Constant(model.base_model)
105+
_logdensity_switched,
106+
model.prep,
107+
model.backend,
108+
x,
109+
DI.Constant(model.base_model),
106110
)
107111
end
108112
end

JuliaBUGS/test/BUGSPrimitives/distributions.jl

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,9 @@ end
1515
A[1:2, 1:2] ~ dwish(B[:, :], 2)
1616
C[1:2] ~ dmnorm(mu[:], A[:, :])
1717
end
18-
ad_model = compile(model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],); adtype=AutoReverseDiff())
18+
ad_model = compile(
19+
model_def, (mu=[0, 0], B=[1 0; 0 1]), (A=[1 0; 0 1],); adtype=AutoReverseDiff()
20+
)
1921

2022
theta = [
2123
0.7931743744870574,

JuliaBUGS/test/ext/JuliaBUGSAdvancedHMCExt.jl

Lines changed: 13 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
y = x[1] + x[3]
77
end
88
data = (mu=[0, 0], sigma=[1 0; 0 1])
9-
ad_model = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
9+
ad_model = compile(model_def, data; adtype=AutoReverseDiff(; compile=true))
1010
n_samples, n_adapts = 10, 0
1111
D = LogDensityProblems.dimension(ad_model)
1212
initial_θ = rand(D)
@@ -34,19 +34,19 @@
3434
end
3535
end
3636
data = (N=5, y=[1.0, 2.0, 1.5, 2.5, 1.8])
37-
37+
3838
# Test that symbol shortcut works
3939
ad_model_symbol = compile(model_def, data; adtype=:ReverseDiff)
40-
ad_model_explicit = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
41-
40+
ad_model_explicit = compile(model_def, data; adtype=AutoReverseDiff(; compile=true))
41+
4242
@test ad_model_symbol isa JuliaBUGS.Model.BUGSModelWithGradient
4343
@test ad_model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient
44-
44+
4545
# Test that both produce equivalent results
4646
n_samples, n_adapts = 100, 100
4747
D = LogDensityProblems.dimension(ad_model_symbol)
4848
initial_θ = rand(StableRNG(123), D)
49-
49+
5050
samples_symbol = AbstractMCMC.sample(
5151
StableRNG(1234),
5252
ad_model_symbol,
@@ -58,7 +58,7 @@
5858
init_params=initial_θ,
5959
discard_initial=n_adapts,
6060
)
61-
61+
6262
samples_explicit = AbstractMCMC.sample(
6363
StableRNG(1234),
6464
ad_model_explicit,
@@ -70,18 +70,20 @@
7070
init_params=initial_θ,
7171
discard_initial=n_adapts,
7272
)
73-
73+
7474
# Results should be very similar (same RNG seed)
75-
@test summarize(samples_symbol)[:mu].nt.mean[1]
76-
summarize(samples_explicit)[:mu].nt.mean[1] rtol=0.1
75+
@test summarize(samples_symbol)[:mu].nt.mean[1]
76+
summarize(samples_explicit)[:mu].nt.mean[1] rtol = 0.1
7777
end
7878

7979
@testset "Inference results on examples: $example" for example in
8080
[:seeds, :rats, :stacks]
8181
(; model_def, data, inits, reference_results) = Base.getfield(
8282
JuliaBUGS.BUGSExamples, example
8383
)
84-
ad_model = JuliaBUGS.compile(model_def, data, inits; adtype=AutoReverseDiff(compile=true))
84+
ad_model = JuliaBUGS.compile(
85+
model_def, data, inits; adtype=AutoReverseDiff(; compile=true)
86+
)
8587

8688
n_samples, n_adapts = 1000, 1000
8789

JuliaBUGS/test/ext/JuliaBUGSMCMCChainsExt.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
y=[1.58, 4.80, 7.10, 8.86, 11.73, 14.52, 18.22, 18.73, 21.04, 22.93],
2727
)
2828

29-
ad_model = compile(model_def, data, (;); adtype=AutoReverseDiff(compile=true))
29+
ad_model = compile(model_def, data, (;); adtype=AutoReverseDiff(; compile=true))
3030
n_samples, n_adapts = 2000, 1000
3131

3232
D = LogDensityProblems.dimension(ad_model)
@@ -106,7 +106,7 @@
106106
sigma[2] ~ InverseGamma(2, 3)
107107
sigma[3] ~ InverseGamma(2, 3)
108108
end
109-
ad_model = compile(model_def, (;); adtype=AutoReverseDiff(compile=true))
109+
ad_model = compile(model_def, (;); adtype=AutoReverseDiff(; compile=true))
110110
hmc_chain = AbstractMCMC.sample(
111111
ad_model, NUTS(0.8), 10; progress=false, chain_type=Chains
112112
)

JuliaBUGS/test/model/bugsmodel.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -414,22 +414,26 @@ end
414414
# Test :ReverseDiff shortcut
415415
model_rd = compile(model_def, data; adtype=:ReverseDiff)
416416
@test model_rd isa JuliaBUGS.Model.BUGSModelWithGradient
417-
417+
418418
# Test equivalence with explicit ADType
419-
model_explicit = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
419+
model_explicit = compile(
420+
model_def, data; adtype=AutoReverseDiff(; compile=true)
421+
)
420422
@test model_explicit isa JuliaBUGS.Model.BUGSModelWithGradient
421-
423+
422424
# Test that unknown symbol throws error
423425
@test_throws ErrorException compile(model_def, data; adtype=:UnknownBackend)
424426
end
425427

426428
@testset "Explicit ADTypes" begin
427429
# Test with compile=true
428-
model_compile = compile(model_def, data; adtype=AutoReverseDiff(compile=true))
430+
model_compile = compile(model_def, data; adtype=AutoReverseDiff(; compile=true))
429431
@test model_compile isa JuliaBUGS.Model.BUGSModelWithGradient
430-
432+
431433
# Test with compile=false
432-
model_nocompile = compile(model_def, data; adtype=AutoReverseDiff(compile=false))
434+
model_nocompile = compile(
435+
model_def, data; adtype=AutoReverseDiff(; compile=false)
436+
)
433437
@test model_nocompile isa JuliaBUGS.Model.BUGSModelWithGradient
434438
end
435439

@@ -443,10 +447,10 @@ end
443447
@testset "Gradient computation" begin
444448
model = compile(model_def, data; adtype=:ReverseDiff)
445449
test_point = [0.0]
446-
450+
447451
# Test that gradient can be computed
448452
ℓ, grad = LogDensityProblems.logdensity_and_gradient(model, test_point)
449-
453+
450454
@testisa Real
451455
@test grad isa Vector
452456
@test length(grad) == 1

JuliaBUGS/test/parallel_sampling.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
inits = (mu=0.0, tau=1.0)
2121

2222
# Use compile=false for thread safety with ReverseDiff
23-
ad_model = compile(model_def, data, inits; adtype=AutoReverseDiff(compile=false))
23+
ad_model = compile(model_def, data, inits; adtype=AutoReverseDiff(; compile=false))
2424

2525
# Single chain reference
2626
n_samples = 200

0 commit comments

Comments
 (0)