Skip to content

Commit d40ca1a

Browse files
authored
[SparseArrayInterface] NestedPermutedDimsArray support (#1590)
1 parent 3594216 commit d40ca1a

File tree

5 files changed

+111
-26
lines changed

5 files changed

+111
-26
lines changed

NDTensors/src/lib/SparseArrayInterface/src/abstractsparsearray/sparsearrayinterface.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,7 +32,7 @@ end
3232

3333
# TODO: Make this into a generic definition of all `AbstractArray`?
3434
function SparseArrayInterface.stored_indices(
35-
a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
35+
a::AnyPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
3636
)
3737
return Iterators.map(
3838
I -> CartesianIndex(map(i -> I[i], perm(a))), stored_indices(parent(a))
@@ -41,7 +41,7 @@ end
4141

4242
# TODO: Make this into a generic definition of all `AbstractArray`?
4343
function SparseArrayInterface.sparse_storage(
44-
a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
44+
a::AnyPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:AbstractSparseArray}
4545
)
4646
return sparse_storage(parent(a))
4747
end

NDTensors/src/lib/SparseArrayInterface/src/sparsearrayinterface/wrappers.jl

Lines changed: 11 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,26 @@
1+
using ..NestedPermutedDimsArrays: NestedPermutedDimsArray
2+
13
## PermutedDimsArray
24

3-
perm(::PermutedDimsArray{<:Any,<:Any,P}) where {P} = P
4-
iperm(::PermutedDimsArray{<:Any,<:Any,<:Any,IP}) where {IP} = IP
5+
const AnyPermutedDimsArray{T,N,perm,iperm,P} = Union{
6+
PermutedDimsArray{T,N,perm,iperm,P},NestedPermutedDimsArray{T,N,perm,iperm,P}
7+
}
8+
9+
# TODO: Use `TypeParameterAccessors`.
10+
perm(::AnyPermutedDimsArray{<:Any,<:Any,Perm}) where {Perm} = Perm
11+
iperm(::AnyPermutedDimsArray{<:Any,<:Any,<:Any,IPerm}) where {IPerm} = IPerm
512

613
# TODO: Use `Base.PermutedDimsArrays.genperm` or
714
# https://github.com/jipolanco/StaticPermutations.jl?
815
genperm(v, perm) = map(j -> v[j], perm)
916
genperm(v::CartesianIndex, perm) = CartesianIndex(map(j -> Tuple(v)[j], perm))
1017

11-
function storage_index_to_index(a::PermutedDimsArray, I)
18+
function storage_index_to_index(a::AnyPermutedDimsArray, I)
1219
return genperm(storage_index_to_index(parent(a), I), perm(a))
1320
end
1421

1522
function index_to_storage_index(
16-
a::PermutedDimsArray{<:Any,N}, I::CartesianIndex{N}
23+
a::AnyPermutedDimsArray{<:Any,N}, I::CartesianIndex{N}
1724
) where {N}
1825
return index_to_storage_index(parent(a), genperm(I, perm(a)))
1926
end

NDTensors/src/lib/SparseArrayInterface/test/SparseArrayInterfaceTestUtils/AbstractSparseArrays.jl

Lines changed: 16 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
11
module AbstractSparseArrays
22
using ArrayLayouts: ArrayLayouts, MatMulMatAdd, MemoryLayout, MulAdd
3-
using NDTensors.SparseArrayInterface: SparseArrayInterface, AbstractSparseArray
3+
using NDTensors.SparseArrayInterface: SparseArrayInterface, AbstractSparseArray, Zero
44

5-
struct SparseArray{T,N} <: AbstractSparseArray{T,N}
5+
struct SparseArray{T,N,Zero} <: AbstractSparseArray{T,N}
66
data::Vector{T}
77
dims::Tuple{Vararg{Int,N}}
88
index_to_dataindex::Dict{CartesianIndex{N},Int}
99
dataindex_to_index::Vector{CartesianIndex{N}}
10+
zero::Zero
1011
end
11-
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}) where {T,N}
12-
return SparseArray{T,N}(
13-
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}()
12+
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}; zero=Zero()) where {T,N}
13+
return SparseArray{T,N,typeof(zero)}(
14+
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}(), zero
1415
)
1516
end
16-
SparseArray{T,N}(dims::Vararg{Int,N}) where {T,N} = SparseArray{T,N}(dims)
17-
SparseArray{T}(dims::Tuple{Vararg{Int}}) where {T} = SparseArray{T,length(dims)}(dims)
18-
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}) where {T}
19-
return SparseArray{T}(dims)
17+
function SparseArray{T,N}(dims::Vararg{Int,N}; kwargs...) where {T,N}
18+
return SparseArray{T,N}(dims; kwargs...)
2019
end
21-
SparseArray{T}(dims::Vararg{Int}) where {T} = SparseArray{T}(dims)
20+
function SparseArray{T}(dims::Tuple{Vararg{Int}}; kwargs...) where {T}
21+
return SparseArray{T,length(dims)}(dims; kwargs...)
22+
end
23+
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}; kwargs...) where {T}
24+
return SparseArray{T}(dims; kwargs...)
25+
end
26+
SparseArray{T}(dims::Vararg{Int}; kwargs...) where {T} = SparseArray{T}(dims; kwargs...)
2227

