@@ -3,6 +3,9 @@ module SGHMCTests
3
3
using .. Models: gdemo_default
4
4
using .. NumericalTests: check_gdemo
5
5
import .. ADUtils
6
+ using DynamicPPL. TestUtils. AD: run_ad
7
+ using DynamicPPL. TestUtils: DEMO_MODELS
8
+ using DynamicPPL: DynamicPPL
6
9
using Distributions: sample
7
10
import ForwardDiff
8
11
using LinearAlgebra: dot
@@ -12,6 +15,21 @@ import Mooncake
12
15
using Test: @test , @testset
13
16
using Turing
14
17
18
+ @testset " AD with SGHMC / SGLD" begin
19
+ @testset " adtype=$adtype " for adtype in ADUtils. adbackends
20
+ @testset " alg=$alg " for alg in [
21
+ SGHMC (; learning_rate= 0.02 , momentum_decay= 0.5 , adtype= adtype),
22
+ SGLD (; stepsize= PolynomialStepsize (0.25 ), adtype= adtype),
23
+ ]
24
+ @testset " model=$(model. f) " for model in DEMO_MODELS
25
+ rng = StableRNG (123 )
26
+ ctx = DynamicPPL. SamplingContext (rng, DynamicPPL. Sampler (alg))
27
+ @test run_ad (model, adtype; context= ctx, test= true , benchmark= false ) isa Any
28
+ end
29
+ end
30
+ end
31
+ end
32
+
15
33
@testset " Testing sghmc.jl" begin
16
34
@testset " sghmc constructor" begin
17
35
alg = SGHMC (; learning_rate= 0.01 , momentum_decay= 0.1 , adtype= Turing. DEFAULT_ADTYPE)
0 commit comments