Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 9 additions & 5 deletions src/stage1/generated.jl
Original file line number Diff line number Diff line change
Expand Up @@ -206,20 +206,20 @@ function (::∂⃖{N})(f::Core.IntrinsicFunction, args...) where {N}
end

# The static parameter on `f` disables the compileable_sig heuristic
function (::∂⃖{N})(f::T, args...) where {T, N}
function (::∂⃖{N})(f::T, args...; kwargs...) where {T, N}
if N == 1
# Base case (inlined to avoid ambiguities with manually specified
# higher order rules)
z = rrule(DiffractorRuleConfig(), f, args...)
z = rrule(DiffractorRuleConfig(), f, args...; kwargs...)
if z === nothing
return ∂⃖recurse{1}()(f, args...)
return ∂⃖recurse{1}()(f, args...; kwargs...)
end
return z
else
∂⃖p = ∂⃖{N-1}()
@destruct z, z̄ = ∂⃖p(rrule, f, args...)
@destruct z, z̄ = ∂⃖p(rrule, f, args...; kwargs...)
if z === nothing
return ∂⃖recurse{N}()(f, args...)
return ∂⃖recurse{N}()(f, args...; kwargs...)
else
return ∂⃖rrule{N}()(z, z̄)
end
Expand All @@ -244,6 +244,10 @@ struct KwFunc{T,S}
end
(kw::KwFunc)(args...) = kw.kwf(args...)

function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, f, args...)
rrule(KwFunc(f), kwargs, f, args...)
Copy link
Member

@oxinabox oxinabox Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

why isn't this

Suggested change
rrule(KwFunc(f), kwargs, f, args...)
rrule(f, args...; kwargs...)

is that the same, or is it different?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Copy link
Contributor Author

@nmheim nmheim Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I am not exactly sure why the KwFunc struct is needed though.. it seems like could be done via rrule(::typeof(Core.kwcall), kwargs, f, args...) directly?

Copy link
Contributor Author

@nmheim nmheim Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Removing the KwFunc and dispatching on kwcall directly seems to work, but I was afraid to remove something which I don't exactly understand

function ChainRulesCore.rrule(::typeof(Core.kwcall), kwargs, f, args...)
    r = Core.kwfunc(rrule)(kwargs, rrule, f, args...)
    if r === nothing
        return nothing
    end
    x, back = r
    x, Δ->begin
        (NoTangent(), NoTangent(), back(Δ)...)
    end
end

Copy link
Member

@oxinabox oxinabox Feb 13, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oh this might be the thing that is there to avoid ADing through so much of the kwarg machinery in the nested AD case

end

function ChainRulesCore.rrule(::typeof(Core.kwfunc), f)
KwFunc(f), Δ->(NoTangent(), Δ)
end
Expand Down
81 changes: 31 additions & 50 deletions test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -99,10 +99,11 @@ end
@test gradcheck(x -> sum(i->x[i], 1:length(x)), randn(10)) # issue #231
@test gradcheck(x -> sum((i->x[i]).(1:length(x))), randn(10))
@test gradcheck(X -> sum(x -> x^2, X), randn(10))
@test jacobicheck(x -> sum(x, dims = (2, 3)), (3,4,5))
@test jacobicheck(x -> sum(abs2, x; dims=1), randn(4, 3, 2))

# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> sum(x, dims = (2, 3)), (3,4,5))
@test_broken jacobicheck(x -> sum(abs2, x; dims=1), randn(4, 3, 2))
# TODO: interesting that this is the only one that is not fixed
@test_broken gradcheck(X -> sum(sum(x -> x^2, X; dims=1)), randn(10)) # issue #681

# Non-differentiable sum of booleans
Expand All @@ -119,23 +120,15 @@ end

@test gradcheck(x -> prod(x), (3,4))
@test gradient(x -> prod(x), (1,2,3))[1] == (6,3,2)

# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> prod(x, dims = (2, 3)), (3,4,5))
@test jacobicheck(x -> prod(x, dims = (2, 3)), (3,4,5))
end

@testset "cumsum" begin
@test jacobicheck(x -> cumsum(x), (4,))

# TypeError: in typeassert, expected Int64, got a value of type Nothing
@test_broken jacobicheck(x -> cumsum(x, dims=2), (3,4,5))
@test_broken jacobicheck(x -> cumsum(x, dims=3), (3,4)) # trivial

# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> cumsum(x, dims=1), (3,))

# Rewrite reached intrinsic function bitcast. Missing rule?
@test_broken jacobicheck(x -> cumsum(x, dims=3), (5,)) # trivial
@test jacobicheck(x -> cumsum(x, dims=2), (3,4,5))
@test jacobicheck(x -> cumsum(x, dims=3), (3,4)) # trivial
@test jacobicheck(x -> cumsum(x, dims=1), (3,))
@test jacobicheck(x -> cumsum(x, dims=3), (5,)) # trivial
end

@testset "getindex" begin
Expand Down Expand Up @@ -221,8 +214,7 @@ end
@test jacobicheck(x -> reverse(x), rand(17))
@test jacobicheck(x -> reverse(x, 8), rand(17))
@test jacobicheck(x -> reverse(x, 8, 13), rand(17))
# Rewrite reached intrinsic function bitcast. Missing rule?
@test_broken jacobicheck(x -> reverse(x, dims=2), rand(17, 42))
@test jacobicheck(x -> reverse(x, dims=2), rand(17, 42))
end

@testset "permutedims" begin
Expand All @@ -237,11 +229,9 @@ end
end