2328
# ArrayLayouts interface
2429
struct SparseLayout <: MemoryLayout end
@@ -41,6 +46,7 @@ function Base.similar(a::SparseArray, elt::Type, dims::Tuple{Vararg{Int}})
4146
end
4247

4348
# Minimal interface
49+
SparseArrayInterface.getindex_zero_function(a::SparseArray) = a.zero
4450
SparseArrayInterface.sparse_storage(a::SparseArray) = a.data
4551
function SparseArrayInterface.index_to_storage_index(
4652
a::SparseArray{<:Any,N}, I::CartesianIndex{N}

NDTensors/src/lib/SparseArrayInterface/test/SparseArrayInterfaceTestUtils/SparseArrays.jl

Lines changed: 43 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -1,24 +1,29 @@
11
module SparseArrays
22
using LinearAlgebra: LinearAlgebra
3-
using NDTensors.SparseArrayInterface: SparseArrayInterface
3+
using NDTensors.SparseArrayInterface: SparseArrayInterface, Zero
44

5-
struct SparseArray{T,N} <: AbstractArray{T,N}
5+
struct SparseArray{T,N,Zero} <: AbstractArray{T,N}
66
data::Vector{T}
77
dims::Tuple{Vararg{Int,N}}
88
index_to_dataindex::Dict{CartesianIndex{N},Int}
99
dataindex_to_index::Vector{CartesianIndex{N}}
10+
zero::Zero
1011
end
11-
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}) where {T,N}
12-
return SparseArray{T,N}(
13-
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}()
12+
function SparseArray{T,N}(dims::Tuple{Vararg{Int,N}}; zero=Zero()) where {T,N}
13+
return SparseArray{T,N,typeof(zero)}(
14+
T[], dims, Dict{CartesianIndex{N},Int}(), Vector{CartesianIndex{N}}(), zero
1415
)
1516
end
16-
SparseArray{T,N}(dims::Vararg{Int,N}) where {T,N} = SparseArray{T,N}(dims)
17-
SparseArray{T}(dims::Tuple{Vararg{Int}}) where {T} = SparseArray{T,length(dims)}(dims)
18-
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}) where {T}
19-
return SparseArray{T}(dims)
17+
function SparseArray{T,N}(dims::Vararg{Int,N}; kwargs...) where {T,N}
18+
return SparseArray{T,N}(dims; kwargs...)
2019
end
21-
SparseArray{T}(dims::Vararg{Int}) where {T} = SparseArray{T}(dims)
20+
function SparseArray{T}(dims::Tuple{Vararg{Int}}; kwargs...) where {T}
21+
return SparseArray{T,length(dims)}(dims; kwargs...)
22+
end
23+
function SparseArray{T}(::UndefInitializer, dims::Tuple{Vararg{Int}}; kwargs...) where {T}
24+
return SparseArray{T}(dims; kwargs...)
25+
end
26+
SparseArray{T}(dims::Vararg{Int}; kwargs...) where {T} = SparseArray{T}(dims; kwargs...)
2227

