Skip to content

Commit 48fdca5

Browse files
authored
Merge pull request #839 from RandySheriffH/rashuai/FixBroadcast
Rashuai/fix broadcast
2 parents 7118977 + dda4027 commit 48fdca5

File tree

2 files changed

+19
-0
lines changed

2 files changed

+19
-0
lines changed

tests/test_backend.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2980,6 +2980,16 @@ def func(input_holder):
29802980
for input_val in input_vals:
29812981
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val})
29822982

2983+
@check_opset_min_version(8)
2984+
def test_broadcast(self):
2985+
input_tensor_val = np.random.randint(low=0, high=256, size=[2, 3]).astype(np.float32)
2986+
new_shape_val = np.array([3, 2, 3]).astype(np.int64)
2987+
2988+
def func(input_tensor, new_shape):
2989+
return tf.broadcast_to(input_tensor, new_shape, _TFOUTPUT)
2990+
2991+
self._run_test_case(func, [_OUTPUT], {_INPUT: input_tensor_val, _INPUT1: new_shape_val})
2992+
29832993

29842994
if __name__ == '__main__':
29852995
unittest_main()

tf2onnx/onnx_opset/tensor.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1806,3 +1806,12 @@ def version_11(cls, ctx, node, **kwargs):
18061806
ctx.remove_node(node.name)
18071807
squeezed_result = ctx.make_node('Squeeze', [gathered_result.output[0]], attr={"axes": [-1]},
18081808
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
1809+
1810+
1811+
@tf_op("BroadcastTo")
1812+
class BroadcastTo:
1813+
@classmethod
1814+
def version_8(cls, ctx, node, **kwargs):
1815+
# broadcast by expanding
1816+
node.type = "Expand"
1817+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)

0 commit comments

Comments
 (0)