Skip to content

Commit 266d7ab

Browse files
committed
DynamicPPL 0.36
1 parent bf95f95 commit 266d7ab

File tree

4 files changed

+32
-211
lines changed

4 files changed

+32
-211
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,4 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"
1616
Zygote = "e88e6eb3-aa80-5325-afca-941959d7151f"
1717

1818
[compat]
19-
DynamicPPL = "0.35"
19+
DynamicPPL = "0.36"

README.md

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,10 +26,6 @@ Note that the links-to-existing-GitHub-issues in the table are also defined in t
2626
The workflow is the most complicated part of this repository.
2727
This section attempts to explain it from the 'bottom up'; if you prefer a 'top down' approach start by looking at the GitHub Action workflow, `.github/workflows/test.yml`.
2828

29-
Firstly, there is library code for running the benchmarks.
30-
This is in `lib.jl`; it should (in the near future) be put directly into DynamicPPL.jl.
31-
Until then, it has to live here.
32-
3329
Under the hood, the main thing that actually runs the AD tests / benchmarks is `main.jl`.
3430
You can run `julia --project=. main.jl` and it will print some usage information.
3531
However, it is the Python script `ad.py` that controls how this Julia script is called.

lib.jl

Lines changed: 0 additions & 186 deletions
This file was deleted.

main.jl

Lines changed: 31 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
import Test: @test, @testset
22
using DynamicPPL: DynamicPPL, VarInfo
3+
using DynamicPPL.TestUtils.AD: run_ad, ADResult, ADIncorrectException
34
using ADTypes
45
using Printf: @printf
56

@@ -26,11 +27,6 @@ ADTYPES = Dict(
2627
include("models.jl")
2728
using .Models: MODELS
2829

29-
# Benchmarking code is defined here. In time this will be put into DynamicPPL.
30-
# See https://github.com/TuringLang/DynamicPPL.jl/pull/882
31-
include("lib.jl")
32-
using .Lib: run_ad, ADIncorrectException
33-
3430
# The entry point to this script itself begins here
3531
if ARGS == ["--list-model-keys"]
3632
foreach(println, sort(collect(keys(MODELS))))
@@ -39,22 +35,37 @@ elseif ARGS == ["--list-adtype-keys"]
3935
elseif length(ARGS) == 3 && ARGS[1] == "--run"
4036
model, adtype = MODELS[ARGS[2]], ADTYPES[ARGS[3]]
4137

42-
if ARGS[2] == "control_flow"
43-
# https://github.com/TuringLang/ADTests/issues/4
44-
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
45-
params = [-0.5, 0.5]
46-
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
47-
else
48-
result = run_ad(model, adtype; benchmark=true)
49-
end
50-
51-
if isnothing(result.error)
38+
try
39+
if ARGS[2] == "control_flow"
40+
# https://github.com/TuringLang/ADTests/issues/4
41+
vi = DynamicPPL.unflatten(VarInfo(model), [0.5, -0.5])
42+
params = [-0.5, 0.5]
43+
result = run_ad(model, adtype; varinfo=vi, params=params, benchmark=true)
44+
else
45+
result = run_ad(model, adtype; benchmark=true)
46+
end
47+
# If reached here - nothing went wrong
5248
@printf("%.3f", result.time_vs_primal)
53-
elseif result.error isa ADIncorrectException
54-
println("wrong")
55-
else
56-
# some other error happened
57-
println("error")
49+
catch e
50+
if result.error isa ADIncorrectException
51+
# First check for completely incorrect ones
52+
for (a, b) in zip(result.grad_expected, result.grad_actual)
53+
if !isnan(a) && !isnan(b) && abs(a - b) > 1e-6
54+
println("wrong")
55+
exit()
56+
end
57+
end
58+
# If not, check for NaN's and report those
59+
if any(isnan, result.grad_expected) || any(isnan, result.grad_actual)
60+
println("NaN")
61+
else
62+
# Something else went wrong, shouldn't happen
63+
println("wrong")
64+
end
65+
else
66+
# Some other error, just say it's an error
67+
println("error")
68+
end
5869
end
5970
else
6071
println("Usage: julia main.jl --list-model-keys")

0 commit comments

Comments
 (0)