Skip to content

Commit df85c6f

Browse files
committed
Add AD tests to sghmc.jl
1 parent d8f6dfa commit df85c6f

File tree

1 file changed

+18
-0
lines changed

1 file changed

+18
-0
lines changed

test/mcmc/sghmc.jl

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,9 @@ module SGHMCTests
33
using ..Models: gdemo_default
44
using ..NumericalTests: check_gdemo
55
import ..ADUtils
6+
using DynamicPPL.TestUtils.AD: run_ad
7+
using DynamicPPL.TestUtils: DEMO_MODELS
8+
using DynamicPPL: DynamicPPL
69
using Distributions: sample
710
import ForwardDiff
811
using LinearAlgebra: dot
@@ -12,6 +15,21 @@ import Mooncake
1215
using Test: @test, @testset
1316
using Turing
1417

18+
@testset "AD with SGHMC / SGLD" begin
19+
@testset "adtype=$adtype" for adtype in ADUtils.adbackends
20+
@testset "alg=$alg" for alg in [
21+
SGHMC(; learning_rate=0.02, momentum_decay=0.5, adtype=adtype),
22+
SGLD(; stepsize=PolynomialStepsize(0.25), adtype=adtype),
23+
]
24+
@testset "model=$(model.f)" for model in DEMO_MODELS
25+
rng = StableRNG(123)
26+
ctx = DynamicPPL.SamplingContext(rng, DynamicPPL.Sampler(alg))
27+
@test run_ad(model, adtype; context=ctx, test=true, benchmark=false) isa Any
28+
end
29+
end
30+
end
31+
end
32+
1533
@testset "Testing sghmc.jl" begin
1634
@testset "sghmc constructor" begin
1735
alg = SGHMC(; learning_rate=0.01, momentum_decay=0.1, adtype=Turing.DEFAULT_ADTYPE)

0 commit comments

Comments
 (0)