diff --git a/Project.toml b/Project.toml index dbbe6bb..46edc19 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "KroneckerArrays" uuid = "05d0b138-81bc-4ff7-84be-08becefb1ccc" authors = ["ITensor developers and contributors"] -version = "0.1.22" +version = "0.1.23" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/cartesianproduct.jl b/src/cartesianproduct.jl index e0138f3..be1c4fa 100644 --- a/src/cartesianproduct.jl +++ b/src/cartesianproduct.jl @@ -134,13 +134,22 @@ function cartesianrange(p::CartesianProduct, range::AbstractUnitRange) end function Base.axes(r::CartesianProductUnitRange) - return (CartesianProductUnitRange(cartesianproduct(r), only(axes(unproduct(r)))),) + prod = cartesianproduct(r) + prod_ax = only(axes(arg1(prod))) × only(axes(arg2(prod))) + return (CartesianProductUnitRange(prod_ax, only(axes(unproduct(r)))),) end function Base.checkindex(::Type{Bool}, inds::CartesianProductUnitRange, i::CartesianPair) return checkindex(Bool, arg1(inds), arg1(i)) && checkindex(Bool, arg2(inds), arg2(i)) end +const CartesianProductOneTo{T,P<:CartesianProduct,R<:Base.OneTo{T}} = CartesianProductUnitRange{ + T,P,R +} +Base.axes(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,) +Base.axes1(S::Base.Slice{<:CartesianProductOneTo}) = S.indices +Base.unsafe_indices(S::Base.Slice{<:CartesianProductOneTo}) = (S.indices,) + function Base.getindex(a::CartesianProductUnitRange, I::CartesianProduct) prod = cartesianproduct(a) prod_I = arg1(prod)[arg1(I)] × arg2(prod)[arg2(I)] diff --git a/src/kroneckerarray.jl b/src/kroneckerarray.jl index 96167f3..6298f29 100644 --- a/src/kroneckerarray.jl +++ b/src/kroneckerarray.jl @@ -38,6 +38,10 @@ function Base.copyto!(dest::KroneckerArray, src::KroneckerArray) return dest end +function Base.convert(::Type{KroneckerArray{T,N,A,B}}, a::KroneckerArray) where {T,N,A,B} + return KroneckerArray(convert(A, arg1(a)), convert(B, arg2(a))) +end + # Like `similar` but allows some custom behavior, such as for `FillArrays.Eye`. function _similar(a::AbstractArray, elt::Type, axs::Tuple{Vararg{AbstractUnitRange}}) return similar(a, elt, axs) @@ -189,7 +193,14 @@ Base.getindex(a::KroneckerArray{<:Any,0}) = arg1(a)[] * arg2(a)[] # Allow customizing for `FillArrays.Eye`. _view(a::AbstractArray, I...) = view(a, I...) -function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianProduct,N}) where {N} +arg1(::Colon) = (:) +arg2(::Colon) = (:) +arg1(::Base.Slice) = (:) +arg2(::Base.Slice) = (:) +function Base.view( + a::KroneckerArray{<:Any,N}, + I::Vararg{Union{CartesianProduct,CartesianProductUnitRange,Base.Slice,Colon},N}, +) where {N} return _view(arg1(a), arg1.(I)...) ⊗ _view(arg2(a), arg2.(I)...) end function Base.view(a::KroneckerArray{<:Any,N}, I::Vararg{CartesianPair,N}) where {N} @@ -272,10 +283,8 @@ function Base.BroadcastStyle(style1::KroneckerStyle{N}, style2::KroneckerStyle{N return KroneckerStyle{N}(style_a, style_b) end function Base.similar(bc::Broadcasted{<:KroneckerStyle{N,A,B}}, elt::Type, ax) where {N,A,B} - ax_a = arg1.(ax) - ax_b = arg2.(ax) - bc_a = Broadcasted(A, nothing, (), ax_a) - bc_b = Broadcasted(B, nothing, (), ax_b) + bc_a = Broadcasted(A, bc.f, arg1.(bc.args), arg1.(ax)) + bc_b = Broadcasted(B, bc.f, arg2.(bc.args), arg2.(ax)) a = similar(bc_a, elt) b = similar(bc_b, elt) return a ⊗ b