Skip to content

Commit bcb9a92

Browse files
authored
force specialization of dims arguments that may be Colon (#59474)
In Julia, the colon (`:`, `Colon`) is a function for constructing ranges. We have: ```julia (:) isa Function ``` However, sadly, as an instance of punning, the same object is often also used as an index for choosing an entire dimension. A problem is that `Function` objects are a special case when it comes to specializing method code for given argument types: by default a `:` argument will get specialized as a `Function`, instead of as a `Colon`. This change forces specialization of dims arguments in a bunch of methods in `base/reducedim.jl`. Probably similar changes would be beneficial across the code base, and across the ecosystem. I believe this change will make the sysimage more resistant to invalidation, we'll see once it builds.
1 parent b82c939 commit bcb9a92

File tree

5 files changed

+37
-37
lines changed

5 files changed

+37
-37
lines changed

base/abstractarray.jl

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2865,7 +2865,7 @@ julia> hvcat(5, M...) |> size # hvcat puts matrices next to each other
28652865
(14, 15)
28662866
```
28672867
"""
2868-
stack(iter; dims=:) = _stack(dims, iter)
2868+
stack(iter; dims::D=:) where {D} = _stack(dims, iter)
28692869

28702870
"""
28712871
stack(f, args...; [dims])
@@ -2894,14 +2894,14 @@ julia> stack(eachrow([1 2 3; 4 5 6]), (10, 100); dims=1) do row, n
28942894
4.0 5.0 6.0 400.0 500.0 600.0 0.04 0.05 0.06
28952895
```
28962896
"""
2897-
stack(f, iter; dims=:) = _stack(dims, f(x) for x in iter)
2898-
stack(f, xs, yzs...; dims=:) = _stack(dims, f(xy...) for xy in zip(xs, yzs...))
2897+
stack(f, iter; dims::D=:) where {D} = _stack(dims, f(x) for x in iter)
2898+
stack(f, xs, yzs...; dims::D=:) where {D} = _stack(dims, f(xy...) for xy in zip(xs, yzs...))
28992899

2900-
_stack(dims::Union{Integer, Colon}, iter) = _stack(dims, IteratorSize(iter), iter)
2900+
_stack(dims::D, iter) where {D<:Union{Integer, Colon}} = _stack(dims, IteratorSize(iter), iter)
29012901

2902-
_stack(dims, ::IteratorSize, iter) = _stack(dims, collect(iter))
2902+
_stack(dims::D, ::IteratorSize, iter) where {D} = _stack(dims, collect(iter))
29032903

2904-
function _stack(dims, ::Union{HasShape, HasLength}, iter)
2904+
function _stack(dims::D, ::Union{HasShape, HasLength}, iter) where {D}
29052905
S = @default_eltype iter
29062906
T = S != Union{} ? eltype(S) : Any # Union{} occurs for e.g. stack(1,2), postpone the error
29072907
if isconcretetype(T)

base/array.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2227,7 +2227,7 @@ end
22272227
# 1d special cases of reverse(A; dims) and reverse!(A; dims):
22282228
for (f,_f) in ((:reverse,:_reverse), (:reverse!,:_reverse!))
22292229
@eval begin
2230-
$f(A::AbstractVector; dims=:) = $_f(A, dims)
2230+
$f(A::AbstractVector; dims::D=:) where {D} = $_f(A, dims)
22312231
$_f(A::AbstractVector, ::Colon) = $f(A, firstindex(A), lastindex(A))
22322232
$_f(A::AbstractVector, dim::Tuple{Integer}) = $_f(A, first(dim))
22332233
function $_f(A::AbstractVector, dim::Integer)

base/arraymath.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -56,8 +56,8 @@ julia> reverse(b)
5656
!!! compat "Julia 1.6"
5757
Prior to Julia 1.6, only single-integer `dims` are supported in `reverse`.
5858
"""
59-
reverse(A::AbstractArray; dims=:) = _reverse(A, dims)
60-
_reverse(A, dims) = reverse!(copymutable(A); dims)
59+
reverse(A::AbstractArray; dims::D=:) where {D} = _reverse(A, dims)
60+
_reverse(A, dims::D) where {D} = reverse!(copymutable(A); dims)
6161

6262
"""
6363
reverse!(A; dims=:)
@@ -67,7 +67,7 @@ Like [`reverse`](@ref), but operates in-place in `A`.
6767
!!! compat "Julia 1.6"
6868
Multidimensional `reverse!` requires Julia 1.6.
6969
"""
70-
reverse!(A::AbstractArray; dims=:) = _reverse!(A, dims)
70+
reverse!(A::AbstractArray; dims::D=:) where {D} = _reverse!(A, dims)
7171
_reverse!(A::AbstractArray{<:Any,N}, ::Colon) where {N} = _reverse!(A, ntuple(identity, Val{N}()))
7272
_reverse!(A, dim::Integer) = _reverse!(A, (Int(dim),))
7373
_reverse!(A, dims::NTuple{M,Integer}) where {M} = _reverse!(A, Int.(dims))

base/multidimensional.jl

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1499,7 +1499,7 @@ end
14991499
# contiguous multidimensional indexing: if the first dimension is a range,
15001500
# we can get some performance from using copy_chunks!
15011501

1502-
@inline function setindex!(B::BitArray, X::Union{StridedArray,BitArray}, J0::Union{Colon,AbstractUnitRange{Int}})
1502+
@inline function setindex!(B::BitArray, X::Union{StridedArray,BitArray}, J0::D) where {D<:Union{Colon,AbstractUnitRange{Int}}}
15031503
I0 = to_indices(B, (J0,))[1]
15041504
@boundscheck checkbounds(B, I0)
15051505
l0 = length(I0)
@@ -1511,7 +1511,7 @@ end
15111511
end
15121512

15131513
@inline function setindex!(B::BitArray, X::Union{StridedArray,BitArray},
1514-
I0::Union{Colon,AbstractUnitRange{Int}}, I::Union{Int,AbstractUnitRange{Int},Colon}...)
1514+
I0::DI0, I::Union{Int,AbstractUnitRange{Int},Colon}...) where {DI0<:Union{Colon,AbstractUnitRange{Int}}, }
15151515
J = to_indices(B, (I0, I...))
15161516
@boundscheck checkbounds(B, J...)
15171517
_unsafe_setindex!(B, X, J...)
@@ -1552,7 +1552,7 @@ end
15521552
end
15531553

