Skip to content

Commit 5390f0c

Browse files
committed
Fixes for BlockSparseArrays
1 parent 01db9ba commit 5390f0c

File tree

2 files changed

+34
-21
lines changed

2 files changed

+34
-21
lines changed

src/abstractsparsearray.jl

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -85,7 +85,12 @@ function Base._cat(dims, a::AnyAbstractSparseArray...)
8585
return concatenate(dims, a...)
8686
end
8787

88+
# TODO: Use `map(WeakPreserving(f), a)` instead.
89+
# Currently that has trouble with type unstable maps, since
90+
# the element type becomes abstract and therefore the zero/unstored
91+
# values are not well defined.
8892
function map_stored(f, a::AnyAbstractSparseArray)
93+
iszero(storedlength(a)) && return a
8994
kvs = storedpairs(a)
9095
# `collect` to convert to `Vector`, since otherwise
9196
# if it stays as `Dictionary` we might hit issues like
@@ -102,6 +107,10 @@ end
102107

103108
using Adapt: adapt
104109
function Base.print_array(io::IO, a::AnyAbstractSparseArray)
110+
# TODO: Use `map(WeakPreserving(adapt(Array)), a)` instead.
111+
# Currently that has trouble with type unstable maps, since
112+
# the element type becomes abstract and therefore the zero/unstored
113+
# values are not well defined.
105114
a′ = map_stored(adapt(Array), a)
106115
return @invoke Base.print_array(io::typeof(io), a′::AbstractArray{<:Any,ndims(a)})
107116
end

src/map.jl

Lines changed: 25 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -48,20 +48,20 @@ function ZeroPreserving(f, T::Type, Ts::Type...)
4848
return NonPreserving(f)
4949
end
5050
end
51+
ZeroPreserving(f::ZeroPreserving, T::Type, Ts::Type...) = f
5152

52-
const _WEAK_FUNCTIONS = (:+, :-)
53-
for f in _WEAK_FUNCTIONS
53+
for F in (:(typeof(+)), :(typeof(-)), :(typeof(identity)))
5454
@eval begin
55-
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = WeakPreserving($f)
55+
ZeroPreserving(f::$F, ::Type, ::Type...) = WeakPreserving(f)
5656
end
5757
end
5858

59-
const _STRONG_FUNCTIONS = (:*,)
60-
for f in _STRONG_FUNCTIONS
59+
using MapBroadcast: MapFunction
60+
for F in (:(typeof(*)), :(MapFunction{typeof(*)}))
6161
@eval begin
62-
ZeroPreserving(::typeof($f), ::Type{<:Number}, ::Type{<:Number}...) = StrongPreserving(
63-
$f
64-
)
62+
function ZeroPreserving(f::$F, ::Type, ::Type...)
63+
return StrongPreserving(f)
64+
end
6565
end
6666
end
6767

@@ -71,29 +71,32 @@ end
7171
f, A::AbstractArray, Bs::AbstractArray...
7272
)
7373
f_pres = ZeroPreserving(f, A, Bs...)
74-
return @interface I map(f_pres, A, Bs...)
74+
return map_sparsearray(f_pres, A, Bs...)
7575
end
76-
@interface I::AbstractSparseArrayInterface function Base.map(
77-
f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...
78-
)
76+
77+
# This isn't an overload of `Base.map` since that leads to ambiguity errors.
78+
function map_sparsearray(f::ZeroPreserving, A::AbstractArray, Bs::AbstractArray...)
7979
T = Base.Broadcast.combine_eltypes(f.f, (A, Bs...))
80-
C = similar(I, T, size(A))
81-
return @interface I map!(f, C, A, Bs...)
80+
C = similar(A, T)
81+
# TODO: Instead use:
82+
# U = map(f.f, map(unstored, (A, Bs...))...)
83+
# C = similar(A, Unstored(U))
84+
return map_sparsearray!(f, C, A, Bs...)
8285
end
8386

8487
@interface I::AbstractSparseArrayInterface function Base.map!(
8588
f, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
8689
)
8790
f_pres = ZeroPreserving(f, A, Bs...)
88-
return @interface I map!(f_pres, C, A, Bs...)
91+
return map_sparsearray!(f_pres, C, A, Bs...)
8992
end
9093

91-
@interface ::AbstractSparseArrayInterface function Base.map!(
94+
# This isn't an overload of `Base.map!` since that leads to ambiguity errors.
95+
function map_sparsearray!(
9296
f::ZeroPreserving, C::AbstractArray, A::AbstractArray, Bs::AbstractArray...
9397
)
9498
checkshape(C, A, Bs...)
9599
unaliased = map(Base.Fix1(Base.unalias, C), (A, Bs...))
96-
97100
if f isa StrongPreserving
98101
style = IndexStyle(C, unaliased...)
99102
inds = intersect(eachstoredindex.(Ref(style), unaliased)...)
@@ -107,19 +110,20 @@ end
107110
else
108111
error(lazy"unknown zero-preserving type $(typeof(f))")
109112
end
110-
111113
@inbounds for I in inds
112114
C[I] = f.f(ith_all(I, unaliased)...)
113115
end
114-
115116
return C
116117
end
117118

118119
# Derived functions
119120
# -----------------
120-
@interface I::AbstractSparseArrayInterface Base.copyto!(C::AbstractArray, A::AbstractArray) = @interface I map!(
121-
identity, C, A
121+
@interface I::AbstractSparseArrayInterface function Base.copyto!(
122+
dest::AbstractArray, src::AbstractArray
122123
)
124+
@interface I map!(identity, dest, src)
125+
return dest
126+
end
123127

124128
# Only map the stored values of the inputs.
125129
function map_stored! end

0 commit comments

Comments
 (0)