Skip to content

Commit 54c4946

Browse files
committed
Unify gpu test utils
1 parent 98ce4e9 commit 54c4946

File tree

7 files changed

+108
-132
lines changed

7 files changed

+108
-132
lines changed

test/amd/basic.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ end
2121
@testset "Chain of Dense layers" begin
2222
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32
2323
x = rand(Float32, 10, 10)
24-
amdgputest(m, x)
24+
gpu_autodiff_test(m, x)
2525
end
2626

2727
@testset "Convolution" begin
@@ -30,7 +30,7 @@ end
3030
x = rand(Float32, fill(10, nd)..., 3, 5)
3131

3232
# Ensure outputs are the same.
33-
amdgputest(m, x; atol=1f-3, checkgrad=false)
33+
gpu_autodiff_test(m, x; atol=1f-3, checkgrad=false)
3434

3535
# Gradients are flipped as well.
3636
md, xd = Flux.gpu.((m, x))
@@ -49,7 +49,7 @@ end
4949
@testset "Cross-correlation" begin
5050
m = CrossCor((2, 2), 3 => 4) |> f32
5151
x = rand(Float32, 10, 10, 3, 2)
52-
amdgputest(m, x; atol=1f-3)
52+
gpu_autodiff_test(m, x)
5353
end
5454

5555
@testset "Restructure" begin
@@ -82,6 +82,6 @@ end
8282
bn = BatchNorm(3, σ)
8383
for nd in 1:3
8484
x = rand(Float32, fill(2, nd - 1)..., 3, 4)
85-
amdgputest(bn, x; atol=1f-3, allow_nothing=true)
85+
gpu_autodiff_test(bn, x; atol=1f-3, allow_nothing=true)
8686
end
8787
end

test/amd/runtests.jl

Lines changed: 18 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,25 @@
11
Flux.gpu_backend!("AMD")
22

3-
include("utils.jl")
4-
53
AMDGPU.allowscalar(false)
64

5+
# Extend test utils to AMDGPU.
6+
7+
# Compare a gradient computed on the AMD GPU against its CPU counterpart.
# `allow_nothing` is accepted only for interface parity with the generic
# `check_grad` methods; it is unused here because both gradients exist.
function check_grad(
    g_gpu::ROCArray{Float32}, g_cpu::Array{Float32}, atol, rtol;
    allow_nothing::Bool,
)
    # `collect` copies the device array to host memory before comparing.
    @test g_cpu ≈ collect(g_gpu) atol=atol rtol=rtol
end
13+
14+
function check_grad(
15+
g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill,
16+
atol, rtol; allow_nothing::Bool,
17+
)
18+
@test g_cpu collect(g_gpu) atol=atol rtol=rtol
19+
end
20+
21+
check_type(x::ROCArray{Float32}) = true
22+
723
@testset "Basic" begin
824
include("basic.jl")
925
end

test/amd/utils.jl

Lines changed: 0 additions & 53 deletions
This file was deleted.

test/cuda/runtests.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@ using Random, LinearAlgebra, Statistics
66
@info "Testing GPU Support"
77
CUDA.allowscalar(false)
88

9-
include("test_utils.jl")
109
include("cuda.jl")
1110
include("losses.jl")
1211
include("layers.jl")

test/cuda/test_utils.jl

Lines changed: 0 additions & 72 deletions
This file was deleted.

test/runtests.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ using IterTools: ncycle
77
using Zygote
88
using CUDA
99

10+
include("test_utils.jl")
11+
1012
Random.seed!(0)
1113

1214
@testset verbose=true "Flux.jl" begin

test/test_utils.jl

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
1+
# Fallback: the CPU and GPU gradients have mismatched types. Unless the
# caller explicitly allows that, print both values and fail the test.
function check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing::Bool)
    if allow_nothing
        return
    end
    @show g_gpu g_cpu
    @test false
end
6+
# Unwrap `Ref` containers on both sides and compare their contents.
function check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol, rtol;
                    allow_nothing::Bool)
    return check_grad(g_gpu[], g_cpu[], atol, rtol; allow_nothing)
end
8+
# Both gradients are absent — that always counts as a match.
function check_grad(::Nothing, ::Nothing, atol, rtol; allow_nothing::Bool)
    @test true
end
10+
# Scalar gradients: approximate equality within the given tolerances.
check_grad(g_gpu::Float32, g_cpu::Float32, atol, rtol; allow_nothing::Bool) =
    @test g_cpu ≈ g_gpu rtol=rtol atol=atol
