Skip to content

Commit fd4a943

Browse files
committed
add support for broadcast_to
1 parent 4f5bfd7 commit fd4a943

File tree

2 files changed

+15
-0
lines changed

2 files changed

+15
-0
lines changed

tests/test_backend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2955,6 +2955,14 @@ def func(input_holder):
29552955
for input_val in input_vals:
29562956
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val})
29572957

2958+
@check_opset_min_version(8)
2959+
def test_broadcast(self):
2960+
input_tensor_val = np.random.randint(low=0, high=256, size=[2, 3]).astype(np.float32)
2961+
new_shape_val = np.array([3,2,3]).astype(np.int64)
2962+
def func(input_tensor, new_shape):
2963+
return tf.broadcast_to(input_tensor, new_shape, _TFOUTPUT)
2964+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_tensor_val, _INPUT1: new_shape_val})
2965+
29582966

29592967
if __name__ == '__main__':
29602968
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1806,3 +1806,10 @@ def version_11(cls, ctx, node, **kwargs):
18061806
ctx.remove_node(node.name)
18071807
squeezed_result = ctx.make_node('Squeeze', [gathered_result.output[0]], attr={"axes": [-1]},
18081808
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
1809+
1810+
@tf_op("BroadcastTo")
1811+
class BroadcastTo:
1812+
@classmethod
1813+
def version_8(cls, ctx, node, **kwargs):
1814+
node.type = "Expand"
1815+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)

0 commit comments

Comments
 (0)