Skip to content

Commit c2f4472

Browse files
author
Katharine Hyatt
committed
A few more blas tests and fix
1 parent 4cb0eec commit c2f4472

File tree

2 files changed

+258
-13
lines changed

2 files changed

+258
-13
lines changed

src/blas/wrappers.jl

Lines changed: 13 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -496,10 +496,10 @@ for (fname, elty) in ((:rocblas_dsyr,:Float64),
496496
end
497497

498498
### her
499-
for (fname, elty) in ((:rocblas_zher,:ComplexF64),
500-
(:rocblas_cher,:ComplexF32))
499+
for (fname, elty, relty) in ((:rocblas_zher,:ComplexF64,:Float64),
500+
(:rocblas_cher,:ComplexF32,:Float32))
501501
@eval begin
502-
function her!(uplo::Char, alpha::$elty, x::ROCVector{$elty}, A::ROCMatrix{$elty})
502+
function her!(uplo::Char, alpha::$relty, x::ROCVector{$elty}, A::ROCMatrix{$elty})
503503
m, n = size(A)
504504
m == n || throw(DimensionMismatch("Matrix A is $m by $n but must be square"))
505505
length(x) == n || throw(DimensionMismatch("Length of vector must be the same as the matrix dimensions"))
@@ -863,12 +863,12 @@ for (fname, elty) in ((:rocblas_zhemm,:ComplexF64),
863863
end
864864

865865
## herk
866-
for (fname, elty) in ((:rocblas_zherk,:ComplexF64),
867-
(:rocblas_cherk,:ComplexF32))
866+
for (fname, elty, relty) in ((:rocblas_zherk,:ComplexF64,:Float64),
867+
(:rocblas_cherk,:ComplexF32,:Float32))
868868
@eval begin
869869
function herk!(
870-
uplo::Char, trans::Char, alpha::($elty), A::ROCVecOrMat{$elty},
871-
beta::($elty), C::ROCMatrix{$elty},
870+
uplo::Char, trans::Char, alpha::($relty), A::ROCVecOrMat{$elty},
871+
beta::($relty), C::ROCMatrix{$elty},
872872
)
873873
mC, n = size(C)
874874
if mC != n throw(DimensionMismatch("C must be square")) end
@@ -881,12 +881,12 @@ for (fname, elty) in ((:rocblas_zherk,:ComplexF64),
881881
$(fname)(handle, uplo, trans, n, k, Ref(alpha), A, lda, Ref(beta), C, ldc)
882882
C
883883
end
884-
function herk(uplo::Char, trans::Char, alpha::($elty), A::ROCVecOrMat{$elty})
884+
function herk(uplo::Char, trans::Char, alpha::($relty), A::ROCVecOrMat{$elty})
885885
n = size(A, trans == 'N' ? 1 : 2)
886-
herk!(uplo, trans, alpha, A, zero($elty), similar(A, $elty, (n,n)))
886+
herk!(uplo, trans, alpha, A, zero($relty), similar(A, $elty, (n,n)))
887887
end
888888
herk(uplo::Char, trans::Char, A::ROCVecOrMat{$elty}) =
889-
herk(uplo, trans, one($elty), A)
889+
herk(uplo, trans, one($relty), A)
890890
end
891891
end
892892

@@ -1092,13 +1092,13 @@ for (fname, elty) in ((:rocblas_dgeam,:Float64),
10921092
)
10931093
m,n = size(B)
10941094
if ((transb == 'T' || transb == 'C'))
1095-
geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (n,m) ) )
1095+
return geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (n,m) ) )
10961096
end
10971097
if (transb == 'N')
1098-
geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (m,n) ) )
1098+
return geam!( transa, transb, alpha, A, beta, B, similar(B, $elty, (m,n) ) )
10991099
end
11001100
end
1101-
geam( uplo::Char, trans::Char, A::ROCMatrix{$elty}, B::ROCMatrix{$elty}) = geam( uplo, trans, one($elty), A, one($elty), B)
1101+
geam( transa::Char, transb::Char, A::ROCMatrix{$elty}, B::ROCMatrix{$elty}) = geam( transa, transb, one($elty), A, one($elty), B)
11021102
end
11031103
end
11041104

