Skip to content

Commit de91f9a

Browse files
committed
Add batchnorm
1 parent d0eb6a0 commit de91f9a

File tree

3 files changed

+28
-0
lines changed

3 files changed

+28
-0
lines changed

ext/AMDGPUExt/AMDGPUExt.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
module AMDGPUExt
22

33
import ChainRulesCore
4+
import ChainRulesCore: NoTangent
45
import Flux
56
import Flux: FluxCPUAdaptor, _amd, _isleaf, adapt_storage, fmap
67

@@ -9,6 +10,7 @@ using Adapt
910
using Random
1011
using Zygote
1112

13+
const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
1214
const USE_AMDGPU = Ref{Union{Nothing, Bool}}(nothing)
1315

1416
function check_use_amdgpu()

ext/AMDGPUExt/batchnorm.jl

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
# Forward pass of `Flux.BatchNorm` for `ROCArray` inputs whose element type is a
# `MIOPENFloat`.
#
# The input is normalized by `_amd_batchnorm` using the layer's `γ`, `β`, `μ`,
# `σ²` and `ϵ`, and the layer's activation is then broadcast over the result.
function (b::Flux.BatchNorm)(x::ROCArray{T}) where T <: MIOPENFloat
    # The broadcast needs an explicit callee (a bare `.(...)` does not parse);
    # `Flux.BatchNorm` keeps its activation function in the `λ` field.
    return b.λ.(_amd_batchnorm(x, b.γ, b.β; μ=b.μ, σ²=b.σ², ϵ=b.ϵ))
end
4+
5+
# Select between the MIOpen training and inference batchnorm calls.
#
# `x`, `γ`, `β`, `μ` and `σ²` are forwarded positionally and `ϵ` as a keyword.
# The training call is used when `NNlib.within_gradient(x)` is true, otherwise
# the inference call is used.
#
# NOTE(review): the two MIOpen calls may not return the same kind of value —
# confirm against AMDGPU.MIOpen before relying on the return shape.
function _amd_batchnorm(x, γ, β; μ, σ², ϵ)
    differentiating = NNlib.within_gradient(x)
    # Guard clause: the plain forward pass goes straight to the inference call.
    differentiating || return AMDGPU.MIOpen.batchnorm_inference(x, γ, β, μ, σ²; ϵ)
    # TODO iteration
    return AMDGPU.MIOpen.batchnorm_training(x, γ, β, μ, σ²; ϵ, iteration=0)
end
12+
13+
# Reverse-mode rule for `_amd_batchnorm`.
#
# Returns the normalized output `y` together with a pullback that maps the
# output cotangent `Δ` to `(NoTangent(), dx, dγ, dβ)`. The keyword arguments
# `μ`, `σ²` and `ϵ` receive no tangents.
function ChainRulesCore.rrule(::typeof(_amd_batchnorm), x, γ, β; μ, σ², ϵ)
    # Run the training call directly. This body executes as ordinary Julia code,
    # so `NNlib.within_gradient(x)` evaluates to `false` here; going through
    # `_amd_batchnorm` would therefore select the inference call, whose result
    # cannot be destructured into `(y, μ_saved, ν_saved)`.
    y, μ_saved, ν_saved = AMDGPU.MIOpen.batchnorm_training(
        x, γ, β, μ, σ²; ϵ, iteration=0) # TODO iteration

    function _batchnorm_pullback(Δ)
        # Fully qualified like the other MIOpen calls in this extension; a bare
        # `MIOpen` binding is not visibly brought into scope.
        dx, dγ, dβ = AMDGPU.MIOpen.∇batchnorm(Δ, x, γ, β, μ_saved, ν_saved)
        return (NoTangent(), dx, dγ, dβ)
    end

    return y, _batchnorm_pullback
end

test/amd/basic.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -78,6 +78,12 @@ end
7878
@test Flux.onecold(y, l) == ['a', 'a', 'a']
7979
end
8080

81+
# Exercise a 3-channel BatchNorm layer with a σ activation on a random Float32
# input of size 16×16×3×4, checked through `amdgputest` with `atol=1f-3`.
# NOTE(review): presumably `amdgputest` compares CPU and AMD GPU results — confirm.
@testset "Batchnorm" begin
    layer = BatchNorm(3, σ)
    input = rand(Float32, 16, 16, 3, 4)
    amdgputest(layer, input; atol=1f-3)
end
86+
8187
# FIXME scalar indexing. Needs NNlib.scatter?
8288
# @testset "Flux.onehot gpu" begin
8389
# y = Flux.onehotbatch(ones(3), 1:2) |> Flux.amd

0 commit comments

Comments
 (0)