@@ -1134,6 +1134,57 @@ def _convert_since_9(cls, ctx, node, op_type, use_target_size=False):
1134
1134
name = node .name , outputs = node .output , shapes = shapes , dtypes = dtypes )
1135
1135
1136
1136
1137
+ @tf_op ("AdjustContrastv2" )
1138
+ class AdjustContrastv2 :
1139
+ @classmethod
1140
+ def version_1 (cls , ctx , node , ** kwargs ):
1141
+ images , contrast_factor = node .input
1142
+ dtype = ctx .get_dtype (images )
1143
+ if ctx .get_dtype (contrast_factor ) != dtype :
1144
+ contrast_factor = ctx .make_node ("Cast" , [dtype ], attr = {'to' : dtype }).output [0 ]
1145
+ rank = ctx .get_rank (images )
1146
+ utils .make_sure (rank is not None , "AdjustContrastv2 requires input of known rank" )
1147
+ # Reduce everything except channels
1148
+ axes_to_reduce = list (range (rank ))[:- 1 ]
1149
+ mean = ctx .make_node ("ReduceMean" , [images ], attr = {'axes' : axes_to_reduce , 'keepdims' : True },
1150
+ op_name_scope = node .name ).output [0 ]
1151
+ diff = ctx .make_node ("Sub" , [images , mean ], op_name_scope = node .name ).output [0 ]
1152
+ scaled = ctx .make_node ("Mul" , [diff , contrast_factor ], op_name_scope = node .name ).output [0 ]
1153
+ result = ctx .make_node ("Add" , [scaled , mean ], op_name_scope = node .name ).output [0 ]
1154
+ ctx .replace_all_inputs (node .output [0 ], result )
1155
+ ctx .remove_node (node .name )
1156
+
1157
+
1158
+ @tf_op ("AdjustSaturation" )
1159
+ class AdjustSaturation :
1160
+ @classmethod
1161
+ def version_11 (cls , ctx , node , ** kwargs ):
1162
+ images , factor = node .input
1163
+ dtype = ctx .get_dtype (images )
1164
+ np_dtype = utils .map_onnx_to_numpy_type (dtype )
1165
+ k = ctx .make_const (utils .make_name ("three" ), np .array ([3 ], np .int64 )).output [0 ]
1166
+ ordered , indices = ctx .make_node ("TopK" , [images , k ], attr = {'axis' : - 1 }, output_count = 2 ).output
1167
+ # Sorted and separated into channels
1168
+ max_c , mid_c , min_c = ctx .make_node ("Split" , [ordered ], attr = {'axis' : - 1 }, output_count = 3 ).output
1169
+ delta = ctx .make_node ("Sub" , [max_c , min_c ]).output [0 ]
1170
+ scaled_delta = ctx .make_node ("Mul" , [delta , factor ], op_name_scope = node .name ).output [0 ]
1171
+ new_delta = ctx .make_node ("Min" , [scaled_delta , max_c ]).output [0 ]
1172
+ new_min = ctx .make_node ("Sub" , [max_c , new_delta ]).output [0 ]
1173
+ delta2 = ctx .make_node ("Sub" , [mid_c , min_c ]).output [0 ]
1174
+ const_zero = ctx .make_const (utils .make_name ("zero" ), np .array (0 , np_dtype )).output [0 ]
1175
+ delta_z = ctx .make_node ("Equal" , [delta , const_zero ]).output [0 ]
1176
+ delta_z_cast = ctx .make_node ("Cast" , [delta_z ], attr = {'to' : dtype }).output [0 ]
1177
+ delta_nz = ctx .make_node ("Add" , [delta , delta_z_cast ]).output [0 ]
1178
+ delta2_scale = ctx .make_node ("Div" , [new_delta , delta_nz ]).output [0 ]
1179
+ new_delta2 = ctx .make_node ("Mul" , [delta2 , delta2_scale ], op_name_scope = node .name ).output [0 ]
1180
+ new_mid = ctx .make_node ("Add" , [new_min , new_delta2 ]).output [0 ]
1181
+ new_ordered = ctx .make_node ("Concat" , [max_c , new_mid , new_min ], attr = {'axis' : - 1 }).output [0 ]
1182
+ # Now put it back in order
1183
+ result = ctx .make_node ("GatherElements" , [new_ordered , indices ], attr = {'axis' : - 1 }).output [0 ]
1184
+ ctx .replace_all_inputs (node .output [0 ], result )
1185
+ ctx .remove_node (node .name )
1186
+
1187
+
1137
1188
@tf_op ("MatrixBandPart" )
1138
1189
class MatrixBandPart :
1139
1190
@classmethod
0 commit comments