Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "1.44.1"
version = "1.44.2"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down
4 changes: 2 additions & 2 deletions src/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -328,13 +328,13 @@ end
unbroadcast(x::Base.AbstractArrayOrBroadcasted, dx::AbstractZero) = dx

function unbroadcast(x::T, dx) where {T<:Tuple{Vararg{Any,N}}} where {N}
val = if length(x) == length(dx)
val = if N == length(dx)
dx
else
sum(dx; dims=2:ndims(dx))
end
eltype(val) <: AbstractZero && return NoTangent()
return ProjectTo(x)(NTuple{length(x)}(val)) # Tangent
return ProjectTo(x)(Tuple{Vararg{Any,N}}(val)) # Tangent
end
unbroadcast(x::Tuple, dx::AbstractZero) = dx

Expand Down
4 changes: 4 additions & 0 deletions test/rulesets/Base/broadcast.jl
Original file line number Diff line number Diff line change
Expand Up @@ -173,4 +173,8 @@ BT1 = Broadcast.BroadcastStyle(Tuple)
test_rrule(copy∘broadcasted, complex, rand())
end
end

@testset "bugs" begin
@test ChainRules.unbroadcast((1,2,[3]), [4,5,[6]]) isa Tangent # earlier, NTuple demanded same type
end
end