Skip to content

Commit bc12f52

Browse files
committed
Initial MPSGraph support
1 parent cb3d8dc commit bc12f52

File tree

12 files changed

+754
-0
lines changed

12 files changed

+754
-0
lines changed

lib/mpsgraphs/MPSGraphs.jl

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,25 @@
1+
"""
2+
# MPSGraphs
3+
4+
`MPSGraphs` is where the Metal Performance Shaders Graph API wrappers are defined.
5+
6+
Not all functionality is currently implemented or documented. For further details,
7+
refer to the [official Apple documentation](https://developer.apple.com/documentation/metalperformanceshadersgraph).
8+
"""
9+
module MPSGraphs
10+
11+
using ..Metal
12+
using .MTL
13+
using .MPS: MPSDataType, MPSMatrix, MPSVector, MPSShape, MPSNDArray
14+
15+
using CEnum
16+
using ObjectiveC, .Foundation, .Dispatch
17+
18+
include("libmpsgraph.jl")
19+
20+
include("core.jl")
21+
include("tensor.jl")
22+
include("operations.jl")
23+
include("random.jl")
24+
25+
end

lib/mpsgraphs/core.jl

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,28 @@
1+
# Contains definitions for api from MPSGraphCore.h, MPSGraphDevice.h
2+
3+
## MPSGraphCore.h
4+
# @objcwrapper MPSGraphObject <: NSObject
5+
# @objcwrapper MPSGraphType <: MPSGraphObject
6+
7+
# @objcwrapper MPSGraph <: MPSGraphObject
8+
function MPSGraph()
9+
MPSGraph(@objc [MPSGraph new]::id{MPSGraph})
10+
end
11+
12+
# @objcwrapper MPSGraphShapedType <: MPSGraphType
13+
function MPSGraphShapedType(shape::MPSShape, dataType)
14+
tmp = @objc [MPSGraphShapedType alloc]::id{MPSGraphShapedType}
15+
obj = MPSGraphShapedType(tmp)
16+
finalizer(release, obj)
17+
@objc [obj::id{MPSGraphShapedType} initWithShape:shape::id{MPSShape}
18+
dataType:dataType::MPSDataType]::id{MPSGraphShapedType}
19+
return obj
20+
end
21+
22+
## MPSGraphDevice.h
23+
# @objcwrapper MPSGraphDevice <: MPSGraphType
24+
25+
function MPSGraphDevice(device::MTLDevice)
26+
obj = @objc [MPSGraphDevice deviceWithMTLDevice:device::id{MTLDevice}]::id{MPSGraphDevice}
27+
MPSGraphDevice(obj)
28+
end

lib/mpsgraphs/libmpsgraph.jl

Lines changed: 460 additions & 0 deletions
Large diffs are not rendered by default.

lib/mpsgraphs/operations.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,24 @@
1+
2+
function matrixMultiplicationWithPrimaryTensor(graph::MPSGraph, primary::MPSGraphTensor, secondary::MPSGraphTensor, name="matmul")
3+
obj = @objc [graph::id{MPSGraph} matrixMultiplicationWithPrimaryTensor:primary::id{MPSGraphTensor}
4+
secondaryTensor:secondary::id{MPSGraphTensor}
5+
name:name::id{NSString}]::id{MPSGraphTensor}
6+
MPSGraphTensor(obj)
7+
end
8+
9+
run(graph::MPSGraph, feeds::Dict, targetTensors::Vector) = run(graph, MPSGraphTensorDataDictionary(feeds), NSArray(targetTensors))
10+
function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray)
11+
obj = @objc [graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary}
12+
targetTensors:targetTensors::id{NSArray}
13+
targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary}
14+
MPSGraphTensorDataDictionary(obj)
15+
end
16+
17+
run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::Dict, targetTensors::Vector) = run(graph, commandQueue, MPSGraphTensorDataDictionary(feeds), NSArray(targetTensors))
18+
function run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray)
19+
obj = @objc [graph::id{MPSGraph} runWithMTLCommandQueue:commandQueue::id{MTLCommandQueue}
20+
feeds:feeds::id{MPSGraphTensorDataDictionary}
21+
targetTensors:targetTensors::id{NSArray}
22+
targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary}
23+
MPSGraphTensorDataDictionary(obj)
24+
end

