-
Notifications
You must be signed in to change notification settings - Fork 269
Wrapper for Blocksparse CuTensor code #3057
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: master
Are you sure you want to change the base?
Changes from all commits
c4918d5
c15fea2
82752ad
9678ecf
affc3d4
a3a3f07
67013c8
f6f5c5f
1ec69cf
8f5ef88
9285b07
cda4a4e
94b8152
138edaf
ce2eeec
3c11bec
cc4b826
3316f63
c493659
f6fb806
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,15 @@ | ||
|
|
||
|
|
||
| ## LinearAlgebra | ||
|
|
||
| using LinearAlgebra | ||
|
|
||
"""
    LinearAlgebra.mul!(C::CuTensorBS, A::CuTensorBS, B::CuTensorBS, α::Number, β::Number)

In-place block-sparse contraction `C = α * A * B + β * C`, forwarding to
`contract!` with identity element-wise operators on every operand. The mode
labels are taken from each tensor's `inds` field. Returns `C`.
"""
function LinearAlgebra.mul!(C::CuTensorBS, A::CuTensorBS, B::CuTensorBS, α::Number, β::Number)
    contract!(α,
              A, A.inds, CUTENSOR_OP_IDENTITY,
              B, B.inds, CUTENSOR_OP_IDENTITY,
              β,
              C, C.inds, CUTENSOR_OP_IDENTITY,
              CUTENSOR_OP_IDENTITY;
              jit=CUTENSOR_JIT_MODE_DEFAULT)
    return C
end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,104 @@ | ||
"""
    nonzero_blocks(A::CuTensorBS)

Return the vector of dense `CuArray` blocks holding the nonzero data of `A`.
Overload this accessor for a custom block-sparse type to make it usable with
`contract!`.
"""
nonzero_blocks(A::CuTensorBS) = A.nonzero_data
|
|
||
"""
    contract!(alpha, A, Ainds, opA, B, Binds, opB, beta, C, Cinds, opC, opOut;
              jit, workspace, algo, compute_type, plan)

Block-sparse tensor contraction `C = alpha * opA(A) * opB(B) + beta * opC(C)`
via cuTENSOR. `Ainds`/`Binds`/`Cinds` are the mode labels of each operand;
the `op*` arguments are element-wise cuTENSOR operators (must be unary).

If `plan === nothing`, a contraction plan is built (honoring `jit`,
`workspace`, `algo`, and `compute_type`) and freed again after the call;
pass a pre-built `CuTensorPlan` to amortize planning cost over repeated
contractions. Operand data is accessed through `nonzero_blocks`, so any type
overloading that accessor (plus `CuTensorBSDescriptor`) can participate.
Returns `C`.
"""
function contract!(
        @nospecialize(alpha::Number),
        @nospecialize(A), Ainds::ModeType, opA::cutensorOperator_t,
        @nospecialize(B), Binds::ModeType, opB::cutensorOperator_t,
        @nospecialize(beta::Number),
        @nospecialize(C), Cinds::ModeType, opC::cutensorOperator_t,
        opOut::cutensorOperator_t;
        jit::cutensorJitMode_t=JIT_MODE_NONE,
        workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
        algo::cutensorAlgo_t=ALGO_DEFAULT,
        compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing,
        plan::Union{CuTensorPlan, Nothing}=nothing)

    # Build a one-shot plan only when the caller did not supply one.
    actual_plan = if plan === nothing
        plan_contraction(A, Ainds, opA, B, Binds, opB, C, Cinds, opC, opOut;
                         jit, workspace, algo, compute_type)
    else
        plan
    end

    contractBS!(actual_plan, alpha, nonzero_blocks(A), nonzero_blocks(B), beta, nonzero_blocks(C))

    # Only free plans created here; caller-owned plans stay alive.
    if plan === nothing
        CUDA.unsafe_free!(actual_plan)
    end

    return C
