Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions .github/workflows/refresh_website.yml
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,12 @@ jobs:

- name: Download results to web/src/data
run: |
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/adtests.json
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/manifest.json
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/model_definitions.json
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/${PR}adtests.json
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/${PR}manifest.json
curl -O https://raw.githubusercontent.com/TuringLang/ADTests/refs/heads/gh-pages/${PR}model_definitions.json
working-directory: web/src/data
env:
PR: ${{ github.event_name == 'pull_request' && 'pr/' || '' }}

# This isn't needed to build the website, it's just there so that the
# JSON is easily accessible on the gh-pages branch
Expand Down
27 changes: 23 additions & 4 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -13,11 +13,30 @@ You can modify the list of AD types in `main.jl`.
## I want to add more models!

You can modify the list of models by adding a new file inside the `models` directory.
This file should contain the model definition, and call the `@register` macro to register the model with the `ADTests` package.
See the existing files in that directory for examples.
Then, make sure to `include` the new source file in `main.jl`.

To make sure that the definition is included in the final website, you will have to make sure that the filename is the same as the model name (i.e. a model `@model function f()` is in `models/f.jl`).
Inside this file, you do not need to call `using Turing` or any of the AD backends.
However, you will have to make sure to import any other packages that your model uses.

This file should have, as the final line, the creation of the Turing model object using `model = model_f(...)`.
(It is mandatory for the model object to be called `model`.)

Then, inside `main.jl`, call `@include_model category_heading model_name`.

- `category_heading` is a string that is used to determine which table the model appears under on the website.
- For the automated tests to run properly, `model_name` **must** be consistent between the following:
- The name of the model itself i.e. `@model function model_name(...)`
- The filename i.e. `models/model_name.jl`
- The name of the model in `main.jl` i.e. `@include_model "Category Heading" model_name`

Ideally, `model_name` would be self-explanatory, i.e. it would serve to illustrate exactly one feature and the name would indicate this.
However, if necessary, you can add explanatory comments inside the model definition file.

You can see the existing files in that directory for examples.

> [!NOTE]
> This setup does admittedly feel a bit complicated.
> Unfortunately I could not find a simpler way to get all the components (Julia, Python, web app) to work together in an automated fashion.
> Hopefully it is a small price to pay for the ability to just add a new model and have it be automatically included on the website.

## I want to edit the website!

Expand Down
11 changes: 11 additions & 0 deletions ad.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,6 +83,15 @@ def run_ad(args):
else:
RUN_JULIA_COMMAND = JULIA_COMMAND

# Get category
try:
category = run_and_capture([*RUN_JULIA_COMMAND, "--get-category", model_key])
except sp.CalledProcessError as e:
print(f"Julia crashed when getting category for {model_key}.")
print(f"To reproduce, run: `julia --project=. main.jl --get-category {model_key}`")
category = "error"
results["__category__"] = category

# Run tests
for adtype in adtypes:
print(f"Running {model_key} with {adtype}...")
Expand Down Expand Up @@ -127,12 +136,14 @@ def html(_args):
# [
# {"model_name": "model1",
# "results": {
# "__category_": "category1",
# "AD1": "result1",
# "AD2": "result2"
# }
# },
# {"model_name": "model2",
# "results": {
# "__category_": "category2",
# "AD1": "result3",
# "AD2": "result4"
# }
Expand Down
134 changes: 68 additions & 66 deletions main.jl
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
import Test: @test, @testset
using DynamicPPL: DynamicPPL, VarInfo
using DynamicPPL.TestUtils.AD: run_ad, ADResult, ADIncorrectException
using ADTypes
using Printf: @printf

import FiniteDifferences: central_fdm
import ForwardDiff
Expand All @@ -13,30 +11,26 @@ import Zygote

