Skip to content

Commit be74c17

Browse files
committed
Use Octavian.jl for large mixed-mode CPU calculations.
1 parent 781f1de commit be74c17

File tree

2 files changed

+28
-8
lines changed

2 files changed

+28
-8
lines changed

test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -5,5 +5,6 @@ Distributed = "8ba89e20-285c-5b6f-9357-94700520ee1b"
55
ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
66
InteractiveUtils = "b77e0a4c-d291-57a0-90e8-8db25a27a240"
77
LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e"
8+
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
89
Test = "8dfed614-e22c-5e08-85e1-65c5234f0b40"
910
XUnit = "3e3c03f2-1a94-11e9-2981-050a4ca824ab"

test/matmul.jl

Lines changed: 27 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,26 @@
11
using CUDA
22
using ForwardDiff
33
using GemmKernels
4-
using LinearAlgebra
4+
import Octavian, LinearAlgebra
5+
6+
# for large, non-BLAS-compatible matrices, use Octavian.
7+
matmul!(C, A, B, alpha=true, beta=false) = LinearAlgebra.mul!(C, A, B, alpha, beta)
8+
function matmul!(C::Array,
9+
A::Union{Array, LinearAlgebra.Transpose{<:Any, <:Array},
10+
LinearAlgebra.Adjoint{<:Any, <:Array}},
11+
B::Union{Array, LinearAlgebra.Transpose{<:Any, <:Array},
12+
LinearAlgebra.Adjoint{<:Any, <:Array}},
13+
alpha::Bool=true, beta::Bool=false)
14+
supported = eltype(C) <: LinearAlgebra.BlasFloat &&
15+
eltype(A) <: LinearAlgebra.BlasFloat &&
16+
eltype(B) <: LinearAlgebra.BlasFloat &&
17+
eltype(C) == eltype(A) == eltype(B)
18+
if !supported && (sizeof(C) > 2^20 || sizeof(A) > 2^20 || sizeof(B) > 2^20)
19+
Octavian.matmul!(C, A, B, alpha, beta)
20+
else
21+
LinearAlgebra.mul!(C, A, B, alpha, beta)
22+
end
23+
end
524

625
################################################################################
726

@@ -63,7 +82,7 @@ using LinearAlgebra
6382
new_a_h = transpose_a ? transpose(a_h) : a_h
6483
new_b_h = transpose_b ? transpose(b_h) : b_h
6584

66-
mul!(c_h, new_a_h, new_b_h, alpha, beta)
85+
matmul!(c_h, new_a_h, new_b_h, alpha, beta)
6786
if A_type <: Integer
6887
@test c_h ≈ Array(d)
6988
else
@@ -121,7 +140,7 @@ using LinearAlgebra
121140
new_a_h = transpose_a ? transpose(a_h) : a_h
122141
new_b_h = transpose_b ? transpose(b_h) : b_h
123142

124-
mul!(c_h, new_a_h, new_b_h, alpha, beta)
143+
matmul!(c_h, new_a_h, new_b_h, alpha, beta)
125144
@test c_h ≈ Array(d) rtol=sqrt(eps(A_type))
126145
end
127146
end
@@ -222,7 +241,7 @@ using LinearAlgebra
222241
new_a_h = transpose_a ? transpose(a_h) : a_h
223242
new_b_h = transpose_b ? transpose(b_h) : b_h
224243

225-
mul!(c_h, new_a_h, new_b_h, alpha, beta)
244+
matmul!(c_h, new_a_h, new_b_h, alpha, beta)
226245
@test c_h ≈ Array(d) rtol=sqrt(eps(AB_type))
227246
end
228247
end
@@ -274,7 +293,7 @@ using LinearAlgebra
274293
new_a_h = transpose_a ? transpose(a_h) : a_h
275294
new_b_h = transpose_b ? transpose(b_h) : b_h
276295

277-
mul!(c_h, new_a_h, new_b_h, true, true)
296+
matmul!(c_h, new_a_h, new_b_h, true, true)
278297
@test c_h .+ Array(bias) ≈ Array(d) rtol=sqrt(eps(Float16))
279298
end
280299
end
@@ -319,7 +338,7 @@ using LinearAlgebra
319338
new_a_h = transpose_a ? transpose(a_h) : a_h
320339
new_b_h = transpose_b ? transpose(b_h) : b_h
321340

322-
mul!(c_h, Diagonal(new_a_h), new_b_h, true, true)
341+
matmul!(c_h, Diagonal(new_a_h), new_b_h, true, true)
323342
@test c_h ≈ Array(d) rtol=sqrt(eps(Float16))
324343
end
325344
end
@@ -383,7 +402,7 @@ using LinearAlgebra
383402
new_a_h = transpose_a ? transpose(new_a_h) : new_a_h
384403
new_b_h = transpose_b ? transpose(new_b_h) : new_b_h
385404

386-
mul!(c_h, new_a_h, new_b_h, true, true)
405+
matmul!(c_h, new_a_h, new_b_h, true, true)
387406
@test c_h ≈ Array(d) rtol=sqrt(eps(Float16))
388407
end
389408
end
@@ -436,7 +455,7 @@ using LinearAlgebra
436455
c_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, c_h)
437456
d_dual = reinterpret(ForwardDiff.Dual{Float32,Float32,1}, Array(d))
438457

439-
mul!(c_dual, a_dual, b_dual, true, true)
458+
matmul!(c_dual, a_dual, b_dual, true, true)
440459
@test c_dual ≈ d_dual rtol=sqrt(eps(Float16))
441460
end
442461
end

0 commit comments

Comments
 (0)