Skip to content

Commit 9783220

Browse files
committed
try to fix test error
1 parent 4f8cac8 commit 9783220

File tree

1 file changed

+31
-3
lines changed

1 file changed

+31
-3
lines changed

test/ext/CUDA/cuda.jl

Lines changed: 31 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,28 @@ using NormalizingFlows
66
using Bijectors, CUDA, Distributions, Flux, LinearAlgebra, Test
77

88
@testset "rand with CUDA" begin
9+
10+
# Bijectors versions use dot for broadcasting, which causes issues with CUDA.
11+
function Bijectors.get_u_hat(u::CuVector{T}, w::CuVector{T}) where {T<:Real}
12+
wT_u = dot(w, u)
13+
scale = (Bijectors.LogExpFunctions.log1pexp(-wT_u) - 1) / sum(abs2, w)
14+
û = CUDA.broadcast(+, u, CUDA.broadcast(*, scale, w))
15+
wT_û = Bijectors.LogExpFunctions.log1pexp(wT_u) - 1
16+
return û, wT_û
17+
end
18+
function Bijectors._transform(flow::PlanarLayer, z::CuArray{T}) where {T<:Real}
19+
w = CuArray(flow.w)
20+
b = T(first(flow.b)) # Scalar
21+
22+
û, wT_û = Bijectors.get_u_hat(CuArray(flow.u), w)
23+
wT_z = Bijectors.aT_b(w, z)
24+
25+
tanh_term = CUDA.tanh.(CUDA.broadcast(+, wT_z, b))
26+
transformed = CUDA.broadcast(+, z, CUDA.broadcast(*, û, tanh_term))
27+
28+
return (transformed=transformed, wT_û=wT_û, wT_z=wT_z)
29+
end
30+
931
dists = [
1032
MvNormal(CUDA.zeros(2), cu(Matrix{Float64}(I, 2, 2))),
1133
MvNormal(CUDA.zeros(2), cu([1.0 0.5; 0.5 1.0])),
@@ -14,18 +36,24 @@ using Bijectors, CUDA, Distributions, Flux, LinearAlgebra, Test
1436
@testset "$dist" for dist in dists
1537
x = NormalizingFlows.rand_device(CUDA.default_rng(), dist)
1638
xs = NormalizingFlows.rand_device(CUDA.default_rng(), dist, 100)
39+
@test_nowarn logpdf(dist, x)
1740
@test x isa CuArray
1841
@test xs isa CuArray
1942
end
2043

2144
@testset "$dist" for dist in dists
2245
CUDA.allowscalar(true)
23-
ts = reduce(, [Bijectors.PlanarLayer(2) for _ in 1:2])
24-
ts_g = gpu(ts)
25-
flow = Bijectors.transformed(dist, ts_g)
46+
pl1 = PlanarLayer(
47+
identity(CUDA.rand(2)), identity(CUDA.rand(2)), identity(CUDA.rand(1))
48+
)
49+
pl2 = PlanarLayer(
50+
identity(CUDA.rand(2)), identity(CUDA.rand(2)), identity(CUDA.rand(1))
51+
)
52+
flow = Bijectors.transformed(dist, ComposedFunction(pl1, pl2))
2653

2754
y = NormalizingFlows.rand_device(CUDA.default_rng(), flow)
2855
ys = NormalizingFlows.rand_device(CUDA.default_rng(), flow, 100)
56+
@test_nowarn logpdf(flow, y)
2957
@test y isa CuArray
3058
@test ys isa CuArray
3159
end

0 commit comments

Comments
 (0)