Skip to content

Commit ea66453

Browse files
authored
Merge pull request #35 from TuringLang/py/subsets
Add categories
2 parents 6e88b49 + c072b81 commit ea66453

File tree

10 files changed

+155
-87
lines changed

10 files changed

+155
-87
lines changed

.github/workflows/refresh_website.yml

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -32,10 +32,12 @@ jobs:
3232

3333
- name: Download results to web/src/data
3434
run: |
35-
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/adtests.json
36-
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/manifest.json
37-
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/model_definitions.json
35+
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/${PR}adtests.json
36+
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/${PR}manifest.json
37+
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/${PR}model_definitions.json
3838
working-directory: web/src/data
39+
env:
40+
PR: ${{ github.event_name == 'pull_request' && 'pr/' || '' }}
3941

4042
# This isn't needed to build the website, it's just there so that the
4143
# JSON is easily accessible on the gh-pages branch

README.md

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,11 +13,30 @@ You can modify the list of AD types in `main.jl`.
1313
## I want to add more models!
1414

1515
You can modify the list of models by adding a new file inside the `models` directory.
16-
This file should contain the model definition, and call the `@register` macro to register the model with the `ADTests` package.
17-
See the existing files in that directory for examples.
18-
Then, make sure to `include` the new source file in `main.jl`.
1916

20-
To make sure that the definition is included in the final website, you will have to make sure that the filename is the same as the model name (i.e. a model `@model function f()` is in `models/f.jl`).
17+
Inside this file, you do not need to call `using Turing` or any of the AD backends.
18+
However, you will have to make sure to import any other packages that your model uses.
19+
20+
This file should have, as the final line, the creation of the Turing model object using `model = model_f(...)`.
21+
(It is mandatory for the model object to be called `model`.)
22+
23+
Then, inside `main.jl`, call `@include_model category_heading model_name`.
24+
25+
- `category_heading` is a string that is used to determine which table the model appears under on the website.
26+
- For the automated tests to run properly, `model_name` **must** be consistent between the following:
27+
- The name of the model itself i.e. `@model function model_name(...)`
28+
- The filename i.e. `models/model_name.jl`
29+
- The name of the model in `main.jl` i.e. `@include_model "Category Heading" model_name`
30+
31+
Ideally, `model_name` would be self-explanatory, i.e. it would serve to illustrate exactly one feature and the name would indicate this.
32+
However, if necessary, you can add explanatory comments inside the model definition file.
33+
34+
You can see the existing files in that directory for examples.
35+
36+
> [!NOTE]
37+
> This setup does admittedly feel a bit complicated.
38+
> Unfortunately I could not find a simpler way to get all the components (Julia, Python, web app) to work together in an automated fashion.
39+
> Hopefully it is a small price to pay for the ability to just add a new model and have it be automatically included on the website.
2140
2241
## I want to edit the website!
2342

ad.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -83,6 +83,15 @@ def run_ad(args):
8383
else:
8484
RUN_JULIA_COMMAND = JULIA_COMMAND
8585

86+
# Get category
87+
try:
88+
category = run_and_capture([*RUN_JULIA_COMMAND, "--get-category", model_key])
89+
except sp.CalledProcessError as e:
90+
print(f"Julia crashed when getting category for {model_key}.")
91+
print(f"To reproduce, run: `julia --project=. main.jl --get-category {model_key}`")
92+
category = "error"
93+
results["__category__"] = category
94+
8695
# Run tests
8796
for adtype in adtypes:
8897
print(f"Running {model_key} with {adtype}...")
@@ -127,12 +136,14 @@ def html(_args):
127136
# [
128137
# {"model_name": "model1",
129138
# "results": {
139+
# "__category_": "category1",
130140
# "AD1": "result1",
131141
# "AD2": "result2"
132142
# }
133143
# },
134144
# {"model_name": "model2",
135145
# "results": {
146+
# "__category_": "category2",
136147
# "AD1": "result3",
137148
# "AD2": "result4"
138149
# }

main.jl

Lines changed: 68 additions & 66 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
1-
import Test: @test, @testset
21
using DynamicPPL: DynamicPPL, VarInfo
32
using DynamicPPL.TestUtils.AD: run_ad, ADResult, ADIncorrectException
43
using ADTypes
5-
using Printf: @printf
64

75
import FiniteDifferences: central_fdm
86
import ForwardDiff
@@ -13,30 +11,26 @@ import Zygote
1311

