Skip to content

Commit 414dba7

Browse files
jishnubdlfivefifty
andauthored
sum for OneElement (#375)
* sum for OneElement * Add tests * Accept dims in sum * Add tests * Bump version to v1.13.0 * Ensure that init kwarg works * Update tests for v1.6 --------- Co-authored-by: Sheehan Olver <[email protected]>
1 parent 7b2bb11 commit 414dba7

File tree

3 files changed

+61
-8
lines changed

3 files changed

+61
-8
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "FillArrays"
22
uuid = "1a297f60-69ca-5386-bcde-b61e274b549b"
3-
version = "1.12.0"
3+
version = "1.13.0"
44

55
[deps]
66
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"

src/oneelement.jl

Lines changed: 10 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -159,13 +159,6 @@ function *(A::OneElementMatrix, B::OneElementVecOrMat)
159159
OneElement(val, (A.ind[1], B.ind[2:end]...), (axes(A,1), axes(B)[2:end]...))
160160
end
161161

162-
function *(A::AbstractFillMatrix, x::OneElementVector)
163-
check_matmul_sizes(A, x)
164-
val = getindex_value(A) * getindex_value(x)
165-
Fill(val, (axes(A,1),))
166-
end
167-
*(A::AbstractZerosMatrix, x::OneElementVector) = mult_zeros(A, x)
168-
169162
*(A::OneElementMatrix, x::AbstractZerosVector) = mult_zeros(A, x)
170163

171164
function *(A::OneElementMatrix, B::AbstractFillVector)
@@ -448,3 +441,13 @@ _maybesize(t) = t
448441
Base.show(io::IO, A::OneElement) = print(io, OneElement, "(", A.val, ", ", A.ind, ", ", _maybesize(axes(A)), ")")
449442
Base.show(io::IO, A::OneElement{<:Any,1,Tuple{Int},Tuple{Base.OneTo{Int}}}) =
450443
print(io, OneElement, "(", A.val, ", ", A.ind[1], ", ", size(A,1), ")")
444+
445+
# mapreduce
446+
Base.sum(O::OneElement; dims=:, kw...) = _sum(O, dims; kw...)
447+
_sum(O::OneElement, ::Colon; kw...) = sum((getindex_value(O),); kw...)
448+
function _sum(O::OneElement, dims; kw...)
449+
v = _sum(O, :; kw...)
450+
ax = Base.reduced_indices(axes(O), dims)
451+
ind = ntuple(x -> x in dims ? first(ax[x]) + (O.ind[x] in axes(O)[x]) - 1 : O.ind[x], ndims(O))
452+
OneElement(v, ind, ax)
453+
end

test/runtests.jl

Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2702,6 +2702,56 @@ end
27022702
@test repr(B) == "OneElement(2, (1, 2), (Base.IdentityUnitRange(1:1), Base.IdentityUnitRange(2:2)))"
27032703
end
27042704

2705+
@testset "sum" begin
2706+
@testset "OneElement($v, $ind, $sz)" for (v, ind, sz) in (
2707+
(Int8(2), 3, 4),
2708+
(3.0, 5, 4),
2709+
(3.0, 0, 0),
2710+
(SMatrix{2,2}(1:4), (4, 2), (12,6)),
2711+
)
2712+
O = OneElement(v,ind,sz)
2713+
A = Array(O)
2714+
if VERSION >= v"1.10"
2715+
@test @inferred(sum(O)) === sum(A)
2716+
else
2717+
@test @inferred(sum(O)) == sum(A)
2718+
end
2719+
@test @inferred(sum(O, init=zero(eltype(O)))) === sum(A, init=zero(eltype(O)))
2720+
@test @inferred(sum(x->1, O, init=0)) === sum(Fill(1, axes(O)), init=0)
2721+
end
2722+
2723+
@testset for O in (OneElement(Int8(2), (1,2), (2,4)),
2724+
OneElement(3, (1,2,3), (2,4,4)),
2725+
OneElement(2.0, (3,2,5), (2,3,2)),
2726+
OneElement(SMatrix{2,2}(1:4), (1,2), (2,4)),
2727+
)
2728+
A = Array(O)
2729+
init = sum((zero(FillArrays.getindex_value(O)),))
2730+
for i in 1:3
2731+
@test @inferred(sum(O, dims=i)) == sum(A, dims=i)
2732+
@test @inferred(sum(O, dims=i, init=init)) == sum(A, dims=i, init=init)
2733+
@test @inferred(sum(x->1, O, dims=i, init=0)) == sum(Fill(1, axes(O)), dims=i, init=0)
2734+
end
2735+
@test @inferred(sum(O, dims=1:1)) == sum(A, dims=1:1)
2736+
@test @inferred(sum(O, dims=1:2)) == sum(A, dims=1:2)
2737+
@test @inferred(sum(O, dims=1:3)) == sum(A, dims=1:3)
2738+
@test @inferred(sum(O, dims=(1,))) == sum(A, dims=(1,))
2739+
@test @inferred(sum(O, dims=(1,2))) == sum(A, dims=(1,2))
2740+
@test @inferred(sum(O, dims=(1,3))) == sum(A, dims=(1,3))
2741+
@test @inferred(sum(O, dims=(2,3))) == sum(A, dims=(2,3))
2742+
@test @inferred(sum(O, dims=(1,2,3))) == sum(A, dims=(1,2,3))
2743+
@test @inferred(sum(O, dims=1:1, init=init)) == sum(A, dims=1:1, init=init)
2744+
@test @inferred(sum(O, dims=1:2, init=init)) == sum(A, dims=1:2, init=init)
2745+
@test @inferred(sum(O, dims=1:3, init=init)) == sum(A, dims=1:3, init=init)
2746+
@test @inferred(sum(O, dims=(1,), init=init)) == sum(A, dims=(1,), init=init)
2747+
@test @inferred(sum(O, dims=(1,2), init=init)) == sum(A, dims=(1,2), init=init)
2748+
@test @inferred(sum(O, dims=(1,3), init=init)) == sum(A, dims=(1,3), init=init)
2749+
@test @inferred(sum(O, dims=(2,3), init=init)) == sum(A, dims=(2,3), init=init)
2750+
@test @inferred(sum(O, dims=(1,2,3), init=init)) == sum(A, dims=(1,2,3), init=init)
2751+
@test @inferred(sum(x->1, O, dims=(1,2,3), init=0)) == sum(Fill(1, axes(O)), dims=(1,2,3), init=0)
2752+
end
2753+
end
2754+
27052755
@testset "diag" begin
27062756
@testset for sz in [(0,0), (0,1), (1,0), (1,1), (4,4), (4,6), (6,3)], ind in CartesianIndices(sz)
27072757
O = OneElement(4, Tuple(ind), sz)

0 commit comments

Comments
 (0)