end
|
|
||
## This function assumes `A`, `B`, and `C` are arrays of dense CuArray blocks
## (the nonzero blocks of each operand). Overload the `nonzero_blocks` function
## for your datatype to reach this function through `contract!`.
function contractBS!(plan::CuTensorPlan,
        @nospecialize(alpha::Number),
        @nospecialize(A::AbstractArray),
        @nospecialize(B::AbstractArray),
        @nospecialize(beta::Number),
        @nospecialize(C::AbstractArray))
    # alpha/beta are passed to the C API at the plan's scalar type.
    scalar_type = plan.scalar_type

    # Extract GPU pointers from each CuArray block:
    # cuTENSOR expects a host-accessible array of GPU pointers per operand.
    A_ptrs = CuPtr{Cvoid}[pointer(block) for block in A]
    B_ptrs = CuPtr{Cvoid}[pointer(block) for block in B]
    C_ptrs = CuPtr{Cvoid}[pointer(block) for block in C]

    # C_ptrs is passed twice: the output D must currently be identical to C.
    cutensorBlockSparseContract(handle(), plan,
        Ref{scalar_type}(alpha), A_ptrs, B_ptrs,
        Ref{scalar_type}(beta), C_ptrs, C_ptrs,
        plan.workspace, sizeof(plan.workspace), stream())
    # Block until the contraction finishes. This also keeps the host pointer
    # arrays reachable for the duration of the (otherwise asynchronous) call.
    synchronize(stream())
    return C
end
|
|
||
"""
    plan_contraction(A, Ainds, opA, B, Binds, opB, C, Cinds, opC, opOut;
                     jit, workspace, algo, compute_type)

Create a `CuTensorPlan` for the block-sparse contraction of `A` and `B` into
`C`. All element-wise operators must be unary (`ArgumentError` otherwise).
When `compute_type === nothing` the compute descriptor is looked up from the
operand element types. The caller owns the returned plan and is responsible
for freeing it (see `contract!`).
"""
function plan_contraction(
        @nospecialize(A), Ainds::ModeType, opA::cutensorOperator_t,
        @nospecialize(B), Binds::ModeType, opB::cutensorOperator_t,
        @nospecialize(C), Cinds::ModeType, opC::cutensorOperator_t,
        opOut::cutensorOperator_t;
        jit::cutensorJitMode_t=JIT_MODE_NONE,
        workspace::cutensorWorksizePreference_t=WORKSPACE_DEFAULT,
        algo::cutensorAlgo_t=ALGO_DEFAULT,
        compute_type::Union{DataType, cutensorComputeDescriptorEnum, Nothing}=nothing)

    # Validate operator kinds up front with real exceptions.
    !is_unary(opA) && throw(ArgumentError("opA must be a unary op!"))
    !is_unary(opB) && throw(ArgumentError("opB must be a unary op!"))
    !is_unary(opC) && throw(ArgumentError("opC must be a unary op!"))
    !is_unary(opOut) && throw(ArgumentError("opOut must be a unary op!"))

    descA = CuTensorBSDescriptor(A)
    descB = CuTensorBSDescriptor(B)
    descC = CuTensorBSDescriptor(C)
    # for now, D must be identical to C (and thus, descD must be identical to descC)

    modeA = collect(Cint, Ainds)
    modeB = collect(Cint, Binds)
    modeC = collect(Cint, Cinds)

    # Pick the compute descriptor from the operand eltypes unless given explicitly.
    actual_compute_type = if compute_type === nothing
        contraction_compute_types[(eltype(A), eltype(B), eltype(C))]
    else
        compute_type
    end

    desc = Ref{cutensorOperationDescriptor_t}()
    # descC/modeC appear twice: once as the C operand, once as the output D.
    cutensorCreateBlockSparseContraction(handle(),
        desc,
        descA, modeA, opA,
        descB, modeB, opB,
        descC, modeC, opC,
        descC, modeC, actual_compute_type)

    plan_pref = Ref{cutensorPlanPreference_t}()
    cutensorCreatePlanPreference(handle(), plan_pref, algo, jit)

    plan = CuTensorPlan(desc[], plan_pref[]; workspacePref=workspace)
    # NOTE(review): the operation descriptor destroy call is commented out.
    # Confirm whether `CuTensorPlan` takes ownership of `desc[]`; if it does
    # not, one operation descriptor leaks per plan created here.
    # cutensorDestroyOperationDescriptor(desc[])
    cutensorDestroyPlanPreference(plan_pref[])
    return plan