# AD backends to test.
ADTYPES = Dict(
"FiniteDifferences" => AutoFiniteDifferences(; fdm = central_fdm(5, 1)),
"FiniteDifferences" => AutoFiniteDifferences(; fdm=central_fdm(5, 1)),
"ForwardDiff" => AutoForwardDiff(),
"ReverseDiff" => AutoReverseDiff(; compile = false),
"ReverseDiffCompiled" => AutoReverseDiff(; compile = true),
"Mooncake" => AutoMooncake(; config = nothing),
"EnzymeForward" => AutoEnzyme(; mode = set_runtime_activity(Forward, true)),
"EnzymeReverse" => AutoEnzyme(; mode = set_runtime_activity(Reverse, true)),
"ReverseDiff" => AutoReverseDiff(; compile=false),
"ReverseDiffCompiled" => AutoReverseDiff(; compile=true),
"Mooncake" => AutoMooncake(; config=nothing),
"EnzymeForward" => AutoEnzyme(; mode=set_runtime_activity(Forward, true)),
"EnzymeReverse" => AutoEnzyme(; mode=set_runtime_activity(Reverse, true)),
"Zygote" => AutoZygote(),
)

# Models to test. The convention is that:
# a, b, c, ... are assumed variables
# x, y, z, ... are observed variables
# although it's hardly a big deal.
MODELS = Dict{String,DynamicPPL.Model}()
macro register(model)
:(MODELS[string($(esc(model)).f)] = $(esc(model)))
MODELS = Dict{String,Tuple{String,DynamicPPL.Model}}()
macro register(category, model)
:(MODELS[string($(esc(model)).f)] = ($(esc(category)), $(esc(model))))
end

"""
include_model(model_name::AbstractString)
include_model(category::AbstractString, model_name::AbstractString)

Add the model defined in `models/model_name.jl` to the full list of models
tested in this script.
tested in this script. The model is registered under the given `category`.

We want to isolate every model in its own module, so that we avoid e.g.
variable clashes, and also so that we make imports explicit.
Expand All @@ -57,17 +51,17 @@ definition file:
- Once defined, the model is created using `model = model_name(...)`. The
`model` on the left-hand side is mandatory.
"""
macro include_model(model_name::AbstractString)
macro include_model(category::AbstractString, model_name::AbstractString)
MODELS_TO_LOAD = get(ENV, "ADTESTS_MODELS_TO_LOAD", "__all__")
if MODELS_TO_LOAD == "__all__" || model_name in split(MODELS_TO_LOAD, ",")
# Declare a module containing the model. In principle esc() shouldn't
# be needed, but see https://github.com/JuliaLang/julia/issues/55677
Expr(:toplevel, esc(:(
module $(gensym())
using .Main: @register
using Turing
include("models/" * $(model_name) * ".jl")
@register model
using .Main: @register
using Turing
include("models/" * $(model_name) * ".jl")
@register $(category) model
end
)))
else
Expand All @@ -76,63 +70,70 @@ macro include_model(model_name::AbstractString)
end
end

