Skip to content

Commit ed385bc

Browse files
committed
Add encode function and move run functions
1 parent 391bfb8 commit ed385bc

File tree

4 files changed

+39
-17
lines changed

4 files changed

+39
-17
lines changed

lib/mpsgraphs/MPSGraphs.jl

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,7 @@ include("libmpsgraph.jl")
4141

4242
include("core.jl")
4343
include("tensor.jl")
44+
include("execution.jl")
4445
include("operations.jl")
4546
include("random.jl")
4647

lib/mpsgraphs/core.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,3 +26,7 @@ function MPSGraphDevice(device::MTLDevice)
2626
obj = @objc [MPSGraphDevice deviceWithMTLDevice:device::id{MTLDevice}]::id{MPSGraphDevice}
2727
MPSGraphDevice(obj)
2828
end
29+
30+
function MPSGraphExecutionDescriptor()
31+
MPSGraphExecutionDescriptor(@objc [MPSGraphExecutionDescriptor new]::id{MPSGraphExecutionDescriptor})
32+
end

lib/mpsgraphs/execution.jl

Lines changed: 34 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,34 @@
1+
2+
MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, resultsDictionary::MPSGraphTensorDataDictionary) = @inline MPS.encode!(commandBuffer, graph, feeds, nil, resultsDictionary, MPSGraphExecutionDescriptor())
3+
function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetOperations, resultsDictionary::MPSGraphTensorDataDictionary, executionDescriptor)
4+
@objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
5+
feedss:feeds::id{MPSGraphTensorDataDictionary}
6+
targetOperations:targetOperations::id{Object}
7+
resultsDictionary:resultsDictionary::id{MPSGraphTensorDataDictionary}
8+
executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::Nothing
9+
return resultsDictionary
10+
end
11+
12+
function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil, executionDescriptor=MPSGraphExecutionDescriptor())
13+
obj = @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer}
14+
feeds:feeds::id{MPSGraphTensorDataDictionary}
15+
targetTensors:targetTensors::id{NSArray}
16+
targetOperations:targetOperations::id{Object}
17+
executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::id{MPSGraphTensorDataDictionary}
18+
MPSGraphTensorDataDictionary(obj)
19+
end
20+
21+
function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil)
22+
obj = @objc [graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary}
23+
targetTensors:targetTensors::id{NSArray}
24+
targetOperations:targetOperations::id{Object}]::id{MPSGraphTensorDataDictionary}
25+
MPSGraphTensorDataDictionary(obj)
26+
end
27+
28+
function run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray)
29+
obj = @objc [graph::id{MPSGraph} runWithMTLCommandQueue:commandQueue::id{MTLCommandQueue}
30+
feeds:feeds::id{MPSGraphTensorDataDictionary}
31+
targetTensors:targetTensors::id{NSArray}
32+
targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary}
33+
MPSGraphTensorDataDictionary(obj)
34+
end

lib/mpsgraphs/operations.jl

Lines changed: 0 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -45,20 +45,3 @@ function identityWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "ide
4545
name:name::id{NSString}]::id{MPSGraphTensor}
4646
MPSGraphTensor(obj)
4747
end
48-
49-
run(graph::MPSGraph, feeds::Dict, targetTensors::Vector) = run(graph, MPSGraphTensorDataDictionary(feeds), NSArray(targetTensors))
50-
function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray)
51-
obj = @objc [graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary}
52-
targetTensors:targetTensors::id{NSArray}
53-
targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary}
54-
MPSGraphTensorDataDictionary(obj)
55-
end
56-
57-
run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::Dict, targetTensors::Vector) = run(graph, commandQueue, MPSGraphTensorDataDictionary(feeds), NSArray(targetTensors))
58-
function run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray)
59-
obj = @objc [graph::id{MPSGraph} runWithMTLCommandQueue:commandQueue::id{MTLCommandQueue}
60-
feeds:feeds::id{MPSGraphTensorDataDictionary}
61-
targetTensors:targetTensors::id{NSArray}
62-
targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary}
63-
MPSGraphTensorDataDictionary(obj)
64-
end

0 commit comments

Comments
 (0)