Skip to content

Commit fab6fc2

Browse files
Initial MPSGraph support & use for matmul (#566)
* Initial MPSGraph support * Use MPSGraph for matrix multiplication * Add encode function and move run functions * Implementation more similar to MPS implementation * Also use MPSGraph matmul for int -> float matmul The only code left running with MPS are the contiguous views. * Remove copying * All operations using Float32 * Tweak * Fix cast to Float32 when beta != 0 * More operations * Support more complex broadcasting behaviour Will unblock NNlib issue 614 * More API * Use optimization Level 0 by default to disable use of neural engine * @autoreleasepool * Push test script * Fix and test MPSGraph random * Comment out unused and slightly broken code * Tests and coverage-related fixes * Add comment explaining why use OptimizationLevel0 * format * Guard against invalid ranks * Move flopscomp to examples * Fix running tests locally * Fix
1 parent 369e292 commit fab6fc2

25 files changed

+1308
-21
lines changed

examples/flopscomp.jl

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
2+
using Metal, GPUArrays, LinearAlgebra, Printf, AppleAccelerate

# Define `TESTING = true` before including this file to run a short,
# plot-free pass (used by the test suite); Plots is only loaded otherwise.
testing = (@isdefined TESTING) && TESTING

@static if !testing
    using Plots
    using Plots.Measures
end

# (input eltype, output eltype) combinations benchmarked below.
const Ts=[
    (Int8, Float16),
    (Int8, Float32),
    (Int16, Float32),
    (Float16, Float16),
    (Float16, Float32),
    (Float32, Float32),
]

# Fallback shown in the plot title if the core-count probe below is removed.
n_gpu_cores = "??"
# Comment this out if scary. Please mention number of cores in your comment when uploading the figure
system_prof = read(`system_profiler SPDisplaysDataType`, String)
n_gpu_cores = only(match(r"Total Number of Cores:\s*(\d+)", system_prof).captures)

PLOT_TITLE = "Matmul peakflops for $(device().name) ($n_gpu_cores GPU cores)"
26+
27+
"""
    cpupeakflops(; n, n_batch, inT, outT, ntrials, verify)

Benchmark CPU matrix multiplication via `mul!` on `n`×`n` matrices and return
the best observed flop rate over `ntrials` runs. `n_batch` is accepted for
interface parity with the GPU benchmarks but only 1 is supported.
"""
function cpupeakflops(; n::Integer=4096,
                        n_batch::Integer=1,
                        inT::DataType=Float32,
                        outT::DataType=inT,
                        ntrials::Integer=4,
                        verify=true)
    if n_batch != 1
        @warn "n_batch > 1 not supported for `mul!`, running with n_batch=1"
    end
    n_batch = 1
    dims = (n, n)
    times = Base.zeros(Float64, ntrials)
    for trial in eachindex(times)
        # Fresh matrices each trial: C = A*B with all-ones inputs gives every
        # entry equal to n, which `verify` checks below.
        out = zeros(outT, dims...)
        lhs = ones(inT, dims...)
        rhs = ones(inT, dims...)
        times[trial] = @elapsed mul!(out, lhs, rhs)
        verify && @assert only(unique(Array(out))) == n
    end

    # 2n^3 flops per matmul; report the fastest trial.
    return n_batch * 2 * Float64(n)^3 / minimum(times)
end
47+
# Shared GPU benchmark driver: time `f(c, a, b)` on `mtl`-backed arrays and
# return the best flop rate. `f` must compute c = a*b (batched when n_batch > 1).
function _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true)
    dims = n_batch == 1 ? (n, n) : (n, n, n_batch)
    times = Base.zeros(Float64, ntrials)
    for trial in eachindex(times)
        out = mtl(zeros(outT, dims...))
        lhs = mtl(ones(inT, dims...))
        rhs = mtl(ones(inT, dims...))
        # Metal.@sync ensures the GPU work is finished before timing stops.
        times[trial] = @elapsed Metal.@sync f(out, lhs, rhs)
        # All-ones inputs: every output entry must equal n.
        verify && @assert only(unique(Array(out))) == n
    end

    return n_batch * 2 * Float64(n)^3 / minimum(times)
