
Commit d0eb6a0

Handle convolutions correctly
1 parent 211ab21 commit d0eb6a0

2 files changed: +101 -76 lines

ext/AMDGPUExt/functor.jl

Lines changed: 36 additions & 21 deletions
@@ -18,24 +18,6 @@ adapt_storage(::FluxAMDAdaptor, x::AbstractRNG) = error("""
     Cannot map RNG of type $(typeof(x)) to AMDGPU.
     AMDGPU execution only supports Random.default_rng().""")
 
-function adapt_storage(to::FluxAMDAdaptor, m::Flux.Conv)
-    Flux.Conv(
-        Adapt.adapt(to, m.σ),
-        Adapt.adapt(to, m.weight[end:-1:1, end:-1:1, :, :]),
-        Adapt.adapt(to, m.bias),
-        m.stride, m.pad, m.dilation, m.groups)
-end
-
-# # Don't adapt again.
-# function adapt_storage(
-#     to::FluxAMDAdaptor, m::Flux.Conv{N, M, F, A, V},
-# ) where {N, M, F, A <: ROCArray, V}
-#     return m
-# end
-
-# TODO GPU -> CPU adaptor
-# TODO don't adapt again when already on AMDGPU
-
 adapt_storage(::FluxCPUAdaptor, x::AMDGPU.rocRAND.RNG) = Random.default_rng()
 
 function ChainRulesCore.rrule(::Type{Array}, x::ROCArray)
@@ -57,11 +39,44 @@ function _amd(x)
     x
 end
 
-function _amd(m::Flux.Conv)
-    to = FluxAMDAdaptor()
+# Since MIOpen supports only cross-correlation as convolution,
+# for the actual convolution we flip the weights horizontally and vertically.
+# Same for CPU -> GPU & GPU -> CPU movements.
+# Note that gradients are also flipped.
+
+# CPU -> GPU
+
+function adapt_storage(to::FluxAMDAdaptor, m::Flux.Conv)
+    flipped_weight = reverse(m.weight; dims=ntuple(i -> i, ndims(m.weight) - 2))
+    Flux.Conv(
+        Adapt.adapt(to, m.σ),
+        Adapt.adapt(to, flipped_weight),
+        Adapt.adapt(to, m.bias),
+        m.stride, m.pad, m.dilation, m.groups)
+end
+
+# Don't adapt again.
+function adapt_storage(
+    to::FluxAMDAdaptor, m::Flux.Conv{N, M, F, A, V},
+) where {N, M, F, A <: ROCArray, V}
+    return m
+end
+
+_amd(m::Flux.Conv) = adapt_storage(FluxAMDAdaptor(), m)
+
+# GPU -> CPU
+
+function Flux.cpu(m::Flux.Conv{N, M, F, A, V}) where {N, M, F, A <: ROCArray, V}
+    adapt_storage(FluxCPUAdaptor(), m)
+end
+
+function adapt_storage(
+    to::FluxCPUAdaptor, m::Flux.Conv{N, M, F, A, V},
+) where {N, M, F, A <: ROCArray, V}
+    dims = ntuple(i -> i, ndims(m.weight) - 2)
     Flux.Conv(
         Adapt.adapt(to, m.σ),
-        Adapt.adapt(to, m.weight[end:-1:1, end:-1:1, :, :]),
+        reverse(Adapt.adapt(to, m.weight); dims),
         Adapt.adapt(to, m.bias),
         m.stride, m.pad, m.dilation, m.groups)
 end
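The core of the change: the hand-written 2-D flip `m.weight[end:-1:1, end:-1:1, :, :]` is replaced by a `reverse` over all spatial dimensions, which covers 1-D, 2-D and 3-D convolutions alike, and the extra `A <: ROCArray` method stops a layer that already lives on the GPU from being flipped twice. A minimal sketch of that equivalence (plain CPU Julia, no AMDGPU required; `w1` and `w2` are throwaway example arrays, not part of the commit):

# Sketch: reversing along the spatial dims generalises the old 2-D-only indexing flip.
w2 = rand(Float32, 2, 2, 3, 4)            # 2-D kernel: (kH, kW, cin, cout)
dims2 = ntuple(i -> i, ndims(w2) - 2)     # == (1, 2), the spatial dims
@assert reverse(w2; dims=dims2) == w2[end:-1:1, end:-1:1, :, :]

w1 = rand(Float32, 5, 3, 4)               # 1-D kernel: (k, cin, cout)
dims1 = ntuple(i -> i, ndims(w1) - 2)     # == (1,)
@assert reverse(w1; dims=dims1) == w1[end:-1:1, :, :]

Doing the flip once at transfer time keeps the forward pass free of index gymnastics: MIOpen runs its native cross-correlation over the pre-flipped weights and matches the true convolution that Flux's Conv computes on the CPU.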

test/amd/basic.jl

Lines changed: 65 additions & 55 deletions
@@ -1,72 +1,82 @@
 @test Flux.AMDGPU_LOADED[]
 
-# @testset "Basic GPU movement" begin
-#     @test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1}
-#     @test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2}
-#     @test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2}
-#     @test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3}
+@testset "Basic GPU movement" begin
+    @test Flux.amd(rand(Float64, 16)) isa ROCArray{Float32, 1}
+    @test Flux.amd(rand(Float64, 16, 16)) isa ROCArray{Float32, 2}
+    @test Flux.amd(rand(Float32, 16, 16)) isa ROCArray{Float32, 2}
+    @test Flux.amd(rand(Float16, 16, 16, 16)) isa ROCArray{Float16, 3}
 
-#     @test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple
-#     @test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple
-# end
+    @test gradient(x -> sum(Flux.amd(x)), rand(Float32, 4, 4)) isa Tuple
+    @test gradient(x -> sum(cpu(x)), AMDGPU.rand(Float32, 4, 4)) isa Tuple
+end
 
-# @testset "Dense no bias" begin
-#     m = Dense(3 => 4; bias=false) |> Flux.amd
-#     x = zeros(Float32, 3, 4) |> Flux.amd
-#     @test sum(m(x)) ≈ 0f0
-#     gs = gradient(m -> sum(m(x)), m)
-#     @test isnothing(gs[1].bias)
-# end
+@testset "Dense no bias" begin
+    m = Dense(3 => 4; bias=false) |> Flux.amd
+    x = zeros(Float32, 3, 4) |> Flux.amd
+    @test sum(m(x)) ≈ 0f0
+    gs = gradient(m -> sum(m(x)), m)
+    @test isnothing(gs[1].bias)
+end
 
-# @testset "Chain of Dense layers" begin
-#     m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32
-#     x = rand(Float32, 10, 10)
-#     amdgputest(m, x)
-# end
+@testset "Chain of Dense layers" begin
+    m = Chain(Dense(10, 5, tanh), Dense(5, 2), softmax) |> f32
+    x = rand(Float32, 10, 10)
+    amdgputest(m, x)
+end
 
 @testset "Convolution" begin
-    m = Conv((2, 2), 1 => 1) |> f32
-    x = rand(Float32, 4, 4, 1, 1)
-    amdgputest(m, x; atol=1f-3, checkgrad=false)
+    for nd in (1, 2, 3)
+        m = Conv(tuple(fill(2, nd)...), 3 => 4) |> f32
+        x = rand(Float32, fill(10, nd)..., 3, 5)
 
-    # Gradients are flipped as well.
-    md, xd = Flux.amd.((m, x))
-    gs = gradient(m -> sum(m(x)), m)
-    gsd = gradient(m -> sum(m(xd)), md)
-    @test gs[1].weight[end:-1:1, end:-1:1, :, :] ≈ Array(gsd[1].weight) atol=1f-3
+        # Ensure outputs are the same.
+        amdgputest(m, x; atol=1f-3, checkgrad=false)
+
+        # Gradients are flipped as well.
+        md, xd = Flux.amd.((m, x))
+        gs = gradient(m -> sum(m(x)), m)
+        gsd = gradient(m -> sum(m(xd)), md)
+
+        dims = ntuple(i -> i, ndims(m.weight) - 2)
+        @test reverse(gs[1].weight; dims) ≈ Array(gsd[1].weight) atol=1f-2
+
+        # Movement back to CPU flips weights back.
+        mh = Flux.cpu(md)
+        @test m.weight ≈ mh.weight
+    end
 end
 
-# @testset "Cross-correlation" begin
-#     m = CrossCor((2, 2), 3 => 4) |> f32
-#     x = rand(Float32, 10, 10, 3, 2)
-#     amdgputest(m, x; atol=1f-3)
-# end
+@testset "Cross-correlation" begin
+    m = CrossCor((2, 2), 3 => 4) |> f32
+    x = rand(Float32, 10, 10, 3, 2)
+    amdgputest(m, x; atol=1f-3)
+end
 
-# @testset "Restructure" begin
-#     m = Dense(1, 1) |> Flux.amd
-#     θ, m̂ = Flux.destructure(m)
-#     foo(x) = sum(re(p)(x))
+@testset "Restructure" begin
+    m = Dense(1, 1) |> Flux.amd
+    θ, m̂ = Flux.destructure(m)
+    foo(x) = sum(re(p)(x))
 
-#     x = Flux.amd(rand(Float32, 1))
-#     @test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32}
-# end
+    x = Flux.amd(rand(Float32, 1))
+    @test gradient(x -> sum(m̂(θ)(x)), x)[1] isa ROCArray{Float32}
+end
 
-# @testset "Flux.amd(x) on structured arrays" begin
-#     g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5)))
-#     @test Flux.amd(g1) isa ROCMatrix{Int64}
-#     g2 = Zygote.Fill(1f0, 2)
-#     @test Flux.amd(g2) isa ROCArray{Float32, 1}
-#     g3 = transpose(Float32[1 2; 3 4])
-#     @test parent(Flux.amd(g3)) isa ROCMatrix{Float32}
-# end
+@testset "Flux.amd(x) on structured arrays" begin
+    g1 = Zygote.OneElement(1, (2, 3), axes(ones(4, 5)))
+    @test Flux.amd(g1) isa ROCMatrix{Int64}
+    g2 = Zygote.Fill(1f0, 2)
+    @test Flux.amd(g2) isa ROCArray{Float32, 1}
+    g3 = transpose(Float32[1 2; 3 4])
+    @test parent(Flux.amd(g3)) isa ROCMatrix{Float32}
+end
 
-# @testset "Flux.onecold gpu" begin
-#     y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd
-#     l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
-#     @test Flux.onecold(y) isa ROCArray
-#     @test y[3, :] isa ROCArray
-#     @test Flux.onecold(y, l) == ['a', 'a', 'a']
-# end
+@testset "Flux.onecold gpu" begin
+    y = Flux.onehotbatch(ones(3), 1:10) |> Flux.amd
+    l = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j']
+    @test Flux.onecold(y) isa ROCArray
+    @test y[3, :] isa ROCArray
+    @test Flux.onecold(y, l) == ['a', 'a', 'a']
+end
 
 # FIXME scalar indexing. Needs NNlib.scatter?
 # @testset "Flux.onehot gpu" begin