Skip to content
Merged
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.44.7"
version = "1.45.0"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand All @@ -20,7 +20,7 @@ StructArrays = "09ab397b-f2b6-538f-b94a-2f83cf4a842a"
Adapt = "3.4.0"
ChainRulesCore = "1.15.3"
ChainRulesTestUtils = "1.5"
Compat = "3.42.0, 4"
Compat = "3.46, 4.2"
FiniteDifferences = "0.12.20"
GPUArraysCore = "0.1.0"
IrrationalConstants = "0.1.1"
Expand Down
6 changes: 6 additions & 0 deletions src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,12 @@ import ChainRulesCore: rrule, frule
# Experimental:
using ChainRulesCore: derivatives_given_output

if isdefined(Base, :stack)
using Base: stack
else
using Compat: stack
end

# numbers that we know commute under multiplication
const CommutativeMulNumber = Union{Real,Complex}

Expand Down
50 changes: 50 additions & 0 deletions src/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -610,3 +610,53 @@ function _extrema_dims(x, dims)
end
return y, extrema_pullback_dims
end

#####
##### `stack`
#####

# function rrule(::typeof(stack), xs; dims::Union{Integer, Colon} = :)
# dims = dims === Colon() ? ndims(first(xs)) + 1 : dims
# function stack_pullback(Δ)
# dy = unthunk(Δ)
# return (NoTangent(), [copy(selectdim(dy, dims, i)) for i in 1:size(dy, dims)])
# end
# return stack(xs; dims), stack_pullback
# end


function frule((_, ẋ), ::typeof(stack), x; dims::Union{Integer, Colon} = :)
return stack(x; dims), stack(ẋ; dims)
end

# Other iterable X also allowed, maybe this should be wider?
function rrule(::typeof(stack), X::AbstractArray; dims::Union{Integer, Colon} = :)
Y = stack(X; dims)
sdims = if dims isa Colon
N = ndims(Y) - ndims(X)
X isa AbstractVector ? ndims(Y) : ntuple(i -> i + N, ndims(X))
else
dims
end
project = ProjectTo(X)
function stack_pullback(Δ)
dY = unthunk(Δ)
dY isa NoTangent && return (NoTangent(), NoTangent())
dY isa ZeroTangent && return (NoTangent(), ZeroTangent())
dX = collect(eachslice(unthunk(dY); dims = sdims))
return (NoTangent(), project(dX))
end
return Y, stack_pullback
end

# # This wants #671, but ought to work with Zygote already?
# function rrule(config::RuleConfig, ::typeof(stack), f, args...; dims::Union{Integer, Colon} = :)
# y, unmap = rrule_via_ad(config, map, f, args...)
# z, unstack = rrule(stack, y)
# function stack_pullback_f(dz)
# _, dy = unstack(dz)
# _, df, dargs... = unmap(dy)
# return (NoTangent(), df, dargs...)
# end
# return z, stack_pullback_f
# end
17 changes: 17 additions & 0 deletions test/rulesets/Base/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,3 +416,20 @@ end
B = hcat(A[:,:,1], A[:,:,1])
@test extrema(B, dims=2) == rrule(extrema, B, dims=2)[1]
end

@testset "stack" begin
# vector container
xs = [rand(3, 4), rand(3, 4)]

test_rrule(stack, xs, check_inferred=false)
test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false)
test_rrule(stack, xs, fkwargs=(dims=2,), check_inferred=false)
test_rrule(stack, xs, fkwargs=(dims=3,), check_inferred=false)

# multidimensional container
xs = [(1,2,3) (4,5,6); (7,8,9) (10,11,12)]

test_rrule(stack, xs, check_inferred=false)
test_rrule(stack, xs, fkwargs=(dims=1,), check_inferred=false)
test_rrule(stack, xs, fkwargs=(dims=2,), check_inferred=false)
end