Skip to content

Commit 893d3dd

Browse files
authored
[Containers] fix sum of DenseAxisArray that doesn't support zero(T) (#4097)
1 parent b8b5dd0 commit 893d3dd

File tree

3 files changed

+37
-6
lines changed

3 files changed

+37
-6
lines changed

src/Containers/DenseAxisArray.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -854,11 +854,13 @@ function Base.summary(io::IO, x::DenseAxisArrayView)
854854
return print(io, "view(::DenseAxisArray, ", join(x.axes, ", "), "), over")
855855
end
856856

857+
struct _InitNotProvided end
858+
857859
function Base.sum(
858860
f::F,
859861
x::Union{DenseAxisArray{T},DenseAxisArrayView{T}};
860862
dims = Colon(),
861-
init = zero(T),
863+
init = _InitNotProvided(),
862864
) where {F<:Function,T}
863865
if dims != Colon()
864866
return error(
@@ -867,7 +869,11 @@ function Base.sum(
867869
"for-loop summation instead.",
868870
)
869871
end
870-
return sum(f(xi) for xi in x; init)
872+
if init == _InitNotProvided()
873+
return sum(f(xi) for xi in x)
874+
else
875+
return sum(f(xi) for xi in x; init)
876+
end
871877
end
872878

873879
function Base.sum(x::Union{DenseAxisArray,DenseAxisArrayView}; kwargs...)

test/Containers/test_DenseAxisArray.jl

Lines changed: 20 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -992,15 +992,31 @@ end
992992

993993
function test_sum_init()
994994
x = Containers.@container([i in Int[]], i)
995-
@test sum(x) == 0
995+
if VERSION < v"1.7"
996+
@test sum(x) == 0
997+
else
998+
@test_throws ArgumentError sum(x)
999+
end
9961000
@test sum(x; init = 1) == 1
9971001
y = Containers.@container([i in BigInt[]], i)
998-
y_1 = sum(y)
999-
@test y_1 == 0
1000-
@test y_1 isa BigInt
1002+
if VERSION < v"1.7"
1003+
@test sum(y) == 0
1004+
else
1005+
@test_throws ArgumentError sum(y)
1006+
end
10011007
y_2 = sum(y; init = 0)
10021008
@test y_2 === 0
10031009
return
10041010
end
10051011

1012+
function test_sum_init_any()
1013+
x = Containers.@container([i in Any[]], i)
1014+
if VERSION < v"1.7"
1015+
@test_throws MethodError sum(x)
1016+
else
1017+
@test_throws ArgumentError sum(x)
1018+
end
1019+
return
1020+
end
1021+
10061022
end # module

test/Containers/test_macro.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -279,4 +279,13 @@ function test_trailing_semicolon()
279279
return
280280
end
281281

282+
function test_init_issue_4096()
283+
x = Containers.DenseAxisArray(Any[1, 2, 3], Int[1, 2, 3])
284+
@test sum(x) == 6
285+
@test (@allocated sum(x)) == 0
286+
@test sum(x; init = 0) == 6
287+
@test (@allocated sum(x; init = 0)) == 0
288+
return
289+
end
290+
282291
end # module

0 commit comments

Comments
 (0)