Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
20 commits
Select commit Hold shift + click to select a range
c4918d5
Working on implementing the wrapper for the new blocksparse cutensor …
kmp5VT Jan 22, 2026
c15fea2
Revert to cutensor_jll.libcutensor as this has the blocksparse cutens…
kmp5VT Jan 23, 2026
82752ad
Remove redundant convert function
kmp5VT Jan 23, 2026
9678ecf
Merge branch 'JuliaGPU:master' into kmp5/feature/wrap_blocksparse_cut…
kmp5VT Mar 6, 2026
affc3d4
Make blocksparse code more generic (generic case). Would it be better…
kmp5VT Mar 16, 2026
a3a3f07
Merge branch 'master' into kmp5/feature/wrap_blocksparse_cutensor
kmp5VT Mar 16, 2026
67013c8
Merge branch 'kmp5/feature/wrap_blocksparse_cutensor' of github.com:k…
kmp5VT Mar 16, 2026
f6f5c5f
Merge branch 'JuliaGPU:master' into kmp5/feature/wrap_blocksparse_cut…
kmp5VT Mar 19, 2026
1ec69cf
Working on simplifying and making accessors
kmp5VT Mar 19, 2026
8f5ef88
Fix problem with stride
kmp5VT Mar 19, 2026
9285b07
Small comment reminder
kmp5VT Mar 19, 2026
cda4a4e
Add a contraction test for the blocksparse system (not comprehensive …
kmp5VT Mar 19, 2026
94b8152
Merge branch 'master' into kmp5/feature/wrap_blocksparse_cutensor
kmp5VT Mar 24, 2026
138edaf
Closer to clang.jl construction
kmp5VT Mar 24, 2026
ce2eeec
Merge branch 'kmp5/feature/wrap_blocksparse_cutensor' of github.com:k…
kmp5VT Mar 24, 2026
3c11bec
Merge branch 'master' into kmp5/feature/wrap_blocksparse_cutensor
kmp5VT Mar 25, 2026
cc4b826
Update cutensor.toml for block sparse contraction
kshyatt Mar 26, 2026
3316f63
Merge branch 'master' into kmp5/feature/wrap_blocksparse_cutensor
kmp5VT Mar 26, 2026
c493659
Apply suggestion from @lkdvos
kmp5VT Mar 27, 2026
f6fb806
Merge branch 'JuliaGPU:master' into kmp5/feature/wrap_blocksparse_cut…
kmp5VT Apr 3, 2026
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 15 additions & 0 deletions lib/cutensor/src/blocksparse/interfaces.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,15 @@


## LinearAlgebra

using LinearAlgebra

"""
    LinearAlgebra.mul!(C::CuTensorBS, A::CuTensorBS, B::CuTensorBS, α::Number, β::Number)

Five-argument `mul!` for block-sparse cuTENSOR tensors. Forwards to
`contract!` with identity unary operators on all operands, contracting the
modes shared between `A.inds` and `B.inds`, and returns `C`.
"""
function LinearAlgebra.mul!(C::CuTensorBS, A::CuTensorBS, B::CuTensorBS, α::Number, β::Number)
    # Identity op everywhere: plain contraction, no elementwise transforms.
    id = CUTENSOR_OP_IDENTITY
    contract!(α,
              A, A.inds, id,
              B, B.inds, id,
              β,
              C, C.inds, id,
              id; jit=CUTENSOR_JIT_MODE_DEFAULT)
    return C
end
104 changes: 104 additions & 0 deletions lib/cutensor/src/blocksparse/operations.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,104 @@
# Accessor for the stored nonzero blocks. Custom block-sparse container types
# should overload this so `contract!` can reach their per-block storage.
nonzero_blocks(A::CuTensorBS) = A.nonzero_data

"""
    contract!(alpha, A, Ainds, opA, B, Binds, opB, beta, C, Cinds, opC, opOut; kwargs...)

Block-sparse contraction driver. Builds (or reuses) a contraction plan and
executes it over the nonzero blocks of `A`, `B`, and `C` via `contractBS!`.
Pass a precomputed `plan` keyword to amortize planning cost over repeated
contractions; plans created internally are freed before returning. Returns `C`.
"""
function contract!(
    @nospecialize(alpha::Number),
    @nospecialize(A), Ainds::ModeType, opA::cutensorOperator_t,
    @nospecialize(B), Binds::ModeType, opB::cutensorOperator_t,
    @nospecialize(beta::Number),
    @nospecialize(C), Cinds::ModeType, opC::cutensorOperator_t,
    opOut::cutensorOperator_t;
    jit::cutensorJitMode_t=JIT_MODE_NONE,
    workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
    algo::cutensorAlgo_t=ALGO_DEFAULT,
    compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing,
    plan::Union{CuTensorPlan, Nothing}=nothing)

    # We own (and must free) the plan only when the caller did not supply one.
    owns_plan = plan === nothing
    p = owns_plan ?
        plan_contraction(A, Ainds, opA, B, Binds, opB, C, Cinds, opC, opOut;
                         jit, workspace, algo, compute_type) :
        plan

    contractBS!(p, alpha, nonzero_blocks(A), nonzero_blocks(B), beta, nonzero_blocks(C))

    # Caller-owned plans stay alive for reuse.
    owns_plan && CUDA.unsafe_free!(p)

    return C
end

## This function assumes A, B, and C are Arrays of pointers to CuArrays.
## Please overwrite the `nonzero_blocks` function for your datatype to access this function from contract!
## This function assumes A, B, and C are Arrays of pointers to CuArrays.
## Please overwrite the `nonzero_blocks` function for your datatype to access this function from contract!
function contractBS!(plan::CuTensorPlan,
                     @nospecialize(alpha::Number),
                     @nospecialize(A::AbstractArray),
                     @nospecialize(B::AbstractArray),
                     @nospecialize(beta::Number),
                     @nospecialize(C::AbstractArray))
    T = plan.scalar_type

    # cuTENSOR expects host-resident arrays of device pointers, one entry per
    # nonzero block of each operand.
    device_ptrs(blocks) = CuPtr{Cvoid}[pointer(b) for b in blocks]
    ptrsA = device_ptrs(A)
    ptrsB = device_ptrs(B)
    ptrsC = device_ptrs(C)

    # D must currently alias C, so the C pointer array is passed twice.
    cutensorBlockSparseContract(handle(), plan,
                                Ref{T}(alpha), ptrsA, ptrsB,
                                Ref{T}(beta), ptrsC, ptrsC,
                                plan.workspace, sizeof(plan.workspace), stream())
    # Block until the contraction has completed on the current stream.
    synchronize(stream())
    return C
end

# Build a `CuTensorPlan` for a block-sparse contraction of A and B into C.
# All operators must be elementwise-unary; D is constrained to alias C, so
# the C descriptor/modes are passed for the D slot as well.
function plan_contraction(
@nospecialize(A), Ainds::ModeType, opA::cutensorOperator_t,
@nospecialize(B), Binds::ModeType, opB::cutensorOperator_t,
@nospecialize(C), Cinds::ModeType, opC::cutensorOperator_t,
opOut::cutensorOperator_t;
jit::cutensorJitMode_t=JIT_MODE_NONE,
workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
algo::cutensorAlgo_t=ALGO_DEFAULT,
compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing)

# Validate operators up front: only unary elementwise ops are supported here.
!is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
!is_unary(opB) && throw(ArgumentError("opB must be a unary op!"))
!is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
!is_unary(opOut) && throw(ArgumentError("opOut must be a unary op!"))

# Library-side block-sparse tensor descriptors (overridable per custom type).
descA = CuTensorBSDescriptor(A)
descB = CuTensorBSDescriptor(B)
descC = CuTensorBSDescriptor(C)
# for now, D must be identical to C (and thus, descD must be identical to descC)

# Mode labels as C integers, in the layout the C API expects.
modeA = collect(Cint, Ainds)
modeB = collect(Cint, Binds)
modeC = collect(Cint, Cinds)

# Fall back to the compute type registered for this eltype combination.
actual_compute_type = if compute_type === nothing
contraction_compute_types[(eltype(A), eltype(B), eltype(C))]
else
compute_type
end


desc = Ref{cutensorOperationDescriptor_t}()
cutensorCreateBlockSparseContraction(handle(),
desc,
descA, modeA, opA,
descB, modeB, opB,
descC, modeC, opC,
descC, modeC, actual_compute_type)

plan_pref = Ref{cutensorPlanPreference_t}()
cutensorCreatePlanPreference(handle(), plan_pref, algo, jit)

plan = CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace)
# NOTE(review): the operation descriptor is never destroyed (call below is
# commented out). Confirm whether CuTensorPlan takes ownership of `desc[]`;
# otherwise this leaks one descriptor per plan.
# cutensorDestroyOperationDescriptor(desc[])
cutensorDestroyPlanPreference(plan_pref[])
return plan
end
120 changes: 120 additions & 0 deletions lib/cutensor/src/blocksparse/types.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
## tensor

