diff --git a/examples/flopscomp.jl b/examples/flopscomp.jl new file mode 100644 index 000000000..4436cedd0 --- /dev/null +++ b/examples/flopscomp.jl @@ -0,0 +1,187 @@ + +using Metal, GPUArrays, LinearAlgebra, Printf, AppleAccelerate + +testing = (@isdefined TESTING) && TESTING + +@static if !testing + using Plots + using Plots.Measures +end + +const Ts=[ + (Int8, Float16), + (Int8, Float32), + (Int16, Float32), + (Float16, Float16), + (Float16, Float32), + (Float32, Float32), + ] + +n_gpu_cores = "??" +# Comment this out if scary. Please mention number of cores in your comment when uploading the figure +system_prof = read(`system_profiler SPDisplaysDataType`, String) +n_gpu_cores = only(match(r"Total Number of Cores:\s*(\d+)", system_prof).captures) + +PLOT_TITLE = "Matmul peakflops for $(device().name) ($n_gpu_cores GPU cores)" + +function cpupeakflops(; n::Integer=4096, + n_batch::Integer=1, + inT::DataType=Float32, + outT::DataType=inT, + ntrials::Integer=4, + verify=true) + t = Base.zeros(Float64, ntrials) + n_batch == 1 || @warn "n_batch > 1 not supported for `mul!`, running with n_batch=1" + n_batch = 1 + shape = (n, n) + for i=1:ntrials + c = zeros(outT, shape...) + a = ones(inT, shape...) + b = ones(inT, shape...) + t[i] = @elapsed mul!(c, a, b) + verify && @assert only(unique(Array(c))) == n + end + + return n_batch*2*Float64(n)^3 / minimum(t) +end +function _peakflops(f, n, n_batch, inT, outT, ntrials; verify=true) + t = Base.zeros(Float64, ntrials) + shape = n_batch == 1 ? (n, n) : (n, n, n_batch) + for i=1:ntrials + c = mtl(zeros(outT, shape...)) + a = mtl(ones(inT, shape...)) + b = mtl(ones(inT, shape...)) + t[i] = @elapsed Metal.@sync f(c, a, b) + verify && @assert only(unique(Array(c))) == n + end + + return n_batch*2*Float64(n)^3 / minimum(t) +end +function gpuarrpeakflops(; n::Integer=4096, + n_batch::Integer=1, + inT::DataType=Float32, + outT::DataType=inT, + ntrials::Integer=3, + verify=true) + n_batch == 1 || @warn "n_batch > 1 not supported for `GPUArrays.generic_matmatmul!`, running with n_batch=1" + _peakflops(n, 1, inT, outT, ntrials; verify) do c, a, b + GPUArrays.generic_matmatmul!(c, LinearAlgebra.wrap(a, 'N'), LinearAlgebra.wrap(b, 'N'), 1, 0) + end +end +function mpspeakflops(; n::Integer=4096, + n_batch::Integer=1, + inT::DataType=Float32, + outT::DataType=inT, + ntrials::Integer=3, + verify=true) + _peakflops(MPS.matmul!, n, n_batch, inT, outT, ntrials; verify) +end +function graphpeakflops(; n::Integer=4096, + n_batch::Integer=1, + inT::DataType=Float32, + outT::DataType=inT, + ntrials::Integer=3, + verify=true) + _peakflops(MPSGraphs.graph_matmul!, n, n_batch, inT, outT, ntrials; verify) +end +function anepeakflops(; kwargs...) + # VERY HACKY + newDesc = MPSGraphs.MPSGraphCompilationDescriptor() + # Use optimization level 0 to avoid operations being moved to the neural engine + newDesc.optimizationLevel = MPSGraphs.MPSGraphOptimizationLevel1 + + oldDesc = MPSGraphs._default_exec_desc[].compilationDescriptor + + MPSGraphs._default_exec_desc[].compilationDescriptor = newDesc + res = graphpeakflops(; kwargs...) + MPSGraphs._default_exec_desc[].compilationDescriptor = oldDesc + + return res +end + +function compare(Ns, Fs, inT, outT=inT; n_batch=1, ntrials) + results = Dict() + + newFs = if (outT == Float16 || (outT == Float32 && inT == Float16)) + Fs + else + filter(x -> !occursin("ANE", x[2]),Fs) + end + + for (_, info_str) in newFs + results[info_str] = Float64[] + end + + prefixstr = "\33[2K\r($inT, $outT) " + @time "$((inT, outT))" begin + for n in Ns + verify = (n < maxintfloat(outT) && (inT != Float16 || (n < maxintfloat(inT)))) + n_str = "$n: " + for (f, info_str) in newFs + print(prefixstr, n_str, info_str) + push!(results[info_str], f(; inT, outT, n, n_batch, ntrials, verify)) + GC.gc() + end + end + print("\33[2K\r") + end + return results +end + +function runcomparison(; Ns=[50, 64, 100, 128, 250, 256, 500, 512, 1000, 1024, 2000, 2048, 4000, 4096, 6000, 6144, 8000, 8192],#, 10000], + Fs=[ + (mpspeakflops, "MPS"), + (graphpeakflops, "MPSGraph"), + (anepeakflops, "MPSGraph (ANE)"), + # (gpuarrpeakflops, "GPUArrays"), + # (cpupeakflops, "CPU (AppleAccelerate)"), # Uncomment to test CPU performance + ], + n_batch=1, + ntrials=5) + res = Dict() + + for (inT, outT) in Ts + res[(inT,outT)] = (n_batch, Ns, compare(Ns, Fs, inT, outT; n_batch, ntrials)) + end + return res +end + +function plot_results(res, Fs=["MPS", "MPSGraph", "MPSGraph (ANE)"]; outpath=nothing, outtype="svg", plt_title=PLOT_TITLE) + ylim_upper = 9e12 + resplts = [] + + n_batches = [] + + for (inT, outT) in Ts + n_batch, Ns, tmpres = res[(inT,outT)] + + plt = plot(xlabel="N, n_batch=$(n_batch)", legendtitle="($inT, $outT)") + for info_str in Fs + haskey(tmpres, info_str) || continue + + flops = tmpres[info_str] + peakf = @sprintf("%.3e", maximum(flops)) + if maximum(flops) > ylim_upper + ylim_upper = maximum(flops) * 1.02 + end + plot!(plt, Ns, tmpres[info_str]; linewidth=1.5, label="$(peakf) peak: $info_str") + end + push!(resplts, plt) + push!(n_batches, n_batch) + end + + finalplot = plot(resplts...; layout=(2,3), + ylim=(0,ylim_upper), + plot_title=plt_title, + tickfonthalign=:left, + bottommargin=15pt, + size=(2000,1200)) + if !isnothing(outpath) + savefig(plot(finalplot, dpi=500), joinpath(outpath, "bench_all_$(first(n_batches)).$outtype")) + end + return finalplot +end + +if testing + runcomparison(Ns=[50, 64, 100, 128, 250, 256, 500, 512]) +end diff --git a/lib/mps/MPS.jl b/lib/mps/MPS.jl index dacc8817e..0ef1f2d0e 100644 --- a/lib/mps/MPS.jl +++ b/lib/mps/MPS.jl @@ -21,7 +21,7 @@ using BFloat16s const MtlFloat = Union{Float32, Float16} const MPSShape = NSArray#{NSNumber} -Base.convert(::Type{MPSShape}, tuple::Union{Vector{N},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple))) +Base.convert(::Type{MPSShape}, tuple::Union{Vector{T},NTuple{T, <:Integer}}) where T = NSArray(NSNumber.(collect(tuple))) # Valid combination of input (A and B matrices) and output (C) types const MPS_VALID_MATMUL_TYPES = diff --git a/lib/mps/ndarray.jl b/lib/mps/ndarray.jl index 6fe2418a0..f1d16e531 100644 --- a/lib/mps/ndarray.jl +++ b/lib/mps/ndarray.jl @@ -7,6 +7,8 @@ export MPSNDArrayDescriptor # @objcwrapper immutable=false MPSNDArrayDescriptor <: NSObject function MPSNDArrayDescriptor(dataType::DataType, dimensionCount, dimensionSizes::Ptr) + 1 <= dimensionCount <= 16 || throw(ArgumentError("`dimensionCount` must be between 1 and 16 inclusive")) + desc = @objc [MPSNDArrayDescriptor descriptorWithDataType:dataType::MPSDataType dimensionCount:dimensionCount::NSUInteger dimensionSizes:dimensionSizes::Ptr{NSUInteger}]::id{MPSNDArrayDescriptor} diff --git a/lib/mpsgraphs/MPSGraphs.jl b/lib/mpsgraphs/MPSGraphs.jl new file mode 100644 index 000000000..2a06ae140 --- /dev/null +++ b/lib/mpsgraphs/MPSGraphs.jl @@ -0,0 +1,51 @@ +""" +# MPSGraphs + +`MPSGraphs` is where the Metal Performance Shaders Graph API wrappers are defined. + +Not all functionality is currently implemented or documented. For further details, +refer to the [official Apple documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph). +""" +module MPSGraphs + +using ..Metal +using .MTL +using .MPS +using .MPS: MPSDataType, MPSShape, exportDataWithCommandBuffer + +using CEnum +using ObjectiveC, .Foundation, .Dispatch + +# Valid combination of input (A and B matrices) and output (C) types +# The commented type combinations work but are slower than with MPSMatrixMultiplicatiom +const MPSGRAPH_VALID_MATMUL_TYPES = + [ + (Int8, Float16), + (Int8, Float32), + (Int16, Float32), + (Float16, Float16), + (Float16, Float32), + (Float32, Float32), + ] + +const MPSGRAPH_VALID_MATVECMUL_TYPES = + [ + (Int8, Float16), + (Int8, Float32), + (Int16, Float32), + (Float16, Float16), + (Float16, Float32), + (Float32, Float32), + ] + +include("libmpsgraph.jl") + +include("core.jl") +include("tensor.jl") +include("execution.jl") +include("operations.jl") +include("random.jl") + +include("matmul.jl") + +end diff --git a/lib/mpsgraphs/core.jl b/lib/mpsgraphs/core.jl new file mode 100644 index 000000000..2f2e868e7 --- /dev/null +++ b/lib/mpsgraphs/core.jl @@ -0,0 +1,42 @@ +# Contains definitions for api from MPSGraphCore.h, MPSGraphDevice.h + +## MPSGraphCore.h +# @objcwrapper MPSGraphObject <: NSObject +# @objcwrapper MPSGraphType <: MPSGraphObject + +# @objcwrapper MPSGraph <: MPSGraphObject +function MPSGraph() + MPSGraph(@objc [MPSGraph new]::id{MPSGraph}) +end + +# @objcwrapper immutable=true MPSGraphShapedType <: MPSGraphType + +# XXX: Not used yet and needs fixing +# function MPSGraphShapedType(shape::MPSShape, dataType) +# tmp = @objc [MPSGraphShapedType alloc]::id{MPSGraphShapedType} +# obj = MPSGraphShapedType(tmp) +# finalizer(release, obj) +# @objc [obj::id{MPSGraphShapedType} initWithShape:shape::id{MPSShape} +# dataType:dataType::MPSDataType]::id{MPSGraphShapedType} +# return obj +# end + +## MPSGraphDevice.h +# @objcwrapper MPSGraphDevice <: MPSGraphType + +function MPSGraphDevice(device::MTLDevice) + obj = @objc [MPSGraphDevice deviceWithMTLDevice:device::id{MTLDevice}]::id{MPSGraphDevice} + MPSGraphDevice(obj) +end + +# @objcwrapper MPSGraphExecutionDescriptor <: MPSGraphObject + +function MPSGraphExecutionDescriptor() + MPSGraphExecutionDescriptor(@objc [MPSGraphExecutionDescriptor new]::id{MPSGraphExecutionDescriptor}) +end + +# @objcwrapper MPSGraphCompilationDescriptor <: MPSGraphObject + +function MPSGraphCompilationDescriptor() + MPSGraphCompilationDescriptor(@objc [MPSGraphCompilationDescriptor new]::id{MPSGraphCompilationDescriptor}) +end diff --git a/lib/mpsgraphs/execution.jl b/lib/mpsgraphs/execution.jl new file mode 100644 index 000000000..9eaf82df9 --- /dev/null +++ b/lib/mpsgraphs/execution.jl @@ -0,0 +1,34 @@ + +MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, resultsDictionary::MPSGraphTensorDataDictionary) = @inline MPS.encode!(commandBuffer, graph, feeds, nil, resultsDictionary, MPSGraphExecutionDescriptor()) +function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetOperations, resultsDictionary::MPSGraphTensorDataDictionary, executionDescriptor) + @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer} + feeds:feeds::id{MPSGraphTensorDataDictionary} + targetOperations:targetOperations::id{Object} + resultsDictionary:resultsDictionary::id{MPSGraphTensorDataDictionary} + executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::Nothing + return resultsDictionary +end + +function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil, executionDescriptor=MPSGraphExecutionDescriptor()) + obj = @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer} + feeds:feeds::id{MPSGraphTensorDataDictionary} + targetTensors:targetTensors::id{NSArray} + targetOperations:targetOperations::id{Object} + executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::id{MPSGraphTensorDataDictionary} + MPSGraphTensorDataDictionary(obj) +end + +function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil) + obj = @objc [graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary} + targetTensors:targetTensors::id{NSArray} + targetOperations:targetOperations::id{Object}]::id{MPSGraphTensorDataDictionary} + MPSGraphTensorDataDictionary(obj) +end + +function run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray) + obj = @objc [graph::id{MPSGraph} runWithMTLCommandQueue:commandQueue::id{MTLCommandQueue} + feeds:feeds::id{MPSGraphTensorDataDictionary} + targetTensors:targetTensors::id{NSArray} + targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary} + MPSGraphTensorDataDictionary(obj) +end diff --git a/lib/mpsgraphs/libmpsgraph.jl b/lib/mpsgraphs/libmpsgraph.jl new file mode 100644 index 000000000..97042cca6 --- /dev/null +++ b/lib/mpsgraphs/libmpsgraph.jl @@ -0,0 +1,460 @@ +# This file is automatically generated. Do not edit! +# To re-generate, execute res/wrap/wrap.jl + +using CEnum: CEnum, @cenum + +const MPSGraphCallableMap = NSDictionary + +const MPSGraphTensorDataDictionary = NSDictionary + + +@objcwrapper immutable = true availability = macos(v"14.0.0") MPSGraphObject <: NSObject + +@objcwrapper immutable = true MPSGraphType <: MPSGraphObject + +@objcwrapper immutable = true MPSGraphShapedType <: MPSGraphType + +@objcproperties MPSGraphShapedType begin + @autoproperty shape::id{MPSShape} setter = setShape + @autoproperty dataType::MPSDataType setter = setDataType +end + +@cenum MPSGraphTensorNamedDataLayout::UInt64 begin + MPSGraphTensorNamedDataLayoutNCHW = 0x0000000000000000 + MPSGraphTensorNamedDataLayoutNHWC = 0x0000000000000001 + MPSGraphTensorNamedDataLayoutOIHW = 0x0000000000000002 + MPSGraphTensorNamedDataLayoutHWIO = 0x0000000000000003 + MPSGraphTensorNamedDataLayoutCHW = 0x0000000000000004 + MPSGraphTensorNamedDataLayoutHWC = 0x0000000000000005 + MPSGraphTensorNamedDataLayoutHW = 0x0000000000000006 + MPSGraphTensorNamedDataLayoutNCDHW = 0x0000000000000007 + MPSGraphTensorNamedDataLayoutNDHWC = 0x0000000000000008 + MPSGraphTensorNamedDataLayoutOIDHW = 0x0000000000000009 + MPSGraphTensorNamedDataLayoutDHWIO = 0x000000000000000a +end + +@cenum MPSGraphPaddingStyle::UInt64 begin + MPSGraphPaddingStyleExplicit = 0x0000000000000000 + MPSGraphPaddingStyleTF_VALID = 0x0000000000000001 + MPSGraphPaddingStyleTF_SAME = 0x0000000000000002 + MPSGraphPaddingStyleExplicitOffset = 0x0000000000000003 + MPSGraphPaddingStyleONNX_SAME_LOWER = 0x0000000000000004 +end + +@cenum MPSGraphPaddingMode::Int64 begin + MPSGraphPaddingModeConstant = 0 + MPSGraphPaddingModeReflect = 1 + MPSGraphPaddingModeSymmetric = 2 + MPSGraphPaddingModeClampToEdge = 3 + MPSGraphPaddingModeZero = 4 + MPSGraphPaddingModePeriodic = 5 + MPSGraphPaddingModeAntiPeriodic = 6 +end + +@cenum MPSGraphReductionMode::UInt64 begin + MPSGraphReductionModeMin = 0x0000000000000000 + MPSGraphReductionModeMax = 0x0000000000000001 + MPSGraphReductionModeSum = 0x0000000000000002 + MPSGraphReductionModeProduct = 0x0000000000000003 + MPSGraphReductionModeArgumentMin = 0x0000000000000004 + MPSGraphReductionModeArgumentMax = 0x0000000000000005 +end + +@cenum MPSGraphDeviceType::UInt32 begin + MPSGraphDeviceTypeMetal = 0x0000000000000000 +end + +@objcwrapper immutable = true MPSGraphDevice <: MPSGraphObject + +@objcproperties MPSGraphDevice begin + @autoproperty type::MPSGraphDeviceType + @autoproperty metalDevice::id{MTLDevice} +end + +@cenum MPSGraphOptions::UInt64 begin + MPSGraphOptionsNone = 0x0000000000000000 + MPSGraphOptionsSynchronizeResults = 0x0000000000000001 + MPSGraphOptionsVerbose = 0x0000000000000002 + MPSGraphOptionsDefault = 0x0000000000000001 +end + +@objcwrapper immutable = true MPSGraph <: MPSGraphObject + +@objcproperties MPSGraph begin + @autoproperty options::MPSGraphOptions setter = setOptions + @autoproperty placeholderTensors::id{NSArray} type = Vector{MPSGraphTensor} +end + +@objcwrapper immutable = true MPSGraphOperation <: MPSGraphObject + +@objcproperties MPSGraphOperation begin + @autoproperty inputTensors::id{NSArray} type = Vector{MPSGraphTensor} + @autoproperty outputTensors::id{NSArray} type = Vector{MPSGraphTensor} + @autoproperty controlDependencies::id{NSArray} type = Vector{MPSGraphOperation} + @autoproperty graph::id{MPSGraph} + @autoproperty name::id{NSString} +end + +@objcwrapper immutable = true MPSGraphTensor <: MPSGraphObject + +@objcproperties MPSGraphTensor begin + @autoproperty shape::id{MPSShape} + @autoproperty dataType::MPSDataType + @autoproperty operation::id{MPSGraphOperation} +end + +@objcwrapper immutable = false MPSGraphTensorData <: MPSGraphObject + +@objcproperties MPSGraphTensorData begin + @autoproperty shape::id{MPSShape} + @autoproperty dataType::MPSDataType + @autoproperty device::id{MPSGraphDevice} +end + +@cenum MPSGraphOptimization::UInt64 begin + MPSGraphOptimizationLevel0 = 0x0000000000000000 + MPSGraphOptimizationLevel1 = 0x0000000000000001 +end + +@cenum MPSGraphOptimizationProfile::UInt64 begin + MPSGraphOptimizationProfilePerformance = 0x0000000000000000 + MPSGraphOptimizationProfilePowerEfficiency = 0x0000000000000001 +end + +@cenum MPSGraphExecutionStage::UInt64 begin + MPSGraphExecutionStageCompleted = 0x0000000000000000 +end + +@objcwrapper immutable = true MPSGraphCompilationDescriptor <: MPSGraphObject + +@objcproperties MPSGraphCompilationDescriptor begin + @autoproperty optimizationLevel::MPSGraphOptimization setter = setOptimizationLevel + @autoproperty waitForCompilationCompletion::Bool setter = setWaitForCompilationCompletion + #= setter = setCompilationCompletionHandler:Skipping property compilationCompletionHandler because it is a CLBlockPointer void (^)(MPSGraphExecutable *, NSError *) =# + @autoproperty dispatchQueue::id{dispatch_queue_t} setter = setDispatchQueue + @autoproperty optimizationProfile::MPSGraphOptimizationProfile setter = setOptimizationProfile + @autoproperty callables::id{MPSGraphCallableMap} setter = setCallables availability = macos(v"14.1.0") +end + +@objcwrapper immutable = true MPSGraphExecutionDescriptor <: MPSGraphObject + +@objcproperties MPSGraphExecutionDescriptor begin + #= setter = setScheduledHandler:Skipping property scheduledHandler because it is a CLBlockPointer void (^)(NSDictionary *, NSError *) =# + #= setter = setCompletionHandler:Skipping property completionHandler because it is a CLBlockPointer void (^)(NSDictionary *, NSError *) =# + @autoproperty waitUntilCompleted::Bool setter = setWaitUntilCompleted + @autoproperty compilationDescriptor::id{MPSGraphCompilationDescriptor} setter = setCompilationDescriptor +end + +@objcwrapper immutable = true MPSGraphExecutableExecutionDescriptor <: MPSGraphObject + +@objcproperties MPSGraphExecutableExecutionDescriptor begin + #= setter = setScheduledHandler:Skipping property scheduledHandler because it is a CLBlockPointer void (^)(NSArray *, NSError *) =# + #= setter = setCompletionHandler:Skipping property completionHandler because it is a CLBlockPointer void (^)(NSArray *, NSError *) =# + @autoproperty waitUntilCompleted::Bool setter = setWaitUntilCompleted +end + +@cenum MPSGraphDeploymentPlatform::UInt64 begin + MPSGraphDeploymentPlatformMacOS = 0x0000000000000000 + MPSGraphDeploymentPlatformIOS = 0x0000000000000001 + MPSGraphDeploymentPlatformTvOS = 0x0000000000000002 + MPSGraphDeploymentPlatformVisionOS = 0x0000000000000003 +end + +@objcwrapper immutable = true availability = macos(v"14.0.0") MPSGraphExecutableSerializationDescriptor <: MPSGraphObject + +@objcproperties MPSGraphExecutableSerializationDescriptor begin + @autoproperty append::Bool setter = setAppend + @autoproperty deploymentPlatform::MPSGraphDeploymentPlatform setter = setDeploymentPlatform + @autoproperty minimumDeploymentTarget::id{NSString} setter = setMinimumDeploymentTarget +end + +@objcwrapper immutable = true MPSGraphExecutable <: MPSGraphObject + +@objcproperties MPSGraphExecutable begin + @autoproperty options::MPSGraphOptions setter = setOptions + @autoproperty feedTensors::id{NSArray} type = Vector{MPSGraphTensor} + @autoproperty targetTensors::id{NSArray} type = Vector{MPSGraphTensor} +end + +@objcwrapper immutable = true MPSGraphConvolution2DOpDescriptor <: MPSGraphObject + +@objcproperties MPSGraphConvolution2DOpDescriptor begin + @autoproperty strideInX::UInt64 setter = setStrideInX + @autoproperty strideInY::UInt64 setter = setStrideInY + @autoproperty dilationRateInX::UInt64 setter = setDilationRateInX + @autoproperty dilationRateInY::UInt64 setter = setDilationRateInY + @autoproperty paddingLeft::UInt64 setter = setPaddingLeft + @autoproperty paddingRight::UInt64 setter = setPaddingRight + @autoproperty paddingTop::UInt64 setter = setPaddingTop + @autoproperty paddingBottom::UInt64 setter = setPaddingBottom + @autoproperty paddingStyle::MPSGraphPaddingStyle setter = setPaddingStyle + @autoproperty dataLayout::MPSGraphTensorNamedDataLayout setter = setDataLayout + @autoproperty weightsLayout::MPSGraphTensorNamedDataLayout setter = setWeightsLayout + @autoproperty groups::UInt64 setter = setGroups +end + +@objcwrapper immutable = true availability = macos(v"13.2.0") MPSGraphConvolution3DOpDescriptor <: MPSGraphObject + +@objcproperties MPSGraphConvolution3DOpDescriptor begin + @autoproperty strideInX::UInt64 setter = setStrideInX + @autoproperty strideInY::UInt64 setter = setStrideInY + @autoproperty strideInZ::UInt64 setter = setStrideInZ + @autoproperty dilationRateInX::UInt64 setter = setDilationRateInX + @autoproperty dilationRateInY::UInt64 setter = setDilationRateInY + @autoproperty dilationRateInZ::UInt64 setter = setDilationRateInZ + @autoproperty paddingLeft::UInt64 setter = setPaddingLeft + @autoproperty paddingRight::UInt64 setter = setPaddingRight + @autoproperty paddingTop::UInt64 setter = setPaddingTop + @autoproperty paddingBottom::UInt64 setter = setPaddingBottom + @autoproperty paddingFront::UInt64 setter = setPaddingFront + @autoproperty paddingBack::UInt64 setter = setPaddingBack + @autoproperty paddingStyle::MPSGraphPaddingStyle setter = setPaddingStyle + @autoproperty dataLayout::MPSGraphTensorNamedDataLayout setter = setDataLayout + @autoproperty weightsLayout::MPSGraphTensorNamedDataLayout setter = setWeightsLayout + @autoproperty groups::UInt64 setter = setGroups +end + +@objcwrapper immutable = true MPSGraphDepthwiseConvolution2DOpDescriptor <: MPSGraphObject + +@objcproperties MPSGraphDepthwiseConvolution2DOpDescriptor begin + @autoproperty strideInX::UInt64 setter = setStrideInX + @autoproperty strideInY::UInt64 setter = setStrideInY + @autoproperty dilationRateInX::UInt64 setter = setDilationRateInX + @autoproperty dilationRateInY::UInt64 setter = setDilationRateInY + @autoproperty paddingLeft::UInt64 setter = setPaddingLeft + @autoproperty paddingRight::UInt64 setter = setPaddingRight + @autoproperty paddingTop::UInt64 setter = setPaddingTop + @autoproperty paddingBottom::UInt64 setter = setPaddingBottom + @autoproperty paddingStyle::MPSGraphPaddingStyle setter = setPaddingStyle + @autoproperty dataLayout::MPSGraphTensorNamedDataLayout setter = setDataLayout + @autoproperty weightsLayout::MPSGraphTensorNamedDataLayout setter = setWeightsLayout +end + +@objcwrapper immutable = true MPSGraphDepthwiseConvolution3DOpDescriptor <: MPSGraphObject + +@objcproperties MPSGraphDepthwiseConvolution3DOpDescriptor begin + @autoproperty strides::id{NSArray} type = Vector{NSNumber} setter = setStrides + @autoproperty dilationRates::id{NSArray} type = Vector{NSNumber} setter = setDilationRates + @autoproperty paddingValues::id{NSArray} type = Vector{NSNumber} setter = setPaddingValues + @autoproperty paddingStyle::MPSGraphPaddingStyle setter = setPaddingStyle + @autoproperty channelDimensionIndex::Int64 setter = setChannelDimensionIndex +end + +@cenum MPSGraphFFTScalingMode::UInt64 begin + MPSGraphFFTScalingModeNone = 0x0000000000000000 + MPSGraphFFTScalingModeSize = 0x0000000000000001 + MPSGraphFFTScalingModeUnitary = 0x0000000000000002 +end + +@objcwrapper immutable = true availability = macos(v"14.0.0") MPSGraphFFTDescriptor <: MPSGraphObject + +@objcproperties MPSGraphFFTDescriptor begin + @autoproperty inverse::Bool setter = setInverse + @autoproperty scalingMode::MPSGraphFFTScalingMode setter = setScalingMode + @autoproperty roundToOddHermitean::Bool setter = setRoundToOddHermitean +end + +@objcwrapper immutable = true availability = macos(v"14.0.0") MPSGraphImToColOpDescriptor <: MPSGraphObject + +@objcproperties MPSGraphImToColOpDescriptor begin + @autoproperty kernelWidth::UInt64 setter = setKernelWidth + @autoproperty kernelHeight::UInt64 setter = setKernelHeight + @autoproperty strideInX::UInt64 setter = setStrideInX + @autoproperty strideInY::UInt64 setter = setStrideInY + @autoproperty dilationRateInX::UInt64 setter = setDilationRateInX + @autoproperty dilationRateInY::UInt64 setter = setDilationRateInY + @autoproperty paddingLeft::UInt64 setter = setPaddingLeft + @autoproperty paddingRight::UInt64 setter = setPaddingRight + @autoproperty paddingTop::UInt64 setter = setPaddingTop + @autoproperty paddingBottom::UInt64 setter = setPaddingBottom + @autoproperty dataLayout::MPSGraphTensorNamedDataLayout setter = setDataLayout +end + +@cenum MPSGraphLossReductionType::UInt64 begin + MPSGraphLossReductionTypeNone = 0x0000000000000000 + MPSGraphLossReductionTypeAxis = 0x0000000000000000 + MPSGraphLossReductionTypeSum = 0x0000000000000001 + MPSGraphLossReductionTypeMean = 0x0000000000000002 +end + +@objcwrapper immutable = true MPSGraphVariableOp <: MPSGraphOperation + +@objcproperties MPSGraphVariableOp begin + @autoproperty shape::id{MPSShape} + @autoproperty dataType::MPSDataType +end + +@cenum MPSGraphNonMaximumSuppressionCoordinateMode::UInt64 begin + MPSGraphNonMaximumSuppressionCoordinateModeCornersHeightFirst = 0x0000000000000000 + MPSGraphNonMaximumSuppressionCoordinateModeCornersWidthFirst = 0x0000000000000001 + MPSGraphNonMaximumSuppressionCoordinateModeCentersHeightFirst = 0x0000000000000002 + MPSGraphNonMaximumSuppressionCoordinateModeCentersWidthFirst = 0x0000000000000003 +end + +@cenum MPSGraphPoolingReturnIndicesMode::UInt64 begin + MPSGraphPoolingReturnIndicesNone = 0x0000000000000000 + MPSGraphPoolingReturnIndicesGlobalFlatten1D = 0x0000000000000001 + MPSGraphPoolingReturnIndicesGlobalFlatten2D = 0x0000000000000002 + MPSGraphPoolingReturnIndicesGlobalFlatten3D = 0x0000000000000003 + MPSGraphPoolingReturnIndicesGlobalFlatten4D = 0x0000000000000004 + MPSGraphPoolingReturnIndicesLocalFlatten1D = 0x0000000000000005 + MPSGraphPoolingReturnIndicesLocalFlatten2D = 0x0000000000000006 + MPSGraphPoolingReturnIndicesLocalFlatten3D = 0x0000000000000007 + MPSGraphPoolingReturnIndicesLocalFlatten4D = 0x0000000000000008 +end + +@objcwrapper immutable = true MPSGraphPooling2DOpDescriptor <: MPSGraphObject + +@objcproperties MPSGraphPooling2DOpDescriptor begin + @autoproperty kernelWidth::UInt64 setter = setKernelWidth + @autoproperty kernelHeight::UInt64 setter = setKernelHeight + @autoproperty strideInX::UInt64 setter = setStrideInX + @autoproperty strideInY::UInt64 setter = setStrideInY + @autoproperty dilationRateInX::UInt64 setter = setDilationRateInX + @autoproperty dilationRateInY::UInt64 setter = setDilationRateInY + @autoproperty paddingLeft::UInt64 setter = setPaddingLeft + @autoproperty paddingRight::UInt64 setter = setPaddingRight + @autoproperty paddingTop::UInt64 setter = setPaddingTop + @autoproperty paddingBottom::UInt64 setter = setPaddingBottom + @autoproperty paddingStyle::MPSGraphPaddingStyle setter = setPaddingStyle + @autoproperty dataLayout::MPSGraphTensorNamedDataLayout setter = setDataLayout + @autoproperty returnIndicesMode::MPSGraphPoolingReturnIndicesMode setter = setReturnIndicesMode + @autoproperty returnIndicesDataType::MPSDataType setter = setReturnIndicesDataType + @autoproperty ceilMode::Bool setter = setCeilMode + @autoproperty includeZeroPadToAverage::Bool setter = setIncludeZeroPadToAverage +end + +@objcwrapper immutable = true MPSGraphPooling4DOpDescriptor <: MPSGraphObject + +@objcproperties MPSGraphPooling4DOpDescriptor begin + @autoproperty kernelSizes::id{NSArray} type = Vector{NSNumber} setter = setKernelSizes + @autoproperty strides::id{NSArray} type = Vector{NSNumber} setter = setStrides + @autoproperty dilationRates::id{NSArray} type = Vector{NSNumber} setter = setDilationRates + @autoproperty paddingValues::id{NSArray} type = Vector{NSNumber} setter = setPaddingValues + @autoproperty paddingStyle::MPSGraphPaddingStyle setter = setPaddingStyle + @autoproperty ceilMode::Bool setter = setCeilMode + @autoproperty includeZeroPadToAverage::Bool setter = setIncludeZeroPadToAverage + @autoproperty returnIndicesMode::MPSGraphPoolingReturnIndicesMode setter = setReturnIndicesMode + @autoproperty returnIndicesDataType::MPSDataType setter = setReturnIndicesDataType +end + +@cenum MPSGraphRandomDistribution::UInt64 begin + MPSGraphRandomDistributionUniform = 0x0000000000000000 + MPSGraphRandomDistributionNormal = 0x0000000000000001 + MPSGraphRandomDistributionTruncatedNormal = 0x0000000000000002 +end + +@cenum MPSGraphRandomNormalSamplingMethod::UInt64 begin + MPSGraphRandomNormalSamplingInvCDF = 0x0000000000000000 + MPSGraphRandomNormalSamplingBoxMuller = 0x0000000000000001 +end + +@objcwrapper immutable = true MPSGraphRandomOpDescriptor <: MPSGraphObject + +@objcproperties MPSGraphRandomOpDescriptor begin + @autoproperty distribution::MPSGraphRandomDistribution setter = setDistribution + @autoproperty dataType::MPSDataType setter = setDataType + @autoproperty min::Cfloat setter = setMin + @autoproperty max::Cfloat setter = setMax + @autoproperty minInteger::Int64 setter = setMinInteger + @autoproperty maxInteger::Int64 setter = setMaxInteger + @autoproperty mean::Cfloat setter = setMean + @autoproperty standardDeviation::Cfloat setter = setStandardDeviation + @autoproperty samplingMethod::MPSGraphRandomNormalSamplingMethod setter = setSamplingMethod +end + +@cenum MPSGraphResizeMode::UInt64 begin + MPSGraphResizeNearest = 0x0000000000000000 + MPSGraphResizeBilinear = 0x0000000000000001 +end + +@cenum MPSGraphResizeNearestRoundingMode::UInt64 begin + MPSGraphResizeNearestRoundingModeRoundPreferCeil = 0x0000000000000000 + MPSGraphResizeNearestRoundingModeRoundPreferFloor = 0x0000000000000001 + MPSGraphResizeNearestRoundingModeCeil = 0x0000000000000002 + MPSGraphResizeNearestRoundingModeFloor = 0x0000000000000003 + MPSGraphResizeNearestRoundingModeRoundToEven = 0x0000000000000004 + MPSGraphResizeNearestRoundingModeRoundToOdd = 0x0000000000000005 +end + +@cenum MPSGraphRNNActivation::UInt64 begin + MPSGraphRNNActivationNone = 0x0000000000000000 + MPSGraphRNNActivationRelu = 0x0000000000000001 + MPSGraphRNNActivationTanh = 0x0000000000000002 + MPSGraphRNNActivationSigmoid = 0x0000000000000003 + MPSGraphRNNActivationHardSigmoid = 0x0000000000000004 +end + +@objcwrapper immutable = true MPSGraphSingleGateRNNDescriptor <: MPSGraphObject + +@objcproperties MPSGraphSingleGateRNNDescriptor begin + @autoproperty reverse::Bool setter = setReverse + @autoproperty bidirectional::Bool setter = setBidirectional + @autoproperty training::Bool setter = setTraining + @autoproperty activation::MPSGraphRNNActivation setter = setActivation +end + +@objcwrapper immutable = true MPSGraphLSTMDescriptor <: MPSGraphObject + +@objcproperties MPSGraphLSTMDescriptor begin + @autoproperty reverse::Bool setter = setReverse + @autoproperty bidirectional::Bool setter = setBidirectional + @autoproperty produceCell::Bool setter = setProduceCell + @autoproperty training::Bool setter = setTraining + @autoproperty forgetGateLast::Bool setter = setForgetGateLast + @autoproperty inputGateActivation::MPSGraphRNNActivation setter = setInputGateActivation + @autoproperty forgetGateActivation::MPSGraphRNNActivation setter = setForgetGateActivation + @autoproperty cellGateActivation::MPSGraphRNNActivation setter = setCellGateActivation + @autoproperty outputGateActivation::MPSGraphRNNActivation setter = setOutputGateActivation + @autoproperty activation::MPSGraphRNNActivation setter = setActivation +end + +@objcwrapper immutable = true MPSGraphGRUDescriptor <: MPSGraphObject + +@objcproperties MPSGraphGRUDescriptor begin + @autoproperty reverse::Bool setter = setReverse + @autoproperty bidirectional::Bool setter = setBidirectional + @autoproperty training::Bool setter = setTraining + @autoproperty resetGateFirst::Bool setter = setResetGateFirst + @autoproperty resetAfter::Bool setter = setResetAfter + @autoproperty flipZ::Bool setter = setFlipZ + @autoproperty updateGateActivation::MPSGraphRNNActivation setter = setUpdateGateActivation + @autoproperty resetGateActivation::MPSGraphRNNActivation setter = setResetGateActivation + @autoproperty outputGateActivation::MPSGraphRNNActivation setter = setOutputGateActivation +end + +@cenum MPSGraphScatterMode::Int64 begin + MPSGraphScatterModeAdd = 0 + MPSGraphScatterModeSub = 1 + MPSGraphScatterModeMul = 2 + MPSGraphScatterModeDiv = 3 + MPSGraphScatterModeMin = 4 + MPSGraphScatterModeMax = 5 + MPSGraphScatterModeSet = 6 +end + +@cenum MPSGraphSparseStorageType::UInt64 begin + MPSGraphSparseStorageCOO = 0x0000000000000000 + MPSGraphSparseStorageCSC = 0x0000000000000001 + MPSGraphSparseStorageCSR = 0x0000000000000002 +end + +@objcwrapper immutable = true MPSGraphCreateSparseOpDescriptor <: MPSGraphObject + +@objcproperties MPSGraphCreateSparseOpDescriptor begin + @autoproperty sparseStorageType::MPSGraphSparseStorageType setter = setSparseStorageType + @autoproperty dataType::MPSDataType setter = setDataType +end + +@objcwrapper immutable = true MPSGraphStencilOpDescriptor <: MPSGraphObject + +@objcproperties MPSGraphStencilOpDescriptor begin + @autoproperty reductionMode::MPSGraphReductionMode setter = setReductionMode + @autoproperty offsets::id{MPSShape} setter = setOffsets + @autoproperty strides::id{MPSShape} setter = setStrides + @autoproperty dilationRates::id{MPSShape} setter = setDilationRates + @autoproperty explicitPadding::id{MPSShape} setter = setExplicitPadding + @autoproperty boundaryMode::MPSGraphPaddingMode setter = setBoundaryMode + @autoproperty paddingStyle::MPSGraphPaddingStyle setter = setPaddingStyle + @autoproperty paddingConstant::Cfloat setter = setPaddingConstant +end diff --git a/lib/mpsgraphs/matmul.jl b/lib/mpsgraphs/matmul.jl new file mode 100644 index 000000000..0cd3fe081 --- /dev/null +++ b/lib/mpsgraphs/matmul.jl @@ -0,0 +1,102 @@ +#= +Creates a default MPSGraphExecutionDescriptor with a MPSGraphCompilationDescriptor + set to use optimization level 0 instead of 1. This is because level 1 causes operations + on eltypes <= 16 bytes to be executed on the ANE instead of the GPU, leading to worse + performance and hangs when the matrices are too big +=# +@static if isdefined(Base, :OncePerProcess) # VERSION >= v"1.12.0-DEV.1421" + const default_exec_desc = OncePerProcess{MPSGraphExecutionDescriptor}() do + compDesc = MPSGraphCompilationDescriptor() + # Use optimization level 0 to avoid operations being moved to the neural engine + compDesc.optimizationLevel = MPSGraphOptimizationLevel0 + + execDesc = MPSGraphExecutionDescriptor() + execDesc.compilationDescriptor = compDesc + execDesc + end +else + const _default_exec_desc::Ref{MPSGraphExecutionDescriptor} = Ref{MPSGraphExecutionDescriptor}() + function default_exec_desc() + if !isassigned(_default_exec_desc) + compDesc = MPSGraphCompilationDescriptor() + # Use optimization level 0 to avoid operations being moved to the neural engine + compDesc.optimizationLevel = MPSGraphOptimizationLevel0 + + _default_exec_desc[] = MPSGraphExecutionDescriptor() + _default_exec_desc[].compilationDescriptor = compDesc + end + _default_exec_desc[] + end +end + + +@autoreleasepool function _matmul!(c::MtlArray{Tc}, a::MtlArray{Tab, Na}, b::MtlArray{Tab, Nb}, + alpha::Number, beta::Number, + transpose_a, transpose_b) where {Tc, Tab, Na, Nb} + graph = MPSGraph() + + placeA = placeholderTensor(graph, size(a), Tab) + placeB = placeholderTensor(graph, size(b), Tab) + placeC = placeholderTensor(graph, size(c), Tc) + + feeds = Dict{MPSGraphTensor, MPSGraphTensorData}( + placeA => MPSGraphTensorData(a), + placeB => MPSGraphTensorData(b), + placeC => MPSGraphTensorData(c) + ) + + # cast to output eltype if input type is an integer type + castT = Tab <: Integer ? Tc : Tab + castA = castTensor(graph, placeA, castT, "castA") + castB = castTensor(graph, placeB, castT, "castB") + + transA = transpose_a ? transposeTensor(graph, castA, Na-2, Na-1, "transpose_a") : castA + transB = transpose_b ? transposeTensor(graph, castB, Nb-2, Nb-1, "transpose_b") : castB + + nBatchA = Na == 2 ? 1 : size(transA)[1] + nBatchB = Nb == 2 ? 1 : size(transB)[1] + + # for batched matmul between different sized tensors + broadcastA, broadcastB = if nBatchA == nBatchB + transA, transB + elseif Na == 1 + broadcastTensor(graph, transA, convert(MPSShape, [nBatchB, size(transA)[2:end]...])), transB + elseif Nb == 1 + transA, broadcastTensor(graph, transB, convert(MPSShape, [nBatchA, size(transB)[2:end]...])) + else + transA, transB + end + + matmul = matrixMultiplicationWithPrimaryTensor(graph, broadcastB, broadcastA) + + afteralpha = let alphatensor = constantWithScalar(graph, alpha, castT) + multiplicationWithPrimaryTensor(graph, alphatensor, matmul) + end + + afterbeta = let betatensor = constantWithScalar(graph, beta, castT) + castplaceC = castTensor(graph, placeC, castT, "castplaceC") + betaC = multiplicationWithPrimaryTensor(graph, betatensor, castplaceC) + afterbeta = additionWithPrimaryTensor(graph, afteralpha, betaC) + end + + castC = castTensor(graph, afterbeta, Tc, "castC") + + resultdict = Dict{MPSGraphTensor, MPSGraphTensorData}( + castC => feeds[placeC] + ) + + cmdbuf = MPSCommandBuffer(Metal.global_queue(device())) + encode!(cmdbuf, graph, NSDictionary(feeds), nil, NSDictionary(resultdict), default_exec_desc()) + commit!(cmdbuf) + wait_completed(cmdbuf) + + return c +end + +function graph_matmul!(c::MtlArray{Tc, N}, a::MtlArray{Tab, N}, b::MtlArray{Tab, N}, alpha::Number = true, beta::Number = false, transpose_a = false, transpose_b = false) where {Tc, Tab, N} + _matmul!(c, a, b, alpha, beta, transpose_a, transpose_b) +end + +function graph_matvecmul!(c::MtlVector{Tc}, a::MtlMatrix{Tab}, b::MtlVector{Tab}, alpha::Number = true, beta::Number = false, transpose = false) where {Tc, Tab} + _matmul!(c, a, b, alpha, beta, transpose, false) +end diff --git a/lib/mpsgraphs/operations.jl b/lib/mpsgraphs/operations.jl new file mode 100644 index 000000000..107c9ae31 --- /dev/null +++ b/lib/mpsgraphs/operations.jl @@ -0,0 +1,76 @@ + +function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shape::MPSShape, name="broadcast") + obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor} + toShape:shape::id{MPSShape} + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end +function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shapeTensor::MPSGraphTensor, name="broadcast") + obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor} + toShapeTensor:shapeTensor::id{MPSGraphTensor} + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end + +function castTensor(graph::MPSGraph, tensor::MPSGraphTensor, toType, name = "cast") + obj = @objc [graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor} + toType:toType::MPSDataType + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end + +function constantWithScalar(graph::MPSGraph, scalar::Number, dataType) + obj = @objc [graph::id{MPSGraph} constantWithScalar:scalar::Float64 + dataType:dataType::MPSDataType]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end + +function matrixMultiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "matmul") + obj = @objc [graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor} + secondaryTensor:secondary::id{MPSGraphTensor} + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end + +function multiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "mul") + obj = @objc [graph::id{MPSGraph} multiplicationWithPrimaryTensor:primary::id{MPSGraphTensor} + secondaryTensor:secondary::id{MPSGraphTensor} + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end +function additionWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name = "add") + obj = @objc [graph::id{MPSGraph} additionWithPrimaryTensor:primary::id{MPSGraphTensor} + secondaryTensor:secondary::id{MPSGraphTensor} + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end + +function transposeTensor(graph::MPSGraph, tensor::MPSGraphTensor, dimension, withDimension, name = "transpose") + obj = @objc [graph::id{MPSGraph} transposeTensor:tensor::id{MPSGraphTensor} + dimension:dimension::NSUInteger + withDimension:withDimension::NSUInteger + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end + +function shapeOfTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "shapeOfTensor") + obj = @objc [graph::id{MPSGraph} shapeOfTensor:tensor::id{MPSGraphTensor} + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end + +function identityWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "identity") + obj = @objc [graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor} + name:name::id{NSString}]::id{MPSGraphTensor} + MPSGraphTensor(obj) +end + +""" + dump_graph(graph::MPSGraph) + +Dumps the `graph`. + +!!! warning + This function is undocumented from Apple so it may stop working at any time. +""" +dump_graph(graph::MPSGraph) = @objc [graph::id{MPSGraph} dump]::Nothing ## COV_EXCL_LINE diff --git a/lib/mpsgraphs/random.jl b/lib/mpsgraphs/random.jl new file mode 100644 index 000000000..b984dcae0 --- /dev/null +++ b/lib/mpsgraphs/random.jl @@ -0,0 +1,8 @@ +# @objcwrapper immutable=false MPSGraphRandomOpDescriptor <: MPSGraphObject + +function MPSGraphRandomOpDescriptor(distribution::MPSGraphRandomDistribution, dataType) + desc = @objc [MPSGraphRandomOpDescriptor descriptorWithDistribution:distribution::MPSGraphRandomDistribution + dataType:dataType::MPSDataType]::id{MPSGraphRandomOpDescriptor} + obj = MPSGraphRandomOpDescriptor(desc) + return obj +end diff --git a/lib/mpsgraphs/tensor.jl b/lib/mpsgraphs/tensor.jl new file mode 100644 index 000000000..67518f5b5 --- /dev/null +++ b/lib/mpsgraphs/tensor.jl @@ -0,0 +1,122 @@ +# Contains definitions for api from MPSGraphTensor.h, MPSGraphTensorData.h, MPSGraphOperation.h + +## MPSGraphTensor.h +# @objcwrapper MPSGraphTensor <: MPSGraphObject + +# Define MPSGraphOperation here to define the MPSGraphTensor properties +# @objcwrapper MPSGraphOperation <: MPSGraphObject + +function Base.size(td::MPSGraphTensor) + temp = map(td.shape) do nsnum + NSNumber(reinterpret(id{NSNumber}, nsnum)).unsignedIntegerValue |> Int + end + Tuple(temp) +end + +function placeholderTensor(graph::MPSGraph, shape::Union{Vector, Tuple}, args...) + mpsshape = convert(MPSShape, reverse(shape)) + return placeholderTensor(graph, mpsshape, args...) +end +function placeholderTensor(graph::MPSGraph, shape::MPSShape, dataType::Type, name = "placeholder tensor") + obj = @objc [graph::id{MPSGraph} placeholderWithShape:shape::id{MPSShape} + dataType:dataType::MPSDataType + name:name::id{NSString}]::id{MPSGraphTensor} + return MPSGraphTensor(obj) +end + +## MPSGraphTensorData.h +# @objcwrapper immutable=false MPSGraphTensorData <: MPSGraphObject + +function Base.size(td::MPSGraphTensorData) + temp = map(td.shape) do nsnum + NSNumber(reinterpret(id{NSNumber}, nsnum)).unsignedIntegerValue |> Int + end + Tuple(temp) +end + +function MPSGraphTensorData(buffer::MTLBuffer, shape::MPSShape, dataType) + obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData} + tensor = MPSGraphTensorData(obj) + finalizer(release, tensor) + @objc [tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer} + shape:shape::id{MPSShape} + dataType:dataType::MPSDataType]::id{MPSGraphTensorData} + return tensor +end +function MPSGraphTensorData(buffer::MTLBuffer, shape::MPSShape, dataType, rowBytes) + obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData} + tensor = MPSGraphTensorData(obj) + finalizer(release, tensor) + @objc [tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer} + shape:shape::id{MPSShape} + dataType:dataType::MPSDataType + rowBytes:rowBytes::NSUInteger]::id{MPSGraphTensorData} + return tensor +end +MPSGraphTensorData(matrix::MtlArray{T}) where T = MPSGraphTensorData(matrix.data[], convert(MPSShape, reverse(size(matrix))), T) + +function MPSGraphTensorData(matrix::MPSMatrix) + obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData} + tensor = MPSGraphTensorData(obj) + finalizer(release, tensor) + @objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:matrix::id{MPSMatrix}]::id{MPSGraphTensorData} + return tensor +end + +function MPSGraphTensorData(matrix::MPSMatrix, rank) + 1 <= rank <= 16 || throw(ArgumentError("`rank` must be between 1 and 16 inclusive")) + + obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData} + tensor = MPSGraphTensorData(obj) + finalizer(release, tensor) + @objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:matrix::id{MPSMatrix} + rank:rank::NSUInteger]::id{MPSGraphTensorData} + return tensor +end + +function MPSGraphTensorData(vector::MPSVector) + obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData} + tensor = MPSGraphTensorData(obj) + finalizer(release, tensor) + @objc [tensor::id{MPSGraphTensorData} initWithMPSVector:vector::id{MPSVector}]::id{MPSGraphTensorData} + return tensor +end + +function MPSGraphTensorData(vector::MPSVector, rank) + 1 <= rank <= 16 || throw(ArgumentError("`rank` must be between 1 and 16 inclusive")) + + obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData} + tensor = MPSGraphTensorData(obj) + finalizer(release, tensor) + @objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:vector::id{MPSVector} + rank:rank::NSUInteger]::id{MPSGraphTensorData} + return tensor +end + +function MPSGraphTensorData(ndarr::MPSNDArray) + obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData} + tensor = MPSGraphTensorData(obj) + finalizer(release, tensor) + @objc [tensor::id{MPSGraphTensorData} initWithMPSNDArray:ndarr::id{MPSNDArray}]::id{MPSGraphTensorData} + return tensor +end +# TODO: MPSImage is not yet implemented +# function MPSGraphTensorData(imgbatch::MPSImageBatch) +# obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData} +# tensor = MPSGraphTensorData(obj) +# finalizer(release, tensor) +# @objc [tensor::id{MPSGraphTensorData} initWithMPSImageBatch:imgbatch::id{MPSImageBatch}]::id{MPSGraphTensorData} +# MPSGraphTensorData(obj) +# end + +""" + MPSNDArray(tens::MPSGraphTensorData) + +Return an MPSNDArray object. + +Will copy contents if the contents are not stored in an MPS ndarray. +""" +function MPS.MPSNDArray(tensor::MPSGraphTensorData) + arr = @objc [tensor::id{MPSNDArray} mpsndarray]::id{MPSNDArray} + MPSNDArray(arr) +end diff --git a/res/wrap/libmpsgraph.toml b/res/wrap/libmpsgraph.toml new file mode 100644 index 000000000..bcd982ddc --- /dev/null +++ b/res/wrap/libmpsgraph.toml @@ -0,0 +1,29 @@ +[general] +library_name = "libmpsgraph" +output_file_path = "../../lib/mpsgraphs/libmpsgraph.jl" +prologue_file_path = "libmpsgraph_prologue.jl" + +minimum_macos_supported = "13" + +printer_blacklist = [ + "mt_macCatalyst", + "mt_ios", + "mt_macos", + "CF.*", + "MTL.*", + "NS.*", + "BOOL" +] + +[codegen] +use_ccall_macro = true +always_NUL_terminated_string = true + +[codegen.macro] +# it's highly recommended to set this entry to "basic". +# if you'd like to skip all of the macros, please set this entry to "disable". +# if you'd like to translate function-like macros to Julia, please set this entry to "aggressive". +macro_mode = "disable" + +[api.MPSGraphTensorData] +immutable=false diff --git a/res/wrap/libmpsgraph_prologue.jl b/res/wrap/libmpsgraph_prologue.jl new file mode 100644 index 000000000..fbed0e2fa --- /dev/null +++ b/res/wrap/libmpsgraph_prologue.jl @@ -0,0 +1,3 @@ +const MPSGraphCallableMap = NSDictionary + +const MPSGraphTensorDataDictionary = NSDictionary diff --git a/res/wrap/wrap.jl b/res/wrap/wrap.jl index d66998b05..29d8d950c 100644 --- a/res/wrap/wrap.jl +++ b/res/wrap/wrap.jl @@ -39,6 +39,12 @@ function main(names::AbstractVector=["all"]; sdk_path=SDK_PATH) push!(ctxs, tctx) end + if "all" in names || "libmpsgraph" in names || "mpsgraph" in names + fwpath = path_to_framework("MetalPerformanceShadersGraph") + tctx = wrap("libmpsgraph", joinpath(fwpath, "MetalPerformanceShadersGraph.h"); defines) + push!(ctxs, tctx) + end + return ctxs end diff --git a/src/Metal.jl b/src/Metal.jl index a4430bdab..5a4d87d09 100644 --- a/src/Metal.jl +++ b/src/Metal.jl @@ -52,6 +52,8 @@ include("compiler/reflection.jl") # libraries include("../lib/mps/MPS.jl") export MPS +include("../lib/mpsgraphs/MPSGraphs.jl") +export MPSGraphs # LinearAlgebra include("linalg.jl") diff --git a/src/initialization.jl b/src/initialization.jl index 2830bff07..871a4283c 100644 --- a/src/initialization.jl +++ b/src/initialization.jl @@ -53,6 +53,7 @@ function __init__() @autoreleasepool try load_framework("CoreGraphics") + load_framework("MetalPerformanceShadersGraph") ver = MTL.MTLCompileOptions().languageVersion @debug "Successfully loaded Metal; targeting v$ver." diff --git a/src/linalg.jl b/src/linalg.jl index 367254dfd..0a6ac7077 100644 --- a/src/linalg.jl +++ b/src/linalg.jl @@ -2,6 +2,24 @@ using LinearAlgebra using LinearAlgebra: MulAddMul, wrap using .MPS using .MPS: MPS_VALID_MATMUL_TYPES, MPS_VALID_MATVECMUL_TYPES, MtlFloat +using .MPSGraphs: MPSGRAPH_VALID_MATMUL_TYPES, MPSGRAPH_VALID_MATVECMUL_TYPES, + graph_matmul!, graph_matvecmul! + +@inline function supports_mps_matmul(A, B, C, valid_types) + MPS.is_supported(device(A)) && + eltype(A) == eltype(B) && + (eltype(A), eltype(C)) in valid_types +end + +@inline function supports_mpsgraph_matmul(A, B, C, valid_types) + MPS.is_supported(device(A)) && + eltype(A) == eltype(B) && + (eltype(A), eltype(C)) in valid_types && + # TODO: remove this limitation + A.offset == 0 && + B.offset == 0 && + C.offset == 0 +end LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatrix, _add::MulAddMul) = LinearAlgebra.generic_matmatmul!(C, tA, tB, A, B, _add.alpha, _add.beta) @@ -28,13 +46,10 @@ LinearAlgebra.generic_matmatmul!(C::MtlMatrix, tA, tB, A::MtlMatrix, B::MtlMatri transA = tA == 'T' || tA == 'C' transB = tB == 'T' || tB == 'C' - typA = eltype(A) - typB = eltype(B) - typC = eltype(C) - - # If possible, dispatch to performance shaders - if MPS.is_supported(device()) && - typA == typB && (typA, typC) in MPS_VALID_MATMUL_TYPES + # If possible, dispatch to MPSGraphs, then performance shaders + if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATMUL_TYPES) + graph_matmul!(C, A, B, alpha, beta, transA, transB) + elseif supports_mps_matmul(A, B, C, MPS_VALID_MATMUL_TYPES) # TODO: Remove once contiguous views are working matmul!(C, A, B, alpha, beta, transA, transB) else GPUArrays.generic_matmatmul!(C, wrap(A, tA), wrap(B, tB), alpha, beta) @@ -66,13 +81,10 @@ LinearAlgebra.generic_matvecmul!(C::MtlVector, tA::AbstractChar, A::MtlMatrix, B transA = tA == 'T' || tA == 'C' - typA = eltype(A) - typB = eltype(B) - typC = eltype(C) - - # If possible, dispatch to performance shaders - if MPS.is_supported(device()) && - typA == typB && (typA, typC) in MPS_VALID_MATVECMUL_TYPES + # If possible, dispatch to MPSGraphs, then performance shaders + if supports_mpsgraph_matmul(A, B, C, MPSGRAPH_VALID_MATVECMUL_TYPES) + graph_matvecmul!(C, A, B, alpha, beta, transA) + elseif supports_mps_matmul(A, B, C, MPS_VALID_MATVECMUL_TYPES) # TODO: Remove once contiguous views are working matvecmul!(C, A, B, alpha, beta, transA) else GPUArrays.generic_matmatmul!(C, wrap(A, tA), B, alpha, beta) diff --git a/test/Project.toml b/test/Project.toml index 22c5da38d..3b87ddaf3 100644 --- a/test/Project.toml +++ b/test/Project.toml @@ -1,5 +1,6 @@ [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" +AppleAccelerate = "13e28ba4-7ad8-5781-acae-3021b1ed3924" BFloat16s = "ab4f0b2a-ad5b-11e8-123f-65d77653426b" BenchmarkTools = "6e4b80f9-dd63-53aa-95a3-0cdb28fa8baf" Dates = "ade2ca70-3891-5945-98fb-dc099432e06a" diff --git a/test/examples.jl b/test/examples.jl index e33a01ab5..d0f9017db 100644 --- a/test/examples.jl +++ b/test/examples.jl @@ -19,6 +19,7 @@ cd(examples_dir) do @testset for example in examples mod = @eval module $(gensym()) end @eval mod begin + const TESTING=true redirect_stdout(devnull) do include($example) end diff --git a/test/linalg.jl b/test/linalg.jl index 0aa986a5a..5d05d0dd2 100644 --- a/test/linalg.jl +++ b/test/linalg.jl @@ -25,7 +25,7 @@ if MPS.is_supported(device()) mtl_view_c = mtl_view_a * mtl_view_b view_c = view_a * view_b - @test Array(mtl_view_c) == view_c + @test Array(mtl_view_c) ≈ view_c end using Metal: storagemode diff --git a/test/mps/linalg.jl b/test/mps/linalg.jl index f2a907835..90c6de9c6 100644 --- a/test/mps/linalg.jl +++ b/test/mps/linalg.jl @@ -34,17 +34,19 @@ if MPS.is_supported(device()) end @testset "batched matrix matrix multiplication" begin - N = 10 + M = 8 + N = 7 + P = 9 batch_size = 3 - rows_a = N + rows_a = M cols_a = N rows_b = N - cols_b = N + cols_b = P - rows_c = rows_a - cols_c = cols_b + rows_c = M + cols_c = P alpha = Float64(1) beta = Float64(1) diff --git a/test/mps/ndarray.jl b/test/mps/ndarray.jl index a667402e5..3fd691de0 100644 --- a/test/mps/ndarray.jl +++ b/test/mps/ndarray.jl @@ -29,6 +29,8 @@ end desc2.numberOfDimensions = 6 @test desc2.numberOfDimensions == 6 + @test_throws ArgumentError MPSNDArrayDescriptor(Float32, ones(Int, 17)) + @static if Metal.macos_version() >= v"15" @test desc1.preferPackedRows == false diff --git a/test/mpsgraphs/core.jl b/test/mpsgraphs/core.jl new file mode 100644 index 000000000..ce716190b --- /dev/null +++ b/test/mpsgraphs/core.jl @@ -0,0 +1,19 @@ + +if MPS.is_supported(device()) + +using .MPS: MPSShape +using .MPSGraphs: MPSGraph, MPSGraphDevice +@testset "Core" begin + +graph = MPSGraph() +@test graph isa MPSGraph + +dev = device() +graphdev = MPSGraphDevice(dev) +@test graphdev isa MPSGraphDevice +@test graphdev.type == MPSGraphs.MPSGraphDeviceTypeMetal +@test graphdev.metalDevice == dev + +end # @testset "Core" + +end # MPS.is_supported(device()) diff --git a/test/mpsgraphs/linalg.jl b/test/mpsgraphs/linalg.jl new file mode 100644 index 000000000..62438980e --- /dev/null +++ b/test/mpsgraphs/linalg.jl @@ -0,0 +1,100 @@ +using LinearAlgebra + + +if MPS.is_supported(device()) + +@testset "mixed-precision matrix matrix multiplication" begin + N = 10 + rows_a = N + cols_a = N + + rows_b = N + cols_b = N + + rows_c = rows_a + cols_c = cols_b + + alpha = Float64(1) + beta = Float64(1) + + @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES + arr_a = rand(input_jl_type, (rows_a, cols_a)) + arr_b = rand(input_jl_type, (rows_b, cols_b)) + arr_c = zeros(accum_jl_type, (rows_c, cols_c)) + + buf_a = MtlArray{input_jl_type}(arr_a) + buf_b = MtlArray{input_jl_type}(arr_b) + buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c)) + + truth_c = (alpha .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (beta .* arr_c) + + MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta) + + @test all(Array(buf_c) .≈ truth_c) + end +end + +@testset "batched matrix matrix multiplication" begin + M = 8 + N = 7 + P = 9 + batch_size = 3 + + rows_a = M + cols_a = N + + rows_b = N + cols_b = P + + rows_c = M + cols_c = P + + alpha = Float64(1) + beta = Float64(1) + + @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATMUL_TYPES + arr_a = rand(input_jl_type, (rows_a, cols_a, batch_size)) + arr_b = rand(input_jl_type, (rows_b, cols_b, batch_size)) + arr_c = zeros(accum_jl_type, (rows_c, cols_c, batch_size)) + + buf_a = MtlArray{input_jl_type}(arr_a) + buf_b = MtlArray{input_jl_type}(arr_b) + buf_c = MtlArray{accum_jl_type}(undef, (rows_c, cols_c, batch_size)) + + truth_c = Array{accum_jl_type}(undef, (rows_c, cols_c, batch_size)) + for i in 1:batch_size + @views truth_c[:, :, i] = (alpha .* accum_jl_type.(arr_a[:, :, i])) * accum_jl_type.(arr_b[:, :, i]) .+ (beta .* arr_c[:, :, i]) + end + + MPSGraphs.graph_matmul!(buf_c, buf_a, buf_b, alpha, beta) + + @test all(Array(buf_c) .≈ truth_c) + end +end + +@testset "mixed-precision matrix vector multiplication" begin + N = 10 + rows = N + cols = N + + alpha = Float64(1) + beta = Float64(0) + + @testset "$(input_jl_type) => $accum_jl_type" for (input_jl_type, accum_jl_type) in MPSGraphs.MPSGRAPH_VALID_MATVECMUL_TYPES + arr_a = rand(input_jl_type, (rows, cols)) + arr_b = rand(input_jl_type, (rows)) + arr_c = zeros(accum_jl_type, (rows)) + + buf_a = MtlArray{input_jl_type}(arr_a) + buf_b = MtlArray{input_jl_type}(arr_b) + buf_c = MtlArray{accum_jl_type}(undef, (rows)) + + truth_c = (accum_jl_type(alpha) .* accum_jl_type.(arr_a)) * accum_jl_type.(arr_b) .+ (accum_jl_type(beta) .* arr_c) + + MPSGraphs.graph_matvecmul!(buf_c, buf_a, buf_b, alpha, beta) + + @test all(Array(buf_c) .≈ truth_c) + end +end + +end # MPS.is_supported(device()) diff --git a/test/mpsgraphs/random.jl b/test/mpsgraphs/random.jl new file mode 100644 index 000000000..4303ee83b --- /dev/null +++ b/test/mpsgraphs/random.jl @@ -0,0 +1,25 @@ +using BFloat16s + +if MPS.is_supported(device()) + +using .MPSGraphs: MPSGraphRandomOpDescriptor, MPSGraphRandomDistributionNormal, MPSGraphRandomDistributionTruncatedNormal, MPSGraphRandomDistributionUniform +@testset "MPSGraph random" begin + # determined by looking at the error message when trying to construct + # an invalid distribution/type combination + for (dist, T) in [(MPSGraphRandomDistributionNormal, Float32), + (MPSGraphRandomDistributionNormal, Float16), + (MPSGraphRandomDistributionNormal, BFloat16), + (MPSGraphRandomDistributionTruncatedNormal, Float32), + (MPSGraphRandomDistributionTruncatedNormal, Float16), + (MPSGraphRandomDistributionTruncatedNormal, BFloat16), + (MPSGraphRandomDistributionUniform, Int64), + (MPSGraphRandomDistributionUniform, Int32), + (MPSGraphRandomDistributionUniform, Float32), + (MPSGraphRandomDistributionUniform, Float16), + (MPSGraphRandomDistributionUniform, BFloat16), + ] + @test MPSGraphRandomOpDescriptor(MPSGraphRandomDistributionNormal, Float32) isa MPSGraphRandomOpDescriptor + end +end + +end # MPS.is_supported(device())