end
60+
"""
    gpuarrpeakflops(; n, n_batch, inT, outT, ntrials, verify)

Benchmark the generic GPUArrays matmul fallback. Only `n_batch == 1` is
supported; larger values fall back to 1 with a warning.
"""
function gpuarrpeakflops(; n::Integer=4096,
                           n_batch::Integer=1,
                           inT::DataType=Float32,
                           outT::DataType=inT,
                           ntrials::Integer=3,
                           verify=true)
    if n_batch != 1
        @warn "n_batch > 1 not supported for `GPUArrays.generic_matmatmul!`, running with n_batch=1"
    end
    kernel = (c, a, b) -> GPUArrays.generic_matmatmul!(
        c, LinearAlgebra.wrap(a, 'N'), LinearAlgebra.wrap(b, 'N'), 1, 0)
    return _peakflops(kernel, n, 1, inT, outT, ntrials; verify)
end
71+
"""
    mpspeakflops(; n, n_batch, inT, outT, ntrials, verify)

Benchmark matmul through the MPS (MPSMatrixMultiplication) backend.
"""
function mpspeakflops(; n::Integer=4096,
                        n_batch::Integer=1,
                        inT::DataType=Float32,
                        outT::DataType=inT,
                        ntrials::Integer=3,
                        verify=true)
    return _peakflops(MPS.matmul!, n, n_batch, inT, outT, ntrials; verify=verify)
end
79+
"""
    graphpeakflops(; n, n_batch, inT, outT, ntrials, verify)

Benchmark matmul through the MPSGraph backend.
"""
function graphpeakflops(; n::Integer=4096,
                          n_batch::Integer=1,
                          inT::DataType=Float32,
                          outT::DataType=inT,
                          ntrials::Integer=3,
                          verify=true)
    return _peakflops(MPSGraphs.graph_matmul!, n, n_batch, inT, outT, ntrials; verify=verify)
end
87+
"""
    anepeakflops(; kwargs...)

Benchmark MPSGraph matmul while allowing operations to be scheduled on the
Apple Neural Engine (ANE). Accepts the same keywords as `graphpeakflops`.
"""
function anepeakflops(; kwargs...)
    # VERY HACKY: temporarily swap the global execution descriptor's
    # compilation descriptor.
    newDesc = MPSGraphs.MPSGraphCompilationDescriptor()
    # Optimization level 1 lets MPSGraph move eligible operations to the
    # neural engine (the package default elsewhere is level 0, which keeps
    # work on the GPU).
    newDesc.optimizationLevel = MPSGraphs.MPSGraphOptimizationLevel1

    oldDesc = MPSGraphs._default_exec_desc[].compilationDescriptor

    MPSGraphs._default_exec_desc[].compilationDescriptor = newDesc
    try
        return graphpeakflops(; kwargs...)
    finally
        # Restore the previous descriptor even if the benchmark throws, so
        # subsequent (non-ANE) benchmarks are not silently run at level 1.
        MPSGraphs._default_exec_desc[].compilationDescriptor = oldDesc
    end
