Skip to content

Commit e8b7944

Browse files
authored
Merge pull request #13 from TuringLang/py/dppl-0.36
DynamicPPL 0.36
2 parents bf95f95 + 9f3bf29 commit e8b7944

File tree

5 files changed

+39
-212
lines changed

5 files changed

+39
-212
lines changed

.github/workflows/generate_website.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -86,7 +86,6 @@ jobs:
8686
8787
collect-results:
8888
runs-on: ubuntu-latest
89-
if: github.event_name != 'pull_request'
9089
needs: [setup-keys, run-models]
9190

9291
steps:
@@ -106,3 +105,4 @@ jobs:
106105
with:
107106
github_token: ${{ secrets.GITHUB_TOKEN }}
108107
publish_dir: ./html
108+
destination_dir: ${{ github.event_name == 'pull_request' && 'pr' || '' }}

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1616
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1717

1818
[compat]
19-
DynamicPPL = "0.35"
19+
DynamicPPL = "0.36"

README.md

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,15 +21,17 @@ You can edit it there.
2121

2222
Note that the links-to-existing-GitHub-issues in the table are also defined in this script.
2323

24+
## I want to see the HTML generated by a PR!
25+
26+
The latest workflow run across all PRs will be published to https://turinglang.org/ADTests/pr.
27+
28+
This is a bit messy, but works for now on the assumption that there aren't many PRs being worked on simultaneously.
29+
2430
## What's going on?
2531

2632
The workflow is the most complicated part of this repository.
2733
This section attempts to explain it from the 'bottom up'; if you prefer a 'top down' approach start by looking at the GitHub Action workflow, `.github/workflows/test.yml`.
2834

29-
Firstly, there is library code for running the benchmarks.
30-
This is in `lib.jl`; it should (in the near future) be put directly into DynamicPPL.jl.
31-
Until then, it has to live here.
32-
3335
Under the hood, the main thing that actually runs the AD tests / benchmarks is `main.jl`.
3436
You can run `julia --project=. main.jl` and it will print some usage information.
3537
However, it is the Python script `ad.py` that controls how this Julia script is called.

lib.jl

Lines changed: 0 additions & 186 deletions
This file was deleted.

main.jl

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import Test: @test, @testset
22
using DynamicPPL: DynamicPPL, VarInfo
3+
using DynamicPPL.TestUtils.AD: run_ad, ADResult, ADIncorrectException
34
using ADTypes
45
using Printf: @printf
56

@@ -26,11 +27,6 @@ ADTYPES = Dict(
2627
include("models.jl")
2728
using .Models: MODELS
2829

29-
# Benchmarking code is defined here. In time this will be put into DynamicPPL.
30-
# See https://github.com/TuringLang/DynamicPPL.jl/pull/882
31-
include("lib.jl")
32-
using .Lib: run_ad, ADIncorrectException
33-
3430
# The entry point to this script itself begins here
3531
if ARGS == ["--list-model-keys"]
3632
foreach(println, sort(collect(keys(MODELS))))
@@ -39,22 +35,37 @@ elseif ARGS == ["--list-adtype-keys"]
3935
elseif length(ARGS) == 3 && ARGS[1] == "--run"
4036
model, adtype = MODELS[ARGS[2]], ADTYPES[ARGS[3]]
4137

42-
if ARGS[2] == "control_flow"
43-
# https://github.com/TuringLang/ADTests/issues/4
44-
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
45-
params = [-0.5, 0.5]
46-
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
47-
else
48-
result = run_ad(model, adtype; benchmark=true)
49-
end
50-
51-
if isnothing(result.error)
38+
try
39+
if ARGS[2] == "control_flow"
40+
# https://github.com/TuringLang/ADTests/issues/4
41+
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
42+
params = [-0.5, 0.5]
43+
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
44+
else
45+
result = run_ad(model, adtype; benchmark=true)
46+
end
47+
# If reached here - nothing went wrong
5248
@printf("%.3f", result.time_vs_primal)
53-
elseif result.error isa ADIncorrectException
54-
println("wrong")
55-
else
56-
# some other error happened
57-
println("error")
49+
catch e
50+
if result.error isa ADIncorrectException
51+
# First check for completely incorrect ones
52+
for (a, b) in zip(result.grad_expected, result.grad_actual)
53+
if !isnan(a) && !isnan(b) && abs(a - b) > 1e-6
54+
println("wrong")
55+
exit()
56+
end
57+
end
58+
# If not, check for NaN's and report those
59+
if any(isnan, result.grad_expected) || any(isnan, result.grad_actual)
60+
println("NaN")
61+
else
62+
# Something else went wrong, shouldn't happen
63+
println("wrong")
64+
end
65+
else
66+
# Some other error, just say it's an error
67+
println("error")
68+
end
5869
end
5970
else
6071
println("Usage: julia main.jl --list-model-keys")

0 commit comments

Comments
 (0)