Skip to content

Commit 278bf8a

Browse files
Support bool inputs for MatrixBandPart (#2269)
* fix MatrixBandPart bool argument * add test case for bool input Signed-off-by: Alexander Gerstenberger <[email protected]> --------- Signed-off-by: Alexander Gerstenberger <[email protected]> Co-authored-by: Alexander Gerstenberger <[email protected]>
1 parent ca17b3c commit 278bf8a

File tree

2 files changed

+37
-26
lines changed

2 files changed

+37
-26
lines changed

tests/test_backend.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3926,6 +3926,14 @@ def func(input_x, low, high):
39263926
return tf.identity(res, name=_TFOUTPUT)
39273927
self._run_test_case(func, [_OUTPUT], {_INPUT: input_val, _INPUT1: low_val, _INPUT2: high_val})
39283928

3929+
def test_matrix_band_part_bool(self):
3930+
input_val = np.random.choice([False, True], size=(10, 15))
3931+
def func(input_x):
3932+
res = tf.linalg.band_part(input_x, -1, 0)
3933+
res1 = tf.linalg.band_part(input_x, 0, -1)
3934+
return tf.identity(res, name=_TFOUTPUT), tf.identity(res1, name=_TFOUTPUT1)
3935+
self._run_test_case(func, [_OUTPUT, _OUTPUT1], {_INPUT: input_val})
3936+
39293937
def test_floordiv(self):
39303938
input_val_1 = np.random.random_sample(100).astype(np.int32)
39313939
input_val_2 = (np.random.random_sample(100) + 1).astype(np.int32)

tf2onnx/onnx_opset/nn.py

Lines changed: 29 additions & 26 deletions
Original file line numberDiff line numberDiff line change
@@ -1644,6 +1644,27 @@ def version_11(cls, ctx, node, **kwargs):
16441644

16451645
@tf_op("MatrixBandPart")
16461646
class MatrixBandPart:
1647+
@classmethod
1648+
def _apply_mask_and_transform(cls, ctx, node, mask):
1649+
shapes = node.output_shapes
1650+
dtypes = node.output_dtypes
1651+
dtype = ctx.get_dtype(node.input[0])
1652+
data = node.input[0]
1653+
if dtype == TensorProto.BOOL:
1654+
# bool is not supported for 'Mul', so convert mask and input supported dtype
1655+
mask = ctx.make_node("Cast", inputs=mask.output, attr={'to': TensorProto.FLOAT}).output[0]
1656+
data = ctx.make_node("Cast", [data], attr={'to': TensorProto.FLOAT}).output[0]
1657+
result = ctx.make_node(op_type="Mul", inputs=[mask, data], shapes=shapes, dtypes=[TensorProto.FLOAT])
1658+
ctx.remove_node(node.name)
1659+
ctx.make_node("Cast", inputs=result.output, attr={'to': dtype},
1660+
name=node.name, outputs=node.output, dtypes=dtypes)
1661+
else:
1662+
mask = ctx.make_node(op_type="Cast", inputs=mask.output, attr={"to": dtype}).output[0]
1663+
ctx.remove_node(node.name)
1664+
ctx.make_node(op_type="Mul", inputs=[mask, data],
1665+
name=node.name, outputs=node.output, shapes=shapes,
1666+
dtypes=dtypes)
1667+
16471668
@classmethod
16481669
def version_7(cls, ctx, node, **kwargs):
16491670
# T output = MatrixBandPart(T input, int num_lower, int num_upper)
@@ -1714,14 +1735,7 @@ def version_7(cls, ctx, node, **kwargs):
17141735
mask_matrix = ctx.make_node(op_type="Transpose", inputs=cast1.output)
17151736
else:
17161737
mask_matrix = squeeze
1717-
cast2 = ctx.make_node(op_type="Cast", inputs=mask_matrix.output,
1718-
attr={"to": ctx.get_dtype(node.input[0])})
1719-
shapes = node.output_shapes
1720-
dtypes = node.output_dtypes
1721-
ctx.remove_node(node.name)
1722-
ctx.make_node(op_type="Mul", inputs=[cast2.output[0], node.input[0]],
1723-
name=node.name, outputs=node.output, shapes=shapes,
1724-
dtypes=dtypes)
1738+
cls._apply_mask_and_transform(ctx, node, mask_matrix)
17251739

17261740
@classmethod
17271741
def version_11(cls, ctx, node, **kwargs):
@@ -1739,17 +1753,12 @@ def version_11(cls, ctx, node, **kwargs):
17391753
{'data': whole_shape, 'starts': [-2], 'ends': [int_max_val], 'axes': [0]})
17401754
if num_lower_const == 0 and num_upper_const == 0:
17411755
if rank == 2:
1742-
identity_node = ctx.make_node("EyeLike", [data]).output[0]
1756+
identity_node = ctx.make_node("EyeLike", [data])
17431757
else:
17441758
zero_tensor = helper.make_tensor("value", dtype, dims=[1], vals=[0])
17451759
const_of_shape = ctx.make_node("ConstantOfShape", [shape], attr={'value': zero_tensor}).output[0]
1746-
identity_node = ctx.make_node("EyeLike", [const_of_shape]).output[0]
1747-
shapes = node.output_shapes
1748-
dtypes = node.output_dtypes
1749-
ctx.remove_node(node.name)
1750-
ctx.make_node(op_type="Mul", inputs=[identity_node, data],
1751-
name=node.name, outputs=node.output, shapes=shapes,
1752-
dtypes=dtypes)
1760+
identity_node = ctx.make_node("EyeLike", [const_of_shape])
1761+
cls._apply_mask_and_transform(ctx, node, identity_node)
17531762
return
17541763
zero_const = ctx.make_const(utils.make_name("zero"), np.array(0, np.int64)).output[0]
17551764
one_const = ctx.make_const(utils.make_name("one"), np.array(1, np.int64)).output[0]
@@ -1771,14 +1780,14 @@ def version_11(cls, ctx, node, **kwargs):
17711780
if ctx.get_dtype(num_upper) != TensorProto.INT64:
17721781
num_upper = ctx.make_node("Cast", [num_upper], attr={'to': TensorProto.INT64}).output[0]
17731782
greater = ctx.make_node("Greater", [idx_diff, num_upper]).output[0]
1774-
less_or_equal = ctx.make_node("Not", [greater]).output[0]
1783+
less_or_equal = ctx.make_node("Not", [greater])
17751784
conditions.append(less_or_equal)
17761785
if num_lower_const is None or num_lower_const >= 0:
17771786
if ctx.get_dtype(num_lower) != TensorProto.INT64:
17781787
num_lower = ctx.make_node("Cast", [num_lower], attr={'to': TensorProto.INT64}).output[0]
17791788
num_lower_neg = ctx.make_node("Neg", [num_lower]).output[0]
17801789
greater = ctx.make_node("Greater", [num_lower_neg, idx_diff]).output[0]
1781-
less_or_equal = ctx.make_node("Not", [greater]).output[0]
1790+
less_or_equal = ctx.make_node("Not", [greater])
17821791
conditions.append(less_or_equal)
17831792
if len(conditions) == 0:
17841793
node.type = "Identity"
@@ -1787,14 +1796,8 @@ def version_11(cls, ctx, node, **kwargs):
17871796
if len(conditions) == 1:
17881797
cond = conditions[0]
17891798
if len(conditions) == 2:
1790-
cond = ctx.make_node("And", conditions).output[0]
1791-
mask = ctx.make_node("Cast", [cond], attr={'to': ctx.get_dtype(data)}).output[0]
1792-
shapes = node.output_shapes
1793-
dtypes = node.output_dtypes
1794-
ctx.remove_node(node.name)
1795-
ctx.make_node(op_type="Mul", inputs=[mask, data],
1796-
name=node.name, outputs=node.output, shapes=shapes,
1797-
dtypes=dtypes)
1799+
cond = ctx.make_node("And", inputs=[c.output[0] for c in conditions])
1800+
cls._apply_mask_and_transform(ctx, node, cond)
17981801

17991802

18001803
def _make_softmax_cross_entropy_with_logits(ctx, label, logit, tf_ori_node):

0 commit comments

Comments
 (0)