Skip to content

Commit 30f076e

Browse files
committed
Add more tests
1 parent 9799053 commit 30f076e

File tree

9 files changed

+186
-40
lines changed

9 files changed

+186
-40
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -57,4 +57,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5757
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
5858

5959
[targets]
60-
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
60+
test = ["AMDGPU", "Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]

ext/AMDGPUExt/AMDGPUExt.jl

Lines changed: 28 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,20 +1,42 @@
11
module AMDGPUExt
22

3+
import ChainRulesCore
4+
import Flux
5+
import Flux: FluxCPUAdaptor, _amd, _isleaf, adapt_storage, fmap
6+
37
using AMDGPU
48
using Adapt
59
using Random
610
using Zygote
7-
import ChainRulesCore
8-
import Functors: fmap
9-
import Flux
10-
import Flux: FluxCPUAdaptor, adapt_storage, _isleaf, _amd
1111

12-
const use_amdgpu = Ref{Bool}(false)
12+
const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)
13+
14+
function check_use_amdgpu()
15+
isnothing(USE_AMDGPU[]) || return
16+
17+
USE_AMDGPU[] = AMDGPU.functional()
18+
if USE_AMDGPU[]
19+
if !AMDGPU.functional(:MIOpen)
20+
@warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be available."
21+
end
22+
else
23+
@info """
24+
The AMDGPU function is being called but the AMDGPU is not functional.
25+
Defaulting back to the CPU. (No action is required if you want to run on the CPU).
26+
""" maxlog=1
27+
end
28+
return
29+
end
30+
ChainRulesCore.@non_differentiable check_use_amdgpu()
1331

1432
include("functor.jl")
1533

1634
function __init__()
17-
Flux.amdgpu_loaded[] = true
35+
Flux.AMDGPU_LOADED[] = true
1836
end
1937

38+
# TODO
39+
# fail early if input to the model is not on the device (e.g. on the host)
40+
# otherwise we get very cryptic errors & segfaults at the rocBLAS level
41+
2042
end

ext/AMDGPUExt/functor.jl

Lines changed: 18 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,20 @@
1-
struct FluxAMDGPUAdaptor end
1+
struct FluxAMDAdaptor end
22

