Skip to content

Commit 2ac5277

Browse files
committed
More functionality working
1 parent 9061d9c commit 2ac5277

File tree

9 files changed

+62
-49
lines changed

9 files changed

+62
-49
lines changed

Project.toml

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
88
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
99
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
1010
BroadcastMapConversion = "4a4adec5-520f-4750-bb37-d5e66b4ddeb2"
11+
Derive = "a07dfc7f-7d04-4eb5-84cc-a97f051f655a"
1112
Dictionaries = "85a47980-9c8c-11e8-2b9f-f7ca1fa99fb4"
1213
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
1314
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
@@ -21,24 +22,20 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
2122
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
2223

2324
[sources]
24-
BroadcastMapConversion = {url = "https://github.com/ITensor/BroadcastMapConversion.jl"}
25-
GradedUnitRanges = {url = "https://github.com/ITensor/GradedUnitRanges.jl"}
26-
LabelledNumbers = {url = "https://github.com/ITensor/LabelledNumbers.jl"}
27-
NestedPermutedDimsArrays = {url = "https://github.com/ITensor/NestedPermutedDimsArrays.jl"}
28-
SparseArraysBase = {url = "https://github.com/ITensor/SparseArraysBase.jl"}
2925
TensorAlgebra = {url = "https://github.com/ITensor/TensorAlgebra.jl"}
30-
TypeParameterAccessors = {url = "https://github.com/ITensor/TypeParameterAccessors.jl"}
3126

3227
[compat]
3328
Adapt = "4.1.1"
3429
Aqua = "0.8.9"
3530
ArrayLayouts = "1.10.4"
3631
BlockArrays = "1.2.0"
32+
Derive = "0.3.1"
3733
Dictionaries = "0.4.3"
3834
GPUArraysCore = "0.1.0"
3935
LinearAlgebra = "1.10"
4036
MacroTools = "0.5.13"
4137
SplitApplyCombine = "1.2.3"
38+
TensorAlgebra = "0.1.0"
4239
Test = "1.10"
4340
julia = "1.10"
4441

docs/Project.toml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -12,10 +12,4 @@ TensorAlgebra = "68bd88dc-f39d-4e12-b2ca-f046b68fcc6a"
1212
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1313

1414
[sources]
15-
BroadcastMapConversion = {url = "https://github.com/ITensor/BroadcastMapConversion.jl"}
16-
GradedUnitRanges = {url = "https://github.com/ITensor/GradedUnitRanges.jl"}
17-
LabelledNumbers = {url = "https://github.com/ITensor/LabelledNumbers.jl"}
18-
NestedPermutedDimsArrays = {url = "https://github.com/ITensor/NestedPermutedDimsArrays.jl"}
19-
SparseArraysBase = {url = "https://github.com/ITensor/SparseArraysBase.jl"}
2015
TensorAlgebra = {url = "https://github.com/ITensor/TensorAlgebra.jl"}
21-
TypeParameterAccessors = {url = "https://github.com/ITensor/TypeParameterAccessors.jl"}

examples/Project.toml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -13,10 +13,4 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1313
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1414

1515
[sources]
16-
BroadcastMapConversion = {url = "https://github.com/ITensor/BroadcastMapConversion.jl"}
17-
GradedUnitRanges = {url = "https://github.com/ITensor/GradedUnitRanges.jl"}
18-
LabelledNumbers = {url = "https://github.com/ITensor/LabelledNumbers.jl"}
19-
NestedPermutedDimsArrays = {url = "https://github.com/ITensor/NestedPermutedDimsArrays.jl"}
20-
SparseArraysBase = {url = "https://github.com/ITensor/SparseArraysBase.jl"}
2116
TensorAlgebra = {url = "https://github.com/ITensor/TensorAlgebra.jl"}
22-
TypeParameterAccessors = {url = "https://github.com/ITensor/TypeParameterAccessors.jl"}

src/abstractblocksparsearray/map.jl

