Skip to content

Commit 87dd357

Browse files
committed
format code
1 parent fd4a943 commit 87dd357

File tree

2 files changed

+5
-2
lines changed

2 files changed

+5
-2
lines changed

tests/test_backend.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2958,9 +2958,11 @@ def func(input_holder):
29582958
@check_opset_min_version(8)
29592959
def test_broadcast(self):
29602960
input_tensor_val = np.random.randint(low=0, high=256, size=[2, 3]).astype(np.float32)
2961-
new_shape_val = np.array([3,2,3]).astype(np.int64)
2961+
new_shape_val = np.array([3, 2, 3]).astype(np.int64)
2962+
29622963
def func(input_tensor, new_shape):
29632964
return tf.broadcast_to(input_tensor, new_shape, _TFOUTPUT)
2965+
29642966
self._run_test_case(func, [_OUTPUT], {_INPUT: input_tensor_val, _INPUT1: new_shape_val})
29652967

29662968

tf2onnx/onnx_opset/tensor.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1807,9 +1807,10 @@ def version_11(cls, ctx, node, **kwargs):
18071807
squeezed_result = ctx.make_node('Squeeze', [gathered_result.output[0]], attr={"axes": [-1]},
18081808
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
18091809

1810+
18101811
@tf_op("BroadcastTo")
18111812
class BroadcastTo:
18121813
@classmethod
18131814
def version_8(cls, ctx, node, **kwargs):
18141815
node.type = "Expand"
1815-
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)
1816+
ctx.insert_new_node_on_input(node, "Cast", node.input[1], to=TensorProto.INT64)

0 commit comments

Comments
 (0)