lib/mpsgraphs/random.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
# @objcwrapper immutable=false MPSGraphRandomOpDescriptor <: MPSGraphObject
2+
3+
function MPSMatrixRandomOpDescriptor(distribution::MPSGraphRandomDistribution, dataType::MPSDataType)
4+
desc = @objc [MPSMatrixRandomOpDescriptor descriptorWithDistribution:distribution::MPSGraphRandomDistribution
5+
dataType:dataType::MPSDataType]::id{MPSGraphRandomOpDescriptor}
6+
obj = MPSGraphRandomOpDescriptor(desc)
7+
return obj
8+
end

lib/mpsgraphs/tensor.jl

Lines changed: 124 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,124 @@
1+
# Contains definitions for api from MPSGraphTensor.h, MPSGraphTensorData.h, MPSGraphOperation.h
2+
3+
## MPSGraphTensor.h
4+
# @objcwrapper MPSGraphTensor <: MPSGraphObject
5+
6+
# Define MPSGraphOperation here to define the MPSGraphTensor properties
7+
# @objcwrapper MPSGraphOperation <: MPSGraphObject
8+
9+
function Base.size(td::MPSGraphTensor)
10+
temp = map(td.shape) do nsnum
11+
NSNumber(reinterpret(id{NSNumber}, nsnum)).unsignedIntegerValue |> Int
12+
end
13+
Tuple(temp)
14+
end
15+
16+
function placeholderTensor(graph::MPSGraph, shape::Union{Vector, Tuple}, args...)
17+
mpsshape = convert(MPSShape, shape)
18+
return placeholderTensor(graph, mpsshape, args...)
19+
end
20+
function placeholderTensor(graph::MPSGraph, shape::MPSShape, dataType::Type, name = "placeholder tensor")
21+
obj = @objc [graph::id{MPSGraph} placeholderWithShape:shape::id{MPSShape}
22+
dataType:dataType::MPSDataType
23+
name:name::id{NSString}]::id{MPSGraphTensor}
24+
return MPSGraphTensor(obj)
25+
end
26+
27+
## MPSGraphTensorData.h
28+
# @objcwrapper immutable=false MPSGraphTensorData <: MPSGraphObject
29+
30+
function Base.size(td::MPSGraphTensorData)
31+
temp = map(td.shape) do nsnum
32+
NSNumber(reinterpret(id{NSNumber}, nsnum)).unsignedIntegerValue |> Int
33+
end
34+
Tuple(temp)
35+
end
36+
37+
function MPSGraphTensorData(buffer::MTLBuffer, shape::MPSShape, dataType)
38+
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
39+
tensor = MPSGraphTensorData(obj)
40+
finalizer(release, tensor)
41+
@objc [tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}
42+
shape:shape::id{MPSShape}
43+
dataType:dataType::MPSDataType]::id{MPSGraphTensorData}
44+
return tensor
45+
end
46+
function MPSGraphTensorData(buffer::MTLBuffer, shape::MPSShape, dataType, rowBytes)
47+
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
48+
tensor = MPSGraphTensorData(obj)
49+
finalizer(release, tensor)
50+
@objc [tensor::id{MPSGraphTensorData} initWithMTLBuffer:buffer::id{MTLBuffer}
51+
shape:shape::id{MPSShape}
52+
dataType:dataType::MPSDataType
53+
rowBytes:rowBytes::NSUInteger]::id{MPSGraphTensorData}
54+
return tensor
55+
end
56+
# MPSGraphTensorData(matrix::MtlMatrix{T}) where T = MPSGraphTensorData(matrix.data[], convert(MPSShape, reverse(size(matrix))), T)
57+
MPSGraphTensorData(matrix::MtlMatrix) = MPSGraphTensorData(MPSMatrix(matrix))
58+
MPSGraphTensorData(arr::MtlArray{<:Any, 3}) = MPSGraphTensorData(MPSMatrix(arr))
59+
60+
function MPSGraphTensorData(matrix::MPSMatrix)
61+
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
62+
tensor = MPSGraphTensorData(obj)
63+
finalizer(release, tensor)
64+
@objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:matrix::id{MPSMatrix}]::id{MPSGraphTensorData}
65+
return tensor
66+
end
67+
68+
# rank must be between 1 and 16 inclusive
69+
function MPSGraphTensorData(matrix::MPSMatrix, rank)
70+
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
71+
tensor = MPSGraphTensorData(obj)
72+
finalizer(release, tensor)
73+
@objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:matrix::id{MPSMatrix}
74+
rank:rank::NSUInteger]::id{MPSGraphTensorData}
75+
return tensor
76+
end
77+
78+
function MPSGraphTensorData(vector::MPSVector)
79+
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
80+
tensor = MPSGraphTensorData(obj)
81+
finalizer(release, tensor)
82+
@objc [tensor::id{MPSGraphTensorData} initWithMPSVector:vector::id{MPSVector}]::id{MPSGraphTensorData}
83+
return tensor
84+
end
85+
MPSGraphTensorData(vector::MtlVector{T}) where T = MPSGraphTensorData(vector.data[], convert(MPSShape, size(vector)), T)
86+
# MPSGraphTensorData(vector::MtlVector) = MPSGraphTensorData(MPSVector(vector))
87+
88+
# rank must be between 1 and 16 inclusive
89+
function MPSGraphTensorData(vector::MPSVector, rank)
90+
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
91+
tensor = MPSGraphTensorData(obj)
92+
finalizer(release, tensor)
93+
@objc [tensor::id{MPSGraphTensorData} initWithMPSMatrix:vector::id{MPSVector}
94+
rank:rank::NSUInteger]::id{MPSGraphTensorData}
95+
return tensor
96+
end
97+
98+
function MPSGraphTensorData(ndarr::MPSNDArray)
99+
obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
100+
tensor = MPSGraphTensorData(obj)
101+
finalizer(release, tensor)
102+
@objc [tensor::id{MPSGraphTensorData} initWithMPSNDArray:ndarr::id{MPSNDArray}]::id{MPSGraphTensorData}
103+
return tensor
104+
end
105+
# TODO: MPSImage is not yet implemented
106+
# function MPSGraphTensorData(imgbatch::MPSImageBatch)
107+
# obj = @objc [MPSGraphTensorData alloc]::id{MPSGraphTensorData}
108+
# tensor = MPSGraphTensorData(obj)
109+
# finalizer(release, tensor)
110+
# @objc [tensor::id{MPSGraphTensorData} initWithMPSImageBatch:imgbatch::id{MPSImageBatch}]::id{MPSGraphTensorData}
111+
# MPSGraphTensorData(obj)
112+
# end
113+
114+
"""
115+
MPSNDArray(tens::MPSGraphTensorData)
116+
117+
Return an MPSNDArray object.
118+
119+
Will copy contents if the contents are not stored in an MPS ndarray.
120+
"""
121+
function MPS.MPSNDArray(tens::MPSGraphTensorData)
122+
arr = @objc [tens::id{MPSNDArray} mpsndarray]::id{MPSNDArray}
123+
MPSNDArray(arr)
124+
end

