Skip to content

Commit 18b0ef4

Browse files
committed
Fix TRMM/TRSM Tests
- TRMM test is completely fixed upstream. - TRSM test is unstable on random A -> Only do small tests.
1 parent 7d8f12e commit 18b0ef4

File tree

1 file changed

+40
-22
lines changed

1 file changed

+40
-22
lines changed

test/init_test_mmul.jl

Lines changed: 40 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ using Test
88
using Random
99
using LinearAlgebra
1010
using LinearAlgebra: BLAS
11+
using DelimitedFiles
1112
using Statistics
1213
Random.seed!(1234)
1314
rtype(::Type{Complex{T}}) where {T} = T
@@ -43,15 +44,17 @@ global Clarge_her2k= [zeros(T, χlarge, χlarge) for T=(Float32, Float64, Comple
4344
global Clarge_syr2k= [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)]
4445
global Clarge_herk = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)]
4546
global Clarge_syrk = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)]
46-
# global Clarge_trmm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)]
47-
# global Clarge_trsm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)]
47+
global Clarge_trmm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)]
48+
# global Clarge_trsm = [zeros(T, χlarge, χlarge) for T=(Float32, Float64, ComplexF32, ComplexF64)] TRSM is unstable on random A.
4849
global Csmall_gemm = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)]
4950
global Csmall_hemm = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)]
5051
global Csmall_symm = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)]
5152
global Csmall_her2k= [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)]
5253
global Csmall_syr2k= [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)]
5354
global Csmall_herk = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)]
5455
global Csmall_syrk = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)]
56+
global Csmall_trmm = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)]
57+
global Csmall_trsm = [zeros(T, χsmall, χsmall) for T=(Float32, Float64, ComplexF32, ComplexF64)]
5558

5659
global Cst_lg_gemm = [zeros(T, χlarge÷2, χlarge÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)]
5760
global Cst_lg_hemm = [zeros(T, χlarge÷2, χlarge÷2) for T=(Float32, Float64, ComplexF32, ComplexF64)]
@@ -102,13 +105,16 @@ for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64])
102105
Clarge_syr2k[i] .= elconv.(Clarge_base)
103106
Clarge_herk[i] .= elconv.(Clarge_base)
104107
Clarge_syrk[i] .= elconv.(Clarge_base)
108+
Clarge_trmm[i] .= elconv.(Clarge_base)
105109
Csmall_gemm[i] .= elconv.(Csmall_base)
106110
Csmall_hemm[i] .= elconv.(Csmall_base)
107111
Csmall_symm[i] .= elconv.(Csmall_base)
108112
Csmall_her2k[i] .= elconv.(Csmall_base)
109113
Csmall_syr2k[i] .= elconv.(Csmall_base)
110114
Csmall_herk[i] .= elconv.(Csmall_base)
111115
Csmall_syrk[i] .= elconv.(Csmall_base)
116+
Csmall_trmm[i] .= elconv.(Csmall_base)
117+
Csmall_trsm[i] .= elconv.(Csmall_base)
112118

113119
# Strided.
114120
Ast_lg = view(Alarge, 1:2:χlarge, 1:2:χlarge)
@@ -131,6 +137,9 @@ for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64])
131137
locl_herk!('U', 'N', αR, Asmall, βR, Csmall_herk[i])
132138
BLAS.syrk!('U', 'N', αu, Alarge, βu, Clarge_syrk[i])
133139
BLAS.syrk!('U', 'N', αu, Asmall, βu, Csmall_syrk[i])
140+
BLAS.trmm!('L', 'U', 'N', 'N', αu, Alarge, Clarge_trmm[i])
141+
BLAS.trmm!('L', 'U', 'N', 'N', αu, Asmall, Csmall_trmm[i])
142+
BLAS.trsm!('L', 'U', 'N', 'N', αu, Asmall, Csmall_trsm[i])
134143

135144
# Execute: generic-strided.
136145
Cst_lg_gemm[i] .= Ast_lg * Bst_lg
@@ -185,13 +194,16 @@ for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64])
185194
Clarge_syr2k_cur = T.(elconv.(Clarge_base))
186195
Clarge_herk_cur = T.(elconv.(Clarge_base))
187196
Clarge_syrk_cur = T.(elconv.(Clarge_base))
197+
Clarge_trmm_cur = T.(elconv.(Clarge_base))
188198
Csmall_gemm_cur = T.(elconv.(Csmall_base))
189199
Csmall_hemm_cur = T.(elconv.(Csmall_base))
190200
Csmall_symm_cur = T.(elconv.(Csmall_base))
191201
Csmall_her2k_cur = T.(elconv.(Csmall_base))
192202
Csmall_syr2k_cur = T.(elconv.(Csmall_base))
193203
Csmall_herk_cur = T.(elconv.(Csmall_base))
194204
Csmall_syrk_cur = T.(elconv.(Csmall_base))
205+
Csmall_trmm_cur = T.(elconv.(Csmall_base))
206+
Csmall_trsm_cur = T.(elconv.(Csmall_base))
195207

