@@ -8,6 +8,7 @@ using Test
8
8
using Random
9
9
using LinearAlgebra
10
10
using LinearAlgebra: BLAS
11
+ using DelimitedFiles
11
12
using Statistics
12
13
Random. seed! (1234 )
13
14
rtype (:: Type{Complex{T}} ) where {T} = T
@@ -43,15 +44,17 @@ global Clarge_her2k= [zeros(T, χlarge, χlarge) for T=(Float32, Float64, Comple
43
44
global Clarge_syr2k= [zeros (T, χlarge, χlarge) for T= (Float32, Float64, ComplexF32, ComplexF64)]
44
45
global Clarge_herk = [zeros (T, χlarge, χlarge) for T= (Float32, Float64, ComplexF32, ComplexF64)]
45
46
global Clarge_syrk = [zeros (T, χlarge, χlarge) for T= (Float32, Float64, ComplexF32, ComplexF64)]
46
- # global Clarge_trmm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)]
47
- # global Clarge_trsm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)]
47
+ global Clarge_trmm = [zeros (T, χlarge, χlarge) for T= (Float32, Float64, ComplexF32, ComplexF64)]
48
+ # global Clarge_trsm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)]  # Disabled: TRSM is unstable on a random A.
48
49
global Csmall_gemm = [zeros (T, χsmall, χsmall) for T= (Float32, Float64, ComplexF32, ComplexF64)]
49
50
global Csmall_hemm = [zeros (T, χsmall, χsmall) for T= (Float32, Float64, ComplexF32, ComplexF64)]
50
51
global Csmall_symm = [zeros (T, χsmall, χsmall) for T= (Float32, Float64, ComplexF32, ComplexF64)]
51
52
global Csmall_her2k= [zeros (T, χsmall, χsmall) for T= (Float32, Float64, ComplexF32, ComplexF64)]
52
53
global Csmall_syr2k= [zeros (T, χsmall, χsmall) for T= (Float32, Float64, ComplexF32, ComplexF64)]
53
54
global Csmall_herk = [zeros (T, χsmall, χsmall) for T= (Float32, Float64, ComplexF32, ComplexF64)]
54
55
global Csmall_syrk = [zeros (T, χsmall, χsmall) for T= (Float32, Float64, ComplexF32, ComplexF64)]
56
+ global Csmall_trmm = [zeros (T, χsmall, χsmall) for T= (Float32, Float64, ComplexF32, ComplexF64)]
57
+ global Csmall_trsm = [zeros (T, χsmall, χsmall) for T= (Float32, Float64, ComplexF32, ComplexF64)]
55
58
56
59
global Cst_lg_gemm = [zeros (T, χlarge÷ 2 , χlarge÷ 2 ) for T= (Float32, Float64, ComplexF32, ComplexF64)]
57
60
global Cst_lg_hemm = [zeros (T, χlarge÷ 2 , χlarge÷ 2 ) for T= (Float32, Float64, ComplexF32, ComplexF64)]
@@ -102,13 +105,16 @@ for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64])
102
105
Clarge_syr2k[i] .= elconv .(Clarge_base)
103
106
Clarge_herk[i] .= elconv .(Clarge_base)
104
107
Clarge_syrk[i] .= elconv .(Clarge_base)
108
+ Clarge_trmm[i] .= elconv .(Clarge_base)
105
109
Csmall_gemm[i] .= elconv .(Csmall_base)
106
110
Csmall_hemm[i] .= elconv .(Csmall_base)
107
111
Csmall_symm[i] .= elconv .(Csmall_base)
108
112
Csmall_her2k[i] .= elconv .(Csmall_base)
109
113
Csmall_syr2k[i] .= elconv .(Csmall_base)
110
114
Csmall_herk[i] .= elconv .(Csmall_base)
111
115
Csmall_syrk[i] .= elconv .(Csmall_base)
116
+ Csmall_trmm[i] .= elconv .(Csmall_base)
117
+ Csmall_trsm[i] .= elconv .(Csmall_base)
112
118
113
119
# Strided.
114
120
Ast_lg = view (Alarge, 1 : 2 : χlarge, 1 : 2 : χlarge)
@@ -131,6 +137,9 @@ for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64])
131
137
locl_herk! (' U' , ' N' , αR, Asmall, βR, Csmall_herk[i])
132
138
BLAS. syrk! (' U' , ' N' , αu, Alarge, βu, Clarge_syrk[i])
133
139
BLAS. syrk! (' U' , ' N' , αu, Asmall, βu, Csmall_syrk[i])
140
+ BLAS. trmm! (' L' , ' U' , ' N' , ' N' , αu, Alarge, Clarge_trmm[i])
141
+ BLAS. trmm! (' L' , ' U' , ' N' , ' N' , αu, Asmall, Csmall_trmm[i])
142
+ BLAS. trsm! (' L' , ' U' , ' N' , ' N' , αu, Asmall, Csmall_trsm[i])
134
143
135
144
# Execute: generic-strided.
136
145
Cst_lg_gemm[i] .= Ast_lg * Bst_lg
@@ -185,13 +194,16 @@ for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64])
185
194
Clarge_syr2k_cur = T .(elconv .(Clarge_base))
186
195
Clarge_herk_cur = T .(elconv .(Clarge_base))
187
196
Clarge_syrk_cur = T .(elconv .(Clarge_base))
197
+ Clarge_trmm_cur = T .(elconv .(Clarge_base))
188
198
Csmall_gemm_cur = T .(elconv .(Csmall_base))
189
199
Csmall_hemm_cur = T .(elconv .(Csmall_base))
190
200
Csmall_symm_cur = T .(elconv .(Csmall_base))
191
201
Csmall_her2k_cur = T .(elconv .(Csmall_base))
192
202
Csmall_syr2k_cur = T .(elconv .(Csmall_base))
193
203
Csmall_herk_cur = T .(elconv .(Csmall_base))
194
204
Csmall_syrk_cur = T .(elconv .(Csmall_base))
205
+ Csmall_trmm_cur = T .(elconv .(Csmall_base))
206
+ Csmall_trsm_cur = T .(elconv .(Csmall_base))
195
207
196
208
# Strided.
197
209
Ast_lg = view (Alarge, 1 : 2 : χlarge, 1 : 2 : χlarge)
@@ -214,6 +226,9 @@ for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64])
214
226
locl_herk! (' U' , ' N' , αR, Asmall, βR, Csmall_herk_cur)
215
227
BLAS. syrk! (' U' , ' N' , αu, Alarge, βu, Clarge_syrk_cur)
216
228
BLAS. syrk! (' U' , ' N' , αu, Asmall, βu, Csmall_syrk_cur)
229
+ BLAS. trmm! (' L' , ' U' , ' N' , ' N' , αu, Alarge, Clarge_trmm_cur)
230
+ BLAS. trmm! (' L' , ' U' , ' N' , ' N' , αu, Asmall, Csmall_trmm_cur)
231
+ BLAS. trsm! (' L' , ' U' , ' N' , ' N' , αu, Asmall, Csmall_trsm_cur)
217
232
218
233
# Execute: generic-strided.
219
234
Cst_lg_gemm_cur = Ast_lg * Bst_lg
@@ -224,28 +239,31 @@ for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64])
224
239
Cst_sm_symm_cur = Symmetric (Ast_sm) * Bst_sm
225
240
226
241
# Check.
227
- @test zrtest (mean (abs .(Clarge_gemm_cur - Clarge_gemm[i] )), 1e-6 * χlarge^ 1.2 , " gemm" )
228
- @test zrtest (mean (abs .(Clarge_hemm_cur - Clarge_hemm[i] )), 1e-6 * χlarge^ 1.2 , " hemm" )
229
- @test zrtest (mean (abs .(Clarge_symm_cur - Clarge_symm[i] )), 1e-6 * χlarge^ 1.2 , " symm" )
230
- @test zrtest (mean (abs .(Clarge_her2k_cur - Clarge_her2k[i])), 1e-6 * χlarge^ 1.2 , " her2k" )
231
- @test zrtest (mean (abs .(Clarge_syr2k_cur - Clarge_syr2k[i])), 1e-6 * χlarge^ 1.2 , " syr2k" )
232
- @test zrtest (mean (abs .(Clarge_herk_cur - Clarge_herk[i] )), 1e-6 * χlarge^ 1.2 , " herk" )
233
- @test zrtest (mean (abs .(Clarge_syrk_cur - Clarge_syrk[i] )), 1e-6 * χlarge^ 1.2 , " syrk" )
234
- @test zrtest (mean (abs .(Csmall_gemm_cur - Csmall_gemm[i] )), 1e-6 * χsmall^ 1.2 , " gemm" )
235
- @test zrtest (mean (abs .(Csmall_hemm_cur - Csmall_hemm[i] )), 1e-6 * χsmall^ 1.2 , " hemm" )
236
- @test zrtest (mean (abs .(Csmall_symm_cur - Csmall_symm[i] )), 1e-6 * χsmall^ 1.2 , " symm" )
237
- @test zrtest (mean (abs .(Csmall_her2k_cur - Csmall_her2k[i])), 1e-6 * χsmall^ 1.2 , " her2k" )
238
- @test zrtest (mean (abs .(Csmall_syr2k_cur - Csmall_syr2k[i])), 1e-6 * χsmall^ 1.2 , " syr2k" )
239
- @test zrtest (mean (abs .(Csmall_herk_cur - Csmall_herk[i] )), 1e-6 * χsmall^ 1.2 , " herk" )
240
- @test zrtest (mean (abs .(Csmall_syrk_cur - Csmall_syrk[i] )), 1e-6 * χsmall^ 1.2 , " syrk" )
242
+ @test zrtest (mean (abs .(Clarge_gemm_cur - Clarge_gemm[i] )), 1e-6 * χlarge^ 1.2 , " 500_gemm_$T " )
243
+ @test zrtest (mean (abs .(Clarge_hemm_cur - Clarge_hemm[i] )), 1e-6 * χlarge^ 1.2 , " 500_hemm_$T " )
244
+ @test zrtest (mean (abs .(Clarge_symm_cur - Clarge_symm[i] )), 1e-6 * χlarge^ 1.2 , " 500_symm_$T " )
245
+ @test zrtest (mean (abs .(Clarge_her2k_cur - Clarge_her2k[i])), 1e-6 * χlarge^ 1.2 , " 500_her2k_$T " )
246
+ @test zrtest (mean (abs .(Clarge_syr2k_cur - Clarge_syr2k[i])), 1e-6 * χlarge^ 1.2 , " 500_syr2k_$T " )
247
+ @test zrtest (mean (abs .(Clarge_herk_cur - Clarge_herk[i] )), 1e-6 * χlarge^ 1.2 , " 500_herk_$T " )
248
+ @test zrtest (mean (abs .(Clarge_syrk_cur - Clarge_syrk[i] )), 1e-6 * χlarge^ 1.2 , " 500_syrk_$T " )
249
+ @test zrtest (mean (abs .(Clarge_trmm_cur - Clarge_trmm[i] )), 1e-6 * χlarge^ 1.2 , " 500_trmm_$T " )
250
+ @test zrtest (mean (abs .(Csmall_gemm_cur - Csmall_gemm[i] )), 1e-6 * χsmall^ 1.2 , " 20_gemm_$T " )
251
+ @test zrtest (mean (abs .(Csmall_hemm_cur - Csmall_hemm[i] )), 1e-6 * χsmall^ 1.2 , " 20_hemm_$T " )
252
+ @test zrtest (mean (abs .(Csmall_symm_cur - Csmall_symm[i] )), 1e-6 * χsmall^ 1.2 , " 20_symm_$T " )
253
+ @test zrtest (mean (abs .(Csmall_her2k_cur - Csmall_her2k[i])), 1e-6 * χsmall^ 1.2 , " 20_her2k_$T " )
254
+ @test zrtest (mean (abs .(Csmall_syr2k_cur - Csmall_syr2k[i])), 1e-6 * χsmall^ 1.2 , " 20_syr2k_$T " )
255
+ @test zrtest (mean (abs .(Csmall_herk_cur - Csmall_herk[i] )), 1e-6 * χsmall^ 1.2 , " 20_herk_$T " )
256
+ @test zrtest (mean (abs .(Csmall_syrk_cur - Csmall_syrk[i] )), 1e-6 * χsmall^ 1.2 , " 20_syrk_$T " )
257
+ @test zrtest (mean (abs .(Csmall_trmm_cur - Csmall_trmm[i] )), 1e-6 * χsmall^ 1.2 , " 20_trmm_$T " )
258
+ @test zrtest (mean (abs .(Csmall_trsm_cur - Csmall_trsm[i] )), 1e-2 * χsmall^ 1.2 , " 20_trsm_$T " ) # Large TRSM err on random A.
241
259
242
260
# Check - strided.
243
- @test zrtest (mean (abs .(Cst_lg_gemm_cur - Cst_lg_gemm[i])), 1e-6 * χlarge^ 1.2 , " gemm " )
244
- @test zrtest (mean (abs .(Cst_sm_gemm_cur - Cst_sm_gemm[i])), 1e-6 * χsmall^ 1.2 , " gemm " )
245
- @test zrtest (mean (abs .(Cst_lg_hemm_cur - Cst_lg_hemm[i])), 1e-6 * χlarge^ 1.2 , " hemm " )
246
- @test zrtest (mean (abs .(Cst_sm_hemm_cur - Cst_sm_hemm[i])), 1e-6 * χsmall^ 1.2 , " hemm " )
247
- @test zrtest (mean (abs .(Cst_lg_symm_cur - Cst_lg_symm[i])), 1e-6 * χlarge^ 1.2 , " symm " )
248
- @test zrtest (mean (abs .(Cst_sm_symm_cur - Cst_sm_symm[i])), 1e-6 * χsmall^ 1.2 , " symm " )
261
+ # Labels follow "<dimension>_rs2_<routine>_<T>". The dimension assumes χlarge = 500 and χsmall = 20 (as in the non-strided check labels), halved by the stride-2 views: 250 for Cst_lg_*, 10 for Cst_sm_*.
+ @test zrtest (mean (abs .(Cst_lg_gemm_cur - Cst_lg_gemm[i])), 1e-6 * χlarge^ 1.2 , " 250_rs2_gemm_ $T " )
262
+ @test zrtest (mean (abs .(Cst_sm_gemm_cur - Cst_sm_gemm[i])), 1e-6 * χsmall^ 1.2 , " 10_rs2_gemm_ $T " )
263
+ @test zrtest (mean (abs .(Cst_lg_hemm_cur - Cst_lg_hemm[i])), 1e-6 * χlarge^ 1.2 , " 250_rs2_hemm_ $T " )
264
+ @test zrtest (mean (abs .(Cst_sm_hemm_cur - Cst_sm_hemm[i])), 1e-6 * χsmall^ 1.2 , " 10_rs2_hemm_ $T " )
265
+ @test zrtest (mean (abs .(Cst_lg_symm_cur - Cst_lg_symm[i])), 1e-6 * χlarge^ 1.2 , " 250_rs2_symm_ $T " )
266
+ @test zrtest (mean (abs .(Cst_sm_symm_cur - Cst_sm_symm[i])), 1e-6 * χsmall^ 1.2 , " 10_rs2_symm_ $T " )
249
267
end
250
268
end
251
269
0 commit comments