end
101+
102+
"""
    compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)

Run every benchmark in `Fs` (pairs of `(f, label)`) over the matrix sizes in
`Ns` for the given input/output element types. Returns a `Dict` mapping each
label to a `Vector{Float64}` of measured flop rates, one entry per size.
"""
function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
    results = Dict{String, Vector{Float64}}()

    # Drop ANE-labelled benchmarks unless outT is Float16, or Float32 with
    # Float16 inputs — presumably the only combinations the ANE path handles;
    # confirm against the MPSGraph backend if this list changes.
    newFs = if (outT == Float16 || (outT == Float32 && inT == Float16))
        Fs
    else
        filter(x -> !occursin("ANE", x[2]), Fs)
    end

    for (_, info_str) in newFs
        results[info_str] = Float64[]
    end

    # ANSI "clear line + carriage return" so progress output overwrites itself.
    prefixstr = "\33[2K\r($inT, $outT) "
    @time "$((inT, outT))" begin
        for n in Ns
            # Only verify results when the expected value `n` is exactly
            # representable in the relevant floating-point type(s).
            verify = (n < maxintfloat(outT) && (inT != Float16 || (n < maxintfloat(inT))))
            n_str = "$n: "
            for (f, info_str) in newFs
                print(prefixstr, n_str, info_str)
                push!(results[info_str], f(; inT, outT, n, n_batch, ntrials, verify))
                # Free GPU buffers between runs so trials don't contend for memory.
                GC.gc()
            end
        end
        print("\33[2K\r")
    end
    return results
end
130+
131+
"""
    runcomparison(; Ns, Fs, n_batch, ntrials)

Run `compare` for every (input, output) type pair in `Ts`. Returns a `Dict`
keyed by type pair whose values are `(n_batch, Ns, results)` tuples, suitable
for `plot_results`.
"""
function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
                         Fs=[
                             (mpspeakflops, "MPS"),
                             (graphpeakflops, "MPSGraph"),
                             (anepeakflops, "MPSGraph (ANE)"),
                             # (gpuarrpeakflops, "GPUArrays"),
                             # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
                         ],
                         n_batch=1,
                         ntrials=5)
    allresults = Dict()

    for (inT, outT) in Ts
        allresults[(inT, outT)] = (n_batch, Ns, compare(Ns, Fs, inT, outT; n_batch, ntrials))
    end
    return allresults
end
148+
149+
"""
    plot_results(res, Fs; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)

Draw one subplot per (input, output) type pair in `Ts` from the results of
`runcomparison`, combine them into a 2×3 grid, and optionally save the figure
to `outpath`. Returns the combined plot.
"""
function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)
    # Shared y-axis ceiling; grown below if any series exceeds it.
    ylim_upper = 9e12
    subplots = []

    batch_sizes = []

    for (inT, outT) in Ts
        n_batch, Ns, typeres = res[(inT, outT)]

        plt = plot(xlabel="N, n_batch=$(n_batch)", legendtitle="($inT, $outT)")
        for label in Fs
            haskey(typeres, label) || continue

            flops = typeres[label]
            best = maximum(flops)
            peakstr = @sprintf("%.3e", best)
            if best > ylim_upper
                ylim_upper = best * 1.02
            end
            plot!(plt, Ns, flops; linewidth=1.5, label="$(peakstr) peak: $label")
        end
        push!(subplots, plt)
        push!(batch_sizes, n_batch)
    end

    finalplot = plot(subplots...; layout=(2,3),
                     ylim=(0, ylim_upper),
                     plot_title=plt_title,
                     tickfonthalign=:left,
                     bottommargin=15pt,
                     size=(2000,1200))
    if !isnothing(outpath)
        savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(batch_sizes)).$outtype"))
    end
    return finalplot
end
184+
185+
# Under the test suite, run a reduced-size comparison (no plotting) just to
# exercise all benchmark paths.
if testing
    runcomparison(Ns=[50, 64, 100, 128, 250, 256, 500, 512])
end

lib/mps/MPS.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using BFloat16s
2121
const MtlFloat = Union{Float32, Float16}
2222

2323
const MPSShape = NSArray#{NSNumber}
24-
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))
24+
# Convert a `Vector`/`NTuple` of dimension sizes into an `MPSShape` (an `NSArray` of `NSNumber`s).
# NOTE(review): `T` is the element type in `Vector{T}` but the tuple *length* in
# `NTuple{T, <:Integer}`, so this method accepts a `Vector` of any eltype — confirm intended.
Base.convert(::Type{MPSShape}, tuple::Union{Vector{T},NTuple{T, <:Integer}}) where T = NSArray(NSNumber.(collect(tuple)))
2525

