5050
5151@testset " CRC Tests" begin
5252 dev = cpu_device () # Other devices don't work with FiniteDifferences.jl
53- test_rrule (Adapt. adapt_storage , dev, randn (Float64, 10 ); check_inferred= true )
53+ test_rrule (Adapt. adapt , dev, randn (Float64, 10 ); check_inferred= true )
5454
5555 gdev = gpu_device ()
5656 if ! (gdev isa MetalDevice) # On intel devices causes problems
5757 x = randn (10 )
58- ∂dev, ∂x = Zygote. gradient (sum ∘ Adapt. adapt_storage , gdev, x)
58+ ∂dev, ∂x = Zygote. gradient (sum ∘ Adapt. adapt , gdev, x)
5959 @test ∂dev === nothing
6060 @test ∂x ≈ ones (10 )
6161
6262 x = randn (10 ) |> gdev
63- ∂dev, ∂x = Zygote. gradient (sum ∘ Adapt. adapt_storage , cpu_device (), x)
63+ ∂dev, ∂x = Zygote. gradient (sum ∘ Adapt. adapt , cpu_device (), x)
6464 @test ∂dev === nothing
6565 @test ∂x ≈ gdev (ones (10 ))
6666 @test get_device (∂x) isa parameterless_type (typeof (gdev))
181181 end
182182
183183 @testset " shared parameters" begin
184- # from
185184 x = rand (1 )
186185 m = (; a= x, b= x' )
187186 count = Ref (0 )
@@ -199,11 +198,24 @@ end
199198 y:: Float64
200199 end
201200
202- for x in [1.0 , ' a' , BitsType (1 , 2.0 )]
201+ @testset for x in [1.0 , ' a' , BitsType (1 , 2.0 )]
203202 @test MLDataDevices. isleaf ([x])
204203 @test ! MLDataDevices. isleaf ([x]' )
205204 @test ! MLDataDevices. isleaf (transpose ([x]))
206205 @test ! MLDataDevices. isleaf (PermutedDimsArray ([x;;], (1 , 2 )))
207206 end
208207 end
209208end
209+
210+ @testset " Zygote.gradient(wrapped arrays)" begin
211+ using Zygote
212+
213+ x = rand (4 , 4 )
214+ cdev = cpu_device ()
215+
216+ @test only (Zygote. gradient (x -> sum (abs2, cdev (x)), x' )) isa Matrix{Float64}
217+
218+ gdev = gpu_device ()
219+
220+ @test only (Zygote. gradient (x -> sum (abs2, gdev (x)), x' )) isa Matrix{Float64}
221+ end
0 commit comments