
Commit 746caa5

Handle ConvTranspose correctly & refactor

1 parent 0a9daf7

File tree: 5 files changed (+53 / -28 lines)

ext/AMDGPUExt/AMDGPUExt.jl
Lines changed: 1 addition & 1 deletion

@@ -4,7 +4,7 @@ import ChainRulesCore
 import ChainRulesCore: NoTangent
 import Flux
 import Flux: FluxCPUAdaptor, FluxAMDAdaptor, _amd, _isleaf, adapt_storage, fmap
-import Flux: DenseConvDims, Conv, conv, conv_reshape_bias
+import Flux: DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
 import NNlib

 using AMDGPU

ext/AMDGPUExt/conv.jl
Lines changed: 18 additions & 6 deletions

@@ -1,9 +1,21 @@
-function (c::Conv)(x::T) where T <: ROCArray
-    Flux._size_check(c, x, ndims(x) - 1 => Flux._channels_in(c))
-    σ = NNlib.fast_act(c.σ, x)
-    cdims = DenseConvDims(
+function Flux.conv_dims(c::Conv, x::T) where T <: ROCArray
+    DenseConvDims(
         x, c.weight; stride=c.stride, padding=c.pad,
         dilation=c.dilation, groups=c.groups, flipkernel=true)
-    xT = Flux._match_eltype(c, x)
-    σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
+end
+
+function Flux.conv_transpose_dims(c::ConvTranspose, x::T) where T <: ROCArray
+    # Calculate size of "input", from ∇conv_data()'s perspective...
+    combined_pad = (c.pad[1:2:end] .+ c.pad[2:2:end])
+    I = (size(x)[1:end - 2] .- 1) .* c.stride .+ 1 .+
+        (size(c.weight)[1:end - 2] .- 1) .* c.dilation .- combined_pad
+    C_in = size(c.weight)[end - 1] * c.groups
+    batch_size = size(x)[end]
+
+    # Create DenseConvDims() that looks like the corresponding conv().
+    w_size = size(c.weight)
+    DenseConvDims(
+        (I..., C_in, batch_size), w_size;
+        stride=c.stride, padding=c.pad, dilation=c.dilation,
+        groups=c.groups, flipkernel=true)
 end
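
Note on conv_transpose_dims: it reconstructs the spatial size of the input that the corresponding forward conv() would have seen, which is exactly the output size of the transposed convolution. Below is a minimal sketch of that size formula in isolation; transpose_output_size is a hypothetical name used only for illustration and is not part of this commit.

    # Output length of a transposed convolution along one spatial dimension,
    # mirroring the `I = ...` computation above.
    transpose_output_size(x_len, w_len; stride=1, dilation=1, pad=(0, 0)) =
        (x_len - 1) * stride + 1 + (w_len - 1) * dilation - sum(pad)

    transpose_output_size(10, 2)            # 11: kernel 2, stride 1, no padding
    transpose_output_size(10, 2; stride=2)  # 20: stride 2 upsamples the 10-element input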

ext/AMDGPUExt/functor.jl
Lines changed: 31 additions & 11 deletions

@@ -44,9 +44,12 @@ end

 # CPU -> GPU

-function adapt_storage(to::FluxAMDAdaptor, m::Flux.Conv)
+_conv_basetype(c::Type{C}) where C <: Conv = Conv
+_conv_basetype(c::Type{C}) where C <: ConvTranspose = ConvTranspose
+
+function adapt_storage(to::FluxAMDAdaptor, m::C) where C <: Union{Conv, ConvTranspose}
     flipped_weight = reverse(m.weight; dims=ntuple(i -> i, ndims(m.weight) - 2))
-    Flux.Conv(
+    _conv_basetype(C)(
         Adapt.adapt(to, m.σ),
         Adapt.adapt(to, flipped_weight),
         Adapt.adapt(to, m.bias),
@@ -55,26 +58,43 @@ end

 # Don't adapt again.
 function adapt_storage(
-    to::FluxAMDAdaptor, m::Flux.Conv{N, M, F, A, V},
+    to::FluxAMDAdaptor, m::Conv{N, M, F, A, V},
 ) where {N, M, F, A <: ROCArray, V}
     return m
 end

-_amd(m::Flux.Conv) = adapt_storage(FluxAMDAdaptor(), m)
+function adapt_storage(
+    to::FluxAMDAdaptor, m::ConvTranspose{N, M, F, A, V},
+) where {N, M, F, A <: ROCArray, V}
+    return m
+end
+
+_amd(m::Union{Conv, ConvTranspose}) = adapt_storage(FluxAMDAdaptor(), m)

 # GPU -> CPU

-function Flux.cpu(m::Flux.Conv{N, M, F, A, V}) where {N, M, F, A <: ROCArray, V}
+function Flux.cpu(m::Conv{N, M, F, A, V}) where {N, M, F, A <: ROCArray, V}
+    adapt_storage(FluxCPUAdaptor(), m)
+end
+
+function Flux.cpu(m::ConvTranspose{N, M, F, A, V}) where {N, M, F, A <: ROCArray, V}
     adapt_storage(FluxCPUAdaptor(), m)
 end

 function adapt_storage(
-    to::FluxCPUAdaptor, m::Flux.Conv{N, M, F, A, V},
+    to::FluxCPUAdaptor, m::Conv{N, M, F, A, V},
 ) where {N, M, F, A <: ROCArray, V}
     dims = ntuple(i -> i, ndims(m.weight) - 2)
-    Flux.Conv(
-        Adapt.adapt(to, m.σ),
-        reverse(Adapt.adapt(to, m.weight); dims),
-        Adapt.adapt(to, m.bias),
-        m.stride, m.pad, m.dilation, m.groups)
+    Conv(
+        Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims),
+        Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups)
+end
+
+function adapt_storage(
+    to::FluxCPUAdaptor, m::ConvTranspose{N, M, F, A, V},
+) where {N, M, F, A <: ROCArray, V}
+    dims = ntuple(i -> i, ndims(m.weight) - 2)
+    ConvTranspose(
+        Adapt.adapt(to, m.σ), reverse(Adapt.adapt(to, m.weight); dims),
+        Adapt.adapt(to, m.bias), m.stride, m.pad, m.dilation, m.groups)
 end
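
Note on the weight handling above: adapting a Conv or ConvTranspose to the GPU reverses the kernel's spatial dimensions (matching the flipkernel=true DenseConvDims built in conv.jl), and moving it back to the CPU reverses them again, so a CPU -> GPU -> CPU round trip should leave the weights unchanged. Below is a CPU-only sketch of that invariant; the weight shape is illustrative and no AMDGPU device is needed.

    w = rand(Float32, 3, 3, 2, 4)        # (spatial..., C_in, C_out), shape chosen for illustration
    dims = ntuple(i -> i, ndims(w) - 2)  # spatial dimensions only: (1, 2)
    flipped = reverse(w; dims)           # what the CPU -> GPU adaptor does to the weight
    @assert reverse(flipped; dims) == w  # and what the GPU -> CPU path undoes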

test/amd/basic.jl
Lines changed: 2 additions & 9 deletions

@@ -25,8 +25,8 @@ end
 end

 @testset "Convolution" begin
-    for nd in 1:3
-        m = Conv(tuple(fill(2, nd)...), 3 => 4) |> f32
+    for conv_type in (Conv, ConvTranspose), nd in 1:3
+        m = conv_type(tuple(fill(2, nd)...), 3 => 4) |> f32
         x = rand(Float32, fill(10, nd)..., 3, 5)

         # Ensure outputs are the same.
@@ -85,10 +85,3 @@ end
         amdgputest(bn, x; atol=1f-3, allow_nothing=true)
     end
 end
-
-# FIXME scalar indexing. Needs NNlib.scatter?
-# @testset "Flux.onehot gpu" begin
-#     y = Flux.onehotbatch(ones(3), 1:2) |> Flux.gpu
-#     x = rand(3, 2) |> Flux.gpu
-#     @test gradient(x -> sum(x * y), x)[1] isa ROCArray
-# end
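
The widened loop now exercises both layer types for 1-, 2- and 3-dimensional kernels. For illustration only, a sketch of the layers it constructs (the real test additionally compares CPU and GPU outputs via amdgputest):

    using Flux

    Conv((2,), 3 => 4)               # nd = 1
    ConvTranspose((2, 2), 3 => 4)    # nd = 2
    ConvTranspose((2, 2, 2), 3 => 4) # nd = 3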

test/amd/utils.jl
Lines changed: 1 addition & 1 deletion

@@ -37,7 +37,7 @@ amd_check_grad(
 amd_check_grad(
     g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill;
     atol, allow_nothing
-) = @test collect(g_cpu) ≈ collect(g_gpu) atol=atol
+) = @test g_cpu ≈ collect(g_gpu) atol=atol

 function amd_check_grad(g_gpu::Tuple, g_cpu::Tuple; atol, allow_nothing)
     for (v1, v2) in zip(g_gpu, g_cpu)
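
Context for this one-line change: a Zygote.FillArrays.AbstractFill is already an AbstractArray, so it can be compared with ≈ against the collected GPU gradient without an extra collect on the CPU side. A small CPU-only sketch with illustrative values:

    using Zygote

    g_cpu = Zygote.FillArrays.Fill(1f0, 3)  # stands in for a constant gradient from Zygote
    g_gpu_collected = ones(Float32, 3)      # stands in for collect(g_gpu)

    @assert g_cpu isa Zygote.FillArrays.AbstractFill
    @assert g_cpu ≈ g_gpu_collected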
