Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
777daf8
Initial MPSGraph support
christiangnrd Dec 18, 2024
4e14cc1
Use MPSGraph for matrix multiplication
christiangnrd Mar 17, 2025
4496900
Add encode function and move run functions
christiangnrd Mar 18, 2025
0e068e0
Implementation more similar to MPS implementation
christiangnrd Mar 18, 2025
46cf94a
Also use MPSGraph matmul for int -> float matmul
christiangnrd Mar 18, 2025
e2fd80d
Remove copying
christiangnrd Mar 18, 2025
119a84d
All operations using Float32
christiangnrd Mar 18, 2025
f539b67
Tweak
christiangnrd Mar 18, 2025
b127183
Fix cast to Float32 when beta != 0
christiangnrd Mar 18, 2025
608ef4a
More operations
christiangnrd Mar 19, 2025
07ae758
Support more complex broadcasting behaviour
christiangnrd Mar 19, 2025
8c569fd
More API
christiangnrd Mar 19, 2025
541d9bc
Use optimization Level 0 by default to disable use of neural engine
christiangnrd Mar 19, 2025
ecef13a
@autoreleasepool
christiangnrd Mar 19, 2025
777c3cc
Push test script
christiangnrd Mar 19, 2025
1fe0db0
Fix and test MPSGraph random
christiangnrd Mar 20, 2025
68ca79a
Comment out unused and slightly broken code
christiangnrd Mar 20, 2025
d8d98a6
Tests and coverage-related fixes
christiangnrd Mar 20, 2025
1d4e678
Add comment explaining why use OptimizationLevel0
christiangnrd Mar 20, 2025
cb27f46
format
christiangnrd Mar 20, 2025
1edfcb9
Guard against invalid ranks
christiangnrd Mar 21, 2025
e7791d9
Move flopscomp to examples
christiangnrd Mar 21, 2025
dccf5bf
Fix running tests locally
christiangnrd Mar 21, 2025
faa8b6a
Fix
christiangnrd Mar 21, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
187 changes: 187 additions & 0 deletions examples/flopscomp.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,187 @@

using Metal, GPUArrays, LinearAlgebra, Printf, AppleAccelerate

testing = (@isdefined TESTING) && TESTING

@static if !testing
using Plots
using Plots.Measures
end

# (input eltype, output eltype) pairs to benchmark.
# NOTE(review): this list appears to mirror MPSGRAPH_VALID_MATMUL_TYPES in
# lib/mpsgraphs/MPSGraphs.jl — keep the two in sync.
const Ts=[
(Int8, Float16),
(Int8, Float32),
(Int16, Float32),
(Float16, Float16),
(Float16, Float32),
(Float32, Float32),
]

# Number of GPU cores, shown in the plot title; stays "??" if undetectable.
n_gpu_cores = "??"
# Comment this out if scary. Please mention number of cores in your comment when uploading the figure
system_prof = read(`system_profiler SPDisplaysDataType`, String)
# Guard against `match` returning `nothing` on unexpected profiler output:
# the original `only(match(...).captures)` would throw a `MethodError` there.
cores_match = match(r"Total Number of Cores:\s*(\d+)", system_prof)
n_gpu_cores = isnothing(cores_match) ? "??" : only(cores_match.captures)

PLOT_TITLE = "Matmul peakflops for $(device().name) ($n_gpu_cores GPU cores)"

"""
    cpupeakflops(; n=4096, n_batch=1, inT=Float32, outT=inT, ntrials=4, verify=true)

Benchmark dense `n`×`n` CPU matrix multiplication via `mul!` over `ntrials`
runs and return the best observed flop rate (`2n³` divided by the fastest
trial time). With `verify=true`, check that multiplying all-ones inputs
yields `n` in every output entry.
"""
function cpupeakflops(; n::Integer=4096,
                        n_batch::Integer=1,
                        inT::DataType=Float32,
                        outT::DataType=inT,
                        ntrials::Integer=4,
                        verify=true)
    times = Base.zeros(Float64, ntrials)
    # `mul!` has no batched form, so always run with a single batch.
    n_batch == 1 || @warn "n_batch > 1 not supported for `mul!`, running with n_batch=1"
    n_batch = 1
    dims = (n, n)
    for trial in eachindex(times)
        out = zeros(outT, dims...)
        lhs = ones(inT, dims...)
        rhs = ones(inT, dims...)
        times[trial] = @elapsed mul!(out, lhs, rhs)
        # All-ones factors make every entry of the product equal to n.
        verify && @assert only(unique(Array(out))) == n
    end

    return n_batch * 2 * Float64(n)^3 / minimum(times)
