Skip to content

Commit 211ab21

Browse files
committed
Partially handle convolution
1 parent 30f076e commit 211ab21

File tree

4 files changed

+127
-90
lines changed

4 files changed

+127
-90
lines changed

ext/AMDGPUExt/functor.jl

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,23 @@ adapt_storage(::FluxAMDAdaptor, x::AbstractRNG) = error("""
1818
Cannot map RNG of type $(typeof(x)) to AMDGPU.
1919
AMDGPU execution only supports Random.default_rng().""")
2020

21-
# TODO adaptor for Conv
21+
function adapt_storage(to::FluxAMDAdaptor, m::Flux.Conv)
22+
Flux.Conv(
23+
Adapt.adapt(to, m.σ),
24+
Adapt.adapt(to, m.weight[end:-1:1, end:-1:1, :, :]),
25+
Adapt.adapt(to, m.bias),
26+
m.stride, m.pad, m.dilation, m.groups)
27+
end
28+
29+
# # Don't adapt again.
30+
# function adapt_storage(
31+
# to::FluxAMDAdaptor, m::Flux.Conv{N, M, F, A, V},
32+
# ) where {N, M, F, A <: ROCArray, V}
33+
# return m
34+
# end
35+
36+
# TODO GPU -> CPU adaptor
37+
# TODO don't adapt again when already on AMDGPU
2238

2339
adapt_storage(::FluxCPUAdaptor, x::AMDGPU.rocRAND.RNG) = Random.default_rng()
2440

@@ -40,3 +56,12 @@ function _amd(x)
4056
fmap(x -> Adapt.adapt(FluxAMDAdaptor(), x), x; exclude=_isleaf) :
4157
x
4258
end
59+
60+
function _amd(m::Flux.Conv)
61+
to = FluxAMDAdaptor()
62+
Flux.Conv(
63+
Adapt.adapt(to, m.σ),
64+
Adapt.adapt(to, m.weight[end:-1:1, end:-1:1, :, :]),
65+
Adapt.adapt(to, m.bias),
66+
m.stride, m.pad, m.dilation, m.groups)
67+
end

test/amd/basic.jl

Lines changed: 58 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,72 @@
11
@test Flux.AMDGPU_LOADED[]
22

3-
@testset "Basic GPU movement" begin
4-
@test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1}
5-
@test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2}
6-
@test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2}
7-
@test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3}
3+
# @testset "Basic GPU movement" begin
4+
# @test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1}
5+
# @test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2}
6+
# @test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2}
7+
# @test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3}
88

9-
@test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple
10-
@test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple
11-
end
9+
# @test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple
10+
# @test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple
11+
# end
1212

13-
@testset "Dense no bias" begin
14-
m = Dense(3 => 4; bias=false) |> Flux.amd
15-
x = zeros(Float32, 3, 4) |> Flux.amd
16-
@test sum(m(x)) ≈ 0f0
17-
gs = gradient(m -> sum(m(x)), m)
18-
@test isnothing(gs[1].bias)
19-
end
13+
# @testset "Dense no bias" begin
14+
# m = Dense(3 => 4; bias=false) |> Flux.amd
15+
# x = zeros(Float32, 3, 4) |> Flux.amd
16+
# @test sum(m(x)) ≈ 0f0
17+
# gs = gradient(m -> sum(m(x)), m)
18+
# @test isnothing(gs[1].bias)
19+
# end
2020

21-
@testset "Chain of Dense layers" begin
22-
m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32
23-
x = rand(Float32, 10, 10)
24-
amdgputest(m, x)
25-
end
21+
# @testset "Chain of Dense layers" begin
22+
# m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32
23+
# x = rand(Float32, 10, 10)
24+
# amdgputest(m, x)
25+
# end
2626

27-
@testset "Cross-correlation" begin
28-
m = CrossCor((2, 2), 3 => 4) |> f32
29-
x = rand(Float32, 10, 10, 3, 2)
30-
amdgputest(m, x; atol=1f-3)
27+
@testset "Convolution" begin
28+
m = Conv((2, 2), 1 => 1) |> f32
29+
x = rand(Float32, 4, 4, 1, 1)
30+
amdgputest(m, x; atol=1f-3, checkgrad=false)
31+
32+
# Gradients are flipped as well.
33+
md, xd = Flux.amd.((m, x))
34+
gs = gradient(m -> sum(m(x)), m)
35+
gsd = gradient(m -> sum(m(xd)), md)
36+
@test gs[1].weight[end:-1:1, end:-1:1, :, :] ≈ Array(gsd[1].weight) atol=1f-3
3137
end
3238

33-
@testset "Restructure" begin
34-
m = Dense(1, 1) |> Flux.amd
35-
θ, m̂ = Flux.destructure(m)
36-
foo(x) = sum(re(p)(x))
39+
# @testset "Cross-correlation" begin
40+
# m = CrossCor((2, 2), 3 => 4) |> f32
41+
# x = rand(Float32, 10, 10, 3, 2)
42+
# amdgputest(m, x; atol=1f-3)
43+
# end
3744

38-
x = Flux.amd(rand(Float32, 1))
39-
@test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32}
40-
end
45+
# @testset "Restructure" begin
46+
# m = Dense(1, 1) |> Flux.amd
47+
# θ, m̂ = Flux.destructure(m)
48+
# foo(x) = sum(re(p)(x))
4149

