|
1 | 1 | using ArrayLayouts: LayoutArray |
2 | | -using BlockArrays: blockisequal |
3 | | -using DerivableInterfaces: @interface, AbstractArrayInterface, interface |
4 | | -using GPUArraysCore: @allowscalar |
5 | 2 | using LinearAlgebra: Adjoint, Transpose |
6 | | -using SparseArraysBase: SparseArraysBase, SparseArrayStyle |
7 | | - |
8 | | -# Returns `Vector{<:CartesianIndices}` |
9 | | -function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray}) |
10 | | - combined_axes = combine_axes(axes.(as)...) |
11 | | - stored_blocked_cartesianindices_as = map(as) do a |
12 | | - return blocked_cartesianindices(axes(a), combined_axes, eachblockstoredindex(a)) |
13 | | - end |
14 | | - return ∪(stored_blocked_cartesianindices_as...) |
15 | | -end |
16 | | - |
17 | | -# This is used by `map` to get the output axes. |
18 | | -# This is type piracy, try to avoid this, maybe requires defining `map`. |
19 | | -## Base.promote_shape(a1::Tuple{Vararg{BlockedUnitRange}}, a2::Tuple{Vararg{BlockedUnitRange}}) = combine_axes(a1, a2) |
20 | | - |
21 | | -reblock(a) = a |
22 | | - |
23 | | -# If the blocking of the slice doesn't match the blocking of the |
24 | | -# parent array, reblock according to the blocking of the parent array. |
25 | | -function reblock( |
26 | | - a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{AbstractUnitRange}}} |
27 | | -) |
28 | | - # TODO: This relies on the behavior that slicing a block sparse |
29 | | - # array with a UnitRange inherits the blocking of the underlying |
30 | | - # block sparse array, we might change that default behavior |
31 | | - # so this might become something like `@blocked parent(a)[...]`. |
32 | | - return @view parent(a)[UnitRange{Int}.(parentindices(a))...] |
33 | | -end |
34 | | - |
35 | | -function reblock( |
36 | | - a::SubArray{<:Any,<:Any,<:AbstractBlockSparseArray,<:Tuple{Vararg{NonBlockedArray}}} |
37 | | -) |
38 | | - return @view parent(a)[map(I -> I.array, parentindices(a))...] |
39 | | -end |
40 | | - |
41 | | -function reblock( |
42 | | - a::SubArray{ |
43 | | - <:Any, |
44 | | - <:Any, |
45 | | - <:AbstractBlockSparseArray, |
46 | | - <:Tuple{Vararg{BlockIndices{<:AbstractBlockVector{<:Block{1}}}}}, |
47 | | - }, |
48 | | -) |
49 | | - # Remove the blocking. |
50 | | - return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...] |
51 | | -end |
52 | | - |
53 | | -# `map!` specialized to zero-dimensional inputs. |
54 | | -function map_zero_dim! end |
55 | | - |
56 | | -@interface ::AbstractArrayInterface function map_zero_dim!( |
57 | | - f, a_dest::AbstractArray, a_srcs::AbstractArray... |
58 | | -) |
59 | | - @allowscalar a_dest[] = f.(map(a_src -> a_src[], a_srcs)...) |
60 | | - return a_dest |
61 | | -end |
62 | | - |
63 | | -# TODO: Move to `blocksparsearrayinterface/map.jl`. |
64 | | -# TODO: Rewrite this so that it takes the blocking structure |
65 | | -# made by combining the blocking of the axes (i.e. the blocking that |
66 | | -# is used to determine `union_stored_blocked_cartesianindices(...)`). |
67 | | -# `reblock` is a partial solution to that, but a bit ad-hoc. |
68 | | -## TODO: Make this an `@interface AbstractBlockSparseArrayInterface` function. |
69 | | -@interface interface::AbstractBlockSparseArrayInterface function Base.map!( |
70 | | - f, a_dest::AbstractArray, a_srcs::AbstractArray... |
71 | | -) |
72 | | - if isempty(a_srcs) |
73 | | - # Broadcast expressions of the form `a .= 2`. |
74 | | - @interface interface fill!(a_dest, f()) |
75 | | - return a_dest |
76 | | - end |
77 | | - if iszero(ndims(a_dest)) |
78 | | - @interface interface map_zero_dim!(f, a_dest, a_srcs...) |
79 | | - return a_dest |
80 | | - end |
81 | | - # TODO: This assumes element types are numbers, generalize this logic. |
82 | | - f_preserves_zeros = f(zero.(eltype.(a_srcs))...) == zero(eltype(a_dest)) |
83 | | - a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs) |
84 | | - for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...) |
85 | | - BI_dest = blockindexrange(a_dest, I) |
86 | | - BI_srcs = map(a_src -> blockindexrange(a_src, I), a_srcs) |
87 | | - # TODO: Investigate why this doesn't work: |
88 | | - # block_dest = @view a_dest[_block(BI_dest)] |
89 | | - block_dest = blocks_maybe_single(a_dest)[Int.(Tuple(_block(BI_dest)))...] |
90 | | - # TODO: Investigate why this doesn't work: |
91 | | - # block_srcs = ntuple(i -> @view(a_srcs[i][_block(BI_srcs[i])]), length(a_srcs)) |
92 | | - block_srcs = ntuple(length(a_srcs)) do i |
93 | | - return blocks_maybe_single(a_srcs[i])[Int.(Tuple(_block(BI_srcs[i])))...] |
94 | | - end |
95 | | - subblock_dest = @view block_dest[BI_dest.indices...] |
96 | | - subblock_srcs = ntuple(i -> @view(block_srcs[i][BI_srcs[i].indices...]), length(a_srcs)) |
97 | | - I_dest = CartesianIndex(Int.(Tuple(_block(BI_dest)))) |
98 | | - # If the function preserves zero values and all of the source blocks are zero, |
99 | | - # the output block will be zero. In that case, if the block isn't stored yet, |
100 | | - # don't do anything. |
101 | | - if f_preserves_zeros && all(iszero, subblock_srcs) && !isstored(blocks(a_dest), I_dest) |
102 | | - continue |
103 | | - end |
104 | | - # TODO: Use `map!!` to handle immutable blocks. |
105 | | - map!(f, subblock_dest, subblock_srcs...) |
106 | | - # Replace the entire block, handles initializing new blocks |
107 | | - # or if blocks are immutable. |
108 | | - blocks(a_dest)[I_dest] = block_dest |
109 | | - end |
110 | | - return a_dest |
111 | | -end |
112 | | - |
113 | | -# TODO: Move to `blocksparsearrayinterface/map.jl`. |
114 | | -@interface ::AbstractBlockSparseArrayInterface function Base.mapreduce( |
115 | | - f, op, as::AbstractArray...; kwargs... |
116 | | -) |
117 | | - # TODO: Define an `init` value based on the element type. |
118 | | - return @interface interface(blocks.(as)...) mapreduce( |
119 | | - block -> mapreduce(f, op, block), op, blocks.(as)...; kwargs... |
120 | | - ) |
121 | | -end |
122 | | - |
123 | | -# TODO: Move to `blocksparsearrayinterface/map.jl`. |
124 | | -@interface ::AbstractBlockSparseArrayInterface function Base.iszero(a::AbstractArray) |
125 | | - # TODO: Just call `iszero(blocks(a))`? |
126 | | - return @interface interface(blocks(a)) iszero(blocks(a)) |
127 | | -end |
128 | | - |
129 | | -# TODO: Move to `blocksparsearrayinterface/map.jl`. |
130 | | -@interface ::AbstractBlockSparseArrayInterface function Base.isreal(a::AbstractArray) |
131 | | - # TODO: Just call `isreal(blocks(a))`? |
132 | | - return @interface interface(blocks(a)) isreal(blocks(a)) |
133 | | -end |
134 | 3 |
|
135 | 4 | function Base.map!(f, a_dest::AbstractArray, a_srcs::AnyAbstractBlockSparseArray...) |
136 | 5 | @interface interface(a_dest, a_srcs...) map!(f, a_dest, a_srcs...) |
|
0 commit comments