export CuTensorBS

## TODO add checks to see if size of data matches expected block size
"""
    CuTensorBS{T, N}

Block-sparse tensor for cuTENSOR with element type `T` and `N` modes. Only the
nonzero blocks are stored, as a vector of dense `CuArray`s, together with the
block-sectioning metadata needed to build a cuTENSOR block-sparse descriptor.
"""
mutable struct CuTensorBS{T, N}
    nonzero_data::Vector{<:CuArray}   # one dense CuArray per nonzero block
    inds::Vector{Int}                 # mode labels, used for contraction matching
    blocks_per_mode::Vector{Int}      # number of block sections along each mode
    ## This expects a Vector{Tuple(Int)} right now
    block_extents                     # per-mode tuple of section extents
    ## This expects a Vector{Tuple(Int)} right now
    nonzero_block_coords              # per-block tuple of section coordinates (1-based)

    function CuTensorBS{T, N}(nonzero_data::Vector{<:CuArray},
            blocks_per_mode::Vector{Int}, block_extents, nonzero_block_coords, inds::Vector) where {T<:Number, N}
        CuArrayT = eltype(nonzero_data)
        # Explicit validation instead of `@assert`, which may be compiled out
        # at higher optimization levels.
        eltype(CuArrayT) == T ||
            throw(ArgumentError("block element type $(eltype(CuArrayT)) does not match T = $T"))
        length(block_extents) == N ||
            throw(ArgumentError("expected $N block-extent entries, got $(length(block_extents))"))
        # @assert ndims(CuArrayT) == N
        return new(nonzero_data, inds, blocks_per_mode, block_extents, nonzero_block_coords)
    end
