Skip to content

Commit 50b953c

Browse files
authored
CUTENSOR: Reduce amount of broadcasts compiled during tests. (#2527)
[only subpackages]
1 parent 6221589 commit 50b953c

File tree

3 files changed

+30
-35
lines changed

3 files changed

+30
-35
lines changed

lib/cutensor/test/elementwise_binary.jl

Lines changed: 10 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -36,12 +36,12 @@ eltypes = [(Float16, Float16),
3636
opAC = cuTENSOR.OP_ADD
3737
dD = elementwise_binary_execute!(1, dA, indsA, opA, 1, dC, indsC, opC, dD, indsC, opAC)
3838
D = collect(dD)
39-
@test D permutedims(A, p) .+ C
39+
@test D permutedims(A, p) + C
4040

4141
# using integers as indices
4242
dD = elementwise_binary_execute!(1, dA, 1:N, opA, 1, dC, p, opC, dD, p, opAC)
4343
D = collect(dD)
44-
@test D permutedims(A, p) .+ C
44+
@test D permutedims(A, p) + C
4545

4646
# multiplication as binary operator
4747
opAC = cuTENSOR.OP_MUL
@@ -57,7 +57,7 @@ eltypes = [(Float16, Float16),
5757
γ = rand(eltyD)
5858
dD = elementwise_binary_execute!(α, dA, indsA, opA, γ, dC, indsC, opC, dD, indsC, opAC)
5959
D = collect(dD)
60-
@test D α .* conj.(permutedims(A, p)) .+ γ .* C
60+
@test D α * conj.(permutedims(A, p)) + γ * C
6161

6262
# test in-place, and more complicated unary and binary operations
6363
opA = eltyA <: Complex ? cuTENSOR.OP_IDENTITY : cuTENSOR.OP_SQRT
@@ -70,12 +70,12 @@ eltypes = [(Float16, Float16),
7070
D = collect(dC)
7171
if eltyD <: Complex
7272
if eltyA <: Complex
73-
@test D α .* permutedims(A, p) .+ γ .* conj.(C)
73+
@test D α * permutedims(A, p) + γ * conj.(C)
7474
else
75-
@test D α .* sqrt.(eltyD.(permutedims(A, p))) .+ γ .* conj.(C)
75+
@test D α * sqrt.(eltyD.(permutedims(A, p))) + γ * conj.(C)
7676
end
7777
else
78-
@test D max.(α .* sqrt.(eltyD.(permutedims(A, p))), γ .* C)
78+
@test D max.(α * sqrt.(eltyD.(permutedims(A, p))), γ * C)
7979
end
8080

8181
# using CuTensor type
@@ -85,24 +85,24 @@ eltypes = [(Float16, Float16),
8585
ctC = CuTensor(dC, indsC)
8686
ctD = ctA + ctC
8787
hD = collect(ctD.data)
88-
@test hD permutedims(A, p) .+ C
88+
@test hD permutedims(A, p) + C
8989
ctD = ctA - ctC
9090
hD = collect(ctD.data)
91-
@test hD permutedims(A, p) .- C
91+
@test hD permutedims(A, p) - C
9292

9393
α = rand(eltyD)
9494
ctC_copy = copy(ctC)
9595
ctD = LinearAlgebra.axpy!(α, ctA, ctC_copy)
9696
@test ctD == ctC_copy
9797
hD = collect(ctD.data)
98-
@test hD α.*permutedims(A, p) .+ C
98+
@test hD α * permutedims(A, p) + C
9999

100100
γ = rand(eltyD)
101101
ctC_copy = copy(ctC)
102102
ctD = LinearAlgebra.axpby!(α, ctA, γ, ctC_copy)
103103
@test ctD == ctC_copy
104104
hD = collect(ctD.data)
105-
@test hD α.*permutedims(A, p) .+ γ.*C
105+
@test hD α * permutedims(A, p) + γ * C
106106
end
107107
end
108108

lib/cutensor/test/elementwise_trinary.jl

Lines changed: 19 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -46,36 +46,35 @@ eltypes = [(Float16, Float16, Float16),
4646
dD = elementwise_trinary_execute!(1, dA, indsA, opA, 1, dB, indsB, opB,
4747
1, dC, indsC, opC, dD, indsC, opAB, opABC)
4848
D = collect(dD)
49-
@test D permutedims(A, pA) .+ permutedims(B, pB) .+ C
49+
@test D permutedims(A, pA) + permutedims(B, pB) + C
5050

5151
# using integers as indices
5252
dD = elementwise_trinary_execute!(1, dA, ipA, opA, 1, dB, ipB, opB,
5353
1, dC, 1:N, opC, dD, 1:N, opAB, opABC)
5454
D = collect(dD)
55-
@test D permutedims(A, pA) .+ permutedims(B, pB) .+ C
55+
@test D permutedims(A, pA) + permutedims(B, pB) + C
5656

5757
# multiplication as binary operator
5858
opAB = cuTENSOR.OP_MUL
5959
opABC = cuTENSOR.OP_ADD
6060
dD = elementwise_trinary_execute!(1, dA, indsA, opA, 1, dB, indsB, opB,
6161
1, dC, indsC, opC, dD, indsC, opAB, opABC)
6262
D = collect(dD)
63-
@test D (eltyD.(permutedims(A, pA)) .* eltyD.(permutedims(B, pB))) .+ C
63+
@test D (eltyD.(permutedims(A, pA)) .* eltyD.(permutedims(B, pB))) + C
6464

6565
opAB = cuTENSOR.OP_ADD
6666
opABC = cuTENSOR.OP_MUL
6767
dD = elementwise_trinary_execute!(1, dA, indsA, opA, 1, dB, indsB, opB,
6868
1, dC, indsC, opC, dD, indsC, opAB, opABC)
6969
D = collect(dD)
70-
@test D (eltyD.(permutedims(A, pA)) .+ eltyD.(permutedims(B, pB))) .* C
70+
@test D (eltyD.(permutedims(A, pA)) + eltyD.(permutedims(B, pB))) .* C
7171

7272
opAB = cuTENSOR.OP_MUL
7373
opABC = cuTENSOR.OP_MUL
7474
dD = elementwise_trinary_execute!(1, dA, indsA, opA, 1, dB, indsB, opB,
7575
1, dC, indsC, opC, dD, indsC, opAB, opABC)
7676
D = collect(dD)
77-
@test D eltyD.(permutedims(A, pA)) .*
78-
eltyD.(permutedims(B, pB)) .* C
77+
@test D eltyD.(permutedims(A, pA)) .* eltyD.(permutedims(B, pB)) .* C
7978

8079
# with non-trivial coefficients and conjugation
8180
α = rand(eltyD)
@@ -88,24 +87,22 @@ eltypes = [(Float16, Float16, Float16),
8887
dD = elementwise_trinary_execute!(α, dA, indsA, opA, β, dB, indsB, opB,
8988
γ, dC, indsC, opC, dD, indsC, opAB, opABC)
9089
D = collect(dD)
91-
@test D α .* conj.(permutedims(A, pA)) .+ β .* permutedims(B, pB) .+ γ .* C
90+
@test D α * conj.(permutedims(A, pA)) + β * permutedims(B, pB) + γ * C
9291

9392
opB = eltyB <: Complex ? cuTENSOR.OP_CONJ : cuTENSOR.OP_IDENTITY
9493
opAB = cuTENSOR.OP_ADD
9594
opABC = cuTENSOR.OP_ADD
9695
dD = elementwise_trinary_execute!(α, dA, indsA, opA, β, dB, indsB, opB,
9796
γ, dC, indsC, opC, dD, indsC, opAB, opABC)
9897
D = collect(dD)
99-
@test D α .* conj.(permutedims(A, pA)) .+
100-
β .* conj.(permutedims(B, pB)) .+ γ .* C
101-
98+
@test D α * conj.(permutedims(A, pA)) + β * conj.(permutedims(B, pB)) + γ * C
10299
opA = cuTENSOR.OP_IDENTITY
103100
opAB = cuTENSOR.OP_MUL
104101
opABC = cuTENSOR.OP_ADD
105102
dD = elementwise_trinary_execute!(α, dA, indsA, opA, β, dB, indsB, opB,
106103
γ, dC, indsC, opC, dD, indsC, opAB, opABC)
107104
D = collect(dD)
108-
@test D α .* permutedims(A, pA) .* β .* conj.(permutedims(B, pB)) .+ γ .* C
105+
@test D * permutedims(A, pA)) .* * conj.(permutedims(B, pB))) + γ * C
109106

110107
# test in-place, and more complicated unary and binary operations
111108
opA = eltyA <: Complex ? cuTENSOR.OP_IDENTITY : cuTENSOR.OP_SQRT
@@ -122,24 +119,22 @@ eltypes = [(Float16, Float16, Float16),
122119
D = collect(dD)
123120
if eltyD <: Complex
124121
if eltyA <: Complex && eltyB <: Complex
125-
@test D α .* permutedims(A, pA) .* β .* permutedims(B, pB) .+
126-
γ .* conj.(C)
122+
@test D * permutedims(A, pA)) .*
123+
* permutedims(B, pB)) + γ * conj.(C)
127124
elseif eltyB <: Complex
128-
@test D α .* sqrt.(eltyD.(permutedims(A, pA))) .*
129-
β .* permutedims(B, pB) .+ γ .* conj.(C)
125+
@test D * sqrt.(eltyD.(permutedims(A, pA)))) .*
126+
* permutedims(B, pB)) + γ * conj.(C)
130127
elseif eltyB <: Complex
131-
@test D α .* permutedims(A, pA) .*
132-
β .* sqrt.(eltyD.(permutedims(B, pB))) .+
133-
γ .* conj.(C)
128+
@test D * permutedims(A, pA)) .*
129+
* sqrt.(eltyD.(permutedims(B, pB)))) + γ * conj.(C)
134130
else
135-
@test D α .* sqrt.(eltyD.(permutedims(A, pA))) .*
136-
β .* sqrt.(eltyD.(permutedims(B, pB))) .+
137-
γ .* conj.(C)
131+
@test D * sqrt.(eltyD.(permutedims(A, pA)))) .*
132+
* sqrt.(eltyD.(permutedims(B, pB)))) + γ * conj.(C)
138133
end
139134
else
140-
@test D max.(min.(α .* sqrt.(eltyD.(permutedims(A, pA))),
141-
β .* sqrt.(eltyD.(permutedims(B, pB)))),
142-
γ .* C)
135+
@test D max.(min.(α * sqrt.(eltyD.(permutedims(A, pA))),
136+
β * sqrt.(eltyD.(permutedims(B, pB)))),
137+
γ * C)
143138
end
144139
end
145140
end

lib/cutensor/test/reductions.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ eltypes = [(Float16, Float16),
5858
γ = rand(eltyC)
5959
dC = reduce!(α, dA, indsA, opA, γ, dC, indsC, opC, opReduce)
6060
@test reshape(collect(dC), (dimsC..., ones(Int,NA-NC)...))
61-
α .* conj.(sum(permutedims(A, p); dims = ((NC+1:NA)...,))) .+ γ .* C
61+
α * conj.(sum(permutedims(A, p); dims = ((NC+1:NA)...,))) + γ * C
6262
end
6363
end
6464

0 commit comments

Comments
 (0)