Skip to content

Commit 1f4206d

Browse files
committed
Revert "Revert today's commits"
This reverts commit 78a2913.
1 parent 78a2913 commit 1f4206d

31 files changed

+218
-42
lines changed

.github/workflows/generate_website.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@ on:
77
# If _any_ of the modified files match this filter, it will trigger this
88
# workflow
99
paths:
10-
- '*.jl'
10+
- '**/*.jl'
1111
- '*.toml'
1212
- 'ad.py'
13+
- '.github/workflows/generate_website.yml'
1314
pull_request:
1415
paths:
15-
- '*.jl'
16+
- '**/*.jl'
1617
- '*.toml'
1718
- 'ad.py'
19+
- '.github/workflows/generate_website.yml'
1820
workflow_dispatch:
1921

2022
permissions:

.github/workflows/refresh_website.yml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,16 @@ on:
77
# If _all_ the modified files match this filter, it won't trigger this
88
# workflow
99
paths-ignore:
10-
- '*.jl'
10+
- '**/*.jl'
1111
- '*.toml'
1212
- 'ad.py'
13+
- '.github/workflows/refresh_website.yml'
1314
pull_request:
1415
paths:
15-
- '*.jl'
16+
- '**/*.jl'
1617
- '*.toml'
1718
- 'ad.py'
19+
- '.github/workflows/refresh_website.yml'
1820
workflow_dispatch:
1921

2022
permissions:

ad.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -89,8 +89,12 @@ def run_ad(args):
8989
try:
9090
output = run_and_capture([*RUN_JULIA_COMMAND, "--run", model_key, adtype])
9191
result = try_float(output.splitlines()[-1])
92+
if not isinstance(result, float):
93+
print(f"Output: {output}")
9294
except sp.CalledProcessError as e:
93-
print(f"Error running {model_key} with {adtype}. Output: {e.output}")
95+
# Julia crashed
96+
print(f"Julia crashed when running {model_key} with {adtype}.")
97+
print(f"To reproduce, run: `julia --project=. main.jl --run {model_key} {adtype}`")
9498
result = "error"
9599

96100
print(f" ... {model_key} with {adtype} ==> {result}")

main.jl

Lines changed: 23 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,13 @@ import Zygote
1313

1414
# AD backends to test.
1515
ADTYPES = Dict(
16-
"FiniteDifferences" => AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
16+
"FiniteDifferences" => AutoFiniteDifferences(; fdm = central_fdm(5, 1)),
1717
"ForwardDiff" => AutoForwardDiff(),
18-
"ReverseDiff" => AutoReverseDiff(; compile=false),
19-
"ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
20-
"Mooncake" => AutoMooncake(; config=nothing),
21-
"EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
22-
"EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
18+
"ReverseDiff" => AutoReverseDiff(; compile = false),
19+
"ReverseDiffCompiled" => AutoReverseDiff(; compile = true),
20+
"Mooncake" => AutoMooncake(; config = nothing),
21+
"EnzymeForward" => AutoEnzyme(; mode = set_runtime_activity(Forward, true)),
22+
"EnzymeReverse" => AutoEnzyme(; mode = set_runtime_activity(Reverse, true)),
2323
"Zygote" => AutoZygote(),
2424
)
2525

@@ -35,7 +35,7 @@ end
3535
# These imports tend to get used a lot in models
3636
using DynamicPPL: @model, to_submodel
3737
using Distributions
38-
using LinearAlgebra: I
38+
using LinearAlgebra
3939