res/wrap/libmpsgraph.toml

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,29 @@
1+
[general]
2+
library_name = "libmpsgraph"
3+
output_file_path = "../../lib/mpsgraphs/libmpsgraph.jl"
4+
prologue_file_path = "libmpsgraph_prologue.jl"
5+
6+
minimum_macos_supported = "13"
7+
8+
printer_blacklist = [
9+
"mt_macCatalyst",
10+
"mt_ios",
11+
"mt_macos",
12+
"CF.*",
13+
"MTL.*",
14+
"NS.*",
15+
"BOOL"
16+
]
17+
18+
[codegen]
19+
use_ccall_macro = true
20+
always_NUL_terminated_string = true
21+
22+
[codegen.macro]
23+
# it's highly recommended to set this entry to "basic".
24+
# if you'd like to skip all of the macros, please set this entry to "disable".
25+
# if you'd like to translate function-like macros to Julia, please set this entry to "aggressive".
26+
macro_mode = "disable"
27+
28+
[api.MPSGraphTensorData]
29+
immutable=false

res/wrap/libmpsgraph_prologue.jl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
const MPSGraphCallableMap = NSDictionary
2+
3+
const MPSGraphTensorDataDictionary = NSDictionary

res/wrap/wrap.jl

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -39,6 +39,12 @@ function main(names::AbstractVector=["all"]; sdk_path=SDK_PATH)
3939
push!(ctxs, tctx)
4040
end
4141

42+
if "all" in names || "libmpsgraph" in names || "mpsgraph" in names
43+
fwpath = path_to_framework("MetalPerformanceShadersGraph")
44+
tctx = wrap("libmpsgraph", joinpath(fwpath, "MetalPerformanceShadersGraph.h"); defines)
45+
push!(ctxs, tctx)
46+
end
47+
4248
return ctxs
4349
end
4450

src/Metal.jl

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,8 @@ include("compiler/reflection.jl")
5252
# libraries
5353
include("../lib/mps/MPS.jl")
5454
export MPS
55+
include("../lib/mpsgraphs/MPSGraphs.jl")
56+
export MPSGraphs
5557

5658
# LinearAlgebra
5759
include("linalg.jl")

0 commit comments

Comments
 (0)