Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
name = "BlockSparseArrays"
uuid = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
authors = ["ITensor developers <[email protected]> and contributors"]
version = "0.3.5"
version = "0.3.6"

[deps]
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
Expand Down Expand Up @@ -33,7 +33,7 @@ Adapt = "4.1.1"
Aqua = "0.8.9"
ArrayLayouts = "1.10.4"
BlockArrays = "1.2.0"
DerivableInterfaces = "0.3.8"
DerivableInterfaces = "0.4"
DiagonalArrays = "0.3"
Dictionaries = "0.4.3"
FillArrays = "1.13.0"
Expand Down
7 changes: 2 additions & 5 deletions src/abstractblocksparsearray/cat.jl
Original file line number Diff line number Diff line change
@@ -1,9 +1,6 @@
using DerivableInterfaces: @interface, interface
using DerivableInterfaces.Concatenate: concatenate

function Base._cat(dims, as::AnyAbstractBlockSparseArray...)
# TODO: Call `DerivableInterfaces.cat_along(dims, as...)` instead,
# for better inferability. See:
# https://github.com/ITensor/DerivableInterfaces.jl/pull/13
# https://github.com/ITensor/DerivableInterfaces.jl/pull/17
return @interface interface(as...) cat(as...; dims)
return concatenate(dims, as...)
end
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using Adapt: Adapt, WrappedArray, adapt
using ArrayLayouts: zero!
using ArrayLayouts: ArrayLayouts
using BlockArrays:
BlockArrays,
AbstractBlockVector,
Expand All @@ -9,7 +9,7 @@ using BlockArrays:
blockedrange,
mortar,
unblock
using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface
using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface, zero!
using GPUArraysCore: @allowscalar
using SplitApplyCombine: groupcount
using TypeParameterAccessors: similartype
Expand Down Expand Up @@ -154,9 +154,8 @@ function Base.setindex!(a::AnyAbstractBlockSparseArray{<:Any,1}, value, I::Block
return a
end

# TODO: Use `@derive`.
function ArrayLayouts.zero!(a::AnyAbstractBlockSparseArray)
return @interface interface(a) zero!(a)
return zero!(a)
end

# TODO: Use `@derive`.
Expand Down
8 changes: 5 additions & 3 deletions src/blocksparsearrayinterface/blocksparsearrayinterface.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
using ArrayLayouts: ArrayLayouts, zero!
using ArrayLayouts: ArrayLayouts
using BlockArrays:
BlockArrays,
AbstractBlockVector,
Expand All @@ -16,7 +16,7 @@ using BlockArrays:
blocklength,
blocks,
findblockindex
using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface
using DerivableInterfaces: DerivableInterfaces, @interface, DefaultArrayInterface, zero!
using LinearAlgebra: Adjoint, Transpose
using SparseArraysBase:
AbstractSparseArrayInterface,
Expand Down Expand Up @@ -266,7 +266,9 @@ end
return a
end

@interface ::AbstractBlockSparseArrayInterface function ArrayLayouts.zero!(a::AbstractArray)
@interface ::AbstractBlockSparseArrayInterface function DerivableInterfaces.zero!(
a::AbstractArray
)
# This will try to empty the storage if possible.
zero!(blocks(a))
return a
Expand Down
26 changes: 11 additions & 15 deletions src/blocksparsearrayinterface/cat.jl
Original file line number Diff line number Diff line change
@@ -1,18 +1,14 @@
using BlockArrays: AbstractBlockedUnitRange, blockedrange, blocklengths
using DerivableInterfaces: DerivableInterfaces, @interface, cat!
using SparseArraysBase: SparseArraysBase
using BlockArrays: blocks
using DerivableInterfaces.Concatenate: Concatenated, cat!

# TODO: Maybe move to `DerivableInterfacesBlockArraysExt`.
# TODO: Handle dual graded unit ranges, for example in a new `SparseArraysBaseGradedUnitRangesExt`.
function DerivableInterfaces.axis_cat(
a1::AbstractBlockedUnitRange, a2::AbstractBlockedUnitRange
function Base.copyto!(
dest::AbstractArray, concat::Concatenated{<:BlockSparseArrayInterface}
)
return blockedrange(vcat(blocklengths(a1), blocklengths(a2)))
end

@interface ::AbstractBlockSparseArrayInterface function DerivableInterfaces.cat!(
a_dest::AbstractArray, as::AbstractArray...; dims
)
cat!(blocks(a_dest), blocks.(as)...; dims)
return a_dest
# TODO: This assumes the destination blocking is commensurate with
# the blocking of the sources, for example because it was constructed
# based on the input arguments. Maybe check that explicitly.
# This should mostly just get called from `cat` anyway and not get
# called explicitly.
cat!(blocks(dest), blocks.(concat.args)...; dims=concat.dims)
return dest
end
2 changes: 1 addition & 1 deletion src/blocksparsearrayinterface/getunstoredblock.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
using ArrayLayouts: zero!
using BlockArrays: Block
using DerivableInterfaces: zero!

struct GetUnstoredBlock{Axes}
axes::Axes
Expand Down
1 change: 1 addition & 0 deletions test/Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ Aqua = "4c88cf16-eb10-579e-8560-4a9242c79595"
ArrayLayouts = "4c555306-a7a7-4459-81d9-ec55ddd5c99a"
BlockArrays = "8e7c35d0-a365-5155-bbbb-fb81a777f24e"
BlockSparseArrays = "2c9a651f-6452-4ace-a6ac-809f4280fbb4"
DerivableInterfaces = "6c5e35bf-e59e-4898-b73c-732dcc4ba65f"
DiagonalArrays = "74fd4be6-21e2-4f6f-823a-4360d37c7a77"
GPUArraysCore = "46192b85-c4d5-4398-a991-12ede77f4527"
GradedUnitRanges = "e2de450a-8a67-46c7-b59c-01d5a3d041c5"
Expand Down
Loading