Skip to content

Commit 5acb5a8

Browse files
committed
More general scalar indexing
1 parent 313374c commit 5acb5a8

File tree

3 files changed

+17
-18
lines changed

3 files changed

+17
-18
lines changed

src/kroneckerarray.jl

Lines changed: 5 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -158,26 +158,13 @@ function Base.getindex(a::KroneckerArray, i::Integer)
158158
return a[CartesianIndices(a)[i]]
159159
end
160160

161-
# TODO: Use this logic from KroneckerProducts.jl for cartesian indexing
162-
# in the n-dimensional case and use it to replace the matrix and vector cases:
163-
# https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66
164-
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N}
165-
return error("Not implemented.")
166-
end
167-
168161
using GPUArraysCore: GPUArraysCore
169-
function Base.getindex(a::KroneckerMatrix, i1::Integer, i2::Integer)
170-
GPUArraysCore.assertscalar("getindex")
171-
# Code logic from Kronecker.jl:
172-
# https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105
173-
k, l = size(arg2(a))
174-
return arg1(a)[cld(i1, k), cld(i2, l)] * arg2(a)[(i1 - 1) % k + 1, (i2 - 1) % l + 1]
175-
end
176-
177-
function Base.getindex(a::KroneckerVector, i::Integer)
162+
function Base.getindex(a::KroneckerArray{<:Any,N}, I::Vararg{Integer,N}) where {N}
178163
GPUArraysCore.assertscalar("getindex")
179-
k = length(arg2(a))
180-
return arg1(a)[cld(i, k)] * arg2(a)[(i - 1) % k + 1]
164+
I′ = ntuple(Val(N)) do dim
165+
return cartesianproduct(axes(a, dim))[I[dim]]
166+
end
167+
return a[I′...]
181168
end
182169

183170
# Allow customizing for `FillArrays.Eye`.

test/Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@ BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
66
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
77
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
88
FillArrays = "1a297f60-69ca-5386-bcde-b61e274b549b"
9+
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
910
JLArrays = "27aeb0d3-9eb9-45fb-866b-73c2ecf80fcb"
1011
KroneckerArrays = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
1112
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
@@ -24,6 +25,7 @@ BlockSparseArrays = "0.7.19"
2425
DerivableInterfaces = "0.5"
2526
DiagonalArrays = "0.3.7"
2627
FillArrays = "1"
28+
GPUArraysCore = "0.2"
2729
JLArrays = "0.2"
2830
KroneckerArrays = "0.1"
2931
LinearAlgebra = "1.10"

test/test_basics.jl

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ using Adapt: adapt
22
using Base.Broadcast: BroadcastStyle, Broadcasted, broadcasted
33
using DerivableInterfaces: zero!
44
using DiagonalArrays: diagonal
5+
using GPUArraysCore: @allowscalar
56
using JLArrays: JLArray
67
using KroneckerArrays:
78
KroneckerArrays,
@@ -44,6 +45,15 @@ elts = (Float32, Float64, ComplexF32, ComplexF64)
4445
@test first(r) == 2
4546
@test last(r) == 7
4647

48+
# Test high-dimensional materialization.
49+
a = randn(elt, 2, 2, 2) randn(elt, 2, 2, 2)
50+
x = Array(a)
51+
y = similar(x)
52+
for I in eachindex(a)
53+
y[I] = @allowscalar x[I]
54+
end
55+
@test x == y
56+
4757
a = @constinferred(randn(elt, 2, 2) randn(elt, 3, 3))
4858
b = randn(elt, 2, 2) randn(elt, 3, 3)
4959
c = a.a b.b

0 commit comments

Comments
 (0)