Skip to content

Commit 89a428a

Browse files
authored
[NDTensors] cuTENSOR extension (#1395)
1 parent 0e5d3f8 commit 89a428a

File tree

15 files changed

+195
-107
lines changed

15 files changed

+195
-107
lines changed

NDTensors/Project.toml

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -32,20 +32,22 @@ TupleTools = "9d95972d-f1c8-5527-a6e0-b4b365fa01f6"
3232
VectorInterface = "409d34a3-91d5-4945-b6ec-7529ddf182d8"
3333

3434
[weakdeps]
35+
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
3536
CUDA = "052768ef-5323-5732-b1bb-66c8b64840ba"
37+
cuTENSOR = "011b41b2-24ef-40a8-b3eb-fa098493e9e1"
3638
HDF5 = "f67ccb44-e63f-5c2f-98bd-6dc0ccc4ba2f"
3739
Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
3840
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
3941
TBLIS = "48530278-0828-4a49-9772-0f3830dfa1e9"
40-
AMDGPU = "21141c5a-9bdb-4563-92ae-f87d6854732e"
4142

4243
[extensions]
44+
NDTensorsAMDGPUExt = "AMDGPU"
4345
NDTensorsCUDAExt = "CUDA"
46+
NDTensorscuTENSORExt = "cuTENSOR"
4447
NDTensorsHDF5Ext = "HDF5"
4548
NDTensorsMetalExt = "Metal"
4649
NDTensorsOctavianExt = "Octavian"
4750
NDTensorsTBLISExt = "TBLIS"
48-
NDTensorsAMDGPUExt = "AMDGPU"
4951

5052
[compat]
5153
Accessors = "0.1.33"
Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
module NDTensorscuTENSORExt
2+
include("contract.jl")
3+
end
Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,20 @@
1+
using NDTensors: NDTensors, DenseTensor, array
2+
using NDTensors.Expose: Exposed, unexpose
3+
using cuTENSOR: cuTENSOR, CuArray, CuTensor
4+
5+
function NDTensors.contract!(
6+
R::Exposed{<:CuArray,<:DenseTensor},
7+
labelsR,
8+
T1::Exposed{<:CuArray,<:DenseTensor},
9+
labelsT1,
10+
T2::Exposed{<:CuArray,<:DenseTensor},
11+
labelsT2,
12+
α::Number=one(Bool),
13+
β::Number=zero(Bool),
14+
)
15+
cuR = CuTensor(array(unexpose(R)), collect(labelsR))
16+
cuT1 = CuTensor(array(unexpose(T1)), collect(labelsT1))
17+
cuT2 = CuTensor(array(unexpose(T2)), collect(labelsT2))
18+
cuTENSOR.mul!(cuR, cuT1, cuT2, α, β)
19+
return R
20+
end

NDTensors/src/blocksparse/contract_generic.jl

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ function contract!(
104104
return R
105105
end
106106

107+
using NDTensors.Expose: expose
107108
# Function barrier to improve type stability,
108109
# since `Folds`/`FLoops` is not type stable:
109110
# https://discourse.julialang.org/t/type-instability-in-floop-reduction/68598
@@ -139,11 +140,11 @@ function _contract!(
139140
)
140141

141142
contract!(
142-
R[blockR],
143+
expose(R[blockR]),
143144
labelsR,
144-
tensor1[blocktensor1],
145+
expose(tensor1[blocktensor1]),
145146
labelstensor1,
146-
tensor2[blocktensor2],
147+
expose(tensor2[blocktensor2]),
147148
labelstensor2,
148149
α,
149150
β,

NDTensors/src/blocksparse/contract_sequential.jl

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ function contract!(
5555
R, labelsR, tensor1, labelstensor1, tensor2, labelstensor2, contraction_plan, executor
5656
)
5757
end
58-
58+
using NDTensors.Expose: expose
5959
###########################################################################
6060
# Old version
6161
# TODO: DELETE, keeping around for now for testing/benchmarking.
@@ -97,7 +97,9 @@ function contract!(
9797
# Overwrite the block of R
9898
β = zero(ElR)
9999
end
100-
contract!(Rblock, labelsR, T1block, labelsT1, T2block, labelsT2, α, β)
100+
contract!(
101+
expose(Rblock), labelsR, expose(T1block), labelsT1, expose(T2block), labelsT2, α, β
102+
)
101103
end
102104
return R
103105
end

NDTensors/src/blocksparse/contract_threads.jl

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,4 @@
1+
using NDTensors.Expose: expose
12
# TODO: This seems to be faster than the newer version using `Folds.jl`
23
# in `contract_folds.jl`, investigate why.
34
function contract_blocks!(
@@ -115,7 +116,16 @@ function contract!(
115116
ElR, labelsR, blockR, indsR, labelsT1, blockT1, indsT1, labelsT2, blockT2, indsT2
116117
)
117118

118-
contract!(blockR, labelsR, blockT1, labelsT1, blockT2, labelsT2, α, β)
119+
contract!(
120+
expose(blockR),
121+
labelsR,
122+
expose(blockT1),
123+
labelsT1,
124+
expose(blockT2),
125+
labelsT2,
126+
α,
127+
β,
128+
)
119129
# Now keep adding to the block, since it has
120130
# been written to
121131
# R .= α .* (T1 * T2) .+ R

NDTensors/src/diag/tensoralgebra/contract.jl

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,28 @@ function _contract!!(
5858
return R
5959
end
6060

61+
function contract!(
62+
output_tensor::Exposed{<:AbstractArray,<:DiagTensor},
63+
labelsoutput_tensor,
64+
tensor1::Exposed,
65+
labelstensor1,
66+
tensor2::Exposed,
67+
labelstensor2,
68+
α::Number=one(Bool),
69+
β::Number=zero(Bool),
70+
)
71+
@assert isone(α)
72+
@assert iszero(β)
73+
return contract!(
74+
unexpose(output_tensor),
75+
labelsoutput_tensor,
76+
unexpose(tensor1),
77+
labelstensor1,
78+
unexpose(tensor2),
79+
labelstensor2,
80+
)
81+
end
82+
6183
function contract!(
6284
R::DiagTensor{ElR,NR},
6385
labelsR,

NDTensors/src/lib/TensorAlgebra/test/Project.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,3 +4,6 @@ Combinatorics = "861a8166-3701-5b0c-9a16-15d98fcdc6aa"
44
EllipsisNotation = "da5c29d0-fa7d-589e-88eb-ea29b0a81949"
55
NDTensors = "23ae76d9-e61a-49c4-8f12-3f1a16adf9cf"
66
TensorOperations = "6aa20fa7-93e2-5fca-9bc0-fbd0db3c71a2"
7+
8+
[compat]
9+
TensorOperations = "4.1.1"

NDTensors/src/lib/TensorAlgebra/test/test_basics.jl

Lines changed: 59 additions & 54 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,6 @@ using NDTensors.TensorAlgebra:
88
using NDTensors: NDTensors
99
include(joinpath(pkgdir(NDTensors), "test", "NDTensorsTestUtils", "NDTensorsTestUtils.jl"))
1010
using .NDTensorsTestUtils: default_rtol
11-
using TensorOperations: TensorOperations
1211
using Test: @test, @test_broken, @testset
1312
const elts = (Float32, Float64, Complex{Float32}, Complex{Float64})
1413
@testset "BlockedPermutation" begin
@@ -111,62 +110,67 @@ end
111110
@test eltype(a_split) === elt
112111
@test a_split ≈ reshape(a, (2, 3, 20))
113112
end
114-
@testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts
115-
dims = (2, 3, 4, 5, 6, 7, 8, 9, 10)
116-
labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i)
117-
for (d1s, d2s, d_dests) in (
118-
((1, 2), (1, 2), ()),
119-
((1, 2), (2, 1), ()),
120-
((1, 2), (2, 3), (1, 3)),
121-
((1, 2), (2, 3), (3, 1)),
122-
((2, 1), (2, 3), (3, 1)),
123-
((1, 2, 3), (2, 3, 4), (1, 4)),
124-
((1, 2, 3), (2, 3, 4), (4, 1)),
125-
((3, 2, 1), (4, 2, 3), (4, 1)),
126-
((1, 2, 3), (3, 4), (1, 2, 4)),
127-
((1, 2, 3), (3, 4), (4, 1, 2)),
128-
((1, 2, 3), (3, 4), (2, 4, 1)),
129-
((3, 1, 2), (3, 4), (2, 4, 1)),
130-
((3, 2, 1), (4, 3), (2, 4, 1)),
131-
((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9), (1, 2, 3, 7, 8, 9)),
132-
((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7), (1, 7, 2, 8, 3, 9)),
133-
)
134-
a1 = randn(elt1, map(i -> dims[i], d1s))
135-
labels1 = map(i -> labels[i], d1s)
136-
a2 = randn(elt2, map(i -> dims[i], d2s))
137-
labels2 = map(i -> labels[i], d2s)
138-
labels_dest = map(i -> labels[i], d_dests)
139-
140-
# Don't specify destination labels
141-
a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2)
142-
a_dest_tensoroperations = TensorOperations.tensorcontract(
143-
labels_dest′, a1, labels1, a2, labels2
113+
## Right now TensorOperations version is downgraded when using cuTENSOR to `v0.7` we
114+
## are waiting for TensorOperations to support the breaking changes in cuTENSOR 2.x
115+
if !("cutensor" ∈ ARGS)
116+
using TensorOperations: TensorOperations
117+
@testset "contract (eltype1=$elt1, eltype2=$elt2)" for elt1 in elts, elt2 in elts
118+
dims = (2, 3, 4, 5, 6, 7, 8, 9, 10)
119+
labels = (:a, :b, :c, :d, :e, :f, :g, :h, :i)
120+
for (d1s, d2s, d_dests) in (
121+
((1, 2), (1, 2), ()),
122+
((1, 2), (2, 1), ()),
123+
((1, 2), (2, 3), (1, 3)),
124+
((1, 2), (2, 3), (3, 1)),
125+
((2, 1), (2, 3), (3, 1)),
126+
((1, 2, 3), (2, 3, 4), (1, 4)),
127+
((1, 2, 3), (2, 3, 4), (4, 1)),
128+
((3, 2, 1), (4, 2, 3), (4, 1)),
129+
((1, 2, 3), (3, 4), (1, 2, 4)),
130+
((1, 2, 3), (3, 4), (4, 1, 2)),
131+
((1, 2, 3), (3, 4), (2, 4, 1)),
132+
((3, 1, 2), (3, 4), (2, 4, 1)),
133+
((3, 2, 1), (4, 3), (2, 4, 1)),
134+
((1, 2, 3, 4, 5, 6), (4, 5, 6, 7, 8, 9), (1, 2, 3, 7, 8, 9)),
135+
((2, 4, 5, 1, 6, 3), (6, 4, 9, 8, 5, 7), (1, 7, 2, 8, 3, 9)),
144136
)
145-
@test a_dest ≈ a_dest_tensoroperations
137+
a1 = randn(elt1, map(i -> dims[i], d1s))
138+
labels1 = map(i -> labels[i], d1s)
139+
a2 = randn(elt2, map(i -> dims[i], d2s))
140+
labels2 = map(i -> labels[i], d2s)
141+
labels_dest = map(i -> labels[i], d_dests)
146142

147-
# Specify destination labels
148-
a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2)
149-
a_dest_tensoroperations = TensorOperations.tensorcontract(
150-
labels_dest, a1, labels1, a2, labels2
151-
)
152-
@test a_dest ≈ a_dest_tensoroperations
143+
# Don't specify destination labels
144+
a_dest, labels_dest′ = TensorAlgebra.contract(a1, labels1, a2, labels2)
145+
a_dest_tensoroperations = TensorOperations.tensorcontract(
146+
labels_dest′, a1, labels1, a2, labels2
147+
)
148+
@test a_dest ≈ a_dest_tensoroperations
153149

154-
# Specify α and β
155-
elt_dest = promote_type(elt1, elt2)
156-
# TODO: Using random `α`, `β` causing
157-
# random test failures, investigate why.
158-
α = elt_dest(1.2) # randn(elt_dest)
159-
β = elt_dest(2.4) # randn(elt_dest)
160-
a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests))
161-
a_dest = copy(a_dest_init)
162-
TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
163-
a_dest_tensoroperations = TensorOperations.tensorcontract(
164-
labels_dest, a1, labels1, a2, labels2
165-
)
166-
## Here we loosened the tolerance because of some floating point roundoff issue.
167-
## with Float32 numbers
168-
@test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol =
169-
10 * default_rtol(elt_dest)
150+
# Specify destination labels
151+
a_dest = TensorAlgebra.contract(labels_dest, a1, labels1, a2, labels2)
152+
a_dest_tensoroperations = TensorOperations.tensorcontract(
153+
labels_dest, a1, labels1, a2, labels2
154+
)
155+
@test a_dest ≈ a_dest_tensoroperations
156+
157+
# Specify α and β
158+
elt_dest = promote_type(elt1, elt2)
159+
# TODO: Using random `α`, `β` causing
160+
# random test failures, investigate why.
161+
α = elt_dest(1.2) # randn(elt_dest)
162+
β = elt_dest(2.4) # randn(elt_dest)
163+
a_dest_init = randn(elt_dest, map(i -> dims[i], d_dests))
164+
a_dest = copy(a_dest_init)
165+
TensorAlgebra.contract!(a_dest, labels_dest, a1, labels1, a2, labels2, α, β)
166+
a_dest_tensoroperations = TensorOperations.tensorcontract(
167+
labels_dest, a1, labels1, a2, labels2
168+
)
169+
## Here we loosened the tolerance because of some floating point roundoff issue.
170+
## with Float32 numbers
171+
@test a_dest ≈ α * a_dest_tensoroperations + β * a_dest_init rtol =
172+
10 * default_rtol(elt_dest)
173+
end
170174
end
171175
end
172176
@testset "qr (eltype=$elt)" for elt in elts
@@ -182,4 +186,5 @@ end
182186
@test a ≈ a′
183187
end
184188
end
189+
185190
end

NDTensors/src/tensoroperations/generic_tensor_operations.jl

Lines changed: 32 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -116,6 +116,7 @@ function contract(
116116
return output_tensor
117117
end
118118

119+
using NDTensors.Expose: Exposed, expose, unexpose
119120
# Overload this function for immutable storage types
120121
function _contract!!(
121122
output_tensor::Tensor,
@@ -129,23 +130,50 @@ function _contract!!(
129130
)
130131
if α ≠ 1 || β ≠ 0
131132
contract!(
132-
output_tensor,
133+
expose(output_tensor),
133134
labelsoutput_tensor,
134-
tensor1,
135+
expose(tensor1),
135136
labelstensor1,
136-
tensor2,
137+
expose(tensor2),
137138
labelstensor2,
138139
α,
139140
β,
140141
)
141142
else
142143
contract!(
143-
output_tensor, labelsoutput_tensor, tensor1, labelstensor1, tensor2, labelstensor2
144+
expose(output_tensor),
145+
labelsoutput_tensor,
146+
expose(tensor1),
147+
labelstensor1,
148+
expose(tensor2),
149+
labelstensor2,
144150
)
145151
end
146152
return output_tensor
147153
end
148154

155+
function contract!(
156+
output_tensor::Exposed,
157+
labelsoutput_tensor,
158+
tensor1::Exposed,
159+
labelstensor1,
160+
tensor2::Exposed,
161+
labelstensor2,
162+
α::Number=one(Bool),
163+
β::Number=zero(Bool),
164+
)
165+
return contract!(
166+
unexpose(output_tensor),
167+
labelsoutput_tensor,
168+
unexpose(tensor1),
169+
labelstensor1,
170+
unexpose(tensor2),
171+
labelstensor2,
172+
α,
173+
β,
174+
)
175+
end
176+
149177
# Is this generic for all storage types?
150178
function contract!!(
151179
output_tensor::Tensor,

0 commit comments

Comments
 (0)