1412
# AD backends to test.
1513
ADTYPES = Dict(
16-
"FiniteDifferences" => AutoFiniteDifferences(; fdm = central_fdm(5, 1)),
14+
"FiniteDifferences" => AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
1715
"ForwardDiff" => AutoForwardDiff(),
18-
"ReverseDiff" => AutoReverseDiff(; compile = false),
19-
"ReverseDiffCompiled" => AutoReverseDiff(; compile = true),
20-
"Mooncake" => AutoMooncake(; config = nothing),
21-
"EnzymeForward" => AutoEnzyme(; mode = set_runtime_activity(Forward, true)),
22-
"EnzymeReverse" => AutoEnzyme(; mode = set_runtime_activity(Reverse, true)),
16+
"ReverseDiff" => AutoReverseDiff(; compile=false),
17+
"ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
18+
"Mooncake" => AutoMooncake(; config=nothing),
19+
"EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
20+
"EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
2321
"Zygote" => AutoZygote(),
2422
)
2523

26-
# Models to test. The convention is that:
27-
# a, b, c, ... are assumed variables
28-
# x, y, z, ... are observed variables
29-
# although it's hardly a big deal.
30-
MODELS = Dict{String,DynamicPPL.Model}()
31-
macro register(model)
32-
:(MODELS[string($(esc(model)).f)] = $(esc(model)))
24+
MODELS = Dict{String,Tuple{String,DynamicPPL.Model}}()
25+
macro register(category, model)
26+
:(MODELS[string($(esc(model)).f)] = ($(esc(category)), $(esc(model))))
3327
end
3428

3529
"""
36-
include_model(model_name::AbstractString)
30+
include_model(category::AbstractString, model_name::AbstractString)
3731
3832
Add the model defined in `models/model_name.jl` to the full list of models
39-
tested in this script.
33+
tested in this script. The model is registered under the given `category`.
4034
4135
We want to isolate every model in its own module, so that we avoid e.g.
4236
variable clashes, and also so that we make imports explicit.
@@ -57,17 +51,17 @@ definition file:
5751
- Once defined, the model is created using `model = model_name(...)`. The
5852
`model` on the left-hand side is mandatory.
5953
"""
60-
macro include_model(model_name::AbstractString)
54+
macro include_model(category::AbstractString, model_name::AbstractString)
6155
MODELS_TO_LOAD = get(ENV, "ADTESTS_MODELS_TO_LOAD", "__all__")
6256
if MODELS_TO_LOAD == "__all__" || model_name in split(MODELS_TO_LOAD, ",")
6357
# Declare a module containing the model. In principle esc() shouldn't
6458
# be needed, but see https://github.com/JuliaLang/julia/issues/55677
6559
Expr(:toplevel, esc(:(
6660
module $(gensym())
67-
using .Main: @register
68-
using Turing
69-
include("models/" * $(model_name) * ".jl")
70-
@register model
61+
using .Main: @register
62+
using Turing
63+
include("models/" * $(model_name) * ".jl")
64+
@register $(category) model
7165
end
7266
)))
7367
else
@@ -76,63 +70,70 @@ macro include_model(model_name::AbstractString)
7670
end
7771
end
7872

