
Commit b1a3a93

Update deps & bump to 0.16.1 (#2574)
* Update deps
* [AMDGPU] Correct batchnorm rrule
* Mark test as unbroken
1 parent 44695a0 commit b1a3a93

4 files changed: +6 -11 lines changed


Project.toml

Lines changed: 1 addition & 1 deletion
@@ -64,6 +64,6 @@ Reexport = "1.0"
 Setfield = "1.1"
 SpecialFunctions = "2.1.2"
 Statistics = "1"
-Zygote = "0.6.67"
+Zygote = "0.6.67, 0.7"
 cuDNN = "1"
 julia = "1.10"
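The widened compat entry keeps the existing lower bound while also admitting the Zygote 0.7 series. As a rough check of what "0.6.67, 0.7" allows (this uses Pkg.Types.semver_spec, an internal Pkg helper, so treat it as an illustration rather than public API):

using Pkg

# Parse the compat string: the comma forms a union of two caret
# specifiers, ^0.6.67 and ^0.7.
spec = Pkg.Types.semver_spec("0.6.67, 0.7")

@assert v"0.6.67" in spec      # old lower bound is still accepted
@assert v"0.7.2" in spec       # the 0.7.x series is now accepted as well
@assert !(v"0.6.66" in spec)   # versions below the lower bound stay excluded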

ext/FluxAMDGPUExt/FluxAMDGPUExt.jl

Lines changed: 2 additions & 7 deletions
@@ -1,10 +1,11 @@
 module FluxAMDGPUExt
 
 import ChainRulesCore
-import ChainRulesCore: NoTangent
+import ChainRulesCore: NoTangent, unthunk
 import Flux
 import Flux: fmap, DenseConvDims, Conv, ConvTranspose, conv, conv_reshape_bias
 import NNlib
+
 using MLDataDevices
 using AMDGPU
 using Adapt
@@ -13,14 +14,8 @@ using Zygote
 
 const MIOPENFloat = AMDGPU.MIOpen.MIOPENFloat
 
-
 include("functor.jl")
 include("batchnorm.jl")
 include("conv.jl")
 
-
-# TODO
-# fail early if input to the model is not on the device (e.g. on the host)
-# otherwise we get very cryptic errors & segfaults at the rocBLAS level
-
 end

ext/FluxAMDGPUExt/batchnorm.jl

Lines changed: 1 addition & 1 deletion
@@ -17,7 +17,7 @@ function ChainRulesCore.rrule(
 )
     y, μ_saved, ν_saved = _amdgpu_batchnorm(x, γ, β; μ, σ², ϵ, within_grad)
     function _batchnorm_pullback(Δ)
-        dx, dγ, dβ = AMDGPU.MIOpen.∇batchnorm(Δ, x, γ, β, μ_saved, ν_saved)
+        dx, dγ, dβ = AMDGPU.MIOpen.∇batchnorm(unthunk(Δ), x, γ, β, μ_saved, ν_saved)
         (NoTangent(), dx, dγ, dβ)
     end
     y, _batchnorm_pullback
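This is the "[AMDGPU] Correct batchnorm rrule" part of the commit, and it is why unthunk is now imported in FluxAMDGPUExt.jl: ChainRulesCore may hand a pullback its cotangent Δ as a lazy thunk, while AMDGPU.MIOpen.∇batchnorm expects a plain array, so Δ is materialised first. A minimal sketch of the same pattern, using a made-up double function rather than anything from Flux:

using ChainRulesCore: ChainRulesCore, NoTangent, unthunk, @thunk

double(x) = 2 .* x

function ChainRulesCore.rrule(::typeof(double), x)
    y = double(x)
    function double_pullback(Δ)
        Δ = unthunk(Δ)             # materialise a possible Thunk into a plain array
        return (NoTangent(), 2 .* Δ)
    end
    return y, double_pullback
end

y, back = ChainRulesCore.rrule(double, [1.0, 2.0])
back(@thunk ones(2))               # works even when the caller passes a thunked cotangent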

test/ext_cuda/cuda.jl

Lines changed: 2 additions & 2 deletions
@@ -106,7 +106,7 @@ end
     # Trivial functions
     @test gradient(x -> sum(abs, gpu(x)), a)[1] isa Matrix
     @test gradient(x -> sum(gpu(x)), a)[1] isa Matrix
-    @test_broken gradient(x -> sum(gpu(x)), a')[1] isa Matrix # sum(::Adjoint{T,CuArray}) makes a Fill
+    @test gradient(x -> sum(gpu(x)), a')[1] isa Matrix # sum(::Adjoint{T,CuArray}) makes a Fill
     @test gradient(x -> sum(abs, cpu(x)), ca)[1] isa CuArray
     # This test should really not go through indirections and pull out Fills for efficiency
     # but we forcefully materialise. TODO: remove materialising CuArray here
@@ -207,4 +207,4 @@ end
     @test collect(post2) isa Vector{<:NamedTuple{(:x, :y)}} # collect makes no sense, but check eltype?
 
     # @test_throws Exception gpu(((x = Flux.DataLoader(X), y = Y),))
-end
+end
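Converting @test_broken to @test is the "Mark test as unbroken" part of the commit: in the Test stdlib, @test_broken records a Broken result while the expression is still false, but reports an unexpected pass (and fails the test set) once the expression starts evaluating to true, so a fixed behaviour has to be re-marked as a plain @test. A small standalone sketch of that semantics:

using Test

@testset "test_broken semantics (sketch)" begin
    @test_broken 1 + 1 == 3   # known failure: recorded as Broken, the test set still passes
    @test 1 + 1 == 2          # once the behaviour is fixed, the check becomes a plain @test
end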
