Skip to content

Commit 7706211

Browse files
Faster matmul sometimes (#580)
* Use MPS instead of MPSGraph matmul when optimal
* Faster testing
* Fix
* Algorithm selection
* flopscomp improvements
* Fix & tests
* No need for AppleAccelerate
* More specific error
* More tests
1 parent ab25f7b commit 7706211

File tree

6 files changed

+125
-33
lines changed

6 files changed

+125
-33
lines changed

Project.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,7 @@ Preferences = "21216c6a-2e73-6563-6e65-726566657250"
2121
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
2222
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
2323
SHA = "ea8e919c-243c-51af-8825-aaa63cd721ce"
24+
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
2425
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"
2526
UUIDs = "cf7118a7-6976-5b1a-9a39-7adc72f591a4"
2627

@@ -49,6 +50,7 @@ Preferences = "1"
4950
Printf = "1"
5051
Random = "1"
5152
SHA = "0.7"
53+
ScopedValues = "1.3.0"
5254
SpecialFunctions = "2"
5355
StaticArrays = "1"
5456
UUIDs = "1"

examples/flopscomp.jl

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,4 @@
1-
2-
using Metal, GPUArrays, LinearAlgebra, Printf, AppleAccelerate
1+
using Metal, GPUArrays, LinearAlgebra, Printf#, AppleAccelerate
32

43
testing = (@isdefined TESTING) && TESTING
54

@@ -8,14 +7,15 @@ testing = (@isdefined TESTING) && TESTING
87
using Plots.Measures
98
end
109

11-
const Ts=[
12-
(Int8, Float16),
13-
(Int8, Float32),
14-
(Int16, Float32),
15-
(Float16, Float16),
16-
(Float16, Float32),
17-
(Float32, Float32),
18-
]
10+
Ts=[
11+
(Int8, Float16),
12+
(Int8, Float32),
13+
(Int16, Float32),
14+
(Float16, Float16),
15+
(Float16, Float32),
16+
(Float32, Float32),
17+
]
18+
DEFAULT_NS = [50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 1500, 2000, 2048, 2500, 3000, 4000, 4096, 5000, 6000, 6144, 8000, 8192]
1919

2020
n_gpu_cores = "??"
2121
# Comment this out if scary. Please mention number of cores in your comment when uploading the figure
@@ -68,6 +68,16 @@ function gpuarrpeakflops(; n::Integer=4096,
6868
GPUArrays.generic_matmatmul!(c, LinearAlgebra.wrap(a, 'N'), LinearAlgebra.wrap(b, 'N'), 1, 0)
6969
end
7070
end
71+
function defaultpeakflops(; n::Integer=4096,
72+
n_batch::Integer=1,
73+
inT::DataType=Float32,
74+
outT::DataType=inT,
75+
ntrials::Integer=3,
76+
verify=true)
77+
_peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b
78+
LinearAlgebra.generic_matmatmul!(c, 'N', 'N', a, b, 1, 0)
79+
end
80+
end
7181
function mpspeakflops(; n::Integer=4096,
7282
n_batch::Integer=1,
7383
inT::DataType=Float32,
@@ -128,25 +138,25 @@ function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
128138
return results
129139
end
130140

131-
function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
132-
Fs=[
133-
(mpspeakflops, "MPS"),
134-
(graphpeakflops, "MPSGraph"),
135-
(anepeakflops, "MPSGraph (ANE)"),
136-
# (gpuarrpeakflops, "GPUArrays"),
137-
# (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
138-
],
139-
n_batch=1,
140-
ntrials=5)
141-
res = Dict()
141+
DEFAULT_FS = [
142+
(mpspeakflops, "MPS"),
143+
(graphpeakflops, "MPSGraph"),
144+
(defaultpeakflops, "Default"),
145+
# (anepeakflops, "MPSGraph (ANE)"),
146+
# (gpuarrpeakflops, "GPUArrays"),
147+
# (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
148+
]
142149

150+
function runcomparison(; Ns=DEFAULT_NS, Fs=DEFAULT_FS, n_batch=1, ntrials=5)
151+
res = Dict()
143152
for (inT, outT) in Ts
144153
res[(inT,outT)] = (n_batch, Ns, compare(Ns, Fs, inT, outT; n_batch, ntrials))
145154
end
146155
return res
147156
end
148157

149-
function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)
158+
function plot_results(res, Fs=DEFAULT_FS; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)
159+
Fs = get.(Fs, 2, "You shouldn't be reading this")
150160
ylim_upper = 9e12
151161
resplts = []
152162

@@ -164,7 +174,7 @@ function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=not
164174
if maximum(flops) > ylim_upper
165175
ylim_upper = maximum(flops) * 1.02
166176
end
167-
plot!(plt, Ns, tmpres[info_str]; linewidth=1.5, label="$(peakf) peak: $info_str")
177+
plot!(plt, Ns, tmpres[info_str]; linewidth=1.5, label="$(peakf) peak: $info_str", α=0.8)
168178
end
169179
push!(resplts, plt)
170180
push!(n_batches, n_batch)
@@ -184,4 +194,7 @@ end
184194

185195
if testing
186196
runcomparison(Ns=[50, 64, 100, 128, 250, 256, 500, 512])
197+
elseif abspath(PROGRAM_FILE) == @__FILE__
198+
res = runcomparison()
199+
plot_results(res; outpath=".")
187200
end

src/Metal.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ using ExprTools: splitdef, combinedef
1212
using ObjectiveC, .CoreFoundation, .Foundation, .Dispatch, .OS
1313
import ObjectiveC: is_macos, darwin_version, macos_version
1414
import KernelAbstractions
15+
using ScopedValues
1516

1617
include("version.jl")
1718

src/linalg.jl

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -6,13 +6,13 @@ using .MPSGraphs: MPSGRAPH_VALID_MATMUL_TYPES, MPSGRAPH_VALID_MATVECMUL_TYPES,
66
graph_matmul!, graph_matvecmul!
77

88
@inline function supports_mps_matmul(A, B, C, valid_types)
9-
MPS.is_supported(device(A)) &&
9+
MPS.is_supported(device(C)) &&
1010
eltype(A) == eltype(B) &&
1111
(eltype(A), eltype(C)) in valid_types
1212
end
1313

1414
@inline function supports_mpsgraph_matmul(A, B, C, valid_types)
15-
MPS.is_supported(device(A)) &&
15+
MPS.is_supported(device(C)) &&
1616
eltype(A) == eltype(B) &&
1717
(eltype(A), eltype(C)) in valid_types &&
1818
# TODO: remove this limitation
@@ -21,6 +21,19 @@ end
2121
C.offset == 0
2222
end
2323

24+
# Assumes support for MPS matrix multiplication has been verified elsewhere
25+
@inline function should_use_MPS(A, _, C)
26+
rows = size(C,1)
27+
cols = size(C,2)
28+
# TODO: matvecmul different?
29+
(eltype(A) <: Integer && rows <= 2000 && cols <= 2000 ) ||
30+
eltype(A) <: AbstractFloat && rows <= 6000 && cols <= 6000 && Metal.supports_family(device(C), MTL.MTLGPUFamilyApple9)
31+
end
32+
33+
# Supported values are :auto, :MPS, :MPSGraph, and :GPUArrays
34+
const matmul_alg = ScopedValue(:auto)
35+
matmul_alg_error(alg, inT, outT, vec) = error("Matrix-$(vec ? "Vector" : "Matrix") multiplication algorithm `:$alg` is not supported for input eltype $inT and output eltype $outT.")
36+
2437
LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix, _add::MulAddMul) =
2538
LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta)
2639
@autoreleasepool function LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB,
@@ -46,13 +59,20 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri
4659
transA = tA == 'T' || tA == 'C'
4760
transB = tB == 'T' || tB == 'C'
4861

