Skip to content

Sparse mapping refactor #63

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 14 commits into from
Jul 17, 2025
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "SparseArraysBase"
uuid = "0d5efcca-f356-4864-8770-e1ed8d78f208"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.5.11"
version = "0.5.12"

[deps]
Accessors = "7d9f7c33-5ae7-4f3b-8dc6-eff91059b697"
Expand Down
1 change: 1 addition & 0 deletions src/SparseArraysBase.jl
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@ export SparseArrayDOK,
include("abstractsparsearrayinterface.jl")
include("sparsearrayinterface.jl")
include("indexing.jl")
include("map.jl")
include("wrappers.jl")
include("abstractsparsearray.jl")
include("sparsearraydok.jl")
Expand Down
32 changes: 16 additions & 16 deletions src/abstractsparsearrayinterface.jl
Original file line number Diff line number Diff line change
Expand Up @@ -153,22 +153,22 @@ function preserves_unstored(f, a_dest::AbstractArray, as::AbstractArray...)
return iszero(f(map(a -> getunstoredindex(a, I), as)...))
end

@interface interface::AbstractSparseArrayInterface function Base.map!(
f, a_dest::AbstractArray, as::AbstractArray...
)
isempty(a_dest) && return a_dest # special case to avoid trying to access empty array
indices = if !preserves_unstored(f, a_dest, as...)
eachindex(a_dest)
elseif any(a -> a_dest !== a, as)
as = map(a -> Base.unalias(a_dest, a), as)
@interface interface zero!(a_dest)
eachstoredindex(as...)
else
eachstoredindex(a_dest)
end
@interface interface map_indices!(indices, f, a_dest, as...)
return a_dest
end
# @interface interface::AbstractSparseArrayInterface function Base.map!(
# f, a_dest::AbstractArray, as::AbstractArray...
# )
# isempty(a_dest) && return a_dest # special case to avoid trying to access empty array
# indices = if !preserves_unstored(f, a_dest, as...)
# eachindex(a_dest)
# elseif any(a -> a_dest !== a, as)
# as = map(a -> Base.unalias(a_dest, a), as)
# @interface interface zero!(a_dest)
# eachstoredindex(as...)
# else
# eachstoredindex(a_dest)
# end
# @interface interface map_indices!(indices, f, a_dest, as...)
# return a_dest
# end

# `f::typeof(norm)`, `op::typeof(max)` used by `norm`.
function reduce_init(f, op, as...)
Expand Down
16 changes: 13 additions & 3 deletions src/indexing.jl
Original file line number Diff line number Diff line change
Expand Up @@ -308,10 +308,20 @@ end
end
end