@include_model "assume_beta"
@include_model "assume_dirichlet"
@include_model "assume_lkjcholu"
@include_model "assume_mvnormal"
@include_model "assume_normal"
@include_model "assume_submodel"
@include_model "assume_wishart"
@include_model "broadcast_macro"
@include_model "control_flow"
@include_model "demo_assume_dot_observe"
@include_model "demo_assume_dot_observe_literal"
@include_model "demo_assume_index_observe"
@include_model "demo_assume_matrix_observe_matrix_index"
@include_model "demo_assume_multivariate_observe"
@include_model "demo_assume_multivariate_observe_literal"
@include_model "demo_assume_observe_literal"
@include_model "demo_assume_submodel_observe_index_literal"
@include_model "demo_dot_assume_observe"
@include_model "demo_dot_assume_observe_index"
@include_model "demo_dot_assume_observe_index_literal"
@include_model "demo_dot_assume_observe_matrix_index"
@include_model "demo_dot_assume_observe_submodel"
@include_model "dot_assume"
@include_model "dot_observe"
@include_model "dynamic_constraint"
@include_model "multiple_constraints_same_var"
@include_model "multithreaded"
@include_model "n010"
@include_model "n050"
@include_model "n100"
@include_model "n500"
@include_model "observe_bernoulli"
@include_model "observe_categorical"
@include_model "observe_index"
@include_model "observe_literal"
@include_model "observe_multivariate"
@include_model "observe_submodel"
@include_model "pdb_eight_schools_centered"
@include_model "pdb_eight_schools_noncentered"
@include_model "von_mises"
# Models to test. The convention is that:
# a, b, c, ... are assumed variables
# x, y, z, ... are observed variables
# although it's hardly a big deal.
@include_model "Base Julia features" "control_flow"
@include_model "Base Julia features" "multithreaded"
@include_model "Core Turing syntax" "broadcast_macro"
@include_model "Core Turing syntax" "dot_assume"
@include_model "Core Turing syntax" "dot_observe"
@include_model "Core Turing syntax" "dynamic_constraint"
@include_model "Core Turing syntax" "multiple_constraints_same_var"
@include_model "Core Turing syntax" "observe_index"
@include_model "Core Turing syntax" "observe_literal"
@include_model "Core Turing syntax" "observe_multivariate"
@include_model "Core Turing syntax" "observe_submodel"
@include_model "Distributions" "assume_beta"
@include_model "Distributions" "assume_dirichlet"
@include_model "Distributions" "assume_lkjcholu"
@include_model "Distributions" "assume_mvnormal"
@include_model "Distributions" "assume_normal"
@include_model "Distributions" "assume_submodel"
@include_model "Distributions" "assume_wishart"
@include_model "Distributions" "observe_bernoulli"
@include_model "Distributions" "observe_categorical"
@include_model "Distributions" "observe_von_mises"
@include_model "DynamicPPL demo models" "demo_assume_dot_observe"
@include_model "DynamicPPL demo models" "demo_assume_dot_observe_literal"
@include_model "DynamicPPL demo models" "demo_assume_index_observe"
@include_model "DynamicPPL demo models" "demo_assume_matrix_observe_matrix_index"
@include_model "DynamicPPL demo models" "demo_assume_multivariate_observe"
@include_model "DynamicPPL demo models" "demo_assume_multivariate_observe_literal"
@include_model "DynamicPPL demo models" "demo_assume_observe_literal"
@include_model "DynamicPPL demo models" "demo_assume_submodel_observe_index_literal"
@include_model "DynamicPPL demo models" "demo_dot_assume_observe"
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_index"
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_index_literal"
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_matrix_index"
@include_model "DynamicPPL demo models" "demo_dot_assume_observe_submodel"
@include_model "Effect of model size" "n010"
@include_model "Effect of model size" "n050"
@include_model "Effect of model size" "n100"
@include_model "Effect of model size" "n500"
@include_model "PosteriorDB" "pdb_eight_schools_centered"
@include_model "PosteriorDB" "pdb_eight_schools_noncentered"

# The entry point to this script itself begins here
if ARGS == ["--list-model-keys"]
foreach(println, sort(collect(keys(MODELS))))
elseif ARGS == ["--list-adtype-keys"]
foreach(println, sort(collect(keys(ADTYPES))))
elseif length(ARGS) == 2 && ARGS[1] == "--get-category"
println(MODELS[ARGS[2]][1])
elseif length(ARGS) == 3 && ARGS[1] == "--run"
model, adtype = MODELS[ARGS[2]], ADTYPES[ARGS[3]]
model_name, adtype_name = ARGS[2], ARGS[3]
model, adtype = MODELS[model_name][2], ADTYPES[adtype_name]