62+
alg = matmul_alg[]
63+
mps_supported = supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES)
64+
mpsgraph_supported = supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES)
4965
# If possible, dispatch to MPSGraphs, then performance shaders
50-
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES)
66+
if alg === :MPSGraph || (alg === :auto && mpsgraph_supported && !should_use_MPS(A, B, C))
67+
mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C), false)
5168
graph_matmul!(C, A, B, alpha, beta, transA, transB)
52-
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES) # TODO: Remove once contiguous views are working
69+
elseif alg === :MPS || (alg === :auto && mps_supported)
70+
mps_supported || matmul_alg_error(alg, eltype(A), eltype(C), false)
5371
matmul!(C, A, B, alpha, beta, transA, transB)
54-
else
72+
elseif alg === :GPUArrays || alg === :auto
5573
GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta)
74+
else
75+
error(":$alg is not a valid matmul algorithm. Options are: `:auto`, `:MPS`, `:MPSGraph`, `:GPUArrays`")
5676
end
5777
end
5878

@@ -81,13 +101,20 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B
81101

82102
transA = tA == 'T' || tA == 'C'
83103

104+
alg = matmul_alg[]
105+
mps_supported = supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES)
106+
mpsgraph_supported = supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES)
84107
# If possible, dispatch to MPSGraphs, then performance shaders
85-
if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES)
108+
if alg === :MPSGraph || (alg === :auto && mpsgraph_supported)
109+
mpsgraph_supported || matmul_alg_error(alg, eltype(A), eltype(C), true)
86110
graph_matvecmul!(C, A, B, alpha, beta, transA)
87-
elseif supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES) # TODO: Remove once contiguous views are working
111+
elseif alg === :MPS || (alg === :auto && mps_supported)
112+
mps_supported || matmul_alg_error(alg, eltype(A), eltype(C), true)
88113
matvecmul!(C, A, B, alpha, beta, transA)
89-
else
114+
elseif alg === :GPUArrays || alg === :auto
90115
GPUArrays.generic_matmatmul!(C, wrap(A, tA), B, alpha, beta)
116+
else
117+
error(":$alg is not a valid matmul algorithm. Options are: `:auto`, `:MPS`, `:MPSGraph`, `:GPUArrays`")
91118
end
92119
end
93120

