Skip to content

Commit beb1ba0

Browse files
committed
Algorithm selection
1 parent 1e9201c commit beb1ba0

File tree

3 files changed

+18
-6
lines changed

3 files changed

+18
-6
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2121
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2222
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2323
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
24+
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
2425
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2526
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2627

@@ -49,6 +50,7 @@ Preferences = "1"
4950
Printf = "1"
5051
Random = "1"
5152
SHA = "0.7"
53+
ScopedValues = "1.3.0"
5254
SpecialFunctions = "2"
5355
StaticArrays = "1"
5456
UUIDs = "1"

src/Metal.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using ExprTools: splitdef, combinedef
1212
using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
1313
import ObjectiveC: is_macos, darwin_version, macos_version
1414
import KernelAbstractions
15+
using ScopedValues
1516

1617
include("version.jl")
1718

src/linalg.jl

Lines changed: 15 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -30,6 +30,9 @@ end
3030
eltype(A) <: AbstractFloat && rows <= 6000 && cols <= 6000 && Metal.supports_family(device(C), MTL.MTLGPUFamilyApple9)
3131
end
3232

33+
# Supported values are :auto, :MPS, :MPSGraph, and :GPUArrays
34+
const matmul_alg = ScopedValue(:auto)
35+
3336
LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix, _add::MulAddMul) =
3437
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
3538
@autoreleasepool function LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB,
@@ -55,13 +58,16 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
5558
transA = tA == 'T' || tA == 'C'
5659
transB = tB == 'T' || tB == 'C'
5760

61+
alg = matmul_alg[]
5862
# If possible, dispatch to MPSGraphs, then performance shaders
59-
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) && !should_use_MPS(A, B, C)
63+
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) && (alg === :MPSGraph || (alg === :auto && !should_use_MPS(A, B, C)))
6064
graph_matmul!(C, A, B, alpha, beta, transA, transB)
61-
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES) # TODO: Remove once contiguous views are working
65+
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES) && (alg === :MPS || alg === :auto)
6266
matmul!(C, A, B, alpha, beta, transA, transB)
63-
else
67+
elseif alg === :GPUArrays || alg === :auto
6468
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
69+
else
70+
error("Invalid matmul algorithm and input combination.")
6571
end
6672
end
6773

@@ -90,13 +96,16 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
9096

9197
transA = tA == 'T' || tA == 'C'
9298

99+
alg = matmul_alg[]
93100
# If possible, dispatch to MPSGraphs, then performance shaders
94-
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES)
101+
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES) && (alg === :MPSGraph || alg === :auto)
95102
graph_matvecmul!(C, A, B, alpha, beta, transA)
96-
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES) # TODO: Remove once contiguous views are working
103+
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES) && (alg === :MPS || alg === :auto)
97104
matvecmul!(C, A, B, alpha, beta, transA)
98-
else
105+
elseif alg === :GPUArrays || alg === :auto
99106
GPUArrays.generic_matmatmul!(C, wrap(A, tA), B, alpha, beta)
107+
else
108+
error("Invalid matmul algorithm and input combination.")
100109
end
101110
end
102111

0 commit comments

Comments
 (0)