Skip to content

Commit ad4c792

Browse files
Implement conversions for AdjustSaturation and AdjustContrastv2 (#1346)
Signed-off-by: Tom Wildenhain <[email protected]> Co-authored-by: Guenther Schmuelling <[email protected]>
1 parent 5497030 commit ad4c792

File tree

2 files changed

+71
-0
lines changed

2 files changed

+71
-0
lines changed

tests/test_backend.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2528,6 +2528,26 @@ def func(x, x_new_size_):
25282528
return tf.identity(x_, name=_TFOUTPUT)
25292529
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: x_new_size})
25302530

2531+
def test_adjust_contrast(self):
2532+
x_shape = [4, 3, 2]
2533+
x_val = np.arange(1, 1 + np.prod(x_shape), dtype=np.float32).reshape(x_shape)
2534+
y_val = np.array(2.1, np.float32)
2535+
def func(x, y):
2536+
x_ = tf.image.adjust_contrast(x, y)
2537+
return tf.identity(x_, name=_TFOUTPUT)
2538+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
2539+
2540+
@check_opset_min_version(11, "GatherElements")
2541+
def test_adjust_saturation(self):
2542+
x_val = np.array([[1, 2, 3], [4, 4, 4], [3, 2, 3], [3, 2, 2]], dtype=np.float32).reshape([2, 2, 3])
2543+
y_val = np.array(2.1, np.float32)
2544+
def func(x, y):
2545+
x_ = tf.image.adjust_saturation(x, y)
2546+
return tf.identity(x_, name=_TFOUTPUT)
2547+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
2548+
y_val = np.array(0.5, np.float32)
2549+
self._run_test_case(func, [_OUTPUT], {_INPUT: x_val, _INPUT1: y_val})
2550+
25312551
@check_tf_min_version("2.0", "Results are slightly different in tf1")
25322552
@check_opset_min_version(11, "resize bicubic")
25332553
def test_resize_bicubic(self):

tf2onnx/onnx_opset/nn.py

Lines changed: 51 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1134,6 +1134,57 @@ def _convert_since_9(cls, ctx, node, op_type, use_target_size=False):
11341134
name=node.name, outputs=node.output, shapes=shapes, dtypes=dtypes)
11351135

11361136

1137+
@tf_op("AdjustContrastv2")
1138+
class AdjustContrastv2:
1139+
@classmethod
1140+
def version_1(cls, ctx, node, **kwargs):
1141+
images, contrast_factor = node.input
1142+
dtype = ctx.get_dtype(images)
1143+
if ctx.get_dtype(contrast_factor) != dtype:
1144+
contrast_factor = ctx.make_node("Cast", [dtype], attr={'to': dtype}).output[0]
1145+
rank = ctx.get_rank(images)
1146+
utils.make_sure(rank is not None, "AdjustContrastv2 requires input of known rank")
1147+
# Reduce everything except channels
1148+
axes_to_reduce = list(range(rank))[:-1]
1149+
mean = ctx.make_node("ReduceMean", [images], attr={'axes': axes_to_reduce, 'keepdims': True},
1150+
op_name_scope=node.name).output[0]
1151+
diff = ctx.make_node("Sub", [images, mean], op_name_scope=node.name).output[0]
1152+
scaled = ctx.make_node("Mul", [diff, contrast_factor], op_name_scope=node.name).output[0]
1153+
result = ctx.make_node("Add", [scaled, mean], op_name_scope=node.name).output[0]
1154+
ctx.replace_all_inputs(node.output[0], result)
1155+
ctx.remove_node(node.name)
1156+
1157+
1158+
@tf_op("AdjustSaturation")
1159+
class AdjustSaturation:
1160+
@classmethod
1161+
def version_11(cls, ctx, node, **kwargs):
1162+
images, factor = node.input
1163+
dtype = ctx.get_dtype(images)
1164+
np_dtype = utils.map_onnx_to_numpy_type(dtype)
1165+
k = ctx.make_const(utils.make_name("three"), np.array([3], np.int64)).output[0]
1166+
ordered, indices = ctx.make_node("TopK", [images, k], attr={'axis': -1}, output_count=2).output
1167+
# Sorted and separated into channels
1168+
max_c, mid_c, min_c = ctx.make_node("Split", [ordered], attr={'axis': -1}, output_count=3).output
1169+
delta = ctx.make_node("Sub", [max_c, min_c]).output[0]
1170+
scaled_delta = ctx.make_node("Mul", [delta, factor], op_name_scope=node.name).output[0]
1171+
new_delta = ctx.make_node("Min", [scaled_delta, max_c]).output[0]
1172+
new_min = ctx.make_node("Sub", [max_c, new_delta]).output[0]
1173+
delta2 = ctx.make_node("Sub", [mid_c, min_c]).output[0]
1174+
const_zero = ctx.make_const(utils.make_name("zero"), np.array(0, np_dtype)).output[0]
1175+
delta_z = ctx.make_node("Equal", [delta, const_zero]).output[0]
1176+
delta_z_cast = ctx.make_node("Cast", [delta_z], attr={'to': dtype}).output[0]
1177+
delta_nz = ctx.make_node("Add", [delta, delta_z_cast]).output[0]
1178+
delta2_scale = ctx.make_node("Div", [new_delta, delta_nz]).output[0]
1179+
new_delta2 = ctx.make_node("Mul", [delta2, delta2_scale], op_name_scope=node.name).output[0]
1180+
new_mid = ctx.make_node("Add", [new_min, new_delta2]).output[0]
1181+
new_ordered = ctx.make_node("Concat", [max_c, new_mid, new_min], attr={'axis': -1}).output[0]
1182+
# Now put it back in order
1183+
result = ctx.make_node("GatherElements", [new_ordered, indices], attr={'axis': -1}).output[0]
1184+
ctx.replace_all_inputs(node.output[0], result)
1185+
ctx.remove_node(node.name)
1186+
1187+
11371188
@tf_op("MatrixBandPart")
11381189
class MatrixBandPart:
11391190
@classmethod

0 commit comments

Comments
 (0)