test/rocarray/blas.jl

Lines changed: 245 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -89,6 +89,13 @@ end
8989
α = rand(T)
9090
rocBLAS.gemv!('N', α, dA, dx, T(0), dy)
9191
@test α * A * x ≈ Array(dy)
92+
A = rand(T, m, n)
93+
x = rand(T, n)
94+
dA, dx = ROCArray.((A, x))
95+
dy = rocBLAS.gemv('N', α, dA, dx)
96+
@test α * A * x ≈ Array(dy)
97+
dy = rocBLAS.gemv('N', dA, dx)
98+
@test A * x ≈ Array(dy)
9299
end
93100

94101
@testset "mul! y = $f(A) * x * $Ts(a) + y * $Ts(b)" for f in (
@@ -106,6 +113,28 @@ end
106113
@test testf(
107114
(y, a, b) -> mul!(y, Hermitian(a), b), rand(T, 5),
108115
rand(T, 5, 5), rand(T, 5))
116+
117+
A_ = rand(T, m, m)
118+
A = A_ + A_'
119+
x = rand(T, m)
120+
y = zeros(T, m)
121+
dA, dx, dy = ROCArray.((A, x, y))
122+
α = rand(T)
123+
if T <: Real
124+
rocBLAS.symv!('U', α, dA, dx, T(0), dy)
125+
@test α * A * x ≈ Array(dy)
126+
dy = rocBLAS.symv('U', α, dA, dx)
127+
@test α * A * x ≈ Array(dy)
128+
dy = rocBLAS.symv('U', dA, dx)
129+
@test A * x ≈ Array(dy)
130+
else
131+
rocBLAS.hemv!('U', α, dA, dx, T(0), dy)
132+
@test α * A * x ≈ Array(dy)
133+
dy = rocBLAS.hemv('U', α, dA, dx)
134+
@test α * A * x ≈ Array(dy)
135+
dy = rocBLAS.hemv('U', dA, dx)
136+
@test A * x ≈ Array(dy)
137+
end
109138
end
110139

111140
A = rand(T, m, m)
@@ -137,6 +166,71 @@ end
137166
)
138167
@test testf(a -> inv(TR(a)), x)
139168
end
169+
170+
@testset "ger!" begin
171+
A = rand(T, m, m)
172+
x = rand(T, m)
173+
y = rand(T, m)
174+
dA = ROCArray(A)
175+
dx = ROCArray(x)
176+
dy = ROCArray(y)
177+
# perform rank one update
178+
dB = copy(dA)
179+
rocBLAS.ger!(alpha,dx,dy,dB)
180+
B = (alpha*x)*y' + A
181+
# move to host and compare
182+
hB = Array(dB)
183+
@test B ≈ hB
184+
end
185+
186+
@testset "syr!" begin
187+
sA = rand(T, m, m)
188+
sA = sA + transpose(sA)
189+
x = rand(T,m)
190+
dx = ROCArray(x)
191+
dB = ROCArray(sA)
192+
rocBLAS.syr!('U',alpha,dx,dB)
193+
B = (alpha*x)*transpose(x) + sA
194+
# move to host and compare upper triangles
195+
hB = Array(dB)
196+
B = triu(B)
197+
hB = triu(hB)
198+
@test B ≈ hB
199+
end
200+
201+
if T <: Complex
202+
@testset "her" begin
203+
hA = rand(T,m,m)
204+
hA = hA + adjoint(hA)
205+
dB = ROCArray(hA)
206+
x = rand(T,m)
207+
dx = ROCArray(x)
208+
# perform rank one update
209+
rocBLAS.her!('U',real(alpha),dx,dB)
210+
B = (real(alpha)*x)*x' + hA
211+
# move to host and compare upper triangles
212+
hB = Array(dB)
213+
B = triu(B)
214+
hB = triu(hB)
215+
@test B ≈ hB
216+
end
217+
@testset "her2!" begin
218+
hA = rand(T,m,m)
219+
hA = hA + adjoint(hA)
220+
x = rand(T, m)
221+
y = rand(T,m)
222+
dB = ROCArray(hA)
223+
dx = ROCArray(x)
224+
dy = ROCArray(y)
225+
rocBLAS.her2!('U',alpha,dx,dy,dB)
226+
B = (alpha*x)*y' + y*(alpha*x)' + hA
227+
# move to host and compare upper triangles
228+
hB = Array(dB)
229+
B = triu(B)
230+
hB = triu(hB)
231+
@test B ≈ hB
232+
end
233+
end
140234
end
141235
end
142236

