@@ -1148,20 +1148,67 @@ class Resize(Layer):
11481148 def initialize (self ):
11491149 inp = self .get_input_variable ()
11501150
1151- if self .get_attr ('data_format' ) == 'channels_last' :
1152- if len (inp .shape ) == 2 : # 1D -> width + chan
1153- shape = [self .get_attr ('out_width' ), self .get_attr ('n_chan' )]
1154- dims = [f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1155- elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1156- shape = [self .get_attr ('out_height' ), self .get_attr ('out_width' ), self .get_attr ('n_chan' )]
1157- dims = [f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1151+ if len (self .inputs ) > 1 :
1152+ # In order to be correctly ingested by hls4ml the QONNX resize node should have 3 inputs set with RoI left empty
1153+ if len (self .inputs ) == 2 :
1154+ raise Exception (
1155+ 'The number of inputs to Resize node is equal to 2. '
1156+ 'In this case, either one is trying to use a version 10 node '
1157+ 'or one is using the RoI parameter only to perform the resize operation, '
1158+ 'both not supported in hls4ml'
1159+ )
1160+ if len (self .inputs ) == 4 :
1161+ raise Exception ('Sizes parameter is not supported by hls4ml. Use scales instead' )
1162+ # get the scales of Resize node from QONNX frontend
1163+ # see doc here: https://onnx.ai/onnx/operators/onnx__Resize.html
1164+ scales_idx = 2 if len (self .inputs ) == 3 or len (self .inputs ) == 4 else 1
1165+ scales = self .get_input_node (self .inputs [scales_idx ]).get_attr ('value' )
1166+ if len (scales ) == 4 : # Resize 2D
1167+ self .set_attr ('out_width' , int (self .get_attr ('in_width' ) * scales [1 ]))
1168+ self .set_attr ('out_height' , int (self .get_attr ('in_height' ) * scales [2 ]))
1169+ self .set_attr ('n_chan' , int (self .get_attr ('n_chan' ) * scales [3 ]))
1170+ elif len (scales ) == 3 : # Resize 1D
1171+ self .set_attr ('out_width' , int (self .get_attr ('in_width' ) * scales [1 ]))
1172+ self .set_attr ('n_chan' , int (self .get_attr ('n_chan' ) * scales [2 ]))
1173+ else :
1174+ raise Exception ('Resize 1D and Resize 2D are the ones supported in hls4ml' )
1175+ if self .get_attr ('data_format' ) == 'channels_last' :
1176+ if len (inp .shape ) == 2 : # 1D -> width + chan
1177+ shape = [int (self .get_attr ('out_width' )), int (self .get_attr ('n_chan' ))]
1178+ dims = [f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1179+ elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1180+ shape = [
1181+ int (self .get_attr ('out_height' )),
1182+ int (self .get_attr ('out_width' )),
1183+ int (self .get_attr ('n_chan' )),
1184+ ]
1185+ dims = [f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1186+ else :
1187+ if len (inp .shape ) == 2 : # 1D -> width + chan
1188+ shape = [int (self .get_attr ('n_chan' )), int (self .get_attr ('out_width' ))]
1189+ dims = [f'N_CHAN_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
1190+ elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1191+ shape = [
1192+ int (self .get_attr ('n_chan' )),
1193+ int (self .get_attr ('out_height' )),
1194+ int (self .get_attr ('out_width' )),
1195+ ]
1196+ dims = [f'N_CHAN_{ self .index } ' , f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
11581197 else :
1159- if len (inp .shape ) == 2 : # 1D -> width + chan
1160- shape = [self .get_attr ('n_chan' ), self .get_attr ('out_width' )]
1161- dims = [f'N_CHAN_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
1162- elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1163- shape = [self .get_attr ('n_chan' ), self .get_attr ('out_height' ), self .get_attr ('out_width' )]
1164- dims = [f'N_CHAN_{ self .index } ' , f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
1198+ if self .get_attr ('data_format' ) == 'channels_last' :
1199+ if len (inp .shape ) == 2 : # 1D -> width + chan
1200+ shape = [self .get_attr ('out_width' ), self .get_attr ('n_chan' )]
1201+ dims = [f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1202+ elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1203+ shape = [self .get_attr ('out_height' ), self .get_attr ('out_width' ), self .get_attr ('n_chan' )]
1204+ dims = [f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' , f'N_CHAN_{ self .index } ' ]
1205+ else :
1206+ if len (inp .shape ) == 2 : # 1D -> width + chan
1207+ shape = [self .get_attr ('n_chan' ), self .get_attr ('out_width' )]
1208+ dims = [f'N_CHAN_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
1209+ elif len (inp .shape ) == 3 : # 2D -> height + width + chan
1210+ shape = [self .get_attr ('n_chan' ), self .get_attr ('out_height' ), self .get_attr ('out_width' )]
1211+ dims = [f'N_CHAN_{ self .index } ' , f'OUT_HEIGHT_{ self .index } ' , f'OUT_WIDTH_{ self .index } ' ]
11651212
11661213 self .add_output_variable (shape , dims , precision = inp .type .precision )
11671214
0 commit comments