79-
@include_model "assume_beta"
80-
@include_model "assume_dirichlet"
81-
@include_model "assume_lkjcholu"
82-
@include_model "assume_mvnormal"
83-
@include_model "assume_normal"
84-
@include_model "assume_submodel"
85-
@include_model "assume_wishart"
86-
@include_model "broadcast_macro"
87-
@include_model "control_flow"
88-
@include_model "demo_assume_dot_observe"
89-
@include_model "demo_assume_dot_observe_literal"
90-
@include_model "demo_assume_index_observe"
91-
@include_model "demo_assume_matrix_observe_matrix_index"
92-
@include_model "demo_assume_multivariate_observe"
93-
@include_model "demo_assume_multivariate_observe_literal"
94-
@include_model "demo_assume_observe_literal"
95-
@include_model "demo_assume_submodel_observe_index_literal"
96-
@include_model "demo_dot_assume_observe"
97-
@include_model "demo_dot_assume_observe_index"
98-
@include_model "demo_dot_assume_observe_index_literal"
99-
@include_model "demo_dot_assume_observe_matrix_index"
100-
@include_model "demo_dot_assume_observe_submodel"
101-
@include_model "dot_assume"
102-
@include_model "dot_observe"
103-
@include_model "dynamic_constraint"
104-
@include_model "multiple_constraints_same_var"
105-
@include_model "multithreaded"
106-
@include_model "n010"
107-
@include_model "n050"
108-
@include_model "n100"
109-
@include_model "n500"
110-
@include_model "observe_bernoulli"
111-
@include_model "observe_categorical"
112-
@include_model "observe_index"
113-
@include_model "observe_literal"
114-
@include_model "observe_multivariate"
115-
@include_model "observe_submodel"
116-
@include_model "pdb_eight_schools_centered"
117-
@include_model "pdb_eight_schools_noncentered"
118-
@include_model "von_mises"
73+
# Models to test. The convention is that:
74+
# a, b, c, ... are assumed variables
75+
# x, y, z, ... are observed variables
76+
# although it's hardly a big deal.
77+
@include_model "Base Julia features" "control_flow"
78+
@include_model "Base Julia features" "multithreaded"
79+
@include_model "Core Turing syntax" "broadcast_macro"
80+
@include_model "Core Turing syntax" "dot_assume"
81+
@include_model "Core Turing syntax" "dot_observe"
82+
@include_model "Core Turing syntax" "dynamic_constraint"
83+
@include_model "Core Turing syntax" "multiple_constraints_same_var"
84+
@include_model "Core Turing syntax" "observe_index"
85+
@include_model "Core Turing syntax" "observe_literal"
86+
@include_model "Core Turing syntax" "observe_multivariate"
87+
@include_model "Core Turing syntax" "observe_submodel"
88+
@include_model "Distributions" "assume_beta"
89+
@include_model "Distributions" "assume_dirichlet"
90+
@include_model "Distributions" "assume_lkjcholu"
91+
@include_model "Distributions" "assume_mvnormal"
92+
@include_model "Distributions" "assume_normal"
93+
@include_model "Distributions" "assume_submodel"
94+
@include_model "Distributions" "assume_wishart"
95+
@include_model "Distributions" "observe_bernoulli"
96+
@include_model "Distributions" "observe_categorical"
97+
@include_model "Distributions" "observe_von_mises"
98+
@include_model "DynamicPPL demo models" "demo_assume_dot_observe"
99+
@include_model "DynamicPPL demo models" "demo_assume_dot_observe_literal"
100+
@include_model "DynamicPPL demo models" "demo_assume_index_observe"
101+
@include_model "DynamicPPL demo models" "demo_assume_matrix_observe_matrix_index"
102+
@include_model "DynamicPPL demo models" "demo_assume_multivariate_observe"
103+
@include_model "DynamicPPL demo models" "demo_assume_multivariate_observe_literal"
104+
@include_model "DynamicPPL demo models" "demo_assume_observe_literal"
105+
@include_model "DynamicPPL demo models" "demo_assume_submodel_observe_index_literal"
106+
@include_model "DynamicPPL demo models" "demo_dot_assume_observe"
107+
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_index"
108+
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_index_literal"
109+
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_matrix_index"
110+
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_submodel"
111+
@include_model "Effect of model size" "n010"
112+
@include_model "Effect of model size" "n050"
113+
@include_model "Effect of model size" "n100"
114+
@include_model "Effect of model size" "n500"
115+
@include_model "PosteriorDB" "pdb_eight_schools_centered"
116+
@include_model "PosteriorDB" "pdb_eight_schools_noncentered"
119117

120118
# The entry point to this script itself begins here
121119
if ARGS == ["--list-model-keys"]
122120
foreach(println, sort(collect(keys(MODELS))))
123121
elseif ARGS == ["--list-adtype-keys"]
124122
foreach(println, sort(collect(keys(ADTYPES))))
123+
elseif length(ARGS) == 2 && ARGS[1] == "--get-category"
124+
println(MODELS[ARGS[2]][1])
125125
elseif length(ARGS) == 3 && ARGS[1] == "--run"
126-
model, adtype = MODELS[ARGS[2]], ADTYPES[ARGS[3]]
126+
model_name, adtype_name = ARGS[2], ARGS[3]
127+
model, adtype = MODELS[model_name][2], ADTYPES[adtype_name]
127128