end

# Convenience constructor: `T` is taken from the block storage and `N` from
# the number of per-mode extent entries.
function CuTensorBS(nonzero_data::Vector{<:CuArray{T}},
        blocks_per_mode, block_extents, nonzero_block_coords, inds::Vector) where {T<:Number}
    N = length(block_extents)
    return CuTensorBS{T, N}(nonzero_data, blocks_per_mode, block_extents,
                            nonzero_block_coords, inds)
end
# array interface
# Logical (dense) size: each mode's total extent is the sum of its section extents.
function Base.size(t::CuTensorBS)
    per_mode = map(sum, t.block_extents)
    return (per_mode...,)
end
# Total logical element count (including implicit zero blocks).
Base.length(t::CuTensorBS) = prod(size(t))
# Number of explicitly stored elements across all nonzero blocks.
nonzero_length(t::CuTensorBS) = sum(length, t.nonzero_data)
# Number of modes, taken from the mode-label list.
Base.ndims(t::CuTensorBS) = length(t.inds)

# Strides of every nonzero block, flattened into one vector, block by block.
Base.strides(t::CuTensorBS) = vcat((collect(strides(b)) for b in t.nonzero_data)...)
# Element type of the underlying block storage.
Base.eltype(t::CuTensorBS) = eltype(eltype(t.nonzero_data))

"""
    block_extents(T::CuTensorBS)

Return all per-mode section extents flattened into a single `Vector{Int64}`,
in mode order — the layout expected by the block-sparse descriptor call.
"""
function block_extents(T::CuTensorBS)
    # Single-pass flatten; replaces the previous O(n^2) repeated-`vcat` loop.
    return collect(Int64, Iterators.flatten(T.block_extents))
end

# Number of block sections along each mode, as stored at construction.
nblocks_per_mode(T::CuTensorBS) = T.blocks_per_mode

# Number of explicitly stored (nonzero) blocks.
num_nonzero_blocks(T::CuTensorBS) = length(T.nonzero_block_coords)

"""
    list_nonzero_block_coords(T::CuTensorBS)

Flatten the per-block coordinate tuples into one `Vector{Int64}`, block by
block, as a single list of coordinates.
"""
function list_nonzero_block_coords(T::CuTensorBS)
    # Single-pass flatten; replaces the previous O(n^2) repeated-`vcat` loop.
    return collect(Int64, Iterators.flatten(T.nonzero_block_coords))
end

