@@ -1147,20 +1147,67 @@ class Resize(Layer):
11471147 def initialize (self ):
11481148 inp = self .get_input_variable ()
11491149
1150- if self .get_attr ('data_format' ) == 'channels_last' :
1151- if len (inp .shape ) == 2 : # 1D -> width + chan
1152- shape = [self .get_attr ('out_width' ), self .get_attr ('n_chan' )]
1153- dims = [f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1154- elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1155- shape = [self .get_attr ('out_height' ), self .get_attr ('out_width' ), self .get_attr ('n_chan' )]
1156- dims = [f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1150+ if len (self .inputs ) > 1 :
1151+ # In order to be correctly ingested by hls4ml the QONNX resize node should have 3 inputs set with RoI left empty
1152+ if len (self .inputs ) == 2 :
1153+ raise Exception (
1154+ 'The number of inputs to Resize node is equal to 2. '
1155+ 'In this case, either one is trying to use a version 10 node '
1156+ 'or one is using the RoI parameter only to perform the resize operation, '
1157+ 'both not supported in hls4ml'
1158+ )
1159+ if len (self .inputs ) == 4 :
1160+ raise Exception ('Sizes parameter is not supported by hls4ml. Use scales instead' )
1161+ # get the scales of Resize node from QONNX frontend
1162+ # see doc here: https://onnx.ai/onnx/operators/onnx__Resize.html
1163+ scales_idx = 2 if len (self .inputs ) == 3 or len (self .inputs ) == 4 else 1
1164+ scales = self .get_input_node (self .inputs [scales_idx ]).get_attr ('value' )
1165+ if len (scales ) == 4 : # Resize 2D
1166+ self .set_attr ('out_width' , int (self .get_attr ('in_width' ) * scales [1 ]))
1167+ self .set_attr ('out_height' , int (self .get_attr ('in_height' ) * scales [2 ]))
1168+ self .set_attr ('n_chan' , int (self .get_attr ('n_chan' ) * scales [3 ]))
1169+ elif len (scales ) == 3 : # Resize 1D
1170+ self .set_attr ('out_width' , int (self .get_attr ('in_width' ) * scales [1 ]))
1171+ self .set_attr ('n_chan' , int (self .get_attr ('n_chan' ) * scales [2 ]))
1172+ else :
1173+ raise Exception ('Resize 1D and Resize 2D are the ones supported in hls4ml' )
1174+ if self .get_attr ('data_format' ) == 'channels_last' :
1175+ if len (inp .shape ) == 2 : # 1D -> width + chan
1176+ shape = [int (self .get_attr ('out_width' )), int (self .get_attr ('n_chan' ))]
1177+ dims = [f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1178+ elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1179+ shape = [
1180+ int (self .get_attr ('out_height' )),
1181+ int (self .get_attr ('out_width' )),
1182+ int (self .get_attr ('n_chan' )),
1183+ ]
1184+ dims = [f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1185+ else :
1186+ if len (inp .shape ) == 2 : # 1D -> width + chan
1187+ shape = [int (self .get_attr ('n_chan' )), int (self .get_attr ('out_width' ))]
1188+ dims = [f'N_CHAN_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
1189+ elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1190+ shape = [
1191+ int (self .get_attr ('n_chan' )),
1192+ int (self .get_attr ('out_height' )),
1193+ int (self .get_attr ('out_width' )),
1194+ ]
1195+ dims = [f'N_CHAN_{ self .index } ' , f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
11571196 else :
1158- if len (inp .shape ) == 2 : # 1D -> width + chan
1159- shape = [self .get_attr ('n_chan' ), self .get_attr ('out_width' )]
1160- dims = [f'N_CHAN_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
1161- elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1162- shape = [self .get_attr ('n_chan' ), self .get_attr ('out_height' ), self .get_attr ('out_width' )]
1163- dims = [f'N_CHAN_{ self .index } ' , f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
1197+ if self .get_attr ('data_format' ) == 'channels_last' :
1198+ if len (inp .shape ) == 2 : # 1D -> width + chan
1199+ shape = [self .get_attr ('out_width' ), self .get_attr ('n_chan' )]
1200+ dims = [f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1201+ elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1202+ shape = [self .get_attr ('out_height' ), self .get_attr ('out_width' ), self .get_attr ('n_chan' )]
1203+ dims = [f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1204+ else :
1205+ if len (inp .shape ) == 2 : # 1D -> width + chan
1206+ shape = [self .get_attr ('n_chan' ), self .get_attr ('out_width' )]
1207+ dims = [f'N_CHAN_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
1208+ elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1209+ shape = [self .get_attr ('n_chan' ), self .get_attr ('out_height' ), self .get_attr ('out_width' )]
1210+ dims = [f'N_CHAN_{ self .index } ' , f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
11641211
11651212 self .add_output_variable (shape , dims , precision = inp .type .precision )
11661213
0 commit comments