|
89 | 89 | α = rand(T) |
90 | 90 | rocBLAS.gemv!('N', α, dA, dx, T(0), dy) |
91 | 91 | @test α * A * x ≈ Array(dy) |
| 92 | + A = rand(T, m, n) |
| 93 | + x = rand(T, n) |
| 94 | + dA, dx = ROCArray.((A, x)) |
| 95 | + dy = rocBLAS.gemv('N', α, dA, dx) |
| 96 | + @test α * A * x ≈ Array(dy) |
| 97 | + dy = rocBLAS.gemv('N', dA, dx) |
| 98 | + @test A * x ≈ Array(dy) |
92 | 99 | end |
93 | 100 |
|
94 | 101 | @testset "mul! y = $f(A) * x * $Ts(a) + y * $Ts(b)" for f in ( |
|
106 | 113 | @test testf( |
107 | 114 | (y, a, b) -> mul!(y, Hermitian(a), b), rand(T, 5), |
108 | 115 | rand(T, 5, 5), rand(T, 5)) |
| 116 | + |
| 117 | + A_ = rand(T, m, m) |
| 118 | + A = A_ + A_' |
| 119 | + x = rand(T, m) |
| 120 | + y = zeros(T, m) |
| 121 | + dA, dx, dy = ROCArray.((A, x, y)) |
| 122 | + α = rand(T) |
| 123 | + if T <: Real |
| 124 | + rocBLAS.symv!('U', α, dA, dx, T(0), dy) |
| 125 | + @test α * A * x ≈ Array(dy) |
| 126 | + dy = rocBLAS.symv('U', α, dA, dx) |
| 127 | + @test α * A * x ≈ Array(dy) |
| 128 | + dy = rocBLAS.symv('U', dA, dx) |
| 129 | + @test A * x ≈ Array(dy) |
| 130 | + else |
| 131 | + rocBLAS.hemv!('U', α, dA, dx, T(0), dy) |
| 132 | + @test α * A * x ≈ Array(dy) |
| 133 | + dy = rocBLAS.hemv('U', α, dA, dx) |
| 134 | + @test α * A * x ≈ Array(dy) |
| 135 | + dy = rocBLAS.hemv('U', dA, dx) |
| 136 | + @test A * x ≈ Array(dy) |
| 137 | + end |
109 | 138 | end |
110 | 139 |
|
111 | 140 | A = rand(T, m, m) |
|
137 | 166 | ) |
138 | 167 | @test testf(a -> inv(TR(a)), x) |
139 | 168 | end |
| 169 | + |
| 170 | + @testset "ger!" begin |
| 171 | + A = rand(T, m, m) |
| 172 | + x = rand(T, m) |
| 173 | + y = rand(T, m) |
| 174 | + dA = ROCArray(A) |
| 175 | + dx = ROCArray(x) |
| 176 | + dy = ROCArray(y) |
| 177 | + # perform rank one update |
| 178 | + dB = copy(dA) |
| 179 | + rocBLAS.ger!(alpha,dx,dy,dB) |
| 180 | + B = (alpha*x)*y' + A |
| 181 | + # move to host and compare |
| 182 | + hB = Array(dB) |
| 183 | + @test B ≈ hB |
| 184 | + end |
| 185 | + |
| 186 | + @testset "syr!" begin |
| 187 | + sA = rand(T, m, m) |
| 188 | + sA = sA + transpose(sA) |
| 189 | + x = rand(T,m) |
| 190 | + dx = ROCArray(x) |
| 191 | + dB = ROCArray(sA) |
| 192 | + rocBLAS.syr!('U',alpha,dx,dB) |
| 193 | + B = (alpha*x)*transpose(x) + sA |
| 194 | + # move to host and compare upper triangles |
| 195 | + hB = Array(dB) |
| 196 | + B = triu(B) |
| 197 | + hB = triu(hB) |
| 198 | + @test B ≈ hB |
| 199 | + end |
| 200 | + |
| 201 | + if T <: Complex |
| 202 | + @testset "her" begin |
| 203 | + hA = rand(T,m,m) |
| 204 | + hA = hA + adjoint(hA) |
| 205 | + dB = ROCArray(hA) |
| 206 | + x = rand(T,m) |
| 207 | + dx = ROCArray(x) |
| 208 | + # perform rank one update |
| 209 | + rocBLAS.her!('U',real(alpha),dx,dB) |
| 210 | + B = (real(alpha)*x)*x' + hA |
| 211 | + # move to host and compare upper triangles |
| 212 | + hB = Array(dB) |
| 213 | + B = triu(B) |
| 214 | + hB = triu(hB) |
| 215 | + @test B ≈ hB |
| 216 | + end |
| 217 | + @testset "her2!" begin |
| 218 | + hA = rand(T,m,m) |
| 219 | + hA = hA + adjoint(hA) |
| 220 | + x = rand(T, m) |
| 221 | + y = rand(T,m) |
| 222 | + dB = ROCArray(hA) |
| 223 | + dx = ROCArray(x) |
| 224 | + dy = ROCArray(y) |
| 225 | + rocBLAS.her2!('U',alpha,dx,dy,dB) |
| 226 | + B = (alpha*x)*y' + y*(alpha*x)' + hA |
| 227 | + # move to host and compare upper triangles |
| 228 | + hB = Array(dB) |
| 229 | + B = triu(B) |
| 230 | + hB = triu(hB) |
| 231 | + @test B ≈ hB |
| 232 | + end |
| 233 | + end |
140 | 234 | end |
141 | 235 | end |
142 | 236 |
|
|
156 | 250 | @test testf( |
157 | 251 | (c, a, b) -> mul!(c, Hermitian(a), b), |
158 | 252 | rand(T, 5, 5), Hermitian(rand(T, 5, 5)), rand(T, 5, 5)) |
| 253 | + A_ = rand(T, m, m) |
| 254 | + A = A_ + A_' |
| 255 | + B = rand(T, m, n) |
| 256 | + dA, dB = ROCArray.((A, B)) |
| 257 | + α = rand(T) |
| 258 | + if T <: Real |
| 259 | + dC = rocBLAS.symm('L', 'U', α, dA, dB) |
| 260 | + @test α * A * B ≈ Array(dC) |
| 261 | + dC = rocBLAS.symm('L', 'U', dA, dB) |
| 262 | + @test A * B ≈ Array(dC) |
| 263 | + else |
| 264 | + dC = rocBLAS.hemm('L', 'U', α, dA, dB) |
| 265 | + @test α * A * B ≈ Array(dC) |
| 266 | + dC = rocBLAS.hemm('L', 'U', dA, dB) |
| 267 | + @test A * B ≈ Array(dC) |
| 268 | + end |
159 | 269 | end |
160 | 270 |
|
161 | 271 | @testset "trsm ($T, $adjtype, $uplotype)" for adjtype in ( |
@@ -286,6 +396,122 @@ end |
286 | 396 | end |
287 | 397 | end |
288 | 398 | end |
| 399 | + @testset "syrk T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64) |
| 400 | + # generate parameters |
| 401 | + α = rand(T) |
| 402 | + β = rand(T) |
| 403 | + A = rand(T, m, m) |
| 404 | + Abad = rand(T, m + 1, m + 1) |
| 405 | + C = rand(T, m, m) |
| 406 | + # move to device |
| 407 | + d_A, d_Abad = ROCArray(A), ROCArray(Abad) |
| 408 | + C = C + transpose(C) |
| 409 | + d_C = ROCArray(C) |
| 410 | + C = α*(A*transpose(A)) + β*C |
| 411 | + rocBLAS.syrk!('U','N',α,d_A,β,d_C) |
| 412 | + # move back to host and compare |
| 413 | + C = triu(C) |
| 414 | + h_C = Array(d_C) |
| 415 | + h_C = triu(h_C) |
| 416 | + @test C ≈ h_C |
| 417 | + @test_throws DimensionMismatch rocBLAS.syrk!('U','N',α,d_Abad,β,d_C) |
| 418 | + |
| 419 | + d_C = rocBLAS.syrk('U','N',d_A) |
| 420 | + C = triu(A*transpose(A)) |
| 421 | + h_C = Array(d_C) |
| 422 | + h_C = triu(h_C) |
| 423 | + @test C ≈ h_C |
| 424 | + end |
| 425 | + @testset "syr2k T=$T" for T in (Float32, Float64, ComplexF32, ComplexF64) |
| 426 | + # generate parameters |
| 427 | + α = rand(T) |
| 428 | + β = rand(T) |
| 429 | + A = rand(T, m, k) |
| 430 | + B = rand(T, m, k) |
| 431 | + Bbad = rand(T, m + 1, k + 1) |
| 432 | + C = rand(T, m, m) |
| 433 | + # move to device |
| 434 | + d_A, d_B, d_Bbad = ROCArray(A), ROCArray(B), ROCArray(Bbad) |
| 435 | + C = C + transpose(C) |
| 436 | + d_C = ROCArray(C) |
| 437 | + C = α*(A*transpose(B) + B*transpose(A)) + β*C |
| 438 | + rocBLAS.syr2k!('U','N',α,d_A,d_B,β,d_C) |
| 439 | + # move back to host and compare |
| 440 | + C = triu(C) |
| 441 | + h_C = Array(d_C) |
| 442 | + h_C = triu(h_C) |
| 443 | + @test C ≈ h_C |
| 444 | + @test_throws DimensionMismatch rocBLAS.syr2k!('U','N',α,d_A,d_Bbad,β,d_C) |
| 445 | + Bbad = rand(T, m, k + 1) |
| 446 | + d_Bbad = ROCArray(Bbad) |
| 447 | + @test_throws DimensionMismatch rocBLAS.syr2k!('U','N',α,d_A,d_Bbad,β,d_C) |
| 448 | + |
| 449 | + d_C = rocBLAS.syr2k('U','N',d_A,d_B) |
| 450 | + C = triu((A*transpose(B)) + (B*transpose(A))) |
| 451 | + h_C = Array(d_C) |
| 452 | + h_C = triu(h_C) |
| 453 | + @test C ≈ h_C |
| 454 | + end |
| 455 | + @testset "herk T=$T" for T in (ComplexF32, ComplexF64) |
| 456 | + T1 = T |
| 457 | + T2 = real(T) |
| 458 | + # generate parameters |
| 459 | + α = rand(T2) |
| 460 | + β = rand(T2) |
| 461 | + A = rand(T, m, m) |
| 462 | + Abad = rand(T, m + 1, m + 1) |
| 463 | + C = rand(T, m, m) |
| 464 | + # move to device |
| 465 | + d_A, d_Abad = ROCArray(A), ROCArray(Abad) |
| 466 | + C = C + C' |
| 467 | + d_C = ROCArray(C) |
| 468 | + C = α*(A*A') + β*C |
| 469 | + rocBLAS.herk!('U','N',α,d_A,β,d_C) |
| 470 | + # move back to host and compare |
| 471 | + C = triu(C) |
| 472 | + h_C = Array(d_C) |
| 473 | + h_C = triu(h_C) |
| 474 | + @test C ≈ h_C |
| 475 | + @test_throws DimensionMismatch rocBLAS.herk!('U','N',α,d_Abad,β,d_C) |
| 476 | + |
| 477 | + d_C = rocBLAS.herk('U','N',d_A) |
| 478 | + C = triu(A*A') |
| 479 | + h_C = Array(d_C) |
| 480 | + h_C = triu(h_C) |
| 481 | + @test C ≈ h_C |
| 482 | + end |
| 483 | + @testset "her2k T=$T" for T in (ComplexF32, ComplexF64) |
| 484 | + T1 = T |
| 485 | + T2 = real(T) |
| 486 | + # generate parameters |
| 487 | + α = rand(T1) |
| 488 | + β = rand(T2) |
| 489 | + A = rand(T, m, k) |
| 490 | + B = rand(T, m, k) |
| 491 | + Bbad = rand(T, m + 1, k + 1) |
| 492 | + C = rand(T, m, m) |
| 493 | + # move to device |
| 494 | + d_A, d_B, d_Bbad = ROCArray(A), ROCArray(B), ROCArray(Bbad) |
| 495 | + C = C + C' |
| 496 | + d_C = ROCArray(C) |
| 497 | + C = α*(A*B') + conj(α)*(B*A') + β*C |
| 498 | + rocBLAS.her2k!('U','N',α,d_A,d_B,β,d_C) |
| 499 | + # move back to host and compare |
| 500 | + C = triu(C) |
| 501 | + h_C = Array(d_C) |
| 502 | + h_C = triu(h_C) |
| 503 | + @test C ≈ h_C |
| 504 | + @test_throws DimensionMismatch rocBLAS.her2k!('U','N',α,d_A,d_Bbad,β,d_C) |
| 505 | + Bbad = rand(T, m, k + 1) |
| 506 | + d_Bbad = ROCArray(Bbad) |
| 507 | + @test_throws DimensionMismatch rocBLAS.her2k!('U','N',α,d_A,d_Bbad,β,d_C) |
| 508 | + |
| 509 | + d_C = rocBLAS.her2k('U','N',d_A,d_B) |
| 510 | + C = triu((A*B') + (B*A')) |
| 511 | + h_C = Array(d_C) |
| 512 | + h_C = triu(h_C) |
| 513 | + @test C ≈ h_C |
| 514 | + end |
289 | 515 | end |
290 | 516 |
|
291 | 517 | @testset "Extension" begin |
|
338 | 564 | d_YA = rmul!(d_YA, d_Y) |
339 | 565 | @test Array(d_YA) ≈ adjoint(AY) * Diagonal(y) |
340 | 566 | end |
| 567 | + @testset "geam" begin |
| 568 | + m = 4 |
| 569 | + for T in (Float32, Float64, ComplexF32, ComplexF64) |
| 570 | + for at in ('N', 'T'), bt in ('N', 'T') |
| 571 | + A = rand(T, m, m) |
| 572 | + B = rand(T, m, m) |
| 573 | + RA = ROCArray(A) |
| 574 | + RB = ROCArray(B) |
| 575 | + α = rand(T) |
| 576 | + β = rand(T) |
| 577 | + RC = rocBLAS.geam(at, bt, α, RA, β, RB) |
| 578 | + C = |
| 579 | + α * (at == 'T' ? transpose(A) : A) + |
| 580 | + β * (bt == 'T' ? transpose(B) : B) |
| 581 | + @test Array(RC) ≈ C |
| 582 | + end |
| 583 | + end |
| 584 | + end |
| 585 | + |
341 | 586 | end |
342 | 587 |
|
343 | 588 | end |
0 commit comments