# ## descriptor
# Wrapper owning a cuTENSOR block-sparse tensor descriptor handle.
# (Removed page-scrape review-UI text that had been embedded mid-definition.)
mutable struct CuTensorBSDescriptor
    handle::cutensorBlockSparseTensorDescriptor_t

    # Inner constructor creates the library descriptor and registers a
    # finalizer that destroys it.
    # NOTE(review): arguments are currently untyped; restricting them (as is
    # done for CuTensorDescriptor) would self-document and give clearer errors.
    function CuTensorBSDescriptor(
            numModes,
            numNonZeroBlocks,
            numSectionsPerMode,
            extent,
            nonZeroCoordinates,
            stride,
            eltype)
        desc = Ref{cuTENSOR.cutensorBlockSparseTensorDescriptor_t}()
        cutensorCreateBlockSparseTensorDescriptor(handle(), desc,
            numModes, numNonZeroBlocks, numSectionsPerMode, extent, nonZeroCoordinates,
            stride, eltype)

        obj = new(desc[])
        finalizer(unsafe_destroy!, obj)
        return obj
    end
end

# Convenience constructor without a stride argument. Passing `stride = C_NULL`
# tells cuTENSOR to assume the natural (densely packed) strides for each block.
# (Removed page-scrape review-UI text that had been embedded mid-definition.)
function CuTensorBSDescriptor(
        numModes,
        numNonZeroBlocks,
        numSectionsPerMode,
        extent,
        nonZeroCoordinates,
        eltype)
    return CuTensorBSDescriptor(numModes, numNonZeroBlocks, numSectionsPerMode,
                                extent, nonZeroCoordinates, C_NULL, eltype)
end

# Compact display showing the raw descriptor handle as a pointer.
Base.show(io::IO, desc::CuTensorBSDescriptor) = @printf(io, "CuTensorBSDescriptor(%p)", desc.handle)

# Allow the wrapper to be passed directly where the raw handle is expected.
Base.unsafe_convert(::Type{cutensorBlockSparseTensorDescriptor_t}, obj::CuTensorBSDescriptor) = obj.handle

# Finalizer target: destroys the library-side descriptor.
function unsafe_destroy!(obj::CuTensorBSDescriptor)
cutensorDestroyBlockSparseTensorDescriptor(obj)
end

## Descriptor function for CuTensorBS type. Please overwrite for custom objects.
## (Removed page-scrape review-UI text embedded in the original, the unused
## local `st = strides(A)`, and a leftover commented-out conversion.)
function CuTensorBSDescriptor(A::CuTensorBS)
    numModes = Int32(ndims(A))
    numNonZeroBlocks = Int64(length(A.nonzero_block_coords))
    numSectionsPerMode = collect(Int32, A.blocks_per_mode)
    extent = block_extents(A)
    # cuTENSOR expects 0-based block coordinates; CuTensorBS stores 1-based ones.
    nonZeroCoordinates = Int32.(vcat([[x...] for x in A.nonzero_block_coords]...) .- 1)
    dataType = eltype(A)

    # Strides are deliberately not forwarded: the descriptor is built with
    # stride = C_NULL, i.e. blocks are assumed densely packed.
    # NOTE(review): confirm with the cuTENSOR team whether explicit strides
    # are supported before wiring `strides(A)` through.
    return CuTensorBSDescriptor(numModes, numNonZeroBlocks,
                                numSectionsPerMode, extent, nonZeroCoordinates, dataType)
end
8 changes: 8 additions & 0 deletions lib/cutensor/src/cuTENSOR.jl
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ using CUDACore
using CUDACore: CUstream, cudaDataType, @gcsafe_ccall, @checked, @enum_without_prefix
using CUDACore: retry_reclaim, initialize_context, isdebug

using CUDA.GPUToolbox

using CEnum: @cenum

using Printf: @printf
Expand Down Expand Up @@ -32,8 +34,14 @@ include("utils.jl")
include("types.jl")
include("operations.jl")


# Block sparse wrappers
include("blocksparse/types.jl")
include("blocksparse/operations.jl")

# high-level integrations
include("interfaces.jl")
include("blocksparse/interfaces.jl")