@testset "repeat" begin
# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> repeat(x; inner=2), rand(5))
@test_broken jacobicheck(x -> repeat(x; inner=2, outer=3), rand(5))
@test_broken jacobicheck(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))

@test jacobicheck(x -> repeat(x; inner=2), rand(5))
@test jacobicheck(x -> repeat(x; inner=2, outer=3), rand(5))
@test jacobicheck(x -> repeat(x; inner=(2,2,1), outer=(1,1,3)), rand(5,4,3))
@test jacobicheck(x -> repeat(x, 3), rand(5))
@test jacobicheck(x -> repeat(x, 2, 3), rand(5))
@test jacobicheck(x -> repeat(x, 5), rand(5,7))
Expand Down Expand Up @@ -453,11 +443,10 @@ end
@test gradient(v->sort(v)[i], [1.,2,3])[1][correct[2][i]] == 1
end
for i = 1:3
# Rewrite reached intrinsic function bitcast. Missing rule?
@test_broken gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1
@test_broken gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1
@test_broken gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1
@test_broken gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1
@test gradient(v->sort(v,by=x->x%10)[i], [11,2,99])[1][correct[3][i]] == 1
@test gradient(v->sort(v,by=x->x%10)[i], [2,11,99])[1][correct[4][i]] == 1
@test gradient(v->sort(v,rev=true)[i], [3.,1,2])[1][correct[5][i]] == 1
@test gradient(v->sort(v,rev=true)[i], [1.,2,3])[1][correct[6][i]] == 1
end
end

Expand All @@ -473,27 +462,21 @@ end

@testset "maximum" begin
@test jacobicheck(maximum, rand(2, 3))

# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> maximum(x, dims=1), rand(2, 3))
@test_broken jacobicheck(x -> maximum(x, dims=3), rand(2, 3, 4))
@test_broken jacobicheck(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))

@test jacobicheck(x -> maximum(x, dims=1), rand(2, 3))
@test jacobicheck(x -> maximum(x, dims=3), rand(2, 3, 4))
@test jacobicheck(x -> maximum(x, dims=[1, 2]), rand(2, 3, 4))
@test gradient(x -> 1 / maximum(x), [1., 2, 3])[1] == [0, 0, -1/9]
end

@testset "minimum" begin
@test jacobicheck(minimum, rand(2, 3))

# MethodError: no method matching copy(::Nothing)
@test_broken jacobicheck(x -> minimum(x, dims=1), rand(2, 3))
@test_broken jacobicheck(x -> minimum(x, dims=2), rand(2, 3))
@test jacobicheck(x -> minimum(x, dims=1), rand(2, 3))
@test jacobicheck(x -> minimum(x, dims=2), rand(2, 3))
end

@testset "dropdims" begin # https://github.com/JuliaDiff/Diffractor.jl/issues/72
# TypeError: in typeassert, expected Int64, got a value of type Nothing
@test_broken jacobicheck(x -> dropdims(x, dims = 3), rand(2, 2, 1, 2))
@test_broken jacobicheck(x -> dropdims(x, dims = (2, 3)), rand(2, 1, 1, 3))
@test jacobicheck(x -> dropdims(x, dims = 3), rand(2, 2, 1, 2))
@test jacobicheck(x -> dropdims(x, dims = (2, 3)), rand(2, 1, 1, 3))
end

@testset "vcat" begin
Expand Down Expand Up @@ -544,20 +527,19 @@ end
end

@testset "cat(...; dims = $dim)" for dim in 1:3
# Rewrite reached intrinsic function bitcast. Missing rule?

catdim = (x...) -> cat(x..., dims = dim)
@test_broken jacobicheck(catdim, rand(4,1))
@test_broken jacobicheck(catdim, rand(5), rand(5,1))
@test_broken jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5))
@test jacobicheck(catdim, rand(4,1))
@test jacobicheck(catdim, rand(5), rand(5,1))
@test jacobicheck(catdim, rand(2,5), rand(2,5), rand(2,5))

catdimval = (x...) -> cat(x...; dims = Val(dim))
@test_broken jacobicheck(catdimval, rand(5), rand(5))
@test_broken jacobicheck(catdimval, rand(2,5), rand(2,5,1))
@test jacobicheck(catdimval, rand(5), rand(5))
@test jacobicheck(catdimval, rand(2,5), rand(2,5,1))

# one empty
dim == 1 || continue
@test_broken jacobicheck(catdim, rand(0,5,3), rand(2,5,3))
@test jacobicheck(catdim, rand(0,5,3), rand(2,5,3))
end

@testset "one(s) and zero(s)" begin
Expand Down Expand Up @@ -586,8 +568,7 @@ end
# tests for https://github.com/FluxML/Zygote.jl/issues/724
x1 = rand(3, 3)
@test gradient(x -> sum(x .== 0.5), x1) |> only |> isZero
# MethodError: no method matching copy(::Nothing)
@test_broken gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1))
@test gradient(x -> sum(x .* (x .== maximum(x, dims=1))), x1)[1] == (x1 .== maximum(x1, dims=1))

# tests for un-broadcasting *, / via scalar rules
@test all(gradient((x,y) -> sum(x .* y), [1,2], 5) .≈ ([5, 5], 3))
Expand Down Expand Up @@ -620,7 +601,7 @@ end
@test_broken jacobicheck(+, A, B, A)
@test jacobicheck(-, A)
# in typeassert, expected Int64, got a value of type Nothing
@test_broken jacobicheck(-, A, B)
@test jacobicheck(-, A, B)
end

end