196208
# Strided.
197209
Ast_lg = view(Alarge, 1:2:χlarge, 1:2:χlarge)
@@ -214,6 +226,9 @@ for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64])
214226
locl_herk!('U', 'N', αR, Asmall, βR, Csmall_herk_cur)
215227
BLAS.syrk!('U', 'N', αu, Alarge, βu, Clarge_syrk_cur)
216228
BLAS.syrk!('U', 'N', αu, Asmall, βu, Csmall_syrk_cur)
229+
BLAS.trmm!('L', 'U', 'N', 'N', αu, Alarge, Clarge_trmm_cur)
230+
BLAS.trmm!('L', 'U', 'N', 'N', αu, Asmall, Csmall_trmm_cur)
231+
BLAS.trsm!('L', 'U', 'N', 'N', αu, Asmall, Csmall_trsm_cur)
217232

218233
# Execute: generic-strided.
219234
Cst_lg_gemm_cur = Ast_lg * Bst_lg
@@ -224,28 +239,31 @@ for (i, T)=zip(1:4, [Float32, Float64, ComplexF32, ComplexF64])
224239
Cst_sm_symm_cur = Symmetric(Ast_sm) * Bst_sm
225240

226241
# Check.
227-
@test zrtest(mean(abs.(Clarge_gemm_cur - Clarge_gemm[i] )), 1e-6*χlarge^1.2, "gemm")
228-
@test zrtest(mean(abs.(Clarge_hemm_cur - Clarge_hemm[i] )), 1e-6*χlarge^1.2, "hemm")
229-
@test zrtest(mean(abs.(Clarge_symm_cur - Clarge_symm[i] )), 1e-6*χlarge^1.2, "symm")
230-
@test zrtest(mean(abs.(Clarge_her2k_cur - Clarge_her2k[i])), 1e-6*χlarge^1.2, "her2k")
231-
@test zrtest(mean(abs.(Clarge_syr2k_cur - Clarge_syr2k[i])), 1e-6*χlarge^1.2, "syr2k")
232-
@test zrtest(mean(abs.(Clarge_herk_cur - Clarge_herk[i] )), 1e-6*χlarge^1.2, "herk")
233-
@test zrtest(mean(abs.(Clarge_syrk_cur - Clarge_syrk[i] )), 1e-6*χlarge^1.2, "syrk")
234-
@test zrtest(mean(abs.(Csmall_gemm_cur - Csmall_gemm[i] )), 1e-6*χsmall^1.2, "gemm")
235-
@test zrtest(mean(abs.(Csmall_hemm_cur - Csmall_hemm[i] )), 1e-6*χsmall^1.2, "hemm")
236-
@test zrtest(mean(abs.(Csmall_symm_cur - Csmall_symm[i] )), 1e-6*χsmall^1.2, "symm")
237-
@test zrtest(mean(abs.(Csmall_her2k_cur - Csmall_her2k[i])), 1e-6*χsmall^1.2, "her2k")
238-
@test zrtest(mean(abs.(Csmall_syr2k_cur - Csmall_syr2k[i])), 1e-6*χsmall^1.2, "syr2k")
239-
@test zrtest(mean(abs.(Csmall_herk_cur - Csmall_herk[i] )), 1e-6*χsmall^1.2, "herk")
240-
@test zrtest(mean(abs.(Csmall_syrk_cur - Csmall_syrk[i] )), 1e-6*χsmall^1.2, "syrk")
242+
@test zrtest(mean(abs.(Clarge_gemm_cur - Clarge_gemm[i] )), 1e-6*χlarge^1.2, "500_gemm_$T")
243+
@test zrtest(mean(abs.(Clarge_hemm_cur - Clarge_hemm[i] )), 1e-6*χlarge^1.2, "500_hemm_$T")
244+
@test zrtest(mean(abs.(Clarge_symm_cur - Clarge_symm[i] )), 1e-6*χlarge^1.2, "500_symm_$T")
245+
@test zrtest(mean(abs.(Clarge_her2k_cur - Clarge_her2k[i])), 1e-6*χlarge^1.2, "500_her2k_$T")
246+
@test zrtest(mean(abs.(Clarge_syr2k_cur - Clarge_syr2k[i])), 1e-6*χlarge^1.2, "500_syr2k_$T")
247+
@test zrtest(mean(abs.(Clarge_herk_cur - Clarge_herk[i] )), 1e-6*χlarge^1.2, "500_herk_$T")
248+
@test zrtest(mean(abs.(Clarge_syrk_cur - Clarge_syrk[i] )), 1e-6*χlarge^1.2, "500_syrk_$T")
249+
@test zrtest(mean(abs.(Clarge_trmm_cur - Clarge_trmm[i] )), 1e-6*χlarge^1.2, "500_trmm_$T")
250+
@test zrtest(mean(abs.(Csmall_gemm_cur - Csmall_gemm[i] )), 1e-6*χsmall^1.2, "20_gemm_$T")
251+
@test zrtest(mean(abs.(Csmall_hemm_cur - Csmall_hemm[i] )), 1e-6*χsmall^1.2, "20_hemm_$T")
252+
@test zrtest(mean(abs.(Csmall_symm_cur - Csmall_symm[i] )), 1e-6*χsmall^1.2, "20_symm_$T")
253+
@test zrtest(mean(abs.(Csmall_her2k_cur - Csmall_her2k[i])), 1e-6*χsmall^1.2, "20_her2k_$T")
254+
@test zrtest(mean(abs.(Csmall_syr2k_cur - Csmall_syr2k[i])), 1e-6*χsmall^1.2, "20_syr2k_$T")
255+
@test zrtest(mean(abs.(Csmall_herk_cur - Csmall_herk[i] )), 1e-6*χsmall^1.2, "20_herk_$T")
256+
@test zrtest(mean(abs.(Csmall_syrk_cur - Csmall_syrk[i] )), 1e-6*χsmall^1.2, "20_syrk_$T")
257+
@test zrtest(mean(abs.(Csmall_trmm_cur - Csmall_trmm[i] )), 1e-6*χsmall^1.2, "20_trmm_$T")
258+
@test zrtest(mean(abs.(Csmall_trsm_cur - Csmall_trsm[i] )), 1e-2*χsmall^1.2, "20_trsm_$T") # Large TRSM err on random A.
241259

