Skip to content

Commit 20b71fb

Browse files
committed
Use MPS instead of MPSGraph matmul when optimal
1 parent ab25f7b commit 20b71fb

File tree

2 files changed

+31
-7
lines changed

2 files changed

+31
-7
lines changed

examples/flopscomp.jl

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,16 @@ function gpuarrpeakflops(; n::Integer=4096,
6868
GPUArrays.generic_matmatmul!(c, LinearAlgebra.wrap(a, 'N'), LinearAlgebra.wrap(b, 'N'), 1, 0)
6969
end
7070
end
71+
function defaultpeakflops(; n::Integer=4096,
72+
n_batch::Integer=1,
73+
inT::DataType=Float32,
74+
outT::DataType=inT,
75+
ntrials::Integer=3,
76+
verify=true)
77+
_peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b
78+
LinearAlgebra.generic_matmatmul!(c, 'N', 'N', a, b, 1, 0)
79+
end
80+
end
7181
function mpspeakflops(; n::Integer=4096,
7282
n_batch::Integer=1,
7383
inT::DataType=Float32,
@@ -128,11 +138,13 @@ function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
128138
return results
129139
end
130140

131-
function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
141+
# function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
142+
function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024:20:2000..., 2000, 2048:20:3000..., 4000, 4096:20:6000..., 6000, 6144, 8000, 8192],#, 10000],
132143
Fs=[
133144
(mpspeakflops, "MPS"),
134145
(graphpeakflops, "MPSGraph"),
135-
(anepeakflops, "MPSGraph (ANE)"),
146+
(defaultpeakflops, "Default"),
147+
# (anepeakflops, "MPSGraph (ANE)"),
136148
# (gpuarrpeakflops, "GPUArrays"),
137149
# (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
138150
],
@@ -146,7 +158,7 @@ function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2
146158
return res
147159
end
148160

149-
function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)
161+
function plot_results(res, Fs=["MPS", "MPSGraph", "Default"]; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)
150162
ylim_upper = 9e12
151163
resplts = []
152164

@@ -164,7 +176,7 @@ function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=not
164176
if maximum(flops) > ylim_upper
165177
ylim_upper = maximum(flops) * 1.02
166178
end
167-
plot!(plt, Ns, tmpres[info_str]; linewidth=1.5, label="$(peakf) peak: $info_str")
179+
plot!(plt, Ns, tmpres[info_str]; linewidth=1.5, label="$(peakf) peak: $info_str", α=0.8)
168180
end
169181
push!(resplts, plt)
170182
push!(n_batches, n_batch)
@@ -185,3 +197,6 @@ end
185197
if testing
186198
runcomparison(Ns=[50, 64, 100, 128, 250, 256, 500, 512])
187199
end
200+
201+
res = runcomparison()
202+
plot_results(res; outpath=".")

src/linalg.jl

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ using .MPSGraphs: MPSGRAPH_VALID_MATMUL_TYPES, MPSGRAPH_VALID_MATVECMUL_TYPES,
66
graph_matmul!, graph_matvecmul!
77

88
@inline function supports_mps_matmul(A, B, C, valid_types)
9-
MPS.is_supported(device(A)) &&
9+
MPS.is_supported(device(C)) &&
1010
eltype(A) == eltype(B) &&
1111
(eltype(A), eltype(C)) in valid_types
1212
end
1313

1414
@inline function supports_mpsgraph_matmul(A, B, C, valid_types)
15-
MPS.is_supported(device(A)) &&
15+
MPS.is_supported(device(C)) &&
1616
eltype(A) == eltype(B) &&
1717
(eltype(A), eltype(C)) in valid_types &&
1818
# TODO: remove this limitation
@@ -21,6 +21,15 @@ end
2121
C.offset == 0
2222
end
2323

24+
# Assumes support for MPS matrix multiplication has been verified elsewhere
25+
@inline function should_use_MPS(A, _, C)
26+
rows = size(C,1)
27+
cols = size(C,2)
28+
# TODO: matvecmul different?
29+
(eltype(A) <: Integer && rows <= 2000 && cols <= 2000 ) ||
30+
eltype(A) <: AbstractFloat && rows <= 6000 && cols <= 6000 && Metal.supports_family(device(C), MTL.MTLGPUFamilyApple9)
31+
end
32+
2433
LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix, _add::MulAddMul) =
2534
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
2635
@autoreleasepool function LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB,
@@ -47,7 +56,7 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
4756
transB = tB == 'T' || tB == 'C'
4857

4958
# If possible, dispatch to MPSGraphs, then performance shaders
50-
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES)
59+
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) && !should_use_MPS(A, B, C)
5160
graph_matmul!(C, A, B, alpha, beta, transA, transB)
5261
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES) # TODO: Remove once contiguous views are working
5362
matmul!(C, A, B, alpha, beta, transA, transB)

0 commit comments

Comments
 (0)