Skip to content

Commit 973b0b9

Browse files
More MPSGraphs API (#624)
* MPSGraphShapedType * MPSGraphDevice * MPSGraph compilation/serialization * Fix cenversion to MPSShape * Oops * Add test
1 parent 9b8c604 commit 973b0b9

File tree

6 files changed

+44
-12
lines changed

6 files changed

+44
-12
lines changed

lib/mps/MPS.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ using BFloat16s
2121
const MtlFloat = Union{Float32, Float16}
2222

2323
const MPSShape = NSArray#{NSNumber}
24-
Base.convert(::Type{MPSShape}, tuple::Union{Vector{T},NTuple{T, <:Integer}}) where T = NSArray(NSNumber.(collect(tuple)))
24+
Base.convert(::Type{MPSShape}, tuple::Union{Vector{<:Integer},NTuple{N, <:Integer}}) where N = NSArray(NSNumber.(collect(tuple)))
2525

2626
# Valid combination of input (A and B matrices) and output (C) types
2727
const MPS_VALID_MATMUL_TYPES =

lib/mpsgraphs/core.jl

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -11,15 +11,15 @@ end
1111

1212
# @objcwrapper immutable=true MPSGraphShapedType <: MPSGraphType
1313

14-
# XXX: Not used yet and needs fixing
15-
# function MPSGraphShapedType(shape::MPSShape, dataType)
16-
# tmp = @objc [MPSGraphShapedType alloc]::id{MPSGraphShapedType}
17-
# obj = MPSGraphShapedType(tmp)
18-
# finalizer(release, obj)
19-
# @objc [obj::id{MPSGraphShapedType} initWithShape:shape::id{MPSShape}
20-
# dataType:dataType::MPSDataType]::id{MPSGraphShapedType}
21-
# return obj
22-
# end
14+
MPSGraphShapedType(shape, dataType) = MPSGraphShapedType(convert(MPSShape, shape), dataType)
15+
function MPSGraphShapedType(shape::MPSShape, dataType)
16+
tmp = @objc [MPSGraphShapedType alloc]::id{MPSGraphShapedType}
17+
obj = MPSGraphShapedType(tmp)
18+
finalizer(release, obj)
19+
@objc [obj::id{MPSGraphShapedType} initWithShape:shape::id{MPSShape}
20+
dataType:dataType::MPSDataType]::id{MPSGraphShapedType}
21+
return obj
22+
end
2323

2424
## MPSGraphDevice.h
2525
# @objcwrapper MPSGraphDevice <: MPSGraphType

lib/mpsgraphs/execution.jl

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,3 +32,27 @@ function run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::MPSGraphTens
3232
targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary}
3333
MPSGraphTensorDataDictionary(obj)
3434
end
35+
36+
const MPSGraphTensorShapedTypeDictionary = NSDictionary#{MPSGraphTensor, MPSGraphTensorShapedType}
37+
38+
compile(graph::MPSGraph, dev::MTLDevice, feeds::MPSGraphTensorShapedTypeDictionary, targetTensors::NSArray, targetOperations=nil, compilationDescriptor=nil) = compile(graph, MPSGraphDevice(dev), feeds, targetTensors, targetOperations, compilationDescriptor)
39+
function compile(graph::MPSGraph, dev::MPSGraphDevice, feeds::MPSGraphTensorShapedTypeDictionary, targetTensors::NSArray, targetOperations=nil, compilationDescriptor=nil)
40+
exec = @objc [graph::id{MPSGraph} compileWithDevice:dev::id{MPSGraphDevice}
41+
feeds:feeds::id{MPSGraphTensorShapedTypeDictionary}
42+
targetTensors:targetTensors::id{NSArray}
43+
targetOperations:targetOperations::id{Object}
44+
compilationDescriptor:compilationDescriptor::id{Object}]::id{MPSGraphExecutable}
45+
return MPSGraphExecutable(exec)
46+
end
47+
48+
function MPSGraphExecutableSerializationDescriptor()
49+
tmp = @objc [MPSGraphExecutableSerializationDescriptor alloc]::id{MPSGraphExecutableSerializationDescriptor}
50+
obj = MPSGraphExecutableSerializationDescriptor(tmp)
51+
return obj
52+
end
53+
54+
serialize(graphExe::MPSGraphExecutable, url, descriptor=MPSGraphExecutableSerializationDescriptor()) = serialize(graphExe, NSFileURL(url), descriptor)
55+
function serialize(graphExe::MPSGraphExecutable, url::NSURL, descriptor=MPSGraphExecutableSerializationDescriptor())
56+
@objc [graphExe::id{MPSGraphExecutable} serializeToMPSGraphPackageAtURL:url::id{NSURL}
57+
descriptor:descriptor::id{MPSGraphExecutableSerializationDescriptor}]::Nothing
58+
end

lib/mpsgraphs/libmpsgraph.jl

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,7 @@ const MPSGraphTensorDataDictionary = NSDictionary
1212

1313
@objcwrapper immutable = true MPSGraphType <: MPSGraphObject
1414

15-
@objcwrapper immutable = true MPSGraphShapedType <: MPSGraphType
15+
@objcwrapper immutable = false MPSGraphShapedType <: MPSGraphType
1616

1717
@objcproperties MPSGraphShapedType begin
1818
@autoproperty shape::id{MPSShape} setter = setShape

res/wrap/libmpsgraph.toml

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,3 +27,6 @@ macro_mode = "disable"
2727

2828
[api.MPSGraphTensorData]
2929
immutable=false
30+
31+
[api.MPSGraphShapedType]
32+
immutable=false

test/mpsgraphs/core.jl

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
if MPS.is_supported(device())
22

33
using .MPS: MPSShape
4-
using .MPSGraphs: MPSGraph, MPSGraphDevice
4+
using .MPSGraphs: MPSGraph, MPSGraphDevice, MPSGraphShapedType
55
@testset "Core" begin
66

77
graph = MPSGraph()
@@ -13,6 +13,11 @@ graphdev = MPSGraphDevice(dev)
1313
@test graphdev.type == MPSGraphs.MPSGraphDeviceTypeMetal
1414
@test graphdev.metalDevice == dev
1515

16+
mpsh = convert(MPS.MPSShape, (2,3,4))
17+
shtyp = MPSGraphShapedType(mpsh, Float32)
18+
@test shtyp.shape == convert(MPS.MPSShape,(2,3,4))
19+
@test shtyp.dataType == MPS.MPSDataTypeFloat32
20+
1621
end # @testset "Core"
1722

1823
end # MPS.is_supported(device())

0 commit comments

Comments
 (0)