Skip to content
This repository was archived by the owner on Nov 4, 2024. It is now read-only.

Commit ece7ba2

Browse files
authored
fix: correctly handle adjoints of wrapped arrays (#90)
* fix: correctly handle adjoints of wrapped arrays * fix: use fast paths for adapt * fix: adapt ranges to JuliaGPU/Adapt.jl#86
1 parent e9a2ed7 commit ece7ba2

File tree

9 files changed

+44
-38
lines changed

9 files changed

+44
-38
lines changed

Project.toml

Lines changed: 2 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,13 +1,12 @@
11
name = "MLDataDevices"
22
uuid = "7e8f7934-dd98-4c1a-8fe8-92b47a384d40"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.4.1"
4+
version = "1.4.2"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
99
Functors = "d9f16b24-f501-4c13-a1f2-28368ffc5196"
10-
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
1110
Preferences = "21216c6a-2e73-6563-6e65-726566657250"
1211
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
1312

@@ -47,14 +46,13 @@ MLDataDevicesoneAPIExt = ["GPUArrays", "oneAPI"]
4746

4847
[compat]
4948
AMDGPU = "0.9.6, 1"
50-
Adapt = "4"
49+
Adapt = "4.1"
5150
CUDA = "5.2"
5251
ChainRulesCore = "1.23"
5352
Compat = "4.15"
5453
FillArrays = "1"
5554
Functors = "0.4.8"
5655
GPUArrays = "10, 11"
57-
LinearAlgebra = "1.10"
5856
MLUtils = "0.4.4"
5957
Metal = "1"
6058
Preferences = "1.4"
Lines changed: 12 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,27 @@
11
module MLDataDevicesChainRulesCoreExt
22

33
using Adapt: Adapt
4-
using ChainRulesCore: ChainRulesCore, NoTangent, @non_differentiable
4+
using ChainRulesCore: ChainRulesCore, NoTangent, ProjectTo, @non_differentiable
55

66
using MLDataDevices: AbstractDevice, UnknownDevice, get_device, get_device_type
77

88
@non_differentiable get_device(::Any)
99
@non_differentiable get_device_type(::Any)
1010

11-
function ChainRulesCore.rrule(
12-
::typeof(Adapt.adapt_storage), to::AbstractDevice, x::AbstractArray)
13-
∇adapt_storage = let dev = get_device(x)
14-
if dev === nothing || dev isa UnknownDevice
11+
function ChainRulesCore.rrule(::typeof(Adapt.adapt), to::AbstractDevice, x::AbstractArray)
12+
dev = get_device(x)
13+
y = Adapt.adapt_storage(to, x)
14+
if dev === nothing || dev isa UnknownDevice
15+
dev isa UnknownDevice &&
1516
@warn "`get_device(::$(typeof(x)))` returned `$(dev)`." maxlog=1
16-
Δ -> (NoTangent(), NoTangent(), Δ)
17-
else
18-
Δ -> (NoTangent(), NoTangent(), dev(Δ))
17+
∇adapt_storage_unknown = Δ -> (NoTangent(), NoTangent(), Δ)
18+
return y, ∇adapt_storage_unknown
19+
else
20+
∇adapt_storage = let dev = dev, x = x
21+
Δ -> (NoTangent(), NoTangent(), ProjectTo(x)(dev(Δ)))
1922
end
23+
return Adapt.adapt_storage(to, x), ∇adapt_storage
2024
end
21-
return Adapt.adapt_storage(to, x), ∇adapt_storage
2225
end
2326

2427
end

src/MLDataDevices.jl

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,6 @@ using Functors: Functors, fleaves
55
using Preferences: @delete_preferences!, @load_preference, @set_preferences!
66
using Random: AbstractRNG, Random
77
using Compat: @compat
8-
using LinearAlgebra: Transpose, Adjoint
98

109
abstract type AbstractDevice <: Function end
1110
abstract type AbstractCPUDevice <: AbstractDevice end

src/public.jl

Lines changed: 5 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -342,8 +342,10 @@ for (dev) in (:CPU, :CUDA, :AMDGPU, :Metal, :oneAPI, :XLA)
342342
ldev = Symbol(dev, :Device)
343343
@eval begin
344344
function (D::$(ldev))(x::AbstractArray{T}) where {T}
345-
return (isbitstype(T) || Internal.special_aos(x)) ? Adapt.adapt(D, x) :
346-
map(D, x)
345+
if isbitstype(T) || Internal.special_aos(x) || x isa Adapt.WrappedArray
346+
return Adapt.adapt(D, x)
347+
end
348+
return map(D, x)
347349
end
348350
(D::$(ldev))(x::Union{Tuple, NamedTuple}) = map(D, x)
349351
function (D::$(ldev))(x)
@@ -373,14 +375,6 @@ for T in (AMDGPUDevice, CUDADevice, MetalDevice, oneAPIDevice, XLADevice)
373375
end
374376
end
375377

376-
Adapt.adapt_storage(::CPUDevice, x::AbstractRange) = x
377-
Adapt.adapt_storage(::XLADevice, x::AbstractRange) = x
378-
# Prevent Ambiguity
379-
for T in (AMDGPUDevice, AMDGPUDevice{Nothing}, CUDADevice,
380-
CUDADevice{Nothing}, MetalDevice, oneAPIDevice)
381-
@eval Adapt.adapt_storage(to::$(T), x::AbstractRange) = Adapt.adapt(to, collect(x))
382-
end
383-
384378
"""
385379
isleaf(x) -> Bool
386380
@@ -399,4 +393,4 @@ If `MLDataDevices.isleaf(x::T)` is not defined, then it will fall back to `Funct
399393
isleaf(x) = Functors.isleaf(x)
400394

401395
isleaf(::AbstractArray{T}) where {T} = isbitstype(T)
402-
isleaf(::Union{Transpose, Adjoint, PermutedDimsArray}) = false
396+
isleaf(::Adapt.WrappedArray) = false

test/amdgpu_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ using FillArrays, Zygote # Extensions
5353
@test ps_xpu.mixed[1] isa Float32
5454
@test ps_xpu.mixed[2] isa Float64
5555
@test ps_xpu.mixed[3] isa aType
56-
@test ps_xpu.range isa aType
56+
@test ps_xpu.range isa AbstractRange
5757
@test ps_xpu.e == ps.e
5858
@test ps_xpu.d == ps.d
5959
@test ps_xpu.rng_default isa rngType
@@ -83,7 +83,7 @@ using FillArrays, Zygote # Extensions
8383
@test ps_cpu.mixed[1] isa Float32
8484
@test ps_cpu.mixed[2] isa Float64
8585
@test ps_cpu.mixed[3] isa Array
86-
@test ps_cpu.range isa Array
86+
@test ps_cpu.range isa AbstractRange
8787
@test ps_cpu.e == ps.e
8888
@test ps_cpu.d == ps.d
8989
@test ps_cpu.rng_default isa Random.TaskLocalRNG

test/cuda_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ using FillArrays, Zygote # Extensions
5252
@test ps_xpu.mixed[1] isa Float32
5353
@test ps_xpu.mixed[2] isa Float64
5454
@test ps_xpu.mixed[3] isa aType
55-
@test ps_xpu.range isa aType
55+
@test ps_xpu.range isa AbstractRange
5656
@test ps_xpu.e == ps.e
5757
@test ps_xpu.d == ps.d
5858
@test ps_xpu.rng_default isa rngType
@@ -82,7 +82,7 @@ using FillArrays, Zygote # Extensions
8282
@test ps_cpu.mixed[1] isa Float32
8383
@test ps_cpu.mixed[2] isa Float64
8484
@test ps_cpu.mixed[3] isa Array
85-
@test ps_cpu.range isa Array
85+
@test ps_cpu.range isa AbstractRange
8686
@test ps_cpu.e == ps.e
8787
@test ps_cpu.d == ps.d
8888
@test ps_cpu.rng_default isa Random.TaskLocalRNG

test/metal_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
5151
@test ps_xpu.mixed[1] isa Float32
5252
@test ps_xpu.mixed[2] isa Float64
5353
@test ps_xpu.mixed[3] isa aType
54-
@test ps_xpu.range isa aType
54+
@test ps_xpu.range isa AbstractRange
5555
@test ps_xpu.e == ps.e
5656
@test ps_xpu.d == ps.d
5757
@test ps_xpu.rng_default isa rngType
@@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions
8181
@test ps_cpu.mixed[1] isa Float32
8282
@test ps_cpu.mixed[2] isa Float64
8383
@test ps_cpu.mixed[3] isa Array
84-
@test ps_cpu.range isa Array
84+
@test ps_cpu.range isa AbstractRange
8585
@test ps_cpu.e == ps.e
8686
@test ps_cpu.d == ps.d
8787
@test ps_cpu.rng_default isa Random.TaskLocalRNG

test/misc_tests.jl

Lines changed: 17 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -50,17 +50,17 @@ end
5050

5151
@testset "CRC Tests" begin
5252
dev = cpu_device() # Other devices don't work with FiniteDifferences.jl
53-
test_rrule(Adapt.adapt_storage, dev, randn(Float64, 10); check_inferred=true)
53+
test_rrule(Adapt.adapt, dev, randn(Float64, 10); check_inferred=true)
5454

5555
gdev = gpu_device()
5656
if !(gdev isa MetalDevice) # On intel devices causes problems
5757
x = randn(10)
58-
∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, gdev, x)
58+
∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, gdev, x)
5959
@test ∂dev === nothing
6060
@test ∂x ≈ ones(10)
6161

6262
x = randn(10) |> gdev
63-
∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt_storage, cpu_device(), x)
63+
∂dev, ∂x = Zygote.gradient(sum ∘ Adapt.adapt, cpu_device(), x)
6464
@test ∂dev === nothing
6565
@test ∂x ≈ gdev(ones(10))
6666
@test get_device(∂x) isa parameterless_type(typeof(gdev))
@@ -181,7 +181,6 @@ end
181181
end
182182

183183
@testset "shared parameters" begin
184-
# from
185184
x = rand(1)
186185
m = (; a=x, b=x')
187186
count = Ref(0)
@@ -199,11 +198,24 @@ end
199198
y::Float64
200199
end
201200

202-
for x in [1.0, 'a', BitsType(1, 2.0)]
201+
@testset for x in [1.0, 'a', BitsType(1, 2.0)]
203202
@test MLDataDevices.isleaf([x])
204203
@test !MLDataDevices.isleaf([x]')
205204
@test !MLDataDevices.isleaf(transpose([x]))
206205
@test !MLDataDevices.isleaf(PermutedDimsArray([x;;], (1, 2)))
207206
end
208207
end
209208
end
209+
210+
@testset "Zygote.gradient(wrapped arrays)" begin
211+
using Zygote
212+
213+
x = rand(4, 4)
214+
cdev = cpu_device()
215+
216+
@test only(Zygote.gradient(x -> sum(abs2, cdev(x)), x')) isa Matrix{Float64}
217+
218+
gdev = gpu_device()
219+
220+
@test only(Zygote.gradient(x -> sum(abs2, gdev(x)), x')) isa Matrix{Float64}
221+
end

test/oneapi_tests.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -51,7 +51,7 @@ using FillArrays, Zygote # Extensions
5151
@test ps_xpu.mixed[1] isa Float32
5252
@test ps_xpu.mixed[2] isa Float64
5353
@test ps_xpu.mixed[3] isa aType
54-
@test ps_xpu.range isa aType
54+
@test ps_xpu.range isa AbstractRange
5555
@test ps_xpu.e == ps.e
5656
@test ps_xpu.d == ps.d
5757
@test ps_xpu.rng_default isa rngType
@@ -81,7 +81,7 @@ using FillArrays, Zygote # Extensions
8181
@test ps_cpu.mixed[1] isa Float32
8282
@test ps_cpu.mixed[2] isa Float64
8383
@test ps_cpu.mixed[3] isa Array
84-
@test ps_cpu.range isa Array
84+
@test ps_cpu.range isa AbstractRange
8585
@test ps_cpu.e == ps.e
8686
@test ps_cpu.d == ps.d
8787
@test ps_cpu.rng_default isa Random.TaskLocalRNG

0 commit comments

Comments (0)