## handles
Expand Down
10 changes: 5 additions & 5 deletions lib/cutensor/src/libcutensor.jl
Original file line number Diff line number Diff line change
Expand Up @@ -545,12 +545,12 @@ end
@gcsafe_ccall libcutensor.cutensorBlockSparseContract(handle::cutensorHandle_t,
plan::cutensorPlan_t,
alpha::Ptr{Cvoid},
A::Ptr{Ptr{Cvoid}},
B::Ptr{Ptr{Cvoid}},
A::Ptr{CuPtr{Cvoid}},
B::Ptr{CuPtr{Cvoid}},
beta::Ptr{Cvoid},
C::Ptr{Ptr{Cvoid}},
D::Ptr{Ptr{Cvoid}},
workspace::Ptr{Cvoid},
C::Ptr{CuPtr{Cvoid}},
D::Ptr{CuPtr{Cvoid}},
workspace::CuPtr{Cvoid},
workspaceSize::UInt64,
stream::cudaStream_t)::cutensorStatus_t
end
Expand Down
58 changes: 58 additions & 0 deletions lib/cutensor/test/contractions.jl
Original file line number Diff line number Diff line change
Expand Up @@ -188,4 +188,62 @@ end
end
end

# (eltyA, eltyB, eltyC, eltyCompute) combinations exercised by the
# block-sparse contraction tests below.
eltypes_compact = [
(Float32, Float32, Float32, Float32),
(ComplexF32, ComplexF32, ComplexF32, Float32),
(Float64, Float64, Float64, Float64),
(ComplexF64, ComplexF64, ComplexF64, Float64)
]
@testset "Blocksparse Contraction" begin
## There are many unsupported types because this is a new functionality
## So I will test with Float32 and ComplexF32 only
@testset for (eltyA, eltyB, eltyC, eltyCompute) in eltypes_compact
## i = [20,20,25]
## k = [10,10,15]
## l = [30,30,35]
## A = Tensor(k,i,l)
## Nonzero blocks are
## [1,1,1], [1,1,3], [1,3,1], [1,3,3], [3,1,1], [3,1,3], [3,3,1], [3,3,3]
A = Vector{CuArray{eltyA, 3}}()
for k in [10,15]
for i in [20,25]
for l in [30,35]
push!(A, CuArray(ones(eltyA, k,i,l)))
end
end
end

## B = Tensor(k,l)
## Nonzero blocks are
## [1,1], [2,3]
B = Array{CuArray{eltyB, 2}}(
[CuArray(randn(eltyB, 10, 30)),
CuArray(randn(eltyB, 10, 35))])

## C = Tensor(i)
## Nonzero blocks are
## [1,], [3,]
C = Vector{CuArray{eltyC, 1}}(
[CuArray(zeros(eltyC, 20)),
CuArray(zeros(eltyC, 25))]
)

cuTenA = cuTENSOR.CuTensorBS(A, [3,3,3],
[(10,10,15), (20,20,25), (30,30,35)],
[(1,1,1), (1,1,3), (1,3,1), (1,3,3), (3,1,1), (3,1,3), (3,3,1), (3,3,3)],
[1,3,2])
cuTenB = cuTENSOR.CuTensorBS(B, [3,3],
[(10,10,15), (30,30,35)],
[(1,1),(2,3)], [1,2], )
cuTenC = cuTENSOR.CuTensorBS(C, [3],
[(20,20,25)],[(1,),(3,)], [3])

mul!(cuTenC, cuTenA, cuTenB, 1, 0)
## C[1] = A[1,1,1] * B[1,1]
@test C[1] ≈ reshape(permutedims(A[1], (2,1,3)), (20, 10 * 30)) * reshape(B[1], (10 * 30))
## C[3] = A[1,3,1] * B[1,1]
@test C[2] ≈ reshape(permutedims(A[3], (2,1,3)), (25, 10 * 30)) * reshape(B[1], (10 * 30))
end
end

end
7 changes: 7 additions & 0 deletions res/wrap/cutensor.toml
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,10 @@ needs_context = false
6 = "CuPtr{Cvoid}"
7 = "CuPtr{Cvoid}"
8 = "CuPtr{Cvoid}"

# Argument-type overrides for the generated cutensorBlockSparseContract
# wrapper: the per-block pointer-array arguments are host arrays of device
# pointers (Ptr{CuPtr{Cvoid}}).
# NOTE(review): confirm these indices line up with the generated signature.
[api.cutensorBlockSparseContract.argtypes]
4 = "Ptr{CuPtr{Cvoid}}"
5 = "Ptr{CuPtr{Cvoid}}"
7 = "Ptr{CuPtr{Cvoid}}"
8 = "Ptr{CuPtr{Cvoid}}"
9 = "Ptr{CuPtr{Cvoid}}"