Lines changed: 37 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
using ArrayLayouts: LayoutArray
22
using BlockArrays: blockisequal
3+
using Derive: @interface, interface
34
using LinearAlgebra: Adjoint, Transpose
45
using SparseArraysBase:
56
SparseArraysBase,
@@ -16,7 +17,7 @@ using SparseArraysBase:
1617
function union_stored_blocked_cartesianindices(as::Vararg{AbstractArray})
1718
combined_axes = combine_axes(axes.(as)...)
1819
stored_blocked_cartesianindices_as = map(as) do a
19-
return blocked_cartesianindices(axes(a), combined_axes, block_stored_indices(a))
20+
return blocked_cartesianindices(axes(a), combined_axes, block_eachstoredindex(a))
2021
end
2122
return (stored_blocked_cartesianindices_as...)
2223
end
@@ -57,14 +58,14 @@ function reblock(
5758
return @view parent(a)[map(I -> Vector(I.blocks), parentindices(a))...]
5859
end
5960

61+
# TODO: Move to `blocksparsearrayinterface/map.jl`.
6062
# TODO: Rewrite this so that it takes the blocking structure
6163
# made by combining the blocking of the axes (i.e. the blocking that
6264
# is used to determine `union_stored_blocked_cartesianindices(...)`).
6365
# `reblock` is a partial solution to that, but a bit ad-hoc.
64-
# TODO: Move to `blocksparsearrayinterface/map.jl`.
6566
## TODO: Make this an `@interface AbstractBlockSparseArrayInterface` function.
66-
function sparse_map!(
67-
::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray}
67+
@interface ::AbstractBlockSparseArrayInterface function Base.map!(
68+
f, a_dest::AbstractArray, a_srcs::AbstractArray...
6869
)
6970
a_dest, a_srcs = reblock(a_dest), reblock.(a_srcs)
7071
for I in union_stored_blocked_cartesianindices(a_dest, a_srcs...)
@@ -89,12 +90,28 @@ function sparse_map!(
8990
return a_dest
9091
end
9192

92-
# TODO: Implement this.
93-
# function SparseArraysBase.sparse_mapreduce(::BlockSparseArrayStyle, f, a_dest::AbstractArray, a_srcs::Vararg{AbstractArray})
94-
# end
93+
# TODO: Move to `blocksparsearrayinterface/map.jl`.
94+
@interface ::AbstractBlockSparseArrayInterface function Base.mapreduce(
95+
f, op, as::AbstractArray...; kwargs...
96+
)
97+
# TODO: Define an `init` value based on the element type.
98+
return @interface interface(blocks.(as)...) mapreduce(
99+
block -> mapreduce(f, op, block), op, blocks.(as)...; kwargs...
100+
)
101+
end
102+
103+
# TODO: Move to `blocksparsearrayinterface/map.jl`.
104+
@interface ::AbstractBlockSparseArrayInterface function Base.iszero(a::AbstractArray)
105+
return @interface interface(blocks(a)) iszero(blocks(a))
106+
end
107+
108+
# TODO: Move to `blocksparsearrayinterface/map.jl`.
109+
@interface ::AbstractBlockSparseArrayInterface function Base.isreal(a::AbstractArray)
110+
return @interface interface(blocks(a)) isreal(blocks(a))
111+
end
95112

96-
function Base.map!(f, a_dest::AbstractArray, a_srcs::Vararg{AnyAbstractBlockSparseArray})
97-
sparse_map!(f, a_dest, a_srcs...)
113+
function Base.map!(f, a_dest::AbstractArray, a_srcs::AnyAbstractBlockSparseArray...)
114+
@interface interface(a_srcs...) map!(f, a_dest, a_srcs...)
98115
return a_dest
99116
end
100117

@@ -103,50 +120,53 @@ function Base.map(f, as::Vararg{AnyAbstractBlockSparseArray})
103120
end
104121

105122
function Base.copy!(a_dest::AbstractArray, a_src::AnyAbstractBlockSparseArray)
123+
# TODO: Call `@interface`.
106124
sparse_copy!(a_dest, a_src)
107125
return a_dest
108126
end
109127

110128
function Base.copyto!(a_dest::AbstractArray, a_src::AnyAbstractBlockSparseArray)
129+
# TODO: Call `@interface`.
111130
sparse_copyto!(a_dest, a_src)
112131
return a_dest
113132
end
114133

115134
# Fix ambiguity error
116135
function Base.copyto!(a_dest::LayoutArray, a_src::AnyAbstractBlockSparseArray)
136+
# TODO: Call `@interface`.
117137
sparse_copyto!(a_dest, a_src)
118138
return a_dest
119139
end
120140

121141
function Base.copyto!(
122142
a_dest::AbstractMatrix, a_src::Transpose{T,<:AbstractBlockSparseMatrix{T}}
123143
) where {T}
144+
# TODO: Call `@interface`.
124145
sparse_copyto!(a_dest, a_src)
125146
return a_dest
126147
end
127148

128149
function Base.copyto!(
129150
a_dest::AbstractMatrix, a_src::Adjoint{T,<:AbstractBlockSparseMatrix{T}}
130151
) where {T}
152+
# TODO: Call `@interface`.
131153
sparse_copyto!(a_dest, a_src)
132154
return a_dest
133155
end
134156

135157
function Base.permutedims!(a_dest, a_src::AnyAbstractBlockSparseArray, perm)
136-
sparse_permutedims!(a_dest, a_src, perm)
137-
return a_dest
158+
return @interface interface(a_src) permutedims!(a_dest, a_src, perm)
138159
end
139160

140-
function Base.mapreduce(f, op, as::Vararg{AnyAbstractBlockSparseArray}; kwargs...)
141-
return sparse_mapreduce(f, op, as...; kwargs...)
161+
function Base.mapreduce(f, op, as::AnyAbstractBlockSparseArray...; kwargs...)
162+
@show interface(as...)
163+
return @interface interface(as...) mapreduce(f, op, as...; kwargs...)
142164
end
143165

144-
# TODO: Why isn't this calling `mapreduce` already?
145166
function Base.iszero(a::AnyAbstractBlockSparseArray)
146-
return sparse_iszero(blocks(a))
167+
return @interface interface(a) iszero(a)
147168
end
148169

149-
# TODO: Why isn't this calling `mapreduce` already?
150170
function Base.isreal(a::AnyAbstractBlockSparseArray)
151-
return sparse_isreal(blocks(a))
171+
return @interface interface(a) isreal(a)
152172
end

src/abstractblocksparsearray/wrappedabstractblocksparsearray.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using BlockArrays:
88
blockedrange,
99
mortar,
1010
unblock
11+
using Derive: Derive
1112
using SplitApplyCombine: groupcount
1213
using TypeParameterAccessors: similartype
1314

@@ -20,6 +21,8 @@ const AnyAbstractBlockSparseArray{T,N} = Union{
2021
<:AbstractBlockSparseArray{T,N},<:WrappedAbstractBlockSparseArray{T,N}
2122
}
2223

24+
Derive.interface(::Type{<:AnyAbstractBlockSparseArray}) = BlockSparseArrayInterface()
25+
2326
# a[1:2, 1:2]
2427
function Base.to_indices(
2528
a::AnyAbstractBlockSparseArray, inds, I::Tuple{UnitRange{<:Integer},Vararg{Any}}

src/blocksparsearrayinterface/blocksparsearrayinterface.jl

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,7 +13,12 @@ using BlockArrays:
1313
blocks,
1414
findblockindex
1515
using LinearAlgebra: Adjoint, Transpose
16-
using SparseArraysBase: perm, iperm, storedlength, sparse_zero!
16+
using SparseArraysBase:
17+
AbstractSparseArrayInterface, perm, iperm, storedlength, sparse_zero!
18+
19+
abstract type AbstractBlockSparseArrayInterface <: AbstractSparseArrayInterface end
20+
21+
struct BlockSparseArrayInterface <: AbstractBlockSparseArrayInterface end
1722

1823
blocksparse_blocks(a::AbstractArray) = error("Not implemented")
1924

@@ -265,8 +270,10 @@ SparseArraysBase.storedlength(a::SparseSubArrayBlocks) = length(eachstoredindex(
265270
## return BlockZero(axes(a.array))
266271
## end
267272

268-
function SparseArraysBase.getunstoredindex(a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}) where {N}
269-
error("Not implemented.")
273+
function SparseArraysBase.getunstoredindex(
274+
a::SparseSubArrayBlocks{<:Any,N}, I::Vararg{Int,N}
275+
) where {N}
276+
return error("Not implemented.")
270277
end
271278

272279
to_blocks_indices(I::BlockSlice{<:BlockRange{1}}) = Int.(I.block)

src/blocksparsearrayinterface/broadcast.jl

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,12 @@
11
using Base.Broadcast: BroadcastStyle, AbstractArrayStyle, DefaultArrayStyle, Broadcasted
22
using BroadcastMapConversion: map_function, map_args
3+
using Derive: Derive, @interface
34

4-
struct BlockSparseArrayStyle{N} <: AbstractArrayStyle{N} end
5+
abstract type AbstractBlockSparseArrayStyle{N} <: AbstractArrayStyle{N} end
6+
7+
Derive.interface(::Type{<:AbstractBlockSparseArrayStyle}) = BlockSparseArrayInterface()
8+
9+
struct BlockSparseArrayStyle{N} <: AbstractBlockSparseArrayStyle{N} end
510

611
# Define for new sparse array types.
712
# function Broadcast.BroadcastStyle(arraytype::Type{<:MyBlockSparseArray})
@@ -29,11 +34,12 @@ function Base.similar(bc::Broadcasted{<:BlockSparseArrayStyle}, elt::Type)
2934
end
3035

3136
# Broadcasting implementation
37+
# TODO: Delete this in favor of `Derive` version.
3238
function Base.copyto!(
3339
dest::AbstractArray{<:Any,N}, bc::Broadcasted{BlockSparseArrayStyle{N}}
3440
) where {N}
3541
# convert to map
3642
# flatten and only keep the AbstractArray arguments
37-
sparse_map!(map_function(bc), dest, map_args(bc)...)
43+
@interface interface(bc) map!(map_function(bc), dest, map_args(bc)...)
3844
return dest
3945
end

src/blocksparsearrayinterface/cat.jl

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,7 @@ using SparseArraysBase: SparseArraysBase, allocate_cat_output, sparse_cat!
55
# TODO: Handle dual graded unit ranges, for example in a new `SparseArraysBaseGradedUnitRangesExt`.
66
## TODO: Add this back.
77
## function SparseArraysBase.axis_cat(
8-
function axis_cat(
9-
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange
10-
)
8+
function axis_cat(a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange)
119
return blockedrange(vcat(blocklengths(a1), blocklengths(a2)))
1210
end
1311

test/Project.toml

Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,11 +18,5 @@ Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
1818
TypeParameterAccessors = "7e5a90cf-f82e-492e-a09b-e3e26432c138"
1919

2020
[sources]
21-
BroadcastMapConversion = {url = "https://github.com/ITensor/BroadcastMapConversion.jl"}
22-
GradedUnitRanges = {url = "https://github.com/ITensor/GradedUnitRanges.jl"}
23-
LabelledNumbers = {url = "https://github.com/ITensor/LabelledNumbers.jl"}
24-
NestedPermutedDimsArrays = {url = "https://github.com/ITensor/NestedPermutedDimsArrays.jl"}
25-
SparseArraysBase = {url = "https://github.com/ITensor/SparseArraysBase.jl"}
2621
SymmetrySectors = {url = "https://github.com/ITensor/SymmetrySectors.jl"}
2722
TensorAlgebra = {url = "https://github.com/ITensor/TensorAlgebra.jl"}
28-
TypeParameterAccessors = {url = "https://github.com/ITensor/TypeParameterAccessors.jl"}

0 commit comments

Comments
 (0)