@@ -156,6 +250,22 @@ end
156250
@test testf(
157251
(c, a, b) -> mul!(c, Hermitian(a), b),
158252
rand(T, 5, 5), Hermitian(rand(T, 5, 5)), rand(T, 5, 5))
253+
A_ = rand(T, m, m)
254+
A = A_ + A_'
255+
B = rand(T, m, n)
256+
dA, dB = ROCArray.((A, B))
257+
α = rand(T)
258+
if T <: Real
259+
dC = rocBLAS.symm('L', 'U', α, dA, dB)
260+
@test α * A * B ≈ Array(dC)
261+
dC = rocBLAS.symm('L', 'U', dA, dB)
262+
@test A * B ≈ Array(dC)
263+
else
264+
dC = rocBLAS.hemm('L', 'U', α, dA, dB)
265+
@test α * A * B ≈ Array(dC)
266+
dC = rocBLAS.hemm('L', 'U', dA, dB)
267+
@test A * B ≈ Array(dC)
268+
end
159269
end
160270

161271
@testset "trsm ($T, $adjtype, $uplotype)" for adjtype in (
@@ -286,6 +396,122 @@ end
286396
end
287397
end
288398
end
399+
@testset "syrk T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
400+
# generate parameters
401+
α = rand(T)
402+
β = rand(T)
403+
A = rand(T, m, m)
404+
Abad = rand(T, m + 1, m + 1)
405+
C = rand(T, m, m)
406+
# move to device
407+
d_A, d_Abad = ROCArray(A), ROCArray(Abad)
408+
C = C + transpose(C)
409+
d_C = ROCArray(C)
410+
C = α*(A*transpose(A)) + β*C
411+
rocBLAS.syrk!('U','N',α,d_A,β,d_C)
412+
# move back to host and compare
413+
C = triu(C)
414+
h_C = Array(d_C)
415+
h_C = triu(h_C)
416+
@test C ≈ h_C
417+
@test_throws DimensionMismatch rocBLAS.syrk!('U','N',α,d_Abad,β,d_C)
418+
419+
d_C = rocBLAS.syrk('U','N',d_A)
420+
C = triu(A*transpose(A))
421+
h_C = Array(d_C)
422+
h_C = triu(h_C)
423+
@test C ≈ h_C
424+
end
425+
@testset "syr2k T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64)
426+
# generate parameters
427+
α = rand(T)
428+
β = rand(T)
429+
A = rand(T, m, k)
430+
B = rand(T, m, k)
431+
Bbad = rand(T, m + 1, k + 1)
432+
C = rand(T, m, m)
433+
# move to device
434+
d_A, d_B, d_Bbad = ROCArray(A), ROCArray(B), ROCArray(Bbad)
435+
C = C + transpose(C)
436+
d_C = ROCArray(C)
437+
C = α*(A*transpose(B) + B*transpose(A)) + β*C
438+
rocBLAS.syr2k!('U','N',α,d_A,d_B,β,d_C)
439+
# move back to host and compare
440+
C = triu(C)
441+
h_C = Array(d_C)
442+
h_C = triu(h_C)
443+
@test C ≈ h_C
444+
@test_throws DimensionMismatch rocBLAS.syr2k!('U','N',α,d_A,d_Bbad,β,d_C)
445+
Bbad = rand(T, m, k + 1)
446+
d_Bbad = ROCArray(Bbad)
447+
@test_throws DimensionMismatch rocBLAS.syr2k!('U','N',α,d_A,d_Bbad,β,d_C)
448+
449+
d_C = rocBLAS.syr2k('U','N',d_A,d_B)
450+
C = triu((A*transpose(B)) + (B*transpose(A)))
451+
h_C = Array(d_C)
452+
h_C = triu(h_C)
453+
@test C ≈ h_C
454+
end
455+
@testset "herk T=$T" for T in (ComplexF32, ComplexF64)
456+
T1 = T
457+
T2 = real(T)
458+
# generate parameters
459+
α = rand(T2)
460+
β = rand(T2)
461+
A = rand(T, m, m)
462+
Abad = rand(T, m + 1, m + 1)
463+
C = rand(T, m, m)
464+
# move to device
465+
d_A, d_Abad = ROCArray(A), ROCArray(Abad)
466+
C = C + C'
467+
d_C = ROCArray(C)
468+
C = α*(A*A') + β*C
469+
rocBLAS.herk!('U','N',α,d_A,β,d_C)
470+
# move back to host and compare
471+
C = triu(C)
472+
h_C = Array(d_C)
473+
h_C = triu(h_C)
474+
@test C ≈ h_C
475+
@test_throws DimensionMismatch rocBLAS.herk!('U','N',α,d_Abad,β,d_C)
476+
477+
d_C = rocBLAS.herk('U','N',d_A)
478+
C = triu(A*A')
479+
h_C = Array(d_C)
480+
h_C = triu(h_C)
481+
@test C ≈ h_C
482+
end
483+
@testset "her2k T=$T" for T in (ComplexF32, ComplexF64)
484+
T1 = T
485+
T2 = real(T)
486+
# generate parameters
487+
α = rand(T1)
488+
β = rand(T2)
489+
A = rand(T, m, k)
490+
B = rand(T, m, k)
491+
Bbad = rand(T, m + 1, k + 1)
492+
C = rand(T, m, m)
493+
# move to device
494+
d_A, d_B, d_Bbad = ROCArray(A), ROCArray(B), ROCArray(Bbad)
495+
C = C + C'
496+
d_C = ROCArray(C)
497+
C = α*(A*B') + conj(α)*(B*A') + β*C
498+
rocBLAS.her2k!('U','N',α,d_A,d_B,β,d_C)
499+
# move back to host and compare
500+
C = triu(C)
501+
h_C = Array(d_C)
502+
h_C = triu(h_C)
503+
@test C ≈ h_C
504+
@test_throws DimensionMismatch rocBLAS.her2k!('U','N',α,d_A,d_Bbad,β,d_C)
505+
Bbad = rand(T, m, k + 1)
506+
d_Bbad = ROCArray(Bbad)
507+
@test_throws DimensionMismatch rocBLAS.her2k!('U','N',α,d_A,d_Bbad,β,d_C)
508+
509+
d_C = rocBLAS.her2k('U','N',d_A,d_B)
510+
C = triu((A*B') + (B*A'))
511+
h_C = Array(d_C)
512+
h_C = triu(h_C)
513+
@test C ≈ h_C
514+
end
289515
end
290516

291517
@testset "Extension" begin
@@ -338,6 +564,25 @@ end
338564
d_YA = rmul!(d_YA, d_Y)
339565
@test Array(d_YA) ≈ adjoint(AY) * Diagonal(y)
340566
end
567+
@testset "geam" begin
568+
m = 4
569+
for T in (Float32, Float64, ComplexF32, ComplexF64)
570+
for at in ('N', 'T'), bt in ('N', 'T')
571+
A = rand(T, m, m)
572+
B = rand(T, m, m)
573+
RA = ROCArray(A)
574+
RB = ROCArray(B)
575+
α = rand(T)
576+
β = rand(T)
577+
RC = rocBLAS.geam(at, bt, α, RA, β, RB)
578+
C =
579+
α * (at == 'T' ? transpose(A) : A) +
580+
β * (bt == 'T' ? transpose(B) : B)
581+
@test Array(RC) ≈ C
582+
end
583+
end
584+
end
585+
341586
end
342587

343588
end

0 commit comments

Comments
 (0)