242260
# Check - strided.
243-
@test zrtest(mean(abs.(Cst_lg_gemm_cur - Cst_lg_gemm[i])), 1e-6*χlarge^1.2, "gemm")
244-
@test zrtest(mean(abs.(Cst_sm_gemm_cur - Cst_sm_gemm[i])), 1e-6*χsmall^1.2, "gemm")
245-
@test zrtest(mean(abs.(Cst_lg_hemm_cur - Cst_lg_hemm[i])), 1e-6*χlarge^1.2, "hemm")
246-
@test zrtest(mean(abs.(Cst_sm_hemm_cur - Cst_sm_hemm[i])), 1e-6*χsmall^1.2, "hemm")
247-
@test zrtest(mean(abs.(Cst_lg_symm_cur - Cst_lg_symm[i])), 1e-6*χlarge^1.2, "symm")
248-
@test zrtest(mean(abs.(Cst_sm_symm_cur - Cst_sm_symm[i])), 1e-6*χsmall^1.2, "symm")
261+
@test zrtest(mean(abs.(Cst_lg_gemm_cur - Cst_lg_gemm[i])), 1e-6*χlarge^1.2, "250_rs2_gemm_$T")
262+
@test zrtest(mean(abs.(Cst_sm_gemm_cur - Cst_sm_gemm[i])), 1e-6*χsmall^1.2, "250_rs2_gemm_$T")
263+
@test zrtest(mean(abs.(Cst_lg_hemm_cur - Cst_lg_hemm[i])), 1e-6*χlarge^1.2, "250_rs2_hemm_$T")
264+
@test zrtest(mean(abs.(Cst_sm_hemm_cur - Cst_sm_hemm[i])), 1e-6*χsmall^1.2, "250_rs2_hemm_$T")
265+
@test zrtest(mean(abs.(Cst_lg_symm_cur - Cst_lg_symm[i])), 1e-6*χlarge^1.2, "250_rs2_symm_$T")
266+
@test zrtest(mean(abs.(Cst_sm_symm_cur - Cst_sm_symm[i])), 1e-6*χsmall^1.2, "250_rs2_symm_$T")
249267
end
250268
end
251269

0 commit comments

Comments
 (0)