Skip to content

Commit 82e0f8b

Browse files
committed
More operations
1 parent bdd1f91 commit 82e0f8b

File tree

1 file changed

+19
-0
lines changed

1 file changed

+19
-0
lines changed

lib/mpsgraphs/operations.jl

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,17 @@
11

2+
function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shape::MPSShape, name="broadcast")
3+
obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
4+
toShape:shape::id{MPSShape}
5+
name:name::id{NSString}]::id{MPSGraphTensor}
6+
MPSGraphTensor(obj)
7+
end
8+
function broadcastTensor(graph::MPSGraph, tensor::MPSGraphTensor, shapeTensor::MPSGraphTensor, name="broadcast")
9+
obj = @objc [graph::id{MPSGraph} broadcastTensor:tensor::id{MPSGraphTensor}
10+
toShapeTensor:shapeTensor::id{MPSGraphTensor}
11+
name:name::id{NSString}]::id{MPSGraphTensor}
12+
MPSGraphTensor(obj)
13+
end
14+
215
function castTensor(graph::MPSGraph, tensor::MPSGraphTensor, toType, name = "cast")
316
obj = @objc [graph::id{MPSGraph} castTensor:tensor::id{MPSGraphTensor}
417
toType:toType::MPSDataType
@@ -40,6 +53,12 @@ function transposeTensor(graph::MPSGraph, tensor::MPSGraphTensor, dimension, wit
4053
MPSGraphTensor(obj)
4154
end
4255

56+
function shapeOfTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "shapeOfTensor")
57+
obj = @objc [graph::id{MPSGraph} shapeOfTensor:tensor::id{MPSGraphTensor}
58+
name:name::id{NSString}]::id{MPSGraphTensor}
59+
MPSGraphTensor(obj)
60+
end
61+
4362
function identityWithTensor(graph::MPSGraph, tensor::MPSGraphTensor, name = "identity")
4463
obj = @objc [graph::id{MPSGraph} identityWithTensor:tensor::id{MPSGraphTensor}
4564
name:name::id{NSString}]::id{MPSGraphTensor}

0 commit comments

Comments
 (0)