Skip to content

Commit 9636477

Browse files
authored
More general truncation and slicing (#29)
1 parent 2f2caf3 commit 9636477

File tree

3 files changed

+25
-7
lines changed

3 files changed

+25
-7
lines changed

Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "KroneckerArrays"
22
uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc"
33
authors = ["ITensor developers <[email protected]> and contributors"]
4-
version = "0.1.22"
4+
version = "0.1.23"
55

66
[deps]
77
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"

src/cartesianproduct.jl

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -134,13 +134,22 @@ function cartesianrange(p::CartesianProduct, range::AbstractUnitRange)
134134
end
135135

136136
function Base.axes(r::CartesianProductUnitRange)
137-
return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),)
137+
prod = cartesianproduct(r)
138+
prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod)))
139+
return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),)
138140
end
139141

140142
function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair)
141143
return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i))
142144
end
143145

146+
const CartesianProductOneTo{T,P<:CartesianProduct,R<:Base.OneTo{T}} = CartesianProductUnitRange{
147+
T,P,R
148+
}
149+
Base.axes(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,)
150+
Base.axes1(S::Base.Slice{<:CartesianProductOneTo}) = S.indices
151+
Base.unsafe_indices(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,)
152+
144153
function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct)
145154
prod = cartesianproduct(a)
146155
prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)]

src/kroneckerarray.jl

Lines changed: 14 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,6 +38,10 @@ function Base.copyto!(dest::KroneckerArray, src::KroneckerArray)
3838
return dest
3939
end
4040

41+
function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where {T,N,A,B}
42+
return KroneckerArray(convert(A, arg1(a)), convert(B, arg2(a)))
43+
end
44+
4145
# Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`.
4246
function _similar(a::AbstractArray, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}})
4347
return similar(a, elt, axs)
@@ -189,7 +193,14 @@ Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[]
189193

190194
# Allow customizing for `FillArrays.Eye`.
191195
_view(a::AbstractArray, I...) = view(a, I...)
192-
function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N}
196+
arg1(::Colon) = (:)
197+
arg2(::Colon) = (:)
198+
arg1(::Base.Slice) = (:)
199+
arg2(::Base.Slice) = (:)
200+
function Base.view(
201+
a::KroneckerArray{<:Any,N},
202+
I::Vararg{Union{CartesianProduct,CartesianProductUnitRange,Base.Slice,Colon},N},
203+
) where {N}
193204
return _view(arg1(a), arg1.(I)...) _view(arg2(a), arg2.(I)...)
194205
end
195206
function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N}
@@ -272,10 +283,8 @@ function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N
272283
return KroneckerStyle{N}(style_a, style_b)
273284
end
274285
function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type, ax) where {N,A,B}
275-
ax_a = arg1.(ax)
276-
ax_b = arg2.(ax)
277-
bc_a = Broadcasted(A, nothing, (), ax_a)
278-
bc_b = Broadcasted(B, nothing, (), ax_b)
286+
bc_a = Broadcasted(A, bc.f, arg1.(bc.args), arg1.(ax))
287+
bc_b = Broadcasted(B, bc.f, arg2.(bc.args), arg2.(ax))
279288
a = similar(bc_a, elt)
280289
b = similar(bc_b, elt)
281290
return a b

0 commit comments

Comments
 (0)