Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
63 changes: 62 additions & 1 deletion NDTensors/ext/NDTensorscuTENSORExt/contract.jl
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
using Base: ReshapedArray
using NDTensors.Expose: Exposed, expose, unexpose
using NDTensors: NDTensors, DenseTensor, array
using NDTensors: NDTensors, BlockSparseTensor, DenseTensor, array,
blockdims, data, eachnzblock, inds, nblocks, nzblocks
using cuTENSOR: cuTENSOR, CuArray, CuTensor

# Handle cases that can't be handled by `cuTENSOR.jl`
Expand All @@ -12,6 +13,66 @@ function to_zero_offset_cuarray(a::ReshapedArray)
return copy(expose(a))
end

"""
    block_extents(ind)

Return a tuple with one entry per block of `ind` (as counted by `nblocks(ind)`),
where entry `i` is the `second` field of `ind.space[i]`.

NOTE(review): presumably `ind.space[i]` is a `sector => dimension` pair, so the
result lists the block dimensions along this index; confirm against the index type.
"""
function block_extents(ind)
    nb = nblocks(ind)
    return ntuple(nb) do i
        ind.space[i].second
    end
end

#### Functions to turn Tensors into BlockSparseCuTensors for contraction
"""
    ITensor_to_cuTensorBS(T::BlockSparseTensor)

Build a `cuTENSOR.CuTensorBS` describing the nonzero blocks of `T`.

The block data are passed as views into `data(T)` (no copy is made here), one
view per nonzero block, covering `offset + 1` through `offset + prod(blockdims)`
of the flat storage. The remaining constructor arguments are, in order:

- the number of blocks along each mode,
- the per-mode block extents (see `block_extents`),
- the coordinates of each nonzero block as `Int64` tuples,
- the default mode labels `1:ndims(T)`.
"""
function ITensor_to_cuTensorBS(T::BlockSparseTensor)
    storage = data(T)

    # View of the contiguous slice of the flat storage that holds one block.
    function blockview(blockT)
        offsetT = NDTensors.offset(T, blockT)
        blockdimT = prod(blockdims(T, blockT))
        return @view storage[(offsetT + 1):(offsetT + blockdimT)]
    end

    # A comprehension yields a concretely typed vector directly and, unlike
    # pushing into an untyped `[]` and then indexing its first element, does not
    # throw a `BoundsError` when `T` has no nonzero blocks.
    blocks_t1 = [blockview(blockT) for blockT in eachnzblock(T)]

    block_extents_t1 = [block_extents(idx) for idx in inds(T)]
    nzblock_coords_t1 = [Int64.(x.data) for x in nzblocks(T)]
    block_per_mode_t1 = length.(block_extents_t1)
    is = collect(1:ndims(T))
    return cuTENSOR.CuTensorBS(
        blocks_t1, block_per_mode_t1, block_extents_t1, nzblock_coords_t1, is
    )
end

# Block-sparse contraction for CuArray-backed tensors.
#
# When the CuTensorBS backend is switched on (`NDTensors.using_CuTensorBS()`) and
# the output and both inputs have at least one dimension, the three tensors are
# wrapped as `CuTensorBS` objects, relabelled with the contraction labels, and
# multiplied with `cuTENSOR.mul!`; `R` is returned. In every other case the call
# is forwarded, with the unwrapped tensors, to the other `NDTensors._contract!`
# methods.
function NDTensors._contract!(
    R::Exposed{<:CuArray,<:BlockSparseTensor},
    labelsR,
    tensor1::Exposed{<:CuArray,<:BlockSparseTensor},
    labelstensor1,
    tensor2::Exposed{<:CuArray,<:BlockSparseTensor},
    labelstensor2,
    grouped_contraction_plan,
    executor,
)
    lhs = unexpose(tensor1)
    rhs = unexpose(tensor2)
    dest = unexpose(R)

    rank1 = ndims(lhs)
    rank2 = ndims(rhs)
    rankR = ndims(dest)
    use_cutensor_bs =
        NDTensors.using_CuTensorBS() && (rank1 > 0) && (rank2 > 0) && (rankR > 0)

    # Fallback: backend disabled, or one of the tensors is zero-dimensional.
    if !use_cutensor_bs
        return NDTensors._contract!(
            dest,
            labelsR,
            lhs,
            labelstensor1,
            rhs,
            labelstensor2,
            grouped_contraction_plan,
            executor,
        )
    end

    cuR = ITensor_to_cuTensorBS(dest)
    cutensor1 = ITensor_to_cuTensorBS(lhs)
    cutensor2 = ITensor_to_cuTensorBS(rhs)

    # Replace the default `1:ndims` mode labels with the contraction labels.
    cuR.inds = [labelsR...]
    cutensor1.inds = [labelstensor1...]
    cutensor2.inds = [labelstensor2...]

    # NOTE(review): the trailing `1.0, 0.0` are presumably the alpha/beta scaling
    # factors of `mul!` (result overwritten, not accumulated); confirm against
    # the cuTENSOR `CuTensorBS` API.
    cuTENSOR.mul!(cuR, cutensor1, cutensor2, 1.0, 0.0)
    return R