end
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,120 @@ | ||
| ## tensor | ||
|
|
||
| export CuTensorBS | ||
|
|
||
## TODO add checks to see if size of data matches expected block size
"""
    CuTensorBS{T, N}

Block-sparse GPU tensor with element type `T` and `N` modes. Only nonzero
blocks are stored, each as a dense `CuArray` in `nonzero_data`;
`nonzero_block_coords` gives the block coordinate of each stored block,
`blocks_per_mode` the number of block sections along each mode, and
`block_extents` (one entry per mode, length `N`) the tuple of section extents
for that mode. `inds` holds the mode labels used for contraction.
"""
mutable struct CuTensorBS{T, N}
    nonzero_data::Vector{<:CuArray}  # one dense CuArray per nonzero block
    inds::Vector{Int}                # mode labels (contraction indices)
    blocks_per_mode::Vector{Int}     # number of block sections along each mode
    ## This expects a Vector{Tuple(Int)} right now
    block_extents
    ## This expects a Vector{Tuple(Int)} right now
    nonzero_block_coords

    function CuTensorBS{T, N}(nonzero_data::Vector{<:CuArray},
            blocks_per_mode::Vector{Int}, block_extents, nonzero_block_coords,
            inds::Vector) where {T<:Number, N}
        CuArrayT = eltype(nonzero_data)
        # Validate at the API boundary with explicit exceptions rather than
        # `@assert` (asserts may be disabled at higher optimization levels).
        eltype(CuArrayT) == T || throw(ArgumentError(
            "block eltype $(eltype(CuArrayT)) does not match requested element type T = $T"))
        # @assert ndims(CuArrayT) == N
        length(block_extents) == N || throw(ArgumentError(
            "length(block_extents) = $(length(block_extents)) does not match N = $N"))
        return new(nonzero_data, inds, blocks_per_mode, block_extents, nonzero_block_coords)
    end
end
|
|
||
"""
    CuTensorBS(nonzero_data, blocks_per_mode, block_extents, nonzero_block_coords, inds)

Convenience constructor: the element type `T` is deduced from the block
arrays and the mode count `N` from `length(block_extents)`.
"""
function CuTensorBS(nonzero_data::Vector{<:CuArray{T}},
        blocks_per_mode, block_extents, nonzero_block_coords,
        inds::Vector) where {T<:Number}
    nmodes = length(block_extents)
    return CuTensorBS{T, nmodes}(nonzero_data,
        blocks_per_mode, block_extents, nonzero_block_coords, inds)
end
# array interface
# Dense (logical) size: the total extent of each mode is the sum of the
# extents of that mode's block sections.
function Base.size(T::CuTensorBS)
    return tuple(sum.(T.block_extents)...)
end
# Total number of logical elements (zero blocks included).
Base.length(T::CuTensorBS) = prod(size(T))
# Number of explicitly stored elements across all nonzero blocks.
nonzero_length(T::CuTensorBS) = sum(length.(T.nonzero_data))
# Mode count, taken from the number of mode labels.
Base.ndims(T::CuTensorBS) = length(T.inds)
|
|
||
# NOTE(review): this deviates from the usual `Base.strides` contract (the
# strides of a single array) — it concatenates the strides of every nonzero
# block into one flat vector, the layout consumed by the cuTENSOR
# block-sparse descriptor. Consider renaming to avoid surprising generic code.
Base.strides(T::CuTensorBS) = vcat([[st...] for st in strides.(T.nonzero_data)]...)
# Element type of the stored blocks.
Base.eltype(T::CuTensorBS) = eltype(eltype(T.nonzero_data))
|
|
||
"""
    block_extents(T::CuTensorBS)

Flatten the per-mode block-extent tuples of `T` into a single `Vector{Int64}`
(mode 0's section extents first, then mode 1's, …), the layout expected by
the cuTENSOR block-sparse tensor descriptor.
"""
function block_extents(T::CuTensorBS)
    # Single-pass flatten; avoids the O(n^2) repeated `vcat` of growing a
    # vector inside a loop.
    return collect(Int64, Iterators.flatten(T.block_extents))
end
|
|
||
# Number of block sections along each mode.
nblocks_per_mode(T::CuTensorBS) = T.blocks_per_mode

# Count of explicitly stored (nonzero) blocks.
num_nonzero_blocks(T::CuTensorBS) = length(T.nonzero_block_coords)
|
|
||
## This function turns the tuples of block coordinates into a single flat
## list of coordinates (block 1's coordinates first, then block 2's, …).
function list_nonzero_block_coords(T::CuTensorBS)
    # Single-pass flatten; avoids the O(n^2) repeated `vcat` of growing a
    # vector inside a loop.
    return collect(Int64, Iterators.flatten(T.nonzero_block_coords))
