Skip to content

Commit c735d30

Browse files
authored
Support cuTENSOR contractors for 1D views (#2650)
1 parent 779e4dc commit c735d30

File tree

2 files changed

+38
-2
lines changed

2 files changed

+38
-2
lines changed

lib/cutensor/src/types.jl

Lines changed: 9 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -194,13 +194,15 @@ function CUDA.unsafe_free!(plan::CuTensorPlan)
194194
end
195195

196196

197+
const CUTENSOR_ALIGNMENT = UInt32(128)
198+
197199
## descriptor
198200

199201
mutable struct CuTensorDescriptor
200202
handle::cutensorTensorDescriptor_t
201203
# inner constructor handles creation and finalizer of the descriptor
202204
function CuTensorDescriptor(sz::Vector{Int64}, st::Vector{Int64}, eltype::DataType,
203-
alignmentRequirement::UInt32=UInt32(128))
205+
alignmentRequirement::UInt32=CUTENSOR_ALIGNMENT)
204206
desc = Ref{cutensorTensorDescriptor_t}()
205207
length(st) == (N = length(sz)) || throw(ArgumentError("size and stride vectors must have the same length"))
206208
cutensorCreateTensorDescriptor(handle(), desc, N, sz, st, eltype, alignmentRequirement)
@@ -236,7 +238,12 @@ mutable struct CuTensor{T, N}
236238
inds::Vector{Int32}
237239

238240
function CuTensor{T, N}(data::CuArray{T,N}, inds::Vector) where {T<:Number, N}
239-
new(data, inds)
241+
if !iszero(UInt(pointer(data)) % CUTENSOR_ALIGNMENT)
242+
@warn "The data for this CuTensor does not obey the CUTENSOR alignment requirement of $CUTENSOR_ALIGNMENT. An explicit copy will be made to ensure the requirement is met."
243+
return new(copy(data), inds)
244+
else
245+
return new(data, inds)
246+
end
240247
end
241248
end
242249

lib/cutensor/test/contractions.jl

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -164,4 +164,33 @@ eltypes = [(Float32, Float32, Float32, Float32),
164164
end
165165
end
166166

167+
# https://github.com/JuliaGPU/CUDA.jl/issues/2407
168+
@testset "contractions of views" begin
169+
@testset for (eltyA, eltyB, eltyC, eltyCompute) in eltypes
170+
dimsA = (16,)
171+
dimsB = (4,)
172+
dimsC = (8,)
173+
A = rand(eltyA, dimsA)
174+
B = rand(eltyB, dimsB)
175+
C = rand(eltyC, dimsC)
176+
dA = CuArray(A)
177+
dB = CuArray(B)
178+
dC = CuArray(C)
179+
dD = CuArray(C)
180+
vA = @view dA[1:4]
181+
vB = @view dB[4:4]
182+
vC = @view dC[3:6]
183+
vD = @view dD[3:6]
184+
tA = CuTensor(reshape(vA, (4, 1)), [1, 2])
185+
tB = CuTensor(reshape(vB, (1, 1)), [3, 2])
186+
tC = CuTensor(reshape(vC, (1, 4)), [3, 1])
187+
mul!(tC, tA, tB)
188+
tA2 = CuTensor(copy(vA), [1, 2])
189+
tB2 = CuTensor(copy(vB), [3, 2])
190+
tD = CuTensor(copy(vD), [3, 1])
191+
mul!(tD, tA2, tB2)
192+
@test tD.data tD.data
193+
end
194+
end
195+
167196
end

0 commit comments

Comments
 (0)