end

function NDTensors.contract!(
exposedR::Exposed{<:CuArray, <:DenseTensor},
labelsR,
Expand Down
15 changes: 15 additions & 0 deletions NDTensors/src/NDTensors.jl
Original file line number Diff line number Diff line change
Expand Up @@ -241,4 +241,19 @@ end

# Declare the generic function `backend_octavian`; no methods are attached here.
function backend_octavian end


# Global switch for the cuTENSOR block-sparse (`CuTensorBS`) contraction path.
# Held in a `const Ref` so that reading it is type-stable (a non-`const` global
# is inferred as `Any`) while the value itself can still be flipped at runtime.
const _using_CuTensorBS = Ref(false)

"""
    using_CuTensorBS()

Return `true` if the `CuTensorBS` contraction path is currently enabled and
`false` otherwise. The flag starts out `false`.
"""
using_CuTensorBS() = _using_CuTensorBS[]

"""
    enable_CuTensorBS()

Turn the `CuTensorBS` contraction path on. Returns `nothing`.
"""
function enable_CuTensorBS()
    _using_CuTensorBS[] = true
    return nothing
end

"""
    disable_CuTensorBS()

Turn the `CuTensorBS` contraction path off. Returns `nothing`.
"""
function disable_CuTensorBS()
    _using_CuTensorBS[] = false
    return nothing
end

end # module NDTensors
24 changes: 19 additions & 5 deletions NDTensors/src/blocksparse/contract_generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@ end
# A generic version making use of `Folds.jl` which
# can take various Executor backends.
# Used for sequential and threaded contract functions.
using .Expose: Exposed, expose, unexpose
function contract!(
R::BlockSparseTensor,
labelsR,
Expand All @@ -59,19 +60,32 @@ function contract!(
push!(grouped_contraction_plan[last(block_contraction)], block_contraction)
end
_contract!(
R,
expose(R),
labelsR,
tensor1,
expose(tensor1),
labelstensor1,
tensor2,
expose(tensor2),
labelstensor2,
grouped_contraction_plan,
executor
)
return R
end

using .Expose: expose
# Generic fallback for `Exposed` arguments: strip the `Exposed` wrappers from the
# output and both input tensors and forward everything else unchanged to the
# `_contract!` method for the unwrapped tensors. The result of that call is
# returned as-is.
function _contract!(
    R::Exposed,
    labelsR,
    tensor1::Exposed,
    labelstensor1,
    tensor2::Exposed,
    labelstensor2,
    grouped_contraction_plan,
    executor,
)
    dest = unexpose(R)
    lhs = unexpose(tensor1)
    rhs = unexpose(tensor2)
    return _contract!(
        dest,
        labelsR,
        lhs,
        labelstensor1,
        rhs,
        labelstensor2,
        grouped_contraction_plan,
        executor,
    )
end
# Function barrier to improve type stability,
# since `Folds`/`FLoops` is not type stable:
# https://discourse.julialang.org/t/type-instability-in-floop-reduction/68598
Expand Down
Loading