Skip to content

Commit 8436b23

Browse files
committed
Fix slicing on GPU
1 parent 0ef437d commit 8436b23

File tree

2 files changed

+20
-4
lines changed

2 files changed

+20
-4
lines changed

NDTensors/src/lib/BlockSparseArrays/src/BlockArraysExtensions/BlockArraysExtensions.jl

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using ArrayLayouts: ArrayLayouts, MemoryLayout, sub_materialize
12
using BlockArrays:
23
BlockArrays,
34
AbstractBlockArray,
@@ -537,6 +538,20 @@ function SparseArrayInterface.nstored(a::BlockView)
537538
return 0
538539
end
539540

541+
## # Allow more fine-grained control:
542+
## function ArrayLayouts.sub_materialize(layout, a::BlockView, ax)
543+
## return blocks(a.array)[Int.(a.block)...]
544+
## end
545+
## function ArrayLayouts.sub_materialize(layout, a::BlockView)
546+
## return sub_materialize(layout, a, axes(a))
547+
## end
548+
## function ArrayLayouts.sub_materialize(a::BlockView)
549+
## return sub_materialize(MemoryLayout(a), a)
550+
## end
551+
function ArrayLayouts.sub_materialize(a::BlockView)
552+
return blocks(a.array)[Int.(a.block)...]
553+
end
554+
540555
function view!(a::AbstractArray{<:Any,N}, index::Block{N}) where {N}
541556
return view!(a, Tuple(index)...)
542557
end

NDTensors/src/lib/BlockSparseArrays/src/abstractblocksparsearray/arraylayouts.jl

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -22,8 +22,10 @@ end
2222

2323
# Materialize a SubArray view.
2424
function ArrayLayouts.sub_materialize(layout::BlockLayout{<:SparseLayout}, a, axes)
25-
# TODO: Make more generic for GPU.
26-
a_dest = BlockSparseArray{eltype(a)}(axes)
25+
# TODO: Define `blocktype`/`blockstype` for `SubArray` wrapping `BlockSparseArray`.
26+
# TODO: Use `similar`?
27+
blocktype_a = blocktype(parent(a))
28+
a_dest = BlockSparseArray{eltype(a),length(axes),blocktype_a}(axes)
2729
a_dest .= a
2830
return a_dest
2931
end
@@ -32,8 +34,7 @@ end
3234
function ArrayLayouts.sub_materialize(
3335
layout::BlockLayout{<:SparseLayout}, a, axes::Tuple{Vararg{Base.OneTo}}
3436
)
35-
# TODO: Make more generic for GPU.
36-
a_dest = Array{eltype(a)}(undef, length.(axes))
37+
a_dest = blocktype(a)(undef, length.(axes))
3738
a_dest .= a
3839
return a_dest
3940
end

0 commit comments

Comments
 (0)