diff --git a/NDTensors/ext/NDTensorscuTENSORExt/contract.jl b/NDTensors/ext/NDTensorscuTENSORExt/contract.jl index 800531ba90..64ffa26b19 100644 --- a/NDTensors/ext/NDTensorscuTENSORExt/contract.jl +++ b/NDTensors/ext/NDTensorscuTENSORExt/contract.jl @@ -1,6 +1,7 @@ using Base: ReshapedArray using NDTensors.Expose: Exposed, expose, unexpose -using NDTensors: NDTensors, DenseTensor, array +using NDTensors: NDTensors, BlockSparseTensor, DenseTensor, array, +blockdims, data, eachnzblock, inds, nblocks, nzblocks using cuTENSOR: cuTENSOR, CuArray, CuTensor # Handle cases that can't be handled by `cuTENSOR.jl` @@ -12,6 +13,66 @@ function to_zero_offset_cuarray(a::ReshapedArray) return copy(expose(a)) end +function block_extents(ind) + return ntuple(i -> ind.space[i].second, nblocks(ind)) +end + +#### Functions to turn Tensors into BlockSparseCuTensors for contraction +function ITensor_to_cuTensorBS(T::BlockSparseTensor) + blocks_t1 = [] + # T = tensor(target) + for blockT in eachnzblock(T) + offsetT = NDTensors.offset(T, blockT) + blockdimsT = blockdims(T, blockT) + blockdimT = prod(blockdimsT) + push!(blocks_t1, @view data(T)[(offsetT + 1):(offsetT + blockdimT)]) + end + blocks_t1 = Vector{typeof(blocks_t1[1])}(blocks_t1) + block_extents_t1 = [block_extents(idx) for idx in inds(T)] ## This is sections + nzblock_coords_t1 = [Int64.(x.data) for x in nzblocks(T)] + block_per_mode_t1 = length.(block_extents_t1) + is = [i for i in 1:ndims(T)] + return cuTENSOR.CuTensorBS(blocks_t1, block_per_mode_t1, block_extents_t1, nzblock_coords_t1, is); +end + +function NDTensors._contract!(R::Exposed{<:CuArray, <:BlockSparseTensor}, + labelsR, + tensor1::Exposed{<:CuArray, <:BlockSparseTensor}, + labelstensor1, + tensor2::Exposed{<:CuArray, <:BlockSparseTensor}, + labelstensor2, + grouped_contraction_plan, + executor, + ) + N1 = ndims(unexpose(tensor1)) + N2 = ndims(unexpose(tensor2)) + NR = ndims(unexpose(R)) + if NDTensors.using_CuTensorBS() && (N1 > 0) && (N2 > 0) && (NR > 0) + # println("Using new function") + cuR = ITensor_to_cuTensorBS(unexpose(R)) + cutensor1 = ITensor_to_cuTensorBS(unexpose(tensor1)) + cutensor2 = ITensor_to_cuTensorBS(unexpose(tensor2)) + + cuR.inds = [labelsR...] + cutensor1.inds = [labelstensor1...] + cutensor2.inds = [labelstensor2...] + + cuTENSOR.mul!(cuR, cutensor1, cutensor2, 1.0, 0.0) + return R + else + return NDTensors._contract!( + unexpose(R), + labelsR, + unexpose(tensor1), + labelstensor1, + unexpose(tensor2), + labelstensor2, + grouped_contraction_plan, + executor, + ) + end +end + function NDTensors.contract!( exposedR::Exposed{<:CuArray, <:DenseTensor}, labelsR, diff --git a/NDTensors/src/NDTensors.jl b/NDTensors/src/NDTensors.jl index 4fe6785856..e7a6068836 100644 --- a/NDTensors/src/NDTensors.jl +++ b/NDTensors/src/NDTensors.jl @@ -241,4 +241,19 @@ end function backend_octavian end + +_using_CuTensorBS = false + +using_CuTensorBS() = _using_CuTensorBS + +function enable_CuTensorBS() + NDTensors._using_CuTensorBS = true + return nothing +end + +function disable_CuTensorBS() + NDTensors._using_CuTensorBS = false + return nothing +end + end # module NDTensors diff --git a/NDTensors/src/blocksparse/contract_generic.jl b/NDTensors/src/blocksparse/contract_generic.jl index e790ef39ee..39b67facc8 100644 --- a/NDTensors/src/blocksparse/contract_generic.jl +++ b/NDTensors/src/blocksparse/contract_generic.jl @@ -34,6 +34,7 @@ end # A generic version making use of `Folds.jl` which # can take various Executor backends. # Used for sequential and threaded contract functions. +using .Expose: Exposed, expose, unexpose function contract!( R::BlockSparseTensor, labelsR, @@ -59,19 +60,32 @@ function contract!( push!(grouped_contraction_plan[last(block_contraction)], block_contraction) end _contract!( - R, + expose(R), labelsR, - tensor1, + expose(tensor1), labelstensor1, - tensor2, + expose(tensor2), labelstensor2, grouped_contraction_plan, executor ) return R end - -using .Expose: expose +function _contract!(R::Exposed, + labelsR, + tensor1::Exposed, + labelstensor1, + tensor2::Exposed, + labelstensor2, + grouped_contraction_plan, + executor, + ) + _contract!(unexpose(R), labelsR, + unexpose(tensor1), labelstensor1, + unexpose(tensor2), labelstensor2, + grouped_contraction_plan,executor + ) +end # Function barrier to improve type stability, # since `Folds`/`FLoops` is not type stable: # https://discourse.julialang.org/t/type-instability-in-floop-reduction/68598