|
| 1 | + |
| 2 | +MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, resultsDictionary::MPSGraphTensorDataDictionary) = @inline MPS.encode!(commandBuffer, graph, feeds, nil, resultsDictionary, MPSGraphExecutionDescriptor()) |
| 3 | +function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetOperations, resultsDictionary::MPSGraphTensorDataDictionary, executionDescriptor) |
| 4 | + @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer} |
| 5 | + feedss:feeds::id{MPSGraphTensorDataDictionary} |
| 6 | + targetOperations:targetOperations::id{Object} |
| 7 | + resultsDictionary:resultsDictionary::id{MPSGraphTensorDataDictionary} |
| 8 | + executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::Nothing |
| 9 | + return resultsDictionary |
| 10 | +end |
| 11 | + |
| 12 | +function MPS.encode!(commandBuffer::MPSCommandBuffer, graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil, executionDescriptor=MPSGraphExecutionDescriptor()) |
| 13 | + obj = @objc [graph::id{MPSGraph} encodeToCommandBuffer:commandBuffer::id{MPSCommandBuffer} |
| 14 | + feeds:feeds::id{MPSGraphTensorDataDictionary} |
| 15 | + targetTensors:targetTensors::id{NSArray} |
| 16 | + targetOperations:targetOperations::id{Object} |
| 17 | + executionDescriptor:executionDescriptor::id{MPSGraphExecutionDescriptor}]::id{MPSGraphTensorDataDictionary} |
| 18 | + MPSGraphTensorDataDictionary(obj) |
| 19 | +end |
| 20 | + |
| 21 | +function run(graph::MPSGraph, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray, targetOperations=nil) |
| 22 | + obj = @objc [graph::id{MPSGraph} runWithFeeds:feeds::id{MPSGraphTensorDataDictionary} |
| 23 | + targetTensors:targetTensors::id{NSArray} |
| 24 | + targetOperations:targetOperations::id{Object}]::id{MPSGraphTensorDataDictionary} |
| 25 | + MPSGraphTensorDataDictionary(obj) |
| 26 | +end |
| 27 | + |
| 28 | +function run(graph::MPSGraph, commandQueue::MTLCommandQueue, feeds::MPSGraphTensorDataDictionary, targetTensors::NSArray) |
| 29 | + obj = @objc [graph::id{MPSGraph} runWithMTLCommandQueue:commandQueue::id{MTLCommandQueue} |
| 30 | + feeds:feeds::id{MPSGraphTensorDataDictionary} |
| 31 | + targetTensors:targetTensors::id{NSArray} |
| 32 | + targetOperations:nil::id{Object}]::id{MPSGraphTensorDataDictionary} |
| 33 | + MPSGraphTensorDataDictionary(obj) |
| 34 | +end |
0 commit comments