end
"""
    _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true)

Shared GPU benchmark driver: time `f(c, a, b)` on freshly allocated Metal
arrays for `ntrials` runs and return the best flop rate. When `n_batch > 1`
the arrays get a third (batch) dimension.
"""
function _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true)
    times = Base.zeros(Float64, ntrials)
    dims = n_batch == 1 ? (n, n) : (n, n, n_batch)
    for trial in eachindex(times)
        out = mtl(zeros(outT, dims...))
        lhs = mtl(ones(inT, dims...))
        rhs = mtl(ones(inT, dims...))
        times[trial] = @elapsed Metal.@sync f(out, lhs, rhs)
        # All-ones factors make every entry of the product equal to n.
        verify && @assert only(unique(Array(out))) == n
    end

    return n_batch * 2 * Float64(n)^3 / minimum(times)
end
"""
    gpuarrpeakflops(; n=4096, n_batch=1, inT=Float32, outT=inT, ntrials=3, verify=true)

Benchmark the generic `GPUArrays.generic_matmatmul!` fallback and return the
peak flop rate. Batched multiplication is unsupported, so `n_batch` is
forced to 1.
"""
function gpuarrpeakflops(; n::Integer=4096,
                           n_batch::Integer=1,
                           inT::DataType=Float32,
                           outT::DataType=inT,
                           ntrials::Integer=3,
                           verify=true)
    n_batch == 1 || @warn "n_batch > 1 not supported for `GPUArrays.generic_matmatmul!`, running with n_batch=1"
    return _peakflops(n, 1, inT, outT, ntrials; verify) do dest, lhs, rhs
        GPUArrays.generic_matmatmul!(dest, LinearAlgebra.wrap(lhs, 'N'), LinearAlgebra.wrap(rhs, 'N'), 1, 0)
    end
end
"""
    mpspeakflops(; n=4096, n_batch=1, inT=Float32, outT=inT, ntrials=3, verify=true)

Benchmark `MPS.matmul!` (Metal Performance Shaders matrix multiplication)
and return the peak flop rate.
"""
function mpspeakflops(; n::Integer=4096,
                        n_batch::Integer=1,
                        inT::DataType=Float32,
                        outT::DataType=inT,
                        ntrials::Integer=3,
                        verify=true)
    return _peakflops(MPS.matmul!, n, n_batch, inT, outT, ntrials; verify)
end
"""
    graphpeakflops(; n=4096, n_batch=1, inT=Float32, outT=inT, ntrials=3, verify=true)

Benchmark `MPSGraphs.graph_matmul!` (MPSGraph-based matrix multiplication)
and return the peak flop rate.
"""
function graphpeakflops(; n::Integer=4096,
                          n_batch::Integer=1,
                          inT::DataType=Float32,
                          outT::DataType=inT,
                          ntrials::Integer=3,
                          verify=true)
    return _peakflops(MPSGraphs.graph_matmul!, n, n_batch, inT, outT, ntrials; verify)
end
"""
    anepeakflops(; kwargs...)

Like [`graphpeakflops`](@ref), but compiled with `MPSGraphOptimizationLevel1`
so that operations may be placed on the Apple Neural Engine (ANE).
Accepts the same keyword arguments as `graphpeakflops`.
"""
function anepeakflops(; kwargs...)
    # VERY HACKY: temporarily swap the package-global compilation descriptor.
    newDesc = MPSGraphs.MPSGraphCompilationDescriptor()
    # Optimization level 1 permits offloading work to the neural engine
    # (the package defaults to level 0 precisely to prevent that).
    # NOTE: the original comment claimed "level 0" here, contradicting the code.
    newDesc.optimizationLevel = MPSGraphs.MPSGraphOptimizationLevel1

    oldDesc = MPSGraphs._default_exec_desc[].compilationDescriptor

    MPSGraphs._default_exec_desc[].compilationDescriptor = newDesc
    try
        return graphpeakflops(; kwargs...)
    finally
        # Always restore the previous descriptor, even if the benchmark throws,
        # so later (non-ANE) benchmarks are not silently run at level 1.
        MPSGraphs._default_exec_desc[].compilationDescriptor = oldDesc
    end
