Skip to content

Commit 8201ee0

Browse files
committed
Format
1 parent aa81d07 commit 8201ee0

12 files changed

+33
-27
lines changed

main.jl

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ import Zygote
1313

1414
# AD backends to test.
1515
ADTYPES = Dict(
16-
"FiniteDifferences" => AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
16+
"FiniteDifferences" => AutoFiniteDifferences(; fdm = central_fdm(5, 1)),
1717
"ForwardDiff" => AutoForwardDiff(),
18-
"ReverseDiff" => AutoReverseDiff(; compile=false),
19-
"ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
20-
"Mooncake" => AutoMooncake(; config=nothing),
21-
"EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
22-
"EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
18+
"ReverseDiff" => AutoReverseDiff(; compile = false),
19+
"ReverseDiffCompiled" => AutoReverseDiff(; compile = true),
20+
"Mooncake" => AutoMooncake(; config = nothing),
21+
"EnzymeForward" => AutoEnzyme(; mode = set_runtime_activity(Forward, true)),
22+
"EnzymeReverse" => AutoEnzyme(; mode = set_runtime_activity(Reverse, true)),
2323
"Zygote" => AutoZygote(),
2424
)
2525

@@ -75,9 +75,9 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
7575
# https://github.com/TuringLang/ADTests/issues/4
7676
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
7777
params = [-0.5, 0.5]
78-
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
78+
result = run_ad(model, adtype; varinfo = vi, params = params, benchmark = true)
7979
else
80-
result = run_ad(model, adtype; benchmark=true)
80+
result = run_ad(model, adtype; benchmark = true)
8181
end
8282
# If reached here - nothing went wrong
8383
@printf("%.3f", result.time_vs_primal)

models/demo_assume_dot_observe.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
@model function demo_assume_dot_observe(x=[1.5, 2.0])
1+
@model function demo_assume_dot_observe(x = [1.5, 2.0])
22
# `assume` and `dot_observe`
33
s ~ InverseGamma(2, 3)
44
m ~ Normal(0, sqrt(s))
55
x .~ Normal(m, sqrt(s))
66

7-
return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
7+
return (; s = s, m = m, x = x, logp = getlogp(__varinfo__))
88
end
99

1010
@register demo_assume_dot_observe()

models/demo_assume_dot_observe_literal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
m ~ Normal(0, sqrt(s))
55
[1.5, 2.0] .~ Normal(m, sqrt(s))
66

7-
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
7+
return (; s = s, m = m, x = [1.5, 2.0], logp = getlogp(__varinfo__))
88
end
99

1010
@model demo_assume_dot_observe_literal()
Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@model function demo_assume_matrix_observe_matrix_index(
2-
x=transpose([1.5 2.0;]), ::Type{TV}=Array{Float64}
2+
x = transpose([1.5 2.0;]),
3+
::Type{TV} = Array{Float64},
34
) where {TV}
45
n = length(x)
56
d = n ÷ 2
@@ -9,7 +10,7 @@
910

1011
x[:, 1] ~ MvNormal(m, Diagonal(s_vec))
1112

12-
return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
13+
return (; s = s, m = m, x = x, logp = getlogp(__varinfo__))
1314
end
1415

1516
@register demo_assume_matrix_observe_matrix_index()
Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
@model function demo_assume_multivariate_observe(x=[1.5, 2.0])
1+
@model function demo_assume_multivariate_observe(x = [1.5, 2.0])
22
# Multivariate `assume` and `observe`
33
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
44
m ~ MvNormal(zero(x), Diagonal(s))
55
x ~ MvNormal(m, Diagonal(s))
66

7-
return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
7+
return (; s = s, m = m, x = x, logp = getlogp(__varinfo__))
88
end
99

1010
@register demo_assume_multivariate_observe()

models/demo_assume_multivariate_observe_literal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
m ~ MvNormal(zeros(2), Diagonal(s))
55
[1.5, 2.0] ~ MvNormal(m, Diagonal(s))
66

7-
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
7+
return (; s = s, m = m, x = [1.5, 2.0], logp = getlogp(__varinfo__))
88
end
99

1010
@register demo_assume_multivariate_observe_literal()

models/demo_assume_observe_literal.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
1.5 ~ Normal(m, sqrt(s))
66
2.0 ~ Normal(m, sqrt(s))
77

8-
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
8+
return (; s = s, m = m, x = [1.5, 2.0], logp = getlogp(__varinfo__))
99
end
1010

1111
@register demo_assume_observe_literal()

models/demo_assume_submodel_observe_index_literal.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
@model function _prior_dot_assume(::Type{TV}=Vector{Float64}) where {TV}
1+
@model function _prior_dot_assume(::Type{TV} = Vector{Float64}) where {TV}
22
s = TV(undef, 2)
33
s .~ InverseGamma(2, 3)
44
m = TV(undef, 2)
@@ -13,7 +13,7 @@ end
1313
1.5 ~ Normal(m[1], sqrt(s[1]))
1414
2.0 ~ Normal(m[2], sqrt(s[2]))
1515

16-
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
16+
return (; s = s, m = m, x = [1.5, 2.0], logp = getlogp(__varinfo__))
1717
end
1818

1919
@register demo_assume_submodel_observe_index_literal()

models/demo_dot_assume_observe_index.jl

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
@model function demo_dot_assume_observe_index(
2-
x=[1.5, 2.0], ::Type{TV}=Vector{Float64}
2+
x = [1.5, 2.0],
3+
::Type{TV} = Vector{Float64},
34
) where {TV}
45
# `dot_assume` and `observe` with indexing
56
s = TV(undef, length(x))
@@ -10,7 +11,7 @@
1011
x[i] ~ Normal(m[i], sqrt(s[i]))
1112
end
1213

13-
return (; s=s, m=m, x=x, logp=getlogp(__varinfo__))
14+
return (; s = s, m = m, x = x, logp = getlogp(__varinfo__))
1415
end
1516

1617
@register demo_dot_assume_observe_index()
Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
@model function demo_dot_assume_observe_index_literal(::Type{TV}=Vector{Float64}) where {TV}
1+
@model function demo_dot_assume_observe_index_literal(
2+
::Type{TV} = Vector{Float64},
3+
) where {TV}
24
# `dot_assume` and literal `observe` with indexing
35
s = TV(undef, 2)
46
m = TV(undef, 2)
@@ -8,7 +10,7 @@
810
1.5 ~ Normal(m[1], sqrt(s[1]))
911
2.0 ~ Normal(m[2], sqrt(s[2]))
1012

11-
return (; s=s, m=m, x=[1.5, 2.0], logp=getlogp(__varinfo__))
13+
return (; s = s, m = m, x = [1.5, 2.0], logp = getlogp(__varinfo__))
1214
end
1315

1416
@register demo_dot_assume_observe_index_literal()

0 commit comments

Comments
 (0)