Skip to content

Commit cf6dd42

Browse files
committed
Fix batchnorm & handle regular convolutions
1 parent 37ce734 commit cf6dd42

File tree

8 files changed

+84
-46
lines changed

8 files changed

+84
-46
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,4 +59,4 @@ LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
5959
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
6060

6161
[targets]
62-
test = ["Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]
62+
test = ["AMDGPU", "Test", "Documenter", "IterTools", "LinearAlgebra", "FillArrays", "ComponentArrays"]

ext/AMDGPUExt/AMDGPUExt.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,8 @@ import ChainRulesCore
44
import ChainRulesCore: NoTangent
55
import Flux
66
import Flux: FluxCPUAdaptor, FluxAMDAdaptor, _amd, _isleaf, adapt_storage, fmap
7+
import Flux: DenseConvDims, Conv, conv, conv_reshape_bias
8+
import NNlib
79

810
using AMDGPU
911
using Adapt
@@ -32,6 +34,8 @@ end
3234
ChainRulesCore.@non_differentiable check_use_amdgpu()
3335

3436
include("functor.jl")
37+
include("batchnorm.jl")
38+
include("conv.jl")
3539

3640
function __init__()
3741
Flux.AMDGPU_LOADED[] = true

ext/AMDGPUExt/batchnorm.jl

Lines changed: 12 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,23 @@
11
function (b::Flux.BatchNorm)(x::ROCArray{T}) where T <: MIOPENFloat
2-
.(_amd_batchnorm(x, b.γ, b.β; μ=b.μ, σ²=b.σ², ϵ=b.ϵ))
2+
b.λ.(_amd_batchnorm(
3+
x, b.γ, b.β; μ=b.μ, σ²=b.σ², ϵ=b.ϵ,
4+
within_grad=NNlib.within_gradient(x)))
35
end
46

5-
function _amd_batchnorm(x, γ, β; μ, σ², ϵ)
6-
if NNlib.within_gradient(x)
7-
return AMDGPU.MIOpen.batchnorm_training(x, γ, β, μ, σ²; ϵ, iteration=0) # TODO iteration
7+
function _amd_batchnorm(x, γ, β; μ, σ², ϵ, within_grad::Bool)
8+
if within_grad
9+
return AMDGPU.MIOpen.batchnorm_training(x, γ, β, μ, σ²; ϵ=Float64(ϵ), iteration=0) # TODO iteration
810
else
9-
return AMDGPU.MIOpen.batchnorm_inference(x, γ, β, μ, σ²; ϵ)
11+
return AMDGPU.MIOpen.batchnorm_inference(x, γ, β, μ, σ²; ϵ=Float64(ϵ))
1012
end
1113
end
1214

13-
function ChainRulesCore.rrule(::typeof(_amd_batchnorm), x, γ, β; μ, σ², ϵ)
14-
y, μ_saved, ν_saved = _amd_batchnorm(x, γ, β; μ, σ², ϵ)
15+
function ChainRulesCore.rrule(
16+
::typeof(_amd_batchnorm), x, γ, β; μ, σ², ϵ, within_grad::Bool,
17+
)
18+
y, μ_saved, ν_saved = _amd_batchnorm(x, γ, β; μ, σ², ϵ, within_grad)
1519
function _batchnorm_pullback(Δ)
16-
dx, dγ, dβ = MIOpen.∇batchnorm(Δ, x, γ, β, μ_saved, ν_saved)
20+
dx, dγ, dβ = AMDGPU.MIOpen.∇batchnorm(Δ, x, γ, β, μ_saved, ν_saved)
1721
(NoTangent(), dx, dγ, dβ)
1822
end
1923
y, _batchnorm_pullback

ext/AMDGPUExt/conv.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,9 @@
1+
function (c::Conv)(x::T) where T <: ROCArray
2+
Flux._size_check(c, x, ndims(x) - 1 => Flux._channels_in(c))
3+
σ = NNlib.fast_act(c.σ, x)
4+
cdims = DenseConvDims(
5+
x, c.weight; stride=c.stride, padding=c.pad,
6+
dilation=c.dilation, groups=c.groups, flipkernel=true)
7+
xT = Flux._match_eltype(c, x)
8+
σ.(conv(xT, c.weight, cdims) .+ conv_reshape_bias(c))
9+
end

src/Flux.jl

Lines changed: 0 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -73,22 +73,4 @@ include("deprecations.jl")
7373

7474
include("cuda/cuda.jl")
7575

76-
const GPU_BACKENDS = Dict(
77-
"CUDA" => FluxCUDAAdaptor(),
78-
"AMD" => FluxAMDAdaptor())
79-
80-
const GPU_BACKEND = Ref{Union{FluxCUDAAdaptor, FluxAMDAdaptor}}(
81-
GPU_BACKENDS[@load_preference("gpu_backend", "CUDA")])
82-
83-
function gpu_backend!(backend::String)
84-
backend in keys(GPU_BACKENDS) || throw(ArgumentError("""
85-
Unsupported GPU backend: $backend.
86-
Supported backends are: $(keys(GPU_BACKENDS)).
87-
"""))
88-
89-
@set_preferences!("gpu_backend" => backend)
90-
GPU_BACKEND[] = GPU_BACKENDS[@load_preference("gpu_backend")]
91-
return
92-
end
93-
9476
end # module

src/functor.jl

Lines changed: 34 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,30 @@ _isbitsarray(x) = false
177177
_isleaf(::AbstractRNG) = true
178178
_isleaf(x) = _isbitsarray(x) || Functors.isleaf(x)
179179

180+
const GPU_BACKENDS = ("CUDA", "AMD")
181+
const GPU_BACKEND = @load_preference("gpu_backend", "CUDA")
182+
183+
function gpu_backend!(backend::String)
184+
if backend == GPU_BACKEND
185+
@info """
186+
GPU backend is already set to: $backend.
187+
No need to do anything else.
188+
"""
189+
return
190+
end
191+
192+
backend in GPU_BACKENDS || throw(ArgumentError("""
193+
Unsupported GPU backend: $backend.
194+
Supported backends are: $GPU_BACKENDS.
195+
"""))
196+
197+
@set_preferences!("gpu_backend" => backend)
198+
@info """
199+
New GPU backend set: $backend.
200+
Restart your Julia session for this change to take effect!
201+
"""
202+
end
203+
180204
"""
181205
gpu(x)
182206
@@ -209,7 +233,16 @@ CUDA.CuArray{Float32, 2, CUDA.Mem.DeviceBuffer}
209233
```
210234
"""
211235
function gpu(x)
212-
gpu(GPU_BACKEND[], x)
236+
@static if GPU_BACKEND == "CUDA"
237+
gpu(FluxCUDAAdaptor(), x)
238+
elseif GPU_BACKEND == "AMD"
239+
gpu(FluxAMDAdaptor(), x)
240+
else
241+
error("""
242+
Unsupported GPU backend: $GPU_BACKEND.
243+
Supported backends are: $GPU_BACKENDS.
244+
""")
245+
end
213246
end
214247

215248
function gpu(::FluxCUDAAdaptor, x)

test/amd/basic.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -81,8 +81,8 @@ end
8181
@testset "Batchnorm" begin
8282
bn = BatchNorm(3, σ)
8383
for nd in 1:3
84-
x = rand(Float32, fill(16, nd - 1)..., 3, 4)
85-
amdgputest(bn, x; atol=1f-3)
84+
x = rand(Float32, fill(2, nd - 1)..., 3, 4)
85+
amdgputest(bn, x; atol=1f-3, allow_nothing=true)
8686
end
8787
end
8888

test/amd/utils.jl

Lines changed: 22 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,6 @@
1-
function amdgputest(model, xs...; checkgrad=true, atol=1e-6)
1+
function amdgputest(
2+
model, xs...; checkgrad=true, atol=1e-6, allow_nothing::Bool = false,
3+
)
24
cpu_model = model
35
gpu_model = Flux.gpu(model)
46

@@ -12,36 +14,40 @@ function amdgputest(model, xs...; checkgrad=true, atol=1e-6)
1214
if checkgrad
1315
cpu_grad = gradient(m -> sum(m(cpu_in...)), cpu_model)
1416
gpu_grad = gradient(m -> sum(m(gpu_in...)), gpu_model)
15-
amd_check_grad(gpu_grad, cpu_grad; atol)
17+
amd_check_grad(gpu_grad, cpu_grad; atol, allow_nothing)
1618
end
1719
end
1820

19-
function amd_check_grad(g_gpu, g_cpu; atol)
20-
@show g_gpu g_cpu
21-
@test false
21+
function amd_check_grad(g_gpu, g_cpu; atol, allow_nothing)
22+
allow_nothing && return
23+
@show g_gpu g_cpu
24+
@test false
2225
end
2326

24-
amd_check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol) =
25-
amd_check_grad(g_gpu[], g_cpu[]; atol)
26-
amd_check_grad(g_gpu::Nothing, g_cpu::Nothing; atol) =
27+
amd_check_grad(g_gpu::Base.RefValue, g_cpu::Base.RefValue, atol, allow_nothing) =
28+
amd_check_grad(g_gpu[], g_cpu[]; atol, allow_nothing)
29+
amd_check_grad(g_gpu::Nothing, g_cpu::Nothing; atol, allow_nothing) =
2730
@test true
28-
amd_check_grad(g_gpu::Float32, g_cpu::Float32; atol) =
31+
amd_check_grad(g_gpu::Float32, g_cpu::Float32; atol, allow_nothing) =
2932
@test g_cpu g_gpu atol=atol
30-
amd_check_grad(g_gpu::ROCArray{Float32}, g_cpu::Array{Float32}; atol) =
31-
@test g_cpu collect(g_gpu) atol=atol
3233
amd_check_grad(
33-
g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill; atol,
34+
g_gpu::ROCArray{Float32}, g_cpu::Array{Float32};
35+
atol, allow_nothing,
36+
) = @test g_cpu collect(g_gpu) atol=atol
37+
amd_check_grad(
38+
g_gpu::ROCArray{Float32}, g_cpu::Zygote.FillArrays.AbstractFill;
39+
atol, allow_nothing
3440
) = @test collect(g_cpu) collect(g_gpu) atol=atol
3541

36-
function amd_check_grad(g_gpu::Tuple, g_cpu::Tuple; atol)
42+
function amd_check_grad(g_gpu::Tuple, g_cpu::Tuple; atol, allow_nothing)
3743
for (v1, v2) in zip(g_gpu, g_cpu)
38-
amd_check_grad(v1, v2; atol)
44+
amd_check_grad(v1, v2; atol, allow_nothing)
3945
end
4046
end
4147

42-
function amd_check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; atol)
48+
function amd_check_grad(g_gpu::NamedTuple, g_cpu::NamedTuple; atol, allow_nothing)
4349
for ((k1, v1), (k2, v2)) in zip(pairs(g_gpu), pairs(g_cpu))
4450
@test k1 == k2
45-
amd_check_grad(v1, v2; atol)
51+
amd_check_grad(v1, v2; atol, allow_nothing)
4652
end
4753
end

0 commit comments

Comments
 (0)