end

"""
    compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)

Run every benchmark in `Fs` (a collection of `(func, label)` pairs, where
`func` takes the keywords `inT, outT, n, n_batch, ntrials, verify`) over each
matrix size in `Ns`, with input eltype `inT` and output eltype `outT`.

Entries whose label contains `"ANE"` are dropped unless the output type is
`Float16` or the combination is `Float16 -> Float32` (the only cases kept by
the filter below).

Returns a `Dict{String, Vector{Float64}}` mapping each label to its measured
flop rates, one per size in `Ns`.
"""
function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials)
    # Concretely typed (was an untyped `Dict()`): keys are labels, values flop rates.
    results = Dict{String, Vector{Float64}}()

    newFs = if (outT == Float16 || (outT == Float32 && inT == Float16))
        Fs
    else
        # Drop ANE benchmarks for type combinations the filter excludes.
        filter(x -> !occursin("ANE", x[2]), Fs)
    end

    for (_, info_str) in newFs
        results[info_str] = Float64[]
    end

    # "\33[2K\r" clears the current terminal line for in-place progress output.
    prefixstr = "\33[2K\r($inT, $outT) "
    @time "$((inT, outT))" begin
        for n in Ns
            # Skip verification once n exceeds the exact-integer range of the
            # element types (the all-ones product check would be unreliable).
            verify = (n < maxintfloat(outT) && (inT != Float16 || (n < maxintfloat(inT))))
            n_str = "$n: "
            for (f, info_str) in newFs
                print(prefixstr, n_str, info_str)
                push!(results[info_str], f(; inT, outT, n, n_batch, ntrials, verify))
                # Collect between runs so allocations don't skew later timings.
                GC.gc()
            end
        end
        print("\33[2K\r")
    end
    return results
end

"""
    runcomparison(; Ns, Fs, n_batch=1, ntrials=5)

Benchmark every `(func, label)` pair in `Fs` across all type pairs in `Ts`
for each matrix size in `Ns`. Returns a `Dict` keyed by `(inT, outT)` whose
values are `(n_batch, Ns, results)` tuples, suitable for `plot_results`.
"""
function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000],
                         Fs=[
                             (mpspeakflops, "MPS"),
                             (graphpeakflops, "MPSGraph"),
                             (anepeakflops, "MPSGraph (ANE)"),
                             # (gpuarrpeakflops, "GPUArrays"),
                             # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance
                         ],
                         n_batch=1,
                         ntrials=5)
    allresults = Dict()

    for typepair in Ts
        inT, outT = typepair
        allresults[typepair] = (n_batch, Ns, compare(Ns, Fs, inT, outT; n_batch, ntrials))
    end
    return allresults
end

"""
    plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)

Plot the benchmark results produced by `runcomparison`: one panel per type
pair in `Ts`, each showing flop rate versus matrix size for the labels in
`Fs`. When `outpath` is given, also save the figure there as
`bench_all_<n_batch>.<outtype>`. Returns the combined plot.
"""
function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE)
    ylim_upper = 9e12
    panels = []

    n_batches = []

    for (inT, outT) in Ts
        n_batch, Ns, measurements = res[(inT,outT)]

        panel = plot(xlabel="N, n_batch=$(n_batch)", legendtitle="($inT, $outT)")
        for info_str in Fs
            haskey(measurements, info_str) || continue

            flops = measurements[info_str]
            peakf = @sprintf("%.3e", maximum(flops))
            # Grow the shared y-limit so the tallest series stays visible.
            if maximum(flops) > ylim_upper
                ylim_upper = maximum(flops) * 1.02
            end
            plot!(panel, Ns, measurements[info_str]; linewidth=1.5, label="$(peakf) peak: $info_str")
        end
        push!(panels, panel)
        push!(n_batches, n_batch)
    end

    finalplot = plot(panels...; layout=(2,3),
                     ylim=(0,ylim_upper),
                     plot_title=plt_title,
                     tickfonthalign=:left,
                     bottommargin=15pt,
                     size=(2000,1200))
    if !isnothing(outpath)
        savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(n_batches)).$outtype"))
    end
    return finalplot
end