test/Project.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,5 @@
11
[deps]
22
Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e"
3-
AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924"
43
BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b"
54
BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf"
65
Dates = "ade2ca70-3891-5945-98fb-dc099432e06a"
@@ -14,6 +13,7 @@ ObjectiveC = "e86c9b32-1129-44ac-8ea0-90d5bb39ded9"
1413
Printf = "de0858da-6303-5e67-8744-51eddeeeb8d7"
1514
REPL = "3fa0cd96-eef1-5676-8a61-b3b8758bbffb"
1615
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
16+
ScopedValues = "7e506255-f358-4e82-b7e4-beb19740aa63"
1717
SparseArrays = "2f01184e-e22b-5df5-ae63-d93ebab69eaf"
1818
SpecialFunctions = "276daf66-3868-5448-9aa4-cd146d93841b"
1919
StaticArrays = "90137ffa-7385-5640-81b9-e52037218182"

test/linalg.jl

Lines changed: 50 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,56 @@
1-
using LinearAlgebra
1+
using LinearAlgebra, ScopedValues
22

33
if MPS.is_supported(device())
44

5+
@testset "matmul algorithm selection" begin
6+
# test that unsupported configurations error properly
7+
N = 20
8+
function test_matmul(inT, outT; vec_b=false, alg=:auto)
9+
a = inT <: Integer ? inT.(rand(-5:5, N,N)) : rand(inT, N, N)
10+
11+
bdims = vec_b ? (N,) : (N, N)
12+
b = inT <: Integer ? inT.(rand(-5:5, bdims)) : rand(inT, bdims)
13+
14+
ma = MtlArray(a)
15+
mb = MtlArray(b)
16+
mc = fill!(similar(mb, outT), zero(outT))
17+
18+
@with (Metal.matmul_alg => alg) mul!(mc,ma,mb)
19+
20+
return all((outT.(a)*outT.(b)) .≈ Array(mc))
21+
end
22+
23+
for vec_b in (true, false)
24+
@testset let vec_b = vec_b
25+
# Unsupported for MPS and MPSGraph
26+
@test_throws "Matrix-$(vec_b ? "Vector" : "Matrix") multiplication algorithm `:MPS`" test_matmul(Int8, Int16; vec_b, alg=:MPS)
27+
@test_throws "Matrix-$(vec_b ? "Vector" : "Matrix") multiplication algorithm `:MPSGraph`" test_matmul(Int8, Int16; vec_b, alg=:MPSGraph)
28+
29+
# Invalid algorithm Symbol
30+
@test_throws ":bad is not a valid matmul algorithm." test_matmul(Int8, Int16; vec_b, alg=:bad)
31+
@test_throws ":bad is not a valid matmul algorithm." test_matmul(Float16, Float16; vec_b, alg=:bad)
32+
33+
# :auto
34+
@test test_matmul(Int32, Int32; vec_b) # fallback to GPUArrays
35+
@test test_matmul(Int8, Float32; vec_b) # should use MPS
36+
@test test_matmul(Float16, Float32; vec_b) # should use MPSGraph on M1/M2
37+
38+
# :MPS
39+
mpsInT = vec_b ? Float32 : Int16
40+
@test test_matmul(mpsInT, Float32; vec_b, alg=:MPS)
41+
@test test_matmul(Float16, Float32; vec_b, alg=:MPS)
42+
43+
# :MPSGraph
44+
@test test_matmul(Int8, Float32; vec_b, alg=:MPSGraph)
45+
@test test_matmul(Float16, Float32; vec_b, alg=:MPSGraph)
46+
47+
# :GPUArrays
48+
@test test_matmul(Int32, Int32; vec_b, alg=:GPUArrays)
49+
@test test_matmul(Int8, Float32; vec_b, alg=:GPUArrays)
50+
@test test_matmul(Float16, Float32; vec_b, alg=:GPUArrays)
51+
end
52+
end
53+
end
554

655
@testset "test matrix vector multiplication of views" begin
756
N = 20

0 commit comments

Comments (0)