Skip to content

Commit 0ef437d

Browse files
committed
Add adapt support to BlockSparseArrays
1 parent 0fbac75 commit 0ef437d

File tree

3 files changed

+21
-0
lines changed

3 files changed

+21
-0
lines changed
Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
module BlockSparseArraysAdaptExt
2+
using Adapt: Adapt, adapt
3+
using ..BlockSparseArrays: AbstractBlockSparseArray, map_stored_blocks
4+
Adapt.adapt_structure(to, x::AbstractBlockSparseArray) = map_stored_blocks(adapt(to), x)
5+
end

NDTensors/src/lib/BlockSparseArrays/src/BlockSparseArrays.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,7 @@ include("blocksparsearrayinterface/blocksparsearrayinterface.jl")
44
include("blocksparsearrayinterface/linearalgebra.jl")
55
include("blocksparsearrayinterface/blockzero.jl")
66
include("blocksparsearrayinterface/broadcast.jl")
7+
include("blocksparsearrayinterface/map.jl")
78
include("blocksparsearrayinterface/arraylayouts.jl")
89
include("blocksparsearrayinterface/views.jl")
910
include("abstractblocksparsearray/abstractblocksparsearray.jl")
@@ -20,4 +21,5 @@ include("blocksparsearray/blocksparsearray.jl")
2021
include("BlockArraysSparseArrayInterfaceExt/BlockArraysSparseArrayInterfaceExt.jl")
2122
include("../ext/BlockSparseArraysTensorAlgebraExt/src/BlockSparseArraysTensorAlgebraExt.jl")
2223
include("../ext/BlockSparseArraysGradedAxesExt/src/BlockSparseArraysGradedAxesExt.jl")
24+
include("../ext/BlockSparseArraysAdaptExt/src/BlockSparseArraysAdaptExt.jl")
2325
end
Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,14 @@
1+
function map_stored_blocks(f, a::AbstractArray)
2+
# TODO: Implement this as:
3+
# ```julia
4+
# mapped_blocks = SparseArraysInterface.map_stored(f, blocks(a))
5+
# BlockSparseArray(mapped_blocks, axes(a))
6+
# ```
7+
# TODO: `block_stored_indices` should output `Indices` storing
8+
# the stored Blocks, not a `Dictionary` from cartesian indices
9+
# to Blocks.
10+
bs = block_stored_indices(a)
11+
mapped_blocks = Dictionary(bs, map(b -> f(@view(a[b])), bs))
12+
# TODO: Use `similartype(typeof(a), eltype(eltype(mapped_blocks)))(...)`.
13+
return BlockSparseArray(mapped_blocks, axes(a))
14+
end

0 commit comments

Comments
 (0)