# When run under the test suite (TESTING defined), do a reduced-size
# comparison with no plotting so the script finishes quickly.
if testing
runcomparison(Ns=[50, 64, 100, 128, 250, 256, 500, 512])
end
2 changes: 1 addition & 1 deletion lib/mps/MPS.jl
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ using BFloat16s
const MtlFloat = Union{Float32, Float16}

const MPSShape = NSArray#{NSNumber}
Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))
Base.convert(::Type{MPSShape}, tuple::Union{Vector{T},NTuple{T, <:Integer}}) where T = NSArray(NSNumber.(collect(tuple)))

# Valid combination of input (A and B matrices) and output (C) types
const MPS_VALID_MATMUL_TYPES =
Expand Down
2 changes: 2 additions & 0 deletions lib/mps/ndarray.jl
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@ export MPSNDArrayDescriptor
# @objcwrapper immutable=false MPSNDArrayDescriptor <: NSObject

function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes::Ptr)
1 <= dimensionCount <= 16 || throw(ArgumentError("`dimensionCount` must be between 1 and 16 inclusive"))

desc = @objc [MPSNDArrayDescriptor descriptorWithDataType:dataType::MPSDataType
dimensionCount:dimensionCount::NSUInteger
dimensionSizes:dimensionSizes::Ptr{NSUInteger}]::id{MPSNDArrayDescriptor}
Expand Down
51 changes: 51 additions & 0 deletions lib/mpsgraphs/MPSGraphs.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
"""
# MPSGraphs

`MPSGraphs` is where the Metal Performance Shaders Graph API wrappers are defined.

Not all functionality is currently implemented or documented. For further details,
refer to the [official Apple documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph).
"""
module MPSGraphs

using ..Metal
using .MTL
using .MPS
using .MPS: MPSDataType, MPSShape, exportDataWithCommandBuffer

using CEnum
using ObjectiveC, .Foundation, .Dispatch

# Valid combination of input (A and B matrices) and output (C) types
# The commented type combinations work but are slower than with MPSMatrixMultiplication
const MPSGRAPH_VALID_MATMUL_TYPES =
[
(Int8, Float16),
(Int8, Float32),
(Int16, Float32),
(Float16, Float16),
(Float16, Float32),
(Float32, Float32),
]

# Valid (input, output) type pairs for matrix-vector multiplication.
const MPSGRAPH_VALID_MATVECMUL_TYPES =
[
(Int8, Float16),
(Int8, Float32),
(Int16, Float32),
(Float16, Float16),
(Float16, Float32),
(Float32, Float32),
]

# Low-level generated bindings first, then the hand-written wrapper layers.
include("libmpsgraph.jl")

include("core.jl")
include("tensor.jl")
include("execution.jl")
include("operations.jl")
include("random.jl")

include("matmul.jl")

end
42 changes: 42 additions & 0 deletions lib/mpsgraphs/core.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
# Contains definitions for api from MPSGraphCore.h, MPSGraphDevice.h

## MPSGraphCore.h
# @objcwrapper MPSGraphObject <: NSObject
# @objcwrapper MPSGraphType <: MPSGraphObject

# @objcwrapper MPSGraph <: MPSGraphObject
"""
    MPSGraph()

Create a new, empty `MPSGraph` via the Objective-C `[MPSGraph new]` call.
"""
function MPSGraph()
    handle = @objc [MPSGraph new]::id{MPSGraph}
    return MPSGraph(handle)
end

# @objcwrapper immutable=true MPSGraphShapedType <: MPSGraphType

# XXX: Not used yet and needs fixing
# function MPSGraphShapedType(shape::MPSShape, dataType)
# tmp = @objc [MPSGraphShapedType alloc]::id{MPSGraphShapedType}
# obj = MPSGraphShapedType(tmp)
# finalizer(release, obj)
# @objc [obj::id{MPSGraphShapedType} initWithShape:shape::id{MPSShape}
# dataType:dataType::MPSDataType]::id{MPSGraphShapedType}
# return obj
# end

## MPSGraphDevice.h
# @objcwrapper MPSGraphDevice <: MPSGraphType

