Skip to content

Commit b026d69

Browse files
committed
Also use MPSGraph matmul for int -> float matmul
The only code path left running with MPS is the one for contiguous views.
1 parent 5da5c46 commit b026d69

File tree

2 files changed

+5
-5
lines changed

2 files changed

+5
-5
lines changed

lib/mpsgraphs/MPSGraphs.jl

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -20,9 +20,9 @@ using ObjectiveC, .Foundation, .Dispatch
2020
# The commented type combinations work but are slower than with MPSMatrixMultiplication
2121
const MPSGRAPH_VALID_MATMUL_TYPES =
2222
[
23-
# (Int8, Float16),
24-
# (Int8, Float32),
25-
# (Int16, Float32),
23+
(Int8, Float16),
24+
(Int8, Float32),
25+
(Int16, Float32),
2626
(Float16, Float16),
2727
(Float16, Float32),
2828
(Float32, Float32),

src/linalg.jl

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -49,7 +49,7 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
4949
# If possible, dispatch to MPSGraphs, then performance shaders
5050
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES)
5151
graph_matmul!(C, A, B, alpha, beta, transA, transB)
52-
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES)
52+
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES) # TODO: Remove once contiguous views are working
5353
matmul!(C, A, B, alpha, beta, transA, transB)
5454
else
5555
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
@@ -84,7 +84,7 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
8484
# If possible, dispatch to MPSGraphs, then performance shaders
8585
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES)
8686
graph_matvecmul!(C, A, B, alpha, beta, transA)
87-
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES)
87+
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES) # TODO: Remove once contiguous views are working
8888
matvecmul!(C, A, B, alpha, beta, transA)
8989
else
9090
GPUArrays.generic_matmatmul!(C, wrap(A, tA), B, alpha, beta)

0 commit comments

Comments
 (0)