3-
adapt_storage(::FluxAMDGPUAdaptor, x) = ROCArray(x)
4-
adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.FillArrays.AbstractFill) =
3+
# Convert Float64 to Float32, but preserve Float16.
4+
adapt_storage(::FluxAMDAdaptor, x::T) where T <: AbstractArray =
5+
isbits(x) ? x : ROCArray(x)
6+
adapt_storage(::FluxAMDAdaptor, x::AbstractArray{T, N}) where {T <: AbstractFloat, N} =
7+
isbits(x) ? x : ROCArray{Float32, N}(x)
8+
adapt_storage(::FluxAMDAdaptor, x::AbstractArray{Float16, N}) where N =
9+
isbits(x) ? x : ROCArray{Float16, N}(x)
10+
11+
adapt_storage(::FluxAMDAdaptor, x::Zygote.FillArrays.AbstractFill) =
512
ROCArray(collect(x))
6-
adapt_storage(::FluxAMDGPUAdaptor, x::Zygote.OneElement) = ROCArray(collect(x))
7-
adapt_storage(::FluxAMDGPUAdaptor, x::Random.TaskLocalRNG) =
13+
adapt_storage(::FluxAMDAdaptor, x::Zygote.OneElement) = ROCArray(collect(x))
14+
adapt_storage(::FluxAMDAdaptor, x::Random.TaskLocalRNG) =
815
AMDGPU.rocRAND.default_rng()
9-
adapt_storage(::FluxAMDGPUAdaptor, x::AMDGPU.rocRAND.RNG) = x
10-
adapt_storage(::FluxAMDGPUAdaptor, x::AbstractRNG) = error("""
16+
adapt_storage(::FluxAMDAdaptor, x::AMDGPU.rocRAND.RNG) = x
17+
adapt_storage(::FluxAMDAdaptor, x::AbstractRNG) = error("""
1118
Cannot map RNG of type $(typeof(x)) to AMDGPU.
1219
AMDGPU execution only supports Random.default_rng().""")
1320

@@ -24,28 +31,12 @@ function ChainRulesCore.rrule(
2431
)
2532
adapt_storage(to, x), dx -> (
2633
NoTangent(), NoTangent(),
27-
adapt_storage(FluxAMDGPUAdaptor(), unthunk(dx)))
34+
adapt_storage(FluxAMDAdaptor(), unthunk(dx)))
2835
end
2936

3037
function _amd(x)
3138
check_use_amdgpu()
32-
use_amdgpu[] ? fmap(x -> Adapt.adapt(FluxAMDGPUAdaptor(), x)) : x
33-
end
34-
35-
function check_use_amdgpu()
36-
use_amdgpu[] === nothing || return
37-
38-
use_amdgpu[] = AMDGPU.functional()
39-
if use_amdgpu[]
40-
if !AMDGPU.functional(:MIOpen)
41-
@warn "MIOpen is not functional in AMDGPU.jl, some functionality will not be available."
42-
end
43-
else
44-
@info """
45-
The AMDGPU function is being called but the AMDGPU is not functional.
46-
Defaulting back to the CPU. (No action is required if you want to run on the CPU).
47-
""" maxlog=1
48-
end
49-
return
39+
USE_AMDGPU[] ?
40+
fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_isleaf) :
41+
x
5042
end
51-
ChainRulesCore.@non_differentiable check_use_amdgpu()

src/functor.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -282,10 +282,10 @@ trainable(c::Cholesky) = ()
282282

283283
# AMDGPU extension.
284284

285-
const amdgpu_loaded = Ref{Bool}(false)
285+
const AMDGPU_LOADED = Ref{Bool}(false)
286286

287287
function amd(x)
288-
if amdgpu_loaded[]
288+
if AMDGPU_LOADED[]
289289
return _amd(x)
290290
else
291291
@info """

test/amd/basic.jl

Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
@test Flux.AMDGPU_LOADED[]
2+
3+
@testset "Basic GPU movement" begin
4+
@test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1}
5+
@test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2}
6+
@test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2}
7+
@test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3}
8+
9+
@test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple
10+
@test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple
11+
end
12+
13+
@testset "Dense no bias" begin
14+
m = Dense(3 => 4; bias=false) |> Flux.amd
15+
x = zeros(Float32, 3, 4) |> Flux.amd
16+
@test sum(m(x)) ≈ 0f0
17+
gs = gradient(m -> sum(m(x)), m)
18+
@test isnothing(gs[1].bias)
19+
end
20+
21+
@testset "Chain of Dense layers" begin
22+
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32
23+
x = rand(Float32, 10, 10)
24+
amdgputest(m, x)
25+
end
26+
27+
@testset "Cross-correlation" begin
28+
m = CrossCor((2, 2), 3 => 4) |> f32
29+
x = rand(Float32, 10, 10, 3, 2)
30+
amdgputest(m, x; atol=1f-3)
31+
end
32+
33+
@testset "Restructure" begin
34+
m = Dense(1, 1) |> Flux.amd
35+
θ, m̂ = Flux.destructure(m)
36+
foo(x) = sum(re(p)(x))
37+
38+
x = Flux.amd(rand(Float32, 1))
39+
@test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32}
40+
end
41+
42+
@testset "Flux.amd(x) on structured arrays" begin
43+
g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5)))
44+
@test Flux.amd(g1) isa ROCMatrix{Int64}
45+
g2 = Zygote.Fill(1f0, 2)
46+
@test Flux.amd(g2) isa ROCArray{Float32, 1}
47+
g3 = transpose(Float32[1 2; 3 4])
48+
@test parent(Flux.amd(g3)) isa ROCMatrix{Float32}
49+
end
50+
51+
@testset "Flux.onecold gpu" begin
52+
y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd
53+
l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
54+
@test Flux.onecold(y) isa ROCArray
55+
@test y[3, :] isa ROCArray
56+
@test Flux.onecold(y, l) == ['a', 'a', 'a']
57+
end
58+
59+
# FIXME scalar indexing. Needs NNlib.scatter?
60+
# @testset "Flux.onehot gpu" begin
61+
# y = Flux.onehotbatch(ones(3), 1:2) |> Flux.amd
62+
# x = rand(3, 2) |> Flux.amd
63+
# @test gradient(x -> sum(x * y), x)[1] isa ROCArray
64+
# end

test/amd/runtests.jl

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,7 @@
1+
include("utils.jl")
2+
3+
AMDGPU.allowscalar(false)
4+
5+
@testset "Basic" begin
6+
include("basic.jl")
7+
end

test/amd/utils.jl

Lines changed: 47 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,47 @@
1+
function amdgputest(model, xs...; checkgrad=true, atol=1e-6, kws...)
2+
cpu_model = model
3+
gpu_model = Flux.amd(model)
4+
5+
cpu_in = xs
6+
gpu_in = Flux.amd.(xs)
7+
8+
cpu_out = cpu_model(cpu_in...)
9+
gpu_out = gpu_model(gpu_in...)
10+
@test collect(cpu_out) ≈ collect(gpu_out) atol=atol
11+
12+
if checkgrad
13+
cpu_grad = gradient(m -> sum(m(cpu_in...)), cpu_model)
14+
gpu_grad = gradient(m -> sum(m(gpu_in...)), gpu_model)
15+
amd_check_grad(gpu_grad, cpu_grad; atol)
16+
end
17+
end
18+
19+
function amd_check_grad(g_gpu, g_cpu; atol)
20+
@show g_gpu g_cpu
21+
@test false
22+
end
23+
24+
# Unwrap `Ref`-boxed gradients and compare their contents.
# `atol` must be a keyword argument (after `;`): every caller passes `; atol`,
# and keyword arguments do not take part in dispatch. With a positional `atol`
# this method could never be selected, and `Ref` gradients would fall through
# to the catch-all method, which fails the test unconditionally.
amd_check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue; atol) =
    amd_check_grad(g_gpu[], g_cpu[]; atol)
26+
amd_check_grad(g_gpu::Nothing, g_cpu::Nothing; atol) =
27+
@test true
28+
amd_check_grad(g_gpu::Float32, g_cpu::Float32; atol) =
29+
@test g_cpu ≈ g_gpu atol=atol
30+
amd_check_grad(g_gpu::ROCArray{Float32}, g_cpu::Array{Float32}; atol) =
31+
@test g_cpu ≈ collect(g_gpu) atol=atol
32+
amd_check_grad(
33+
g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill; atol,
34+
) = @test collect(g_cpu) ≈ collect(g_gpu) atol=atol
35+
36+
function amd_check_grad(g_gpu::Tuple, g_cpu::Tuple; atol)
37+
for (v1, v2) in zip(g_gpu, g_cpu)
38+
amd_check_grad(v1, v2; atol)
39+
end
40+
end
41+
42+
function amd_check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; atol)
43+
for ((k1, v1), (k2, v2)) in zip(pairs(g_gpu), pairs(g_cpu))
44+
@test k1 == k2
45+
amd_check_grad(v1, v2; atol)
46+
end
47+
end

test/cuda/cuda.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ end
9191
struct SimpleBits
9292
field::Int32
9393
end
94-
94+
9595
@test gpu((;a=ones(1))).a isa CuVector{Float32}
9696
@test gpu((;a=['a', 'b', 'c'])).a isa CuVector{Char}
9797
@test gpu((;a=[SimpleBits(1)])).a isa CuVector{SimpleBits}
@@ -167,14 +167,14 @@ end
167167
@test parent(gpu(g3)) isa CuArray
168168

169169

170-
#Issue #2116
170+
#Issue #2116
171171
struct A2116
172172
x::Int
173173
y::Int
174174
end
175175
x = [A2116(1,1), A2116(2,2)]
176-
xgpu = gpu(x)
176+
xgpu = gpu(x)
177177
@test xgpu isa CuVector{A2116}
178-
@test cpu(xgpu) isa Vector{A2116}
178+
@test cpu(xgpu) isa Vector{A2116}
179179
@test cpu(gpu([CartesianIndex(1)])) isa Vector{CartesianIndex{1}}
180180
end

test/runtests.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -60,4 +60,19 @@ Random.seed!(0)
6060
doctest(Flux)
6161
end
6262
end
63+
64+
if get(ENV, "FLUX_TEST_AMDGPU", "false") == "true"
65+
using AMDGPU
66+
AMDGPU.versioninfo()
67+
if AMDGPU.functional() && AMDGPU.functional(:MIOpen)
68+
@show AMDGPU.MIOpen.version()
69+
@testset "AMDGPU" begin
70+
include("amd/runtests.jl")
71+
end
72+
else
73+
@info "AMDGPU.jl package is not functional. Skipping AMDGPU tests."
74+
end
75+
else
76+
# The enclosing gate reads ENV["FLUX_TEST_AMDGPU"], so the hint must name that
# variable (it previously pointed users at FLUX_TEST_CUDA, which has no effect here).
@info "Skipping AMDGPU tests, set FLUX_TEST_AMDGPU=true to run them."
77+
end
6378
end

0 commit comments

Comments
 (0)