15541554
@propagate_inbounds function setindex!(B::BitArray, X::AbstractArray,
1555-
I0::Union{Colon,AbstractUnitRange{Int}}, I::Union{Int,AbstractUnitRange{Int},Colon}...)
1555+
I0::DI0, I::Union{Int,AbstractUnitRange{Int},Colon}...) where {DI0<:Union{Colon,AbstractUnitRange{Int}}}
15561556
_setindex!(IndexStyle(B), B, X, to_indices(B, (I0, I...))...)
15571557
end
15581558

@@ -1747,7 +1747,7 @@ julia> unique(A, dims=3)
17471747
0 0
17481748
```
17491749
"""
1750-
unique(A::AbstractArray; dims::Union{Colon,Integer} = :) = _unique_dims(A, dims)
1750+
unique(A::AbstractArray; dims::D = :) where {D<:Union{Colon,Integer}} = _unique_dims(A, dims)
17511751

17521752
_unique_dims(A::AbstractArray, dims::Colon) = invoke(unique, Tuple{Any}, A)
17531753

base/reducedim.jl

Lines changed: 23 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -327,7 +327,7 @@ julia> mapreduce(isodd, |, a, dims=1)
327327
1 1 1 1
328328
```
329329
"""
330-
mapreduce(f, op, A::AbstractArrayOrBroadcasted; dims=:, init=_InitialValue()) =
330+
mapreduce(f, op, A::AbstractArrayOrBroadcasted; dims::D=:, init=_InitialValue()) where {D} =
331331
_mapreduce_dim(f, op, init, A, dims)
332332
mapreduce(f, op, A::AbstractArrayOrBroadcasted, B::AbstractArrayOrBroadcasted...; kw...) =
333333
reduce(op, map(f, A, B...); kw...)
@@ -338,10 +338,10 @@ _mapreduce_dim(f, op, nt, A::AbstractArrayOrBroadcasted, ::Colon) =
338338
_mapreduce_dim(f, op, ::_InitialValue, A::AbstractArrayOrBroadcasted, ::Colon) =
339339
_mapreduce(f, op, IndexStyle(A), A)
340340

341-
_mapreduce_dim(f, op, nt, A::AbstractArrayOrBroadcasted, dims) =
341+
_mapreduce_dim(f, op, nt, A::AbstractArrayOrBroadcasted, dims::D) where {D} =
342342
mapreducedim!(f, op, reducedim_initarray(A, dims, nt), A)
343343

344-
_mapreduce_dim(f, op, ::_InitialValue, A::AbstractArrayOrBroadcasted, dims) =
344+
_mapreduce_dim(f, op, ::_InitialValue, A::AbstractArrayOrBroadcasted, dims::D) where {D} =
345345
mapreducedim!(f, op, reducedim_init(f, op, A, dims), A)
346346

347347
"""
@@ -409,8 +409,8 @@ julia> count(<=(2), A, dims=2)
409409
0
410410
```
411411
"""
412-
count(A::AbstractArrayOrBroadcasted; dims=:, init=0) = count(identity, A; dims, init)
413-
count(f, A::AbstractArrayOrBroadcasted; dims=:, init=0) = _count(f, A, dims, init)
412+
count(A::AbstractArrayOrBroadcasted; dims::D=:, init=0) where {D} = count(identity, A; dims, init)
413+
count(f, A::AbstractArrayOrBroadcasted; dims::D=:, init=0) where {D} = _count(f, A, dims, init)
414414

415415
_count(f, A::AbstractArrayOrBroadcasted, dims::Colon, init) = _simple_count(f, A, init)
416416
_count(f, A::AbstractArrayOrBroadcasted, dims, init) = mapreduce(_bool(f), add_sum, A; dims, init)
@@ -980,20 +980,20 @@ for (fname, _fname, op) in [(:sum, :_sum, :add_sum), (:prod, :_prod,
980980
mapf = fname === :extrema ? :(ExtremaMap(f)) : :f
981981
@eval begin
982982
# User-facing methods with keyword arguments
983-
@inline ($fname)(a::AbstractArray; dims=:, kw...) = ($_fname)(a, dims; kw...)
984-
@inline ($fname)(f, a::AbstractArray; dims=:, kw...) = ($_fname)(f, a, dims; kw...)
983+
@inline ($fname)(a::AbstractArray; dims::D=:, kw...) where {D} = ($_fname)(a, dims; kw...)
984+
@inline ($fname)(f, a::AbstractArray; dims::D=:, kw...) where {D} = ($_fname)(f, a, dims; kw...)
985985

986986
# Underlying implementations using dispatch
987987
($_fname)(a, ::Colon; kw...) = ($_fname)(identity, a, :; kw...)
988988
($_fname)(f, a, ::Colon; kw...) = mapreduce($mapf, $op, a; kw...)
989989
end
990990
end
991991

992-
any(a::AbstractArray; dims=:) = _any(a, dims)
993-
any(f::Function, a::AbstractArray; dims=:) = _any(f, a, dims)
992+
any(a::AbstractArray; dims::D=:) where {D} = _any(a, dims)
993+
any(f::Function, a::AbstractArray; dims::D=:) where {D} = _any(f, a, dims)
994994
_any(a, ::Colon) = _any(identity, a, :)
995-
all(a::AbstractArray; dims=:) = _all(a, dims)
996-
all(f::Function, a::AbstractArray; dims=:) = _all(f, a, dims)
995+
all(a::AbstractArray; dims::D=:) where {D} = _all(a, dims)
996+
all(f::Function, a::AbstractArray; dims::D=:) where {D} = _all(f, a, dims)
997997
_all(a, ::Colon) = _all(identity, a, :)
998998

999999
for (fname, op) in [(:sum, :add_sum), (:prod, :mul_prod),
@@ -1008,8 +1008,8 @@ for (fname, op) in [(:sum, :add_sum), (:prod, :mul_prod),
10081008
mapreducedim!($mapf, $(op), initarray!(r, $mapf, $(op), init, A), A)
10091009
$(fname!)(r::AbstractArray, A::AbstractArray; init::Bool=true) = $(fname!)(identity, r, A; init=init)
10101010

1011-
$(_fname)(A, dims; kw...) = $(_fname)(identity, A, dims; kw...)
1012-
$(_fname)(f, A, dims; kw...) = mapreduce($mapf, $(op), A; dims=dims, kw...)
1011+
$(_fname)(A, dims::D; kw...) where {D} = $(_fname)(identity, A, dims; kw...)
1012+
$(_fname)(f, A, dims::D; kw...) where {D} = mapreduce($mapf, $(op), A; dims=dims, kw...)
10131013
end
10141014
end
10151015

@@ -1100,8 +1100,8 @@ julia> findmin(A, dims=2)
11001100
([1.0; 3.0;;], CartesianIndex{2}[CartesianIndex(1, 1); CartesianIndex(2, 1);;])
11011101
```
11021102
"""
1103-
findmin(A::AbstractArray; dims=:) = _findmin(A, dims)
1104-
_findmin(A, dims) = _findmin(identity, A, dims)
1103+
findmin(A::AbstractArray; dims::D=:) where {D} = _findmin(A, dims)
1104+
_findmin(A, dims::D) where {D} = _findmin(identity, A, dims)
11051105

11061106
"""
11071107
findmin(f, A; dims) -> (f(x), index)
@@ -1123,9 +1123,9 @@ julia> findmin(abs2, A, dims=2)
11231123
([1.0; 0.25;;], CartesianIndex{2}[CartesianIndex(1, 1); CartesianIndex(2, 1);;])
11241124
```
11251125
"""
1126-
findmin(f, A::AbstractArray; dims=:) = _findmin(f, A, dims)
1126+
findmin(f, A::AbstractArray; dims::D=:) where {D} = _findmin(f, A, dims)
11271127

1128-
function _findmin(f, A, region)
1128+
function _findmin(f, A, region::D) where {D}
11291129
ri = reduced_indices0(A, region)
11301130
if isempty(A)
11311131
if prod(map(length, reduced_indices(A, region))) != 0
@@ -1173,8 +1173,8 @@ julia> findmax(A, dims=2)
11731173
([2.0; 4.0;;], CartesianIndex{2}[CartesianIndex(1, 2); CartesianIndex(2, 2);;])
11741174
```
11751175
"""
1176-
findmax(A::AbstractArray; dims=:) = _findmax(A, dims)
1177-
_findmax(A, dims) = _findmax(identity, A, dims)
1176+
findmax(A::AbstractArray; dims::D=:) where {D} = _findmax(A, dims)
1177+
_findmax(A, dims::D) where {D} = _findmax(identity, A, dims)
11781178

11791179
"""
11801180
findmax(f, A; dims) -> (f(x), index)
@@ -1196,9 +1196,9 @@ julia> findmax(abs2, A, dims=2)
11961196
([1.0; 4.0;;], CartesianIndex{2}[CartesianIndex(1, 1); CartesianIndex(2, 2);;])
11971197
```
11981198
"""
1199-
findmax(f, A::AbstractArray; dims=:) = _findmax(f, A, dims)
1199+
findmax(f, A::AbstractArray; dims::D=:) where {D} = _findmax(f, A, dims)
12001200

1201-
function _findmax(f, A, region)
1201+
function _findmax(f, A, region::D) where {D}
12021202
ri = reduced_indices0(A, region)
12031203
if isempty(A)
12041204
if prod(map(length, reduced_indices(A, region))) != 0
@@ -1247,7 +1247,7 @@ julia> argmin(A, dims=2)
12471247
CartesianIndex(2, 1)
12481248
```
12491249
"""
1250-
argmin(A::AbstractArray; dims=:) = findmin(A; dims=dims)[2]
1250+
argmin(A::AbstractArray; dims::D=:) where {D} = findmin(A; dims=dims)[2]
12511251

12521252
"""
12531253
argmax(A; dims) -> indices
@@ -1272,4 +1272,4 @@ julia> argmax(A, dims=2)
12721272
CartesianIndex(2, 2)
12731273
```
12741274
"""
1275-
argmax(A::AbstractArray; dims=:) = findmax(A; dims=dims)[2]
1275+
argmax(A::AbstractArray; dims::D=:) where {D} = findmax(A; dims=dims)[2]

0 commit comments

Comments
 (0)