|
1 |
| -using CUDA, Test, LinearAlgebra, Distributions |
2 |
| -using Flux |
| 1 | +using Pkg |
| 2 | +Pkg.activate(@__DIR__) |
| 3 | +Pkg.develop(; path=joinpath(@__DIR__, "..", "..", "..")) |
3 | 4 |
|
4 |
| -if CUDA.functional() |
5 |
| - @testset "rand with CUDA" begin |
6 |
| - dists = [ |
7 |
| - MvNormal(CUDA.zeros(2), I), MvNormal(CUDA.zeros(2), cu([1.0 0.5; 0.5 1.0])) |
8 |
| - ] |
| 5 | +using NormalizingFlows |
| 6 | +using CUDA, Test, LinearAlgebra, Distributions, Flux |
9 | 7 |
|
10 |
| - @testset "$dist" for dist in dists |
11 |
| - x = rand_device(CUDA.default_rng(), dist) |
12 |
| - xs = rand_device(CUDA.default_rng(), dist, 100) |
13 |
| - @info logpdf(dist, x) |
14 |
| - @test x isa CuArray |
15 |
| - @test xs isa CuArray |
16 |
| - end |
| 8 | +@testset "rand with CUDA" begin |
| 9 | + dists = [MvNormal(CUDA.zeros(2), I), MvNormal(CUDA.zeros(2), cu([1.0 0.5; 0.5 1.0]))] |
17 | 10 |
|
18 |
| - @testset "$dist" for dist in dists |
19 |
| - CUDA.allowscalar(true) |
20 |
| - ts = reduce(∘, [Bijectors.PlanarLayer(2) for _ in 1:2]) |
21 |
| - ts_g = gpu(ts) |
22 |
| - flow = Bijectors.transformed(dist, ts_g) |
| 11 | + @testset "$dist" for dist in dists |
| 12 | + x = rand_device(CUDA.default_rng(), dist) |
| 13 | + xs = rand_device(CUDA.default_rng(), dist, 100) |
| 14 | + @info logpdf(dist, x) |
| 15 | + @test x isa CuArray |
| 16 | + @test xs isa CuArray |
| 17 | + end |
| 18 | + |
| 19 | + @testset "$dist" for dist in dists |
| 20 | + CUDA.allowscalar(true) |
| 21 | + ts = reduce(∘, [Bijectors.PlanarLayer(2) for _ in 1:2]) |
| 22 | + ts_g = gpu(ts) |
| 23 | + flow = Bijectors.transformed(dist, ts_g) |
23 | 24 |
|
24 |
| - y = rand_device(CUDA.default_rng(), flow) |
25 |
| - ys = rand_device(CUDA.default_rng(), flow, 100) |
26 |
| - @info logpdf(flow, y) |
27 |
| - @test y isa CuArray |
28 |
| - @test ys isa CuArray |
29 |
| - end |
| 25 | + y = rand_device(CUDA.default_rng(), flow) |
| 26 | + ys = rand_device(CUDA.default_rng(), flow, 100) |
| 27 | + @info logpdf(flow, y) |
| 28 | + @test y isa CuArray |
| 29 | + @test ys isa CuArray |
30 | 30 | end
|
31 | 31 | end
|
0 commit comments