@@ -6,6 +6,28 @@ using NormalizingFlows
6
6
using Bijectors, CUDA, Distributions, Flux, LinearAlgebra, Test
7
7
8
8
@testset " rand with CUDA" begin
9
+
10
+ # Bijectors versions use dot for broadcasting, which causes issues with CUDA.
11
+ function Bijectors. get_u_hat (u:: CuVector{T} , w:: CuVector{T} ) where {T<: Real }
12
+ wT_u = dot (w, u)
13
+ scale = (Bijectors. LogExpFunctions. log1pexp (- wT_u) - 1 ) / sum (abs2, w)
14
+ û = CUDA. broadcast (+ , u, CUDA. broadcast (* , scale, w))
15
+ wT_û = Bijectors. LogExpFunctions. log1pexp (wT_u) - 1
16
+ return û, wT_û
17
+ end
18
+ function Bijectors. _transform (flow:: PlanarLayer , z:: CuArray{T} ) where {T<: Real }
19
+ w = CuArray (flow. w)
20
+ b = T (first (flow. b)) # Scalar
21
+
22
+ û, wT_û = Bijectors. get_u_hat (CuArray (flow. u), w)
23
+ wT_z = Bijectors. aT_b (w, z)
24
+
25
+ tanh_term = CUDA. tanh .(CUDA. broadcast (+ , wT_z, b))
26
+ transformed = CUDA. broadcast (+ , z, CUDA. broadcast (* , û, tanh_term))
27
+
28
+ return (transformed= transformed, wT_û= wT_û, wT_z= wT_z)
29
+ end
30
+
9
31
dists = [
10
32
MvNormal (CUDA. zeros (2 ), cu (Matrix {Float64} (I, 2 , 2 ))),
11
33
MvNormal (CUDA. zeros (2 ), cu ([1.0 0.5 ; 0.5 1.0 ])),
@@ -14,18 +36,24 @@ using Bijectors, CUDA, Distributions, Flux, LinearAlgebra, Test
14
36
@testset " $dist " for dist in dists
15
37
x = NormalizingFlows. rand_device (CUDA. default_rng (), dist)
16
38
xs = NormalizingFlows. rand_device (CUDA. default_rng (), dist, 100 )
39
+ @test_nowarn logpdf (dist, x)
17
40
@test x isa CuArray
18
41
@test xs isa CuArray
19
42
end
20
43
21
44
@testset " $dist " for dist in dists
22
45
CUDA. allowscalar (true )
23
- ts = reduce (∘ , [Bijectors. PlanarLayer (2 ) for _ in 1 : 2 ])
24
- ts_g = gpu (ts)
25
- flow = Bijectors. transformed (dist, ts_g)
46
+ pl1 = PlanarLayer (
47
+ identity (CUDA. rand (2 )), identity (CUDA. rand (2 )), identity (CUDA. rand (1 ))
48
+ )
49
+ pl2 = PlanarLayer (
50
+ identity (CUDA. rand (2 )), identity (CUDA. rand (2 )), identity (CUDA. rand (1 ))
51
+ )
52
+ flow = Bijectors. transformed (dist, ComposedFunction (pl1, pl2))
26
53
27
54
y = NormalizingFlows. rand_device (CUDA. default_rng (), flow)
28
55
ys = NormalizingFlows. rand_device (CUDA. default_rng (), flow, 100 )
56
+ @test_nowarn logpdf (flow, y)
29
57
@test y isa CuArray
30
58
@test ys isa CuArray
31
59
end
0 commit comments