"""
    MPSGraphDevice(device::MTLDevice)

Wrap a Metal `MTLDevice` as an `MPSGraphDevice` via
`[MPSGraphDevice deviceWithMTLDevice:]`.
"""
function MPSGraphDevice(device::MTLDevice)
    handle = @objc [MPSGraphDevice deviceWithMTLDevice:device::id{MTLDevice}]::id{MPSGraphDevice}
    return MPSGraphDevice(handle)
end

# @objcwrapper MPSGraphExecutionDescriptor <: MPSGraphObject

"""
    MPSGraphExecutionDescriptor()

Create a new `MPSGraphExecutionDescriptor` via `[MPSGraphExecutionDescriptor new]`.
"""
function MPSGraphExecutionDescriptor()
    handle = @objc [MPSGraphExecutionDescriptor new]::id{MPSGraphExecutionDescriptor}
    return MPSGraphExecutionDescriptor(handle)
end

# @objcwrapper MPSGraphCompilationDescriptor <: MPSGraphObject

"""
    MPSGraphCompilationDescriptor()

Create a new `MPSGraphCompilationDescriptor` via `[MPSGraphCompilationDescriptor new]`.
"""
function MPSGraphCompilationDescriptor()
    handle = @objc [MPSGraphCompilationDescriptor new]::id{MPSGraphCompilationDescriptor}
    return MPSGraphCompilationDescriptor(handle)
end
34 changes: 34 additions & 0 deletions lib/mpsgraphs/execution.jl
Original file line number Diff line number Diff line change
@@ -0,0 +1,34 @@

# Convenience overload: encode with no target operations (`nil`) and a freshly
# created default execution descriptor, forwarding to the 6-argument method.
MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, resultsDictionary::MPSGraphTensorDataDictionary) = @inline MPS.encode!(commandBuffer, graph, feeds, nil, resultsDictionary, MPSGraphExecutionDescriptor())
# Encode `graph` onto `commandBuffer`, writing outputs into the caller-supplied
# `resultsDictionary` (ObjC selector
# `encodeToCommandBuffer:feeds:targetOperations:resultsDictionary:executionDescriptor:`).
# Returns `resultsDictionary` for caller convenience, since the ObjC call itself
# returns nothing in this variant.
function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetOperations, resultsDictionary::MPSGraphTensorDataDictionary, executionDescriptor)
@objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
feeds:feeds::id{MPSGraphTensorDataDictionary}
targetOperations:targetOperations::id{Object}
resultsDictionary:resultsDictionary::id{MPSGraphTensorDataDictionary}
executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::Nothing
return resultsDictionary
end

"""
    MPS.encode!(commandBuffer, graph, feeds, targetTensors::NSArray, targetOperations=nil, executionDescriptor=MPSGraphExecutionDescriptor())

Encode `graph` onto `commandBuffer`, letting MPSGraph allocate result storage
for `targetTensors`. Returns the resulting `MPSGraphTensorDataDictionary`.
"""
function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil, executionDescriptor=MPSGraphExecutionDescriptor())
    resultsdict = @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
                         feeds:feeds::id{MPSGraphTensorDataDictionary}
                         targetTensors:targetTensors::id{NSArray}
                         targetOperations:targetOperations::id{Object}
                         executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::id{MPSGraphTensorDataDictionary}
    return MPSGraphTensorDataDictionary(resultsdict)
end

"""
    run(graph::MPSGraph, feeds, targetTensors::NSArray, targetOperations=nil)

Synchronously run `graph` with the given `feeds` and return a dictionary of
tensor data for `targetTensors` (`runWithFeeds:targetTensors:targetOperations:`).
"""
function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil)
    resultsdict = @objc [graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary}
                         targetTensors:targetTensors::id{NSArray}
                         targetOperations:targetOperations::id{Object}]::id{MPSGraphTensorDataDictionary}
    return MPSGraphTensorDataDictionary(resultsdict)
end

"""
    run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds, targetTensors::NSArray)

Synchronously run `graph` on `commandQueue` with the given `feeds` and return
a dictionary of tensor data for `targetTensors`. Target operations are always
`nil` in this variant.
"""
function run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray)
    resultsdict = @objc [graph::id{MPSGraph} runWithMTLCommandQueue:commandQueue::id{MTLCommandQueue}
                         feeds:feeds::id{MPSGraphTensorDataDictionary}
                         targetTensors:targetTensors::id{NSArray}
                         targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary}
    return MPSGraphTensorDataDictionary(resultsdict)
end
Loading