128129
try
129-
if ARGS[2] == "control_flow"
130+
if model_name == "control_flow"
130131
# https://github.com/TuringLang/ADTests/issues/4
131132
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
132133
params = [-0.5, 0.5]
133-
result = run_ad(model, adtype; varinfo = vi, params = params, benchmark = true)
134+
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
134135
else
135-
result = run_ad(model, adtype; benchmark = true)
136+
result = run_ad(model, adtype; benchmark=true)
136137
end
137138
# If reached here - nothing went wrong
138139
println(result.time_vs_primal)
@@ -162,4 +163,5 @@ else
162163
println("Usage: julia main.jl --list-model-keys")
163164
println(" julia main.jl --list-adtype-keys")
164165
println(" julia main.jl --run <model> <adtype>")
166+
println(" julia main.jl --get-category <model>")
165167
end

models/observe_von_mises.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
@model function observe_von_mises(x)
2+
a ~ InverseGamma(2, 3)
3+
x ~ VonMises(0, a)
4+
end
5+
6+
model = observe_von_mises(0.4)

models/von_mises.jl

Lines changed: 0 additions & 6 deletions
This file was deleted.

web/src/App.svelte

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,30 @@
22
import data from "./data/adtests.json";
33
import modelDefinitions from "./data/model_definitions.json";
44
5+
// Parse data into nice JS objects.
6+
// Obviously, the nested strings are a bit ugly. From outer to inner, they are:
7+
// category -> model_name -> adtype -> result
8+
let categorisedData = new Map<
9+
string,
10+
Map<string, Map<string, string | number>>
11+
>();
12+
for (const [model_name, results] of Object.entries(data)) {
13+
let category = results.__category__;
14+
delete results.__category__;
15+
let resultsMap = new Map<string, string | number>();
16+
for (const [adtype, result] of Object.entries(results)) {
17+
resultsMap.set(adtype, result);
18+
}
19+
if (!categorisedData.has(category)) {
20+
categorisedData.set(
21+
category,
22+
new Map<string, Map<string, string | number>>(),
23+
);
24+
}
25+
categorisedData.get(category).set(model_name, resultsMap);
26+
}
27+
console.log(categorisedData);
28+
529
import Manifest from "./lib/Manifest.svelte";
630
import ResultsTable from "./lib/ResultsTable.svelte";
731
</script>
@@ -78,7 +102,10 @@
78102
>Download the raw data (JSON)</a
79103
>
80104
</p>
81-
<ResultsTable {data} {modelDefinitions} />
105+
{#each categorisedData.entries() as [category, modelData]}
106+
<h3>{category}</h3>
107+
<ResultsTable data={modelData} {modelDefinitions} />
108+
{/each}
82109

83110
<h2>Manifest</h2>
84111
<p>The tests above were run with the following package versions:</p>

web/src/lib/Manifest.svelte

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,11 @@
11
<script lang="ts">
2-
import manifest from "../data/manifest.json";
2+
import manifestObj from "../data/manifest.json";
3+
4+
// convert manifest to a Map
5+
let manifest = new Map<string, string | null>();
6+
for (const [packageName, version] of Object.entries(manifestObj)) {
7+
manifest.set(packageName, version === "" ? null : version);
8+
}
39
410
import { getSortedEntries } from "./utils";
511
</script>

web/src/lib/ResultsTable.svelte

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,14 @@
55
import { getSortedEntries } from "./utils";
66
77
interface Props {
8-
data: object;
8+
// model name -> adtype -> result
9+
data: Map<string, Map<string, string | number>>;
910
modelDefinitions: object;
1011
}
1112
const { data, modelDefinitions }: Props = $props();
1213
13-
const models = Object.keys(data);
14-
const adtypes = Object.keys(data[models[0]]);
14+
const models = [...data.keys()];
15+
const adtypes = data.get(models[0]).keys();
1516
1617
// Known errors
1718
const ENZYME_FWD_BLAS = "https://github.com/EnzymeAD/Enzyme.jl/issues/1995";

web/src/lib/utils.ts

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
1-
export function getSortedEntries(obj: object) {
2-
return Object.entries(obj).sort(([a, _x], [b, _y]) =>
3-
a.localeCompare(b),
1+
export function getSortedEntries(m: Map<string, any>) {
2+
return [...m.entries()].sort(([k1, _v1], [k2, _v2]) =>
3+
k1.localeCompare(k2),
44
);
55
}

0 commit comments

Comments
 (0)