end
|
|
||
# ## descriptor
"""
    CuTensorBSDescriptor

Finalizer-managed wrapper around a raw
`cutensorBlockSparseTensorDescriptor_t` handle. The inner constructor creates
the library-side descriptor and registers `unsafe_destroy!` so the handle is
released when the wrapper is garbage collected.
"""
mutable struct CuTensorBSDescriptor
    handle::cutensorBlockSparseTensorDescriptor_t
    # inner constructor handles creation and finalizer of the descriptor
    # NOTE(review): consider restricting the argument types here (as
    # `CuTensorDescriptor` does) so conversion errors surface at the Julia
    # boundary instead of inside the library call.
    function CuTensorBSDescriptor(
            numModes,
            numNonZeroBlocks,
            numSectionsPerMode,
            extent,
            nonZeroCoordinates,
            stride,
            eltype)

        desc = Ref{cuTENSOR.cutensorBlockSparseTensorDescriptor_t}()
        cutensorCreateBlockSparseTensorDescriptor(handle(), desc,
            numModes, numNonZeroBlocks, numSectionsPerMode, extent, nonZeroCoordinates,
            stride, eltype)

        obj = new(desc[])
        finalizer(unsafe_destroy!, obj)
        return obj
    end
end
|
|
||
"""
    CuTensorBSDescriptor(numModes, numNonZeroBlocks, numSectionsPerMode,
                         extent, nonZeroCoordinates, eltype)

Convenience constructor that omits per-block strides: `C_NULL` is forwarded
for the `stride` argument, telling cuTENSOR to use the default (packed)
layout for every block.
"""
function CuTensorBSDescriptor(
        numModes,
        numNonZeroBlocks,
        numSectionsPerMode,
        extent,
        nonZeroCoordinates,
        eltype)
    # stride = C_NULL ⇒ library-default block strides
    return CuTensorBSDescriptor(numModes, numNonZeroBlocks, numSectionsPerMode,
                                extent, nonZeroCoordinates, C_NULL, eltype)
end
|
|
||
# Compact display showing the raw handle address.
Base.show(io::IO, desc::CuTensorBSDescriptor) = @printf(io, "CuTensorBSDescriptor(%p)", desc.handle)

# Allow passing the wrapper directly where the raw library handle is expected.
Base.unsafe_convert(::Type{cutensorBlockSparseTensorDescriptor_t}, obj::CuTensorBSDescriptor) = obj.handle

# Finalizer target: releases the library-side descriptor.
function unsafe_destroy!(obj::CuTensorBSDescriptor)
    cutensorDestroyBlockSparseTensorDescriptor(obj)
end
|
|
||
## Descriptor function for the CuTensorBS type. Overload for custom block-sparse types.
"""
    CuTensorBSDescriptor(A::CuTensorBS)

Build a cuTENSOR block-sparse tensor descriptor from the metadata of `A`:
mode count, nonzero-block count, sections per mode, flattened block extents,
and the (0-based) coordinates of each nonzero block. Strides are currently
passed as `NULL` (library-default block layout).
"""
function CuTensorBSDescriptor(A::CuTensorBS)
    numModes = Int32(ndims(A))
    numNonZeroBlocks = Int64(length(A.nonzero_block_coords))
    numSectionsPerMode = collect(Int32, A.blocks_per_mode)

    extent = block_extents(A)
    # cuTENSOR expects 0-based block coordinates; the Julia side stores
    # 1-based tuples, hence the `.- 1`.
    nonZeroCoordinates = Int32.(vcat([[x...] for x in A.nonzero_block_coords]...) .- 1)

    dataType = eltype(A) # convert(cuTENSOR.cutensorDataType_t, eltype(A))

    ## Strides are assumed NULL for now (the previously computed `strides(A)`
    ## was never used and has been removed). Whether explicit strides work
    ## needs to be confirmed with the cuTENSOR team before enabling them.
    return CuTensorBSDescriptor(numModes, numNonZeroBlocks,
        numSectionsPerMode, extent, nonZeroCoordinates, dataType)
end
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think it would be both helpful for clarity/self-documentation and for avoiding hard-to-decipher errors to restrict the types of these arguments in the inner constructor. This would also be more in line with the `CuTensorDescriptor` type and its constructors.