diff --git a/Project.toml b/Project.toml
index a01cab0f9f..2832ddb8e9 100644
--- a/Project.toml
+++ b/Project.toml
@@ -38,7 +38,7 @@ ProgressLogging = "0.1"
 Reexport = "0.2, 1.0"
 SpecialFunctions = "1.8.2, 2.1.2"
 StatsBase = "0.33"
-Zygote = "0.6.34"
+Zygote = "0.6.49"
 julia = "1.6"
 
 [extras]
diff --git a/src/functor.jl b/src/functor.jl
index 13adbe13ff..986574e33e 100644
--- a/src/functor.jl
+++ b/src/functor.jl
@@ -89,7 +89,10 @@ function params(m...)
 end
 
 # Allows caching of the parameters when params is called within gradient() to fix #2040.
-@non_differentiable params(m...)
+# @non_differentiable params(m...)  # https://github.com/FluxML/Flux.jl/pull/2054
+# That speeds up implicit use, and silently breaks explicit use.
+# From @macroexpand Zygote.@nograd params(m...) and https://github.com/FluxML/Zygote.jl/pull/1248
+Zygote._pullback(::Zygote.Context{true}, ::typeof(params), m...) = params(m), _ -> nothing
 
 struct FluxCUDAAdaptor end
 adapt_storage(to::FluxCUDAAdaptor, x) = CUDA.cu(x)
diff --git a/test/utils.jl b/test/utils.jl
index 20359daf25..fbb7f7d9d1 100644
--- a/test/utils.jl
+++ b/test/utils.jl
@@ -270,6 +270,20 @@ end
   @test size.(Flux.params(m)) == [(2,), (1, 2)]
 end
 
+@testset "params gradient" begin
+  m = (x=[1,2.0], y=[3.0]);
+
+  # Explicit -- was broken by #2054
+  gnew = gradient(m -> (sum(norm, Flux.params(m))), m)[1]
+  @test gnew.x ≈ [0.4472135954999579, 0.8944271909999159]
+  @test gnew.y ≈ [1.0]
+
+  # Implicit
+  gold = gradient(() -> (sum(norm, Flux.params(m))), Flux.params(m))
+  @test gold[m.x] ≈ [0.4472135954999579, 0.8944271909999159]
+  @test gold[m.y] ≈ [1.0]
+end
+
 @testset "Precision" begin
   m = Chain(Dense(10, 5, relu), Dense(5, 2))
   x64 = rand(Float64, 10)
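
For context, a minimal sketch of how the two calling styles exercise the patched method. This is not part of the patch; it assumes Zygote >= 0.6.49, where https://github.com/FluxML/Zygote.jl/pull/1248 makes the pullback context record whether implicit Params are in use, so Zygote.Context{true} should correspond to the implicit style only.

using Flux, Zygote, LinearAlgebra  # norm comes from LinearAlgebra

m = (x = [1, 2.0], y = [3.0])

# Explicit style: differentiating with respect to m itself runs under
# Context{false}, so the new _pullback method does not apply and
# Flux.params(m) is differentiated like any other function of m.
g_explicit = gradient(m -> sum(norm, Flux.params(m)), m)[1]

# Implicit style: gradient(f, ::Params) builds a Context{true}, the new
# method fires, and params(m) is treated as non-differentiable (only cached),
# preserving the speedup from #2054.
ps = Flux.params(m)
g_implicit = gradient(() -> sum(norm, Flux.params(m)), ps)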