Skip to content

Commit 668047d

Browse files
N5N3vtjnashDilumAluthge
authored
Fix ndims for Broadcasted with no args (#45477)
Follow up #44061. This PR makes `collect(Base.broadcast(randn))` works correctly, and improve the inference result. Test added. --------- Co-authored-by: Jameson Nash <[email protected]> Co-authored-by: Dilum Aluthge <[email protected]>
1 parent 30c34ef commit 668047d

File tree

2 files changed

+18
-18
lines changed

2 files changed

+18
-18
lines changed

base/broadcast.jl

Lines changed: 11 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -264,6 +264,17 @@ Base.LinearIndices(bc::Broadcasted{<:Any,<:Tuple{Any}}) = LinearIndices(axes(bc)
264264

265265
Base.ndims(bc::Broadcasted) = ndims(typeof(bc))
266266
Base.ndims(::Type{<:Broadcasted{<:Any,<:NTuple{N,Any}}}) where {N} = N
267+
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims(argtype(BC))
268+
function Base.ndims(BC::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N}
269+
N isa Int ? N : _maxndims(argtype(BC))
270+
end
271+
_maxndims(::Type{Tuple{}}) = 0
272+
_maxndims(::Type{Tuple{T}}) where {T} = T <: Tuple ? 1 : Int(ndims(T))::Int
273+
function _maxndims(Args::Type{<:Tuple{T,Vararg}}) where {T}
274+
m = T <: Tuple ? 1 : Int(ndims(T))::Int
275+
n = _maxndims(Base.tuple_type_tail(Args))
276+
max(m, n)
277+
end
267278

268279
Base.size(bc::Broadcasted) = map(length, axes(bc))
269280
Base.length(bc::Broadcasted) = prod(size(bc))
@@ -280,20 +291,6 @@ Base.@propagate_inbounds function Base.iterate(bc::Broadcasted, s)
280291
end
281292

282293
Base.IteratorSize(::Type{T}) where {T<:Broadcasted} = Base.HasShape{ndims(T)}()
283-
Base.ndims(BC::Type{<:Broadcasted{<:Any,Nothing}}) = _maxndims_broadcasted(BC)
284-
# the `AbstractArrayStyle` type parameter is required to be either equal to `Any` or be an `Int` value
285-
Base.ndims(BC::Type{<:Broadcasted{<:AbstractArrayStyle{Any},Nothing}}) = _maxndims_broadcasted(BC)
286-
Base.ndims(::Type{<:Broadcasted{<:AbstractArrayStyle{N},Nothing}}) where {N} = N::Int
287-
288-
function _maxndims_broadcasted(BC::Type{<:Broadcasted})
289-
_maxndims(fieldtype(BC, :args))
290-
end
291-
_maxndims(::Type{T}) where {T<:Tuple} = reduce(max, ntuple(n -> (F = fieldtype(T, n); F <: Tuple ? 1 : ndims(F)), Base._counttuple(T)))
292-
_maxndims(::Type{<:Tuple{T}}) where {T} = T <: Tuple ? 1 : ndims(T)
293-
function _maxndims(::Type{<:Tuple{T, S}}) where {T, S}
294-
return max(T <: Tuple ? 1 : ndims(T), S <: Tuple ? 1 : ndims(S))
295-
end
296-
297294
Base.IteratorEltype(::Type{<:Broadcasted}) = Base.EltypeUnknown()
298295

299296
## Instantiation fills in the "missing" fields in Broadcasted.

test/broadcast.jl

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -938,11 +938,14 @@ let
938938

939939
@test @inferred(Base.IteratorSize(Base.broadcasted(randn))) === Base.HasShape{0}()
940940

941+
@test @inferred(Base.IteratorSize(convert(Broadcast.Broadcasted{Nothing}, Base.broadcasted(randn)))) === Base.HasShape{0}()
942+
941943
# inference on nested
942-
bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)))
943-
bc_nest = Base.broadcasted(+, bc , bc)
944-
@test @inferred(Base.IteratorSize(bc_nest)) === Base.HasShape{1}()
945-
end
944+
bc = Base.broadcasted(+, AD1(randn(3)), AD1(randn(3)), AD1(randn(3)))
945+
bc_nest = Base.broadcasted(*, bc, bc, bc, bc, AD1(randn(3)))
946+
bc_nest2 = Base.broadcasted(-, bc_nest, bc_nest)
947+
@test @inferred(Base.IteratorSize(bc_nest2)) === Base.HasShape{1}()
948+
end
946949

947950
# issue #31295
948951
let a = rand(5), b = rand(5), c = copy(a)

0 commit comments

Comments
 (0)