@@ -158,26 +158,13 @@ function Base.getindex(a::KroneckerArray, i::Integer)
158158 return a[CartesianIndices (a)[i]]
159159end
160160
161- # TODO : Use this logic from KroneckerProducts.jl for cartesian indexing
162- # in the n-dimensional case and use it to replace the matrix and vector cases:
163- # https://github.com/perrutquist/KroneckerProducts.jl/blob/8c0104caf1f17729eb067259ba1473986121d032/src/KroneckerProducts.jl#L59-L66
164- function Base. getindex (a:: KroneckerArray{<:Any,N} , I:: Vararg{Integer,N} ) where {N}
165- return error (" Not implemented." )
166- end
167-
168161using GPUArraysCore: GPUArraysCore
169- function Base. getindex (a:: KroneckerMatrix , i1:: Integer , i2:: Integer )
170- GPUArraysCore. assertscalar (" getindex" )
171- # Code logic from Kronecker.jl:
172- # https://github.com/MichielStock/Kronecker.jl/blob/v0.5.5/src/base.jl#L101-L105
173- k, l = size (arg2 (a))
174- return arg1 (a)[cld (i1, k), cld (i2, l)] * arg2 (a)[(i1 - 1 ) % k + 1 , (i2 - 1 ) % l + 1 ]
175- end
176-
177- function Base. getindex (a:: KroneckerVector , i:: Integer )
162+ function Base. getindex (a:: KroneckerArray{<:Any,N} , I:: Vararg{Integer,N} ) where {N}
178163 GPUArraysCore. assertscalar (" getindex" )
179- k = length (arg2 (a))
180- return arg1 (a)[cld (i, k)] * arg2 (a)[(i - 1 ) % k + 1 ]
164+ I′ = ntuple (Val (N)) do dim
165+ return cartesianproduct (axes (a, dim))[I[dim]]
166+ end
167+ return a[I′... ]
181168end
182169
183170# Allow customizing for `FillArrays.Eye`.
0 commit comments