try
if ARGS[2] == "control_flow"
if model_name == "control_flow"
# https://github.com/TuringLang/ADTests/issues/4
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
params = [-0.5, 0.5]
result = run_ad(model, adtype; varinfo = vi, params = params, benchmark = true)
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
else
result = run_ad(model, adtype; benchmark = true)
result = run_ad(model, adtype; benchmark=true)
end
# If reached here - nothing went wrong
println(result.time_vs_primal)
Expand Down Expand Up @@ -162,4 +163,5 @@ else
println("Usage: julia main.jl --list-model-keys")
println(" julia main.jl --list-adtype-keys")
println(" julia main.jl --run <model> <adtype>")
println(" julia main.jl --get-category <model>")
end
6 changes: 6 additions & 0 deletions models/observe_von_mises.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,6 @@
@model function observe_von_mises(x)
a ~ InverseGamma(2, 3)
x ~ VonMises(0, a)
end

model = observe_von_mises(0.4)
6 changes: 0 additions & 6 deletions models/von_mises.jl

This file was deleted.

29 changes: 28 additions & 1 deletion web/src/App.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,30 @@
import data from "./data/adtests.json";
import modelDefinitions from "./data/model_definitions.json";

// Parse data into nice JS objects.
// Obviously, the nested strings are a bit ugly. From outer to inner, they are:
// category -> model_name -> adtype -> result
let categorisedData = new Map<
string,
Map<string, Map<string, string | number>>
>();
for (const [model_name, results] of Object.entries(data)) {
let category = results.__category__;
delete results.__category__;
let resultsMap = new Map<string, string | number>();
for (const [adtype, result] of Object.entries(results)) {
resultsMap.set(adtype, result);
}
if (!categorisedData.has(category)) {
categorisedData.set(
category,
new Map<string, Map<string, string | number>>(),
);
}
categorisedData.get(category).set(model_name, resultsMap);
}
console.log(categorisedData);

import Manifest from "./lib/Manifest.svelte";
import ResultsTable from "./lib/ResultsTable.svelte";
</script>
Expand Down Expand Up @@ -78,7 +102,10 @@
>Download the raw data (JSON)</a
>
</p>
<ResultsTable {data} {modelDefinitions} />
{#each categorisedData.entries() as [category, modelData]}
<h3>{category}</h3>
<ResultsTable data={modelData} {modelDefinitions} />
{/each}

<h2>Manifest</h2>
<p>The tests above were run with the following package versions:</p>
Expand Down
8 changes: 7 additions & 1 deletion web/src/lib/Manifest.svelte
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
<script lang="ts">
import manifest from "../data/manifest.json";
import manifestObj from "../data/manifest.json";

// convert manifest to a Map
let manifest = new Map<string, string | null>();
for (const [packageName, version] of Object.entries(manifestObj)) {
manifest.set(packageName, version === "" ? null : version);
}

import { getSortedEntries } from "./utils";
</script>
Expand Down
7 changes: 4 additions & 3 deletions web/src/lib/ResultsTable.svelte
Original file line number Diff line number Diff line change
Expand Up @@ -5,13 +5,14 @@
import { getSortedEntries } from "./utils";

interface Props {
data: object;
// model name -> adtype -> result
data: Map<string, Map<string, string | number>>;
modelDefinitions: object;
}
const { data, modelDefinitions }: Props = $props();

const models = Object.keys(data);
const adtypes = Object.keys(data[models[0]]);
const models = [...data.keys()];
const adtypes = data.get(models[0]).keys();

// Known errors
const ENZYME_FWD_BLAS = "https://github.com/EnzymeAD/Enzyme.jl/issues/1995";
Expand Down
6 changes: 3 additions & 3 deletions web/src/lib/utils.ts
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
export function getSortedEntries(obj: object) {
return Object.entries(obj).sort(([a, _x], [b, _y]) =>
a.localeCompare(b),
export function getSortedEntries(m: Map<string, any>) {
return [...m.entries()].sort(([k1, _v1], [k2, _v2]) =>
k1.localeCompare(k2),
);
}
Loading