@@ -1,7 +1,26 @@
using CUDA
using ForwardDiff
using GemmKernels
-using LinearAlgebra
+import Octavian, LinearAlgebra
+
+# for large, non-BLAS-compatible matrices, use Octavian.
+matmul!(C, A, B, alpha=true, beta=false) = LinearAlgebra.mul!(C, A, B, alpha, beta)
+function matmul!(C::Array,
+                 A::Union{Array, LinearAlgebra.Transpose{<:Any, <:Array},
+                          LinearAlgebra.Adjoint{<:Any, <:Array}},
+                 B::Union{Array, LinearAlgebra.Transpose{<:Any, <:Array},
+                          LinearAlgebra.Adjoint{<:Any, <:Array}},
+                 alpha::Bool=true, beta::Bool=false)
+    supported = eltype(C) <: LinearAlgebra.BlasFloat &&
+                eltype(A) <: LinearAlgebra.BlasFloat &&
+                eltype(B) <: LinearAlgebra.BlasFloat &&
+                eltype(C) == eltype(A) == eltype(B)
+    if !supported && (sizeof(C) > 2^20 || sizeof(A) > 2^20 || sizeof(B) > 2^20)
+        Octavian.matmul!(C, A, B, alpha, beta)
+    else
+        LinearAlgebra.mul!(C, A, B, alpha, beta)
+    end
+end

################################################################################

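For context (not part of the commit): a minimal usage sketch of the new matmul! helper defined above, assuming Octavian.jl is installed. Float16 is not a LinearAlgebra.BlasFloat, so once the buffers exceed 2^20 bytes the call is routed to Octavian.matmul!; BLAS-compatible element types keep going through LinearAlgebra.mul!.

    A = rand(Float16, 2048, 2048)   # Float16 has no BLAS gemm
    B = rand(Float16, 2048, 2048)
    C = zeros(Float16, 2048, 2048)  # 2048^2 * 2 bytes > 2^20, so the Octavian path is taken
    matmul!(C, A, B)                # alpha = true, beta = false by default

    D = rand(Float64, 64, 64)       # BlasFloat element type
    matmul!(zeros(64, 64), D, D)    # falls through to LinearAlgebra.mul!
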
@@ -63,7 +82,7 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(a_h) : a_h
new_b_h = transpose_b ? transpose(b_h) : b_h

-mul!(c_h, new_a_h, new_b_h, alpha, beta)
+matmul!(c_h, new_a_h, new_b_h, alpha, beta)
if A_type <: Integer
    @test c_h ≈ Array(d)
else
@@ -121,7 +140,7 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(a_h) : a_h
new_b_h = transpose_b ? transpose(b_h) : b_h

-mul!(c_h, new_a_h, new_b_h, alpha, beta)
+matmul!(c_h, new_a_h, new_b_h, alpha, beta)
@test c_h ≈ Array(d) rtol=sqrt(eps(A_type))
end
end
@@ -222,7 +241,7 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(a_h) : a_h
new_b_h = transpose_b ? transpose(b_h) : b_h

-mul!(c_h, new_a_h, new_b_h, alpha, beta)
+matmul!(c_h, new_a_h, new_b_h, alpha, beta)
@test c_h ≈ Array(d) rtol=sqrt(eps(AB_type))
end
end
@@ -274,7 +293,7 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(a_h) : a_h
new_b_h = transpose_b ? transpose(b_h) : b_h

-mul!(c_h, new_a_h, new_b_h, true, true)
+matmul!(c_h, new_a_h, new_b_h, true, true)
@test c_h .+ Array(bias) ≈ Array(d) rtol=sqrt(eps(Float16))
end
end
@@ -319,7 +338,7 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(a_h) : a_h
new_b_h = transpose_b ? transpose(b_h) : b_h

-mul!(c_h, Diagonal(new_a_h), new_b_h, true, true)
+matmul!(c_h, Diagonal(new_a_h), new_b_h, true, true)
@test c_h ≈ Array(d) rtol=sqrt(eps(Float16))
end
end
@@ -383,7 +402,7 @@ using LinearAlgebra
new_a_h = transpose_a ? transpose(new_a_h) : new_a_h
new_b_h = transpose_b ? transpose(new_b_h) : new_b_h

-mul!(c_h, new_a_h, new_b_h, true, true)
+matmul!(c_h, new_a_h, new_b_h, true, true)
@test c_h ≈ Array(d) rtol=sqrt(eps(Float16))
end
end
@@ -436,7 +455,7 @@ using LinearAlgebra
c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_h)
d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, Array(d))

-mul!(c_dual, a_dual, b_dual, true, true)
+matmul!(c_dual, a_dual, b_dual, true, true)
@test c_dual ≈ d_dual rtol=sqrt(eps(Float16))
end
end
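Side note (not from the commit): the dual-number test works because reinterpreting a Float32 matrix as ForwardDiff.Dual{Float32,Float32,1} packs each value/partial pair into one element, halving the first dimension; the reinterpreted arrays are not plain Arrays, so they take the generic matmul! method, i.e. LinearAlgebra.mul!. A small sketch:

    using ForwardDiff
    x = rand(Float32, 4, 4)
    xd = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, x)  # 2x4 matrix of Dual numbers
    size(xd) == (2, 4)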