2626
# Valid combination of input (A and B matrices) and output (C) types
2727
const MPS_VALID_MATMUL_TYPES =

lib/mps/ndarray.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,8 @@ export MPSNDArrayDescriptor
77
# @objcwrapper immutable=false MPSNDArrayDescriptor <: NSObject
88

99
function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes::Ptr)
10+
1 <= dimensionCount <= 16 || throw(ArgumentError("`dimensionCount` must be between 1 and 16 inclusive"))
11+
1012
desc = @objc [MPSNDArrayDescriptor descriptorWithDataType:dataType::MPSDataType
1113
dimensionCount:dimensionCount::NSUInteger
1214
dimensionSizes:dimensionSizes::Ptr{NSUInteger}]::id{MPSNDArrayDescriptor}

lib/mpsgraphs/MPSGraphs.jl

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,51 @@
1+
"""
2+
# MPSGraphs
3+
4+
`MPSGraphs` is where the Metal Performance Shaders Graph API wrappers are defined.
5+
6+
Not all functionality is currently implemented or documented. For further details,
7+
refer to the [official Apple documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph).
8+
"""
9+
module MPSGraphs
10+
11+
using ..Metal
12+
using .MTL
13+
using .MPS
14+
using .MPS: MPSDataType, MPSShape, exportDataWithCommandBuffer
15+
16+
using CEnum
17+
using ObjectiveC, .Foundation, .Dispatch
18+
19+
# Valid combination of input (A and B matrices) and output (C) types
20+
# The commented type combinations work but are slower than with MPSMatrixMultiplicatiom
21+
const MPSGRAPH_VALID_MATMUL_TYPES =
22+
[
23+
(Int8, Float16),
24+
(Int8, Float32),
25+
(Int16, Float32),
26+
(Float16, Float16),
27+
(Float16, Float32),
28+
(Float32, Float32),
29+
]
30+
31+
const MPSGRAPH_VALID_MATVECMUL_TYPES =
32+
[
33+
(Int8, Float16),
34+
(Int8, Float32),
35+
(Int16, Float32),
36+
(Float16, Float16),
37+
(Float16, Float32),
38+
(Float32, Float32),
39+
]
40+
41+
include("libmpsgraph.jl")
42+
43+
include("core.jl")
44+
include("tensor.jl")
45+
include("execution.jl")
46+
include("operations.jl")
47+
include("random.jl")
48+
49+
include("matmul.jl")
50+
51+
end

lib/mpsgraphs/core.jl

Lines changed: 42 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,42 @@
1+
# Contains definitions for api from MPSGraphCore.h, MPSGraphDevice.h
2+
3+
## MPSGraphCore.h
4+
# @objcwrapper MPSGraphObject <: NSObject
5+
# @objcwrapper MPSGraphType <: MPSGraphObject
6+
7+
# @objcwrapper MPSGraph <: MPSGraphObject
8+
# Construct a new, empty `MPSGraph` via `[MPSGraph new]` and wrap the ObjC handle.
function MPSGraph()
    MPSGraph(@objc [MPSGraph new]::id{MPSGraph})
end
11+
12+
# @objcwrapper immutable=true MPSGraphShapedType <: MPSGraphType
13+
14+
# XXX: Not used yet and needs fixing
15+
# function MPSGraphShapedType(shape::MPSShape, dataType)
16+
# tmp = @objc [MPSGraphShapedType alloc]::id{MPSGraphShapedType}
17+
# obj = MPSGraphShapedType(tmp)
18+
# finalizer(release, obj)
19+
# @objc [obj::id{MPSGraphShapedType} initWithShape:shape::id{MPSShape}
20+
# dataType:dataType::MPSDataType]::id{MPSGraphShapedType}
21+
# return obj
22+
# end
23+
24+
## MPSGraphDevice.h
25+
# @objcwrapper MPSGraphDevice <: MPSGraphType
26+
27+
# Wrap the `MPSGraphDevice` associated with the given Metal device.
function MPSGraphDevice(device::MTLDevice)
    obj = @objc [MPSGraphDevice deviceWithMTLDevice:device::id{MTLDevice}]::id{MPSGraphDevice}
    MPSGraphDevice(obj)