# required:
@interface ::AbstractSparseArrayInterface eachstoredindex(style::IndexStyle, A::AbstractArray) = throw(
MethodError(eachstoredindex, Tuple{typeof(style),typeof(A)})
# required: one implementation for canonical index style
@interface ::AbstractSparseArrayInterface function eachstoredindex(
style::IndexStyle, A::AbstractArray
)
if style == IndexStyle(A)
throw(MethodError(eachstoredindex, Tuple{typeof(style),typeof(A)}))
elseif style == IndexCartesian()
return map(Base.Fix1(Base.getindex, CartesianIndices(A)), eachindex(A))
elseif style == IndexLinear()
return map(Base.Fix1(Base.getindex, LinearIndices(A)), eachindex(A))
else
throw(ArgumentError(lazy"unknown index style $style"))
end
end

# derived but may be specialized:
@interface ::AbstractSparseArrayInterface function eachstoredindex(
Expand Down
142 changes: 142 additions & 0 deletions src/map.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,142 @@
# zero-preserving Traits
# ----------------------
"""
abstract type ZeroPreserving <: Function end

Holy Trait to indicate how a function interacts with abstract zero values:

- `StrongPreserving` : output is guaranteed to be zero if **any** input is.
- `WeakPreserving` : output is guaranteed to be zero if **all** inputs are.
- `NonPreserving` : no guarantees on output.

To attempt to automatically determine this, either `ZeroPreserving(f, A::AbstractArray...)` or
`ZeroPreserving(f, T::Type...)` can be used/overloaded.

!!! warning
incorrectly registering a function to be zero-preserving will lead to silently wrong results.
"""
abstract type ZeroPreserving <: Function end

struct StrongPreserving{F} <: ZeroPreserving
f::F
end
struct WeakPreserving{F} <: ZeroPreserving
f::F
end
struct NonPreserving{F} <: ZeroPreserving
f::F
end

# Backport: remove in 1.12
@static if !isdefined(Base, :haszero)
_haszero(T::Type) = false
_haszero(::Type{<:Number}) = true
else
_haszero = Base.haszero
end

# warning: cannot automatically detect WeakPreserving since this would mean checking all values
function ZeroPreserving(f, A::AbstractArray, Bs::AbstractArray...)
return ZeroPreserving(f, eltype(A), eltype.(Bs)...)
end
# TODO: the following might not properly specialize on the types
# TODO: non-concrete element types
function ZeroPreserving(f, T::Type, Ts::Type...)
if all(_haszero, (T, Ts...))
return iszero(f(zero(T), zero.(Ts)...)) ? WeakPreserving(f) : NonPreserving(f)
else
return NonPreserving(f)
end
end

const _WEAK_FUNCTIONS = (:+, :-)
for f in _WEAK_FUNCTIONS
@eval begin
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = WeakPreserving($f)
end
end

const _STRONG_FUNCTIONS = (:*,)
for f in _STRONG_FUNCTIONS
@eval begin
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = StrongPreserving(
$f
)
end
end

# map(!)
# ------
@interface I::AbstractSparseArrayInterface function Base.map(
f, A::AbstractArray, Bs::AbstractArray...
)
f_pres = ZeroPreserving(f, A, Bs...)
return @interface I map(f_pres, A, Bs...)
end
@interface I::AbstractSparseArrayInterface function Base.map(
f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...
)
T = Base.Broadcast.combine_eltypes(f.f, (A, Bs...))
C = similar(I, T, size(A))
return @interface I map!(f, C, A, Bs...)
end

@interface I::AbstractSparseArrayInterface function Base.map!(
f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
)
f_pres = ZeroPreserving(f, A, Bs...)
return @interface I map!(f_pres, C, A, Bs...)
end

@interface ::AbstractSparseArrayInterface function Base.map!(
f::ZeroPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
)
checkshape(C, A, Bs...)
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))

if f isa StrongPreserving
style = IndexStyle(C, unaliased...)
inds = intersect(eachstoredindex.(Ref(style), unaliased)...)
zero!(C)
elseif f isa WeakPreserving
style = IndexStyle(C, unaliased...)
inds = union(eachstoredindex.(Ref(style), unaliased)...)
zero!(C)
elseif f isa NonPreserving
inds = eachindex(C, unaliased...)
else
error(lazy"unknown zero-preserving type $(typeof(f))")
end

@inbounds for I in inds
C[I] = f.f(ith_all(I, unaliased)...)
end

return C
end

# Derived functions
# -----------------
@interface I::AbstractSparseArrayInterface Base.copyto!(C::AbstractArray, A::AbstractArray) = @interface I map!(
identity, C, A
)

# Utility functions
# -----------------
# shape check similar to checkbounds
checkshape(::Type{Bool}, A::AbstractArray) = true
checkshape(::Type{Bool}, A::AbstractArray, B::AbstractArray) = size(A) == size(B)
function checkshape(::Type{Bool}, A::AbstractArray, Bs::AbstractArray...)
return allequal(size, (A, Bs...))
end

function checkshape(A::AbstractArray, Bs::AbstractArray...)
return checkshape(Bool, A, Bs...) ||
throw(DimensionMismatch("argument shapes must match"))
end

@inline ith_all(i, ::Tuple{}) = ()
function ith_all(i, as)
@_propagate_inbounds_meta
return (as[1][i], ith_all(i, Base.tail(as))...)
end
2 changes: 1 addition & 1 deletion src/oneelementarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -287,7 +287,7 @@ storedindex(a::OneElementArray) = getfield(a, :index)
function isstored(a::OneElementArray, I::Int...)
return I == storedindex(a)
end
function eachstoredindex(a::OneElementArray)
function eachstoredindex(::IndexCartesian, a::OneElementArray)
return Fill(CartesianIndex(storedindex(a)), 1)
end

Expand Down
Loading