@@ -46,36 +46,35 @@ eltypes = [(Float16, Float16, Float16),
46
46
dD = elementwise_trinary_execute! (1 , dA, indsA, opA, 1 , dB, indsB, opB,
47
47
1 , dC, indsC, opC, dD, indsC, opAB, opABC)
48
48
D = collect (dD)
49
- @test D ≈ permutedims (A, pA) . + permutedims (B, pB) . + C
49
+ @test D ≈ permutedims (A, pA) + permutedims (B, pB) + C
50
50
51
51
# using integers as indices
52
52
dD = elementwise_trinary_execute! (1 , dA, ipA, opA, 1 , dB, ipB, opB,
53
53
1 , dC, 1 : N, opC, dD, 1 : N, opAB, opABC)
54
54
D = collect (dD)
55
- @test D ≈ permutedims (A, pA) . + permutedims (B, pB) . + C
55
+ @test D ≈ permutedims (A, pA) + permutedims (B, pB) + C
56
56
57
57
# multiplication as binary operator
58
58
opAB = cuTENSOR. OP_MUL
59
59
opABC = cuTENSOR. OP_ADD
60
60
dD = elementwise_trinary_execute! (1 , dA, indsA, opA, 1 , dB, indsB, opB,
61
61
1 , dC, indsC, opC, dD, indsC, opAB, opABC)
62
62
D = collect (dD)
63
- @test D ≈ (eltyD .(permutedims (A, pA)) .* eltyD .(permutedims (B, pB))) . + C
63
+ @test D ≈ (eltyD .(permutedims (A, pA)) .* eltyD .(permutedims (B, pB))) + C
64
64
65
65
opAB = cuTENSOR. OP_ADD
66
66
opABC = cuTENSOR. OP_MUL
67
67
dD = elementwise_trinary_execute! (1 , dA, indsA, opA, 1 , dB, indsB, opB,
68
68
1 , dC, indsC, opC, dD, indsC, opAB, opABC)
69
69
D = collect (dD)
70
- @test D ≈ (eltyD .(permutedims (A, pA)) . + eltyD .(permutedims (B, pB))) .* C
70
+ @test D ≈ (eltyD .(permutedims (A, pA)) + eltyD .(permutedims (B, pB))) .* C
71
71
72
72
opAB = cuTENSOR. OP_MUL
73
73
opABC = cuTENSOR. OP_MUL
74
74
dD = elementwise_trinary_execute! (1 , dA, indsA, opA, 1 , dB, indsB, opB,
75
75
1 , dC, indsC, opC, dD, indsC, opAB, opABC)
76
76
D = collect (dD)
77
- @test D ≈ eltyD .(permutedims (A, pA)) .*
78
- eltyD .(permutedims (B, pB)) .* C
77
+ @test D ≈ eltyD .(permutedims (A, pA)) .* eltyD .(permutedims (B, pB)) .* C
79
78
80
79
# with non-trivial coefficients and conjugation
81
80
α = rand (eltyD)
@@ -88,24 +87,22 @@ eltypes = [(Float16, Float16, Float16),
88
87
dD = elementwise_trinary_execute! (α, dA, indsA, opA, β, dB, indsB, opB,
89
88
γ, dC, indsC, opC, dD, indsC, opAB, opABC)
90
89
D = collect (dD)
91
- @test D ≈ α . * conj .(permutedims (A, pA)) . + β . * permutedims (B, pB) . + γ . * C
90
+ @test D ≈ α * conj .(permutedims (A, pA)) + β * permutedims (B, pB) + γ * C
92
91
93
92
opB = eltyB <: Complex ? cuTENSOR. OP_CONJ : cuTENSOR. OP_IDENTITY
94
93
opAB = cuTENSOR. OP_ADD
95
94
opABC = cuTENSOR. OP_ADD
96
95
dD = elementwise_trinary_execute! (α, dA, indsA, opA, β, dB, indsB, opB,
97
96
γ, dC, indsC, opC, dD, indsC, opAB, opABC)
98
97
D = collect (dD)
99
- @test D ≈ α .* conj .(permutedims (A, pA)) .+
100
- β .* conj .(permutedims (B, pB)) .+ γ .* C
101
-
98
+ @test D ≈ α * conj .(permutedims (A, pA)) + β * conj .(permutedims (B, pB)) + γ * C
102
99
opA = cuTENSOR. OP_IDENTITY
103
100
opAB = cuTENSOR. OP_MUL
104
101
opABC = cuTENSOR. OP_ADD
105
102
dD = elementwise_trinary_execute! (α, dA, indsA, opA, β, dB, indsB, opB,
106
103
γ, dC, indsC, opC, dD, indsC, opAB, opABC)
107
104
D = collect (dD)
108
- @test D ≈ α . * permutedims (A, pA) .* β . * conj .(permutedims (B, pB)) . + γ . * C
105
+ @test D ≈ (α * permutedims (A, pA)) .* (β * conj .(permutedims (B, pB))) + γ * C
109
106
110
107
# test in-place, and more complicated unary and binary operations
111
108
opA = eltyA <: Complex ? cuTENSOR. OP_IDENTITY : cuTENSOR. OP_SQRT
@@ -122,24 +119,22 @@ eltypes = [(Float16, Float16, Float16),
122
119
D = collect (dD)
123
120
if eltyD <: Complex
124
121
if eltyA <: Complex && eltyB <: Complex
125
- @test D ≈ α . * permutedims (A, pA) .* β .* permutedims (B, pB ) .+
126
- γ . * conj .(C)
122
+ @test D ≈ (α * permutedims (A, pA)) .*
123
+ (β * permutedims (B, pB)) + γ * conj .(C)
127
124
elseif eltyB <: Complex
128
- @test D ≈ α . * sqrt .(eltyD .(permutedims (A, pA))) .*
129
- β . * permutedims (B, pB) . + γ . * conj .(C)
125
+ @test D ≈ (α * sqrt .(eltyD .(permutedims (A, pA) ))) .*
126
+ (β * permutedims (B, pB)) + γ * conj .(C)
130
127
elseif eltyB <: Complex
131
- @test D ≈ α .* permutedims (A, pA) .*
132
- β .* sqrt .(eltyD .(permutedims (B, pB))) .+
133
- γ .* conj .(C)
128
+ @test D ≈ (α * permutedims (A, pA)) .*
129
+ (β * sqrt .(eltyD .(permutedims (B, pB)))) + γ * conj .(C)
134
130
else
135
- @test D ≈ α .* sqrt .(eltyD .(permutedims (A, pA))) .*
136
- β .* sqrt .(eltyD .(permutedims (B, pB))) .+
137
- γ .* conj .(C)
131
+ @test D ≈ (α * sqrt .(eltyD .(permutedims (A, pA)))) .*
132
+ (β * sqrt .(eltyD .(permutedims (B, pB)))) + γ * conj .(C)
138
133
end
139
134
else
140
- @test D ≈ max .(min .(α . * sqrt .(eltyD .(permutedims (A, pA))),
141
- β . * sqrt .(eltyD .(permutedims (B, pB)))),
142
- γ . * C)
135
+ @test D ≈ max .(min .(α * sqrt .(eltyD .(permutedims (A, pA))),
136
+ β * sqrt .(eltyD .(permutedims (B, pB)))),
137
+ γ * C)
143
138
end
144
139
end
145
140
end
0 commit comments