end
31+
32+
# @objcwrapper MPSGraphExecutionDescriptor <: MPSGraphObject
33+
34+
# Construct a fresh `MPSGraphExecutionDescriptor` (default execution options).
function MPSGraphExecutionDescriptor()
    MPSGraphExecutionDescriptor(@objc [MPSGraphExecutionDescriptor new]::id{MPSGraphExecutionDescriptor})
end
37+
38+
# @objcwrapper MPSGraphCompilationDescriptor <: MPSGraphObject
39+
40+
# Construct a fresh `MPSGraphCompilationDescriptor` (default compilation options).
function MPSGraphCompilationDescriptor()
    MPSGraphCompilationDescriptor(@objc [MPSGraphCompilationDescriptor new]::id{MPSGraphCompilationDescriptor})
end

lib/mpsgraphs/execution.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
2+
# Convenience overload: encode with no target operations (`nil`) and a fresh
# default execution descriptor; results land in `resultsDictionary`.
MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, resultsDictionary::MPSGraphTensorDataDictionary) = @inline MPS.encode!(commandBuffer, graph, feeds, nil, resultsDictionary, MPSGraphExecutionDescriptor())
# Encode `graph` onto `commandBuffer`, supplying input tensor data via `feeds`
# and writing outputs into the caller-provided `resultsDictionary`, which is
# returned for convenience. The ObjC call itself returns nothing.
function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetOperations, resultsDictionary::MPSGraphTensorDataDictionary, executionDescriptor)
    @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
                               feeds:feeds::id{MPSGraphTensorDataDictionary}
                               targetOperations:targetOperations::id{Object}
                               resultsDictionary:resultsDictionary::id{MPSGraphTensorDataDictionary}
                               executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::Nothing
    return resultsDictionary
end
11+
12+
# Encode `graph` onto `commandBuffer`, letting MPSGraph allocate the result
# storage: the framework returns a dictionary mapping each of `targetTensors`
# to its output tensor data, which is wrapped and returned.
function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil, executionDescriptor=MPSGraphExecutionDescriptor())
    obj = @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
                                     feeds:feeds::id{MPSGraphTensorDataDictionary}
                                     targetTensors:targetTensors::id{NSArray}
                                     targetOperations:targetOperations::id{Object}
                                     executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::id{MPSGraphTensorDataDictionary}
    MPSGraphTensorDataDictionary(obj)
end
20+
21+
# Synchronously run `graph` with the given feeds and return the dictionary of
# results for `targetTensors`.
# NOTE(review): defining `run` here shadows `Base.run` inside this module;
# intentional for an API wrapper, but callers outside must qualify the name.
function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil)
    obj = @objc [graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary}
                                     targetTensors:targetTensors::id{NSArray}
                                     targetOperations:targetOperations::id{Object}]::id{MPSGraphTensorDataDictionary}
    MPSGraphTensorDataDictionary(obj)
end
27+
28+
# Synchronously run `graph` on a specific Metal command queue; target
# operations are fixed to `nil` (unlike the feeds-only overload above).
function run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray)
    obj = @objc [graph::id{MPSGraph} runWithMTLCommandQueue:commandQueue::id{MTLCommandQueue}
                                     feeds:feeds::id{MPSGraphTensorDataDictionary}
                                     targetTensors:targetTensors::id{NSArray}
                                     targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary}
    MPSGraphTensorDataDictionary(obj)
end

0 commit comments

Comments
 (0)