12+
# CUDA device gradients are copied to the host before the elementwise comparison.
check_grad(g_gpu::CuArray{Float32}, g_cpu::Array{Float32}, atol, rtol; allow_nothing::Bool) =
    @test g_cpu ≈ collect(g_gpu) rtol=rtol atol=atol
14+
15+
# Recurse elementwise through tuple-structured gradients.
function check_grad(g_gpu::Tuple, g_cpu::Tuple, atol, rtol; allow_nothing::Bool)
    foreach(zip(g_gpu, g_cpu)) do (gpu_part, cpu_part)
        check_grad(gpu_part, cpu_part, atol, rtol; allow_nothing)
    end
end
20+
21+
# Recurse through named-tuple gradients, checking that field names line up.
function check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple, atol, rtol; allow_nothing::Bool)
    for (gpu_entry, cpu_entry) in zip(pairs(g_gpu), pairs(g_cpu))
        @test first(gpu_entry) == first(cpu_entry)
        check_grad(last(gpu_entry), last(cpu_entry), atol, rtol; allow_nothing)
    end
end
27+
28+
# Outputs of the tested model must be Float32-typed: a scalar, a host
# array, or a CUDA device array. Anything else fails the type check.
check_type(::Any) = false
check_type(::Float32) = true
check_type(::CuArray{Float32}) = true
check_type(::Array{Float32}) = true
32+
33+
"""
    gpu_autodiff_test(f_cpu, xs_cpu...; kwargs...)

Move `f_cpu` and its `Array{Float32}` inputs to the GPU and check that the
forward pass and the Zygote gradients — both with respect to the inputs and
with respect to the parameters of `f_cpu` — agree with the CPU results.

# Keywords
- `test_equal=true`: compare GPU results against CPU results.
- `rtol=1e-4`, `atol=1e-4`: tolerances for the approximate comparisons.
- `checkgrad=true`: when `false`, only the forward pass is compared.
- `allow_nothing=false`: tolerate mismatched gradient types (see `check_grad`).
"""
function gpu_autodiff_test(
    f_cpu, xs_cpu::Array{Float32}...;
    test_equal=true, rtol=1e-4, atol=1e-4,
    checkgrad::Bool = true, allow_nothing::Bool = false,
)
    # Compare CPU & GPU function outputs.
    f_gpu = f_cpu |> gpu
    xs_gpu = gpu.(xs_cpu)

    y_cpu = f_cpu(xs_cpu...)
    y_gpu = f_gpu(xs_gpu...)
    @test collect(y_cpu) ≈ collect(y_gpu) atol=atol rtol=rtol

    checkgrad || return

    ### GRADIENT WITH RESPECT TO INPUT ###

    y_cpu, back_cpu = pullback((x...) -> f_cpu(x...), xs_cpu...)
    @test check_type(y_cpu)
    # Scalar outputs need a scalar seed; array outputs a same-shaped seed.
    Δ_cpu = size(y_cpu) == () ? randn(Float32) : randn(Float32, size(y_cpu))
    gs_cpu = back_cpu(Δ_cpu)

    Δ_gpu = Δ_cpu |> gpu
    y_gpu, back_gpu = pullback((x...) -> f_gpu(x...), xs_gpu...)
    @test check_type(y_gpu)
    gs_gpu = back_gpu(Δ_gpu)

    if test_equal
        @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol
        for (g_gpu, g_cpu) in zip(gs_gpu, gs_cpu)
            check_grad(g_gpu, g_cpu, atol, rtol; allow_nothing)
        end
    end

    ### GRADIENT WITH RESPECT TO f ###

    ps_cpu = Flux.params(f_cpu)
    y_cpu, back_cpu = pullback(() -> f_cpu(xs_cpu...), ps_cpu)
    gs_cpu = back_cpu(Δ_cpu)

    ps_gpu = Flux.params(f_gpu)
    y_gpu, back_gpu = pullback(() -> f_gpu(xs_gpu...), ps_gpu)
    gs_gpu = back_gpu(Δ_gpu)

    if test_equal
        @test collect(y_cpu) ≈ collect(y_gpu) rtol=rtol atol=atol
        # Internal sanity check: params must pair up one-to-one.
        @assert length(ps_gpu) == length(ps_cpu)
        for (p_gpu, p_cpu) in zip(ps_gpu, ps_cpu)
            check_grad(gs_gpu[p_gpu], gs_cpu[p_cpu], atol, rtol; allow_nothing)
        end
    end
end

0 commit comments

Comments
 (0)