2328
# LinearAlgebra interface
2429
function LinearAlgebra.mul!(
@@ -53,6 +58,7 @@ function Base.fill!(a::SparseArray, value)
5358
end
5459

5560
# Minimal interface
61+
SparseArrayInterface.getindex_zero_function(a::SparseArray) = a.zero
5662
SparseArrayInterface.sparse_storage(a::SparseArray) = a.data
5763
function SparseArrayInterface.index_to_storage_index(
5864
a::SparseArray{<:Any,N}, I::CartesianIndex{N}
@@ -79,6 +85,33 @@ function SparseArrayInterface.stored_indices(
7985
)
8086
end
8187

88+
# TODO: Make this into a generic definition of all `AbstractArray`?
89+
using NDTensors.SparseArrayInterface: sparse_storage
90+
function SparseArrayInterface.sparse_storage(
91+
a::PermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray}
92+
)
93+
return sparse_storage(parent(a))
94+
end
95+
96+
# TODO: Make this into a generic definition of all `AbstractArray`?
97+
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
98+
function SparseArrayInterface.stored_indices(
99+
a::NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray}
100+
)
101+
return Iterators.map(
102+
I -> CartesianIndex(map(i -> I[i], perm(a))), stored_indices(parent(a))
103+
)
104+
end
105+
106+
# TODO: Make this into a generic definition of all `AbstractArray`?
107+
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
108+
using NDTensors.SparseArrayInterface: sparse_storage
109+
function SparseArrayInterface.sparse_storage(
110+
a::NestedPermutedDimsArray{<:Any,<:Any,<:Any,<:Any,<:SparseArray}
111+
)
112+
return sparse_storage(parent(a))
113+
end
114+
82115
# Empty the storage, helps with efficiency in `map!` to drop
83116
# zeros.
84117
function SparseArrayInterface.dropall!(a::SparseArray)

NDTensors/src/lib/SparseArrayInterface/test/test_abstractsparsearray.jl

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
@eval module $(gensym())
22
using LinearAlgebra: dot, mul!, norm
33
using NDTensors.SparseArrayInterface: SparseArrayInterface
4+
using NDTensors.NestedPermutedDimsArrays: NestedPermutedDimsArray
45
include("SparseArrayInterfaceTestUtils/SparseArrayInterfaceTestUtils.jl")
56
using .SparseArrayInterfaceTestUtils.AbstractSparseArrays: AbstractSparseArrays
67
using .SparseArrayInterfaceTestUtils.SparseArrays: SparseArrays
@@ -224,6 +225,44 @@ using Test: @test, @testset
224225
end
225226
end
226227

228+
a = SparseArray{elt}(2, 3)
229+
a[1, 2] = 12
230+
b = PermutedDimsArray(a, (2, 1))
231+
@test size(b) == (3, 2)
232+
@test axes(b) == (1:3, 1:2)
233+
@test SparseArrayInterface.sparse_storage(b) == elt[12]
234+
@test SparseArrayInterface.stored_length(b) == 1
235+
@test collect(SparseArrayInterface.stored_indices(b)) == [CartesianIndex(2, 1)]
236+
@test !iszero(b)
237+
@test !iszero(norm(b))
238+
for I in eachindex(b)
239+
if I == CartesianIndex(2, 1)
240+
@test b[I] == 12
241+
else
242+
@test iszero(b[I])
243+
end
244+
end
245+
246+
a = SparseArray{Matrix{elt}}(
247+
2, 3; zero=(a, I) -> (z = similar(eltype(a), 2, 3); fill!(z, false); z)
248+
)
249+
a[1, 2] = randn(elt, 2, 3)
250+
b = NestedPermutedDimsArray(a, (2, 1))
251+
@test size(b) == (3, 2)
252+
@test axes(b) == (1:3, 1:2)
253+
@test SparseArrayInterface.sparse_storage(b) == [a[1, 2]]
254+
@test SparseArrayInterface.stored_length(b) == 1
255+
@test collect(SparseArrayInterface.stored_indices(b)) == [CartesianIndex(2, 1)]
256+
@test !iszero(b)
257+
@test !iszero(norm(b))
258+
for I in eachindex(b)
259+
if I == CartesianIndex(2, 1)
260+
@test b[I] == permutedims(a[1, 2], (2, 1))
261+
else
262+
@test iszero(b[I])
263+
end
264+
end
265+
227266
a = SparseArray{elt}(2, 3)
228267
a[1, 2] = 12
229268
b = randn(elt, 2, 3)

0 commit comments

Comments
 (0)