4040
include("models/assume_dirichlet.jl")
4141
include("models/assume_lkjcholu.jl")
@@ -44,8 +44,21 @@ include("models/assume_normal.jl")
4444
include("models/assume_submodel.jl")
4545
include("models/assume_wishart.jl")
4646
include("models/control_flow.jl")
47-
include("models/dot_assume_observe_index.jl")
47+
include("models/demo_assume_dot_observe_literal.jl")
48+
include("models/demo_assume_dot_observe.jl")
49+
include("models/demo_assume_index_observe.jl")
50+
include("models/demo_assume_matrix_observe_matrix_index.jl")
51+
include("models/demo_assume_multivariate_observe_literal.jl")
52+
include("models/demo_assume_multivariate_observe.jl")
53+
include("models/demo_assume_observe_literal.jl")
54+
include("models/demo_assume_submodel_observe_index_literal.jl")
55+
include("models/demo_dot_assume_observe_index_literal.jl")
56+
include("models/demo_dot_assume_observe_index.jl")
57+
include("models/demo_dot_assume_observe_matrix_index.jl")
58+
include("models/demo_dot_assume_observe_submodel.jl")
59+
include("models/demo_dot_assume_observe.jl")
4860
include("models/dot_assume.jl")
61+
include("models/demo_dot_assume_observe.jl")
4962
include("models/dot_observe.jl")
5063
include("models/dynamic_constraint.jl")
5164
include("models/multiple_constraints_same_var.jl")
@@ -74,9 +87,9 @@ elseif length(ARGS) == 3 && ARGS[1] == "--run"
7487
# https://github.com/TuringLang/ADTests/issues/4
7588
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
7689
params = [-0.5, 0.5]
77-
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
90+
result = run_ad(model, adtype; varinfo = vi, params = params, benchmark = true)
7891
else
79-
result = run_ad(model, adtype; benchmark=true)
92+
result = run_ad(model, adtype; benchmark = true)
8093
end
8194
# If reached here - nothing went wrong
8295
@printf("%.3f", result.time_vs_primal)

models/demo_assume_dot_observe.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@model function demo_assume_dot_observe(x = [1.5, 2.0])
2+
# `assume` and `dot_observe`
3+
s ~ InverseGamma(2, 3)
4+
m ~ Normal(0, sqrt(s))
5+
x .~ Normal(m, sqrt(s))
6+
end
7+
8+
@register demo_assume_dot_observe()
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@model function demo_assume_dot_observe_literal()
2+
# `assume` and literal `dot_observe`
3+
s ~ InverseGamma(2, 3)
4+
m ~ Normal(0, sqrt(s))
5+
[1.5, 2.0] .~ Normal(m, sqrt(s))
6+
end
7+
8+
@register demo_assume_dot_observe_literal()
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
@model function demo_assume_index_observe(
2+
x = [1.5, 2.0],
3+
::Type{TV} = Vector{Float64},
4+
) where {TV}
5+
# `assume` with indexing and `observe`
6+
s = TV(undef, length(x))
7+
for i in eachindex(s)
8+
s[i] ~ InverseGamma(2, 3)
9+
end
10+
m = TV(undef, length(x))
11+
for i in eachindex(m)
12+
m[i] ~ Normal(0, sqrt(s[i]))
13+
end
14+
x ~ MvNormal(m, Diagonal(s))
15+
end
16+
17+
@register demo_assume_index_observe()
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
@model function demo_assume_matrix_observe_matrix_index(
2+
x = transpose([1.5 2.0;]),
3+
::Type{TV} = Array{Float64},
4+
) where {TV}
5+
n = length(x)
6+
d = n ÷ 2
7+
s ~ reshape(product_distribution(fill(InverseGamma(2, 3), n)), d, 2)
8+
s_vec = vec(s)
9+
m ~ MvNormal(zeros(n), Diagonal(s_vec))
10+
11+
x[:, 1] ~ MvNormal(m, Diagonal(s_vec))
12+
end
13+
14+
@register demo_assume_matrix_observe_matrix_index()
Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
@model function demo_assume_multivariate_observe(x = [1.5, 2.0])
2+
# Multivariate `assume` and `observe`
3+
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
4+
m ~ MvNormal(zero(x), Diagonal(s))
5+
x ~ MvNormal(m, Diagonal(s))
6+
end
7+
@register demo_assume_multivariate_observe()
Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
@model function demo_assume_multivariate_observe_literal()
2+
# multivariate `assume` and literal `observe`
3+
s ~ product_distribution([InverseGamma(2, 3), InverseGamma(2, 3)])
4+
m ~ MvNormal(zeros(2), Diagonal(s))
5+
[1.5, 2.0] ~ MvNormal(m, Diagonal(s))
6+
end
7+
8+
@register demo_assume_multivariate_observe_literal()

0 commit comments

Comments
 (0)