42-
@testset "Flux.amd(x) on structured arrays" begin
43-
g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5)))
44-
@test Flux.amd(g1) isa ROCMatrix{Int64}
45-
g2 = Zygote.Fill(1f0, 2)
46-
@test Flux.amd(g2) isa ROCArray{Float32, 1}
47-
g3 = transpose(Float32[1 2; 3 4])
48-
@test parent(Flux.amd(g3)) isa ROCMatrix{Float32}
49-
end
50+
# x = Flux.amd(rand(Float32, 1))
51+
# @test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32}
52+
# end
5053

51-
@testset "Flux.onecold gpu" begin
52-
y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd
53-
l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
54-
@test Flux.onecold(y) isa ROCArray
55-
@test y[3, :] isa ROCArray
56-
@test Flux.onecold(y, l) == ['a', 'a', 'a']
57-
end
54+
# @testset "Flux.amd(x) on structured arrays" begin
55+
# g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5)))
56+
# @test Flux.amd(g1) isa ROCMatrix{Int64}
57+
# g2 = Zygote.Fill(1f0, 2)
58+
# @test Flux.amd(g2) isa ROCArray{Float32, 1}
59+
# g3 = transpose(Float32[1 2; 3 4])
60+
# @test parent(Flux.amd(g3)) isa ROCMatrix{Float32}
61+
# end
62+
63+
# @testset "Flux.onecold gpu" begin
64+
# y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd
65+
# l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
66+
# @test Flux.onecold(y) isa ROCArray
67+
# @test y[3, :] isa ROCArray
68+
# @test Flux.onecold(y, l) == ['a', 'a', 'a']
69+
# end
5870

5971
# FIXME scalar indexing. Needs NNlib.scatter?
6072
# @testset "Flux.onehot gpu" begin

test/amd/utils.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
function amdgputest(model, xs...; checkgrad=true, atol=1e-6, kws...)
1+
function amdgputest(model, xs...; checkgrad=true, atol=1e-6)
22
cpu_model = model
33
gpu_model = Flux.amd(model)
44

test/runtests.jl

Lines changed: 42 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -11,55 +11,55 @@ Random.seed!(0)
1111

1212
@testset verbose=true "Flux.jl" begin
1313

14-
@testset "Utils" begin
15-
include("utils.jl")
16-
end
14+
# @testset "Utils" begin
15+
# include("utils.jl")
16+
# end
1717

18-
@testset "Optimise / Train" begin
19-
include("optimise.jl")
20-
include("train.jl")
21-
end
18+
# @testset "Optimise / Train" begin
19+
# include("optimise.jl")
20+
# include("train.jl")
21+
# end
2222

23-
@testset "Data" begin
24-
include("data.jl")
25-
end
23+
# @testset "Data" begin
24+
# include("data.jl")
25+
# end
2626

27-
@testset "Losses" begin
28-
include("losses.jl")
29-
include("ctc.jl")
30-
CUDA.functional() && include("ctc-gpu.jl")
31-
end
27+
# @testset "Losses" begin
28+
# include("losses.jl")
29+
# include("ctc.jl")
30+
# CUDA.functional() && include("ctc-gpu.jl")
31+
# end
3232

33-
@testset "Layers" begin
34-
include("layers/basic.jl")
35-
include("layers/normalisation.jl")
36-
include("layers/stateless.jl")
37-
include("layers/recurrent.jl")
38-
include("layers/conv.jl")
39-
include("layers/upsample.jl")
40-
include("layers/show.jl")
41-
end
33+
# @testset "Layers" begin
34+
# include("layers/basic.jl")
35+
# include("layers/normalisation.jl")
36+
# include("layers/stateless.jl")
37+
# include("layers/recurrent.jl")
38+
# include("layers/conv.jl")
39+
# include("layers/upsample.jl")
40+
# include("layers/show.jl")
41+
# end
4242

43-
@testset "outputsize" begin
44-
using Flux: outputsize
45-
include("outputsize.jl")
46-
end
43+
# @testset "outputsize" begin
44+
# using Flux: outputsize
45+
# include("outputsize.jl")
46+
# end
4747

48-
@testset "CUDA" begin
49-
if CUDA.functional()
50-
include("cuda/runtests.jl")
51-
else
52-
@warn "CUDA unavailable, not testing GPU support"
53-
end
54-
end
48+
# @testset "CUDA" begin
49+
# if CUDA.functional()
50+
# include("cuda/runtests.jl")
51+
# else
52+
# @warn "CUDA unavailable, not testing GPU support"
53+
# end
54+
# end
5555

56-
@static if VERSION == v"1.6"
57-
using Documenter
58-
@testset "Docs" begin
59-
DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
60-
doctest(Flux)
61-
end
62-
end
56+
# @static if VERSION == v"1.6"
57+
# using Documenter
58+
# @testset "Docs" begin
59+
# DocMeta.setdocmeta!(Flux, :DocTestSetup, :(using Flux); recursive=true)
60+
# doctest(Flux)
61+
# end
62+
# end
6363

6464
if get(ENV, "FLUX_TEST_AMDGPU", "false") == "true"
6565
using AMDGPU

0 commit comments

Comments
 (0)