@@ -943,14 +943,16 @@ def parseNode(self, node: gs.Node) -> (bool):
943943 # ONNX v11: 'axes' is a node attribute
944944 if 'axes' in node .attrs :
945945 ret = all (['axes' in node .attrs , len (node .inputs ) == 1 , len (node .outputs ) == 1 ])
946- # ONNX v13+: 'axes' becomes an input together with the data (source: https://onnx.ai/onnx/operators/onnx__Squeeze.html)
946+ # ONNX v13+: 'axes' becomes an input with the data
947+ # Source: https://onnx.ai/onnx/operators/onnx__Unsqueeze.html
947948 else :
948949 ret = all ([len (node .inputs ) == 2 , len (node .outputs ) == 1 ])
949950
950951 if ret and 'axes' in node .attrs :
951- self .operatorRepresentation ['axes' ] = node .attrs ['axes' ]
952- elif ret :
953- self .operatorRepresentation ['axes' ] = node .inputs [1 ]
952+ axes_attr = node .attrs ['axes' ]
953+ self .operatorRepresentation ['axes' ] = [int (axes_attr )] if isinstance (axes_attr , int ) \
954+ else [int (a ) for a in axes_attr ]
955+ # For opset 13+, axes will be extracted from the second input in parseNodeCtxt
954956
955957 return ret
956958
@@ -959,16 +961,26 @@ def parseNodeCtxt(self,
959961 node : gs .Node ,
960962 channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
961963
964+ outputs = ['data_out' ]
962965 if len (node .inputs ) == 1 :
963966 inputs = ['data_in' ]
967+ for idx , inputNode in enumerate (node .inputs ):
968+ self .operatorRepresentation [inputs [idx ]] = ctxt .lookup (inputNode .name ).name
969+ for idx , outputNode in enumerate (node .outputs ):
970+ self .operatorRepresentation [outputs [idx ]] = ctxt .lookup (outputNode .name ).name
964971 else :
965- inputs = ['data_in' ,'axes' ]
966- outputs = ['data_out' ]
967-
968- for idx , inputNode in enumerate (node .inputs ):
969- self .operatorRepresentation [inputs [idx ]] = ctxt .lookup (inputNode .name ).name
970- for idx , outputNode in enumerate (node .outputs ):
971- self .operatorRepresentation [outputs [idx ]] = ctxt .lookup (outputNode .name ).name
972+ data_in = ctxt .lookup (node .inputs [0 ].name )
973+ data_out = ctxt .lookup (node .outputs [0 ].name )
974+ self .operatorRepresentation ['data_in' ] = data_in .name
975+ self .operatorRepresentation ['data_out' ] = data_out .name
976+ # axes must be a constant; extract values
977+ axes_buf = ctxt .lookup (node .inputs [1 ].name )
978+ assert hasattr (axes_buf , 'values' ), "Unsqueeze: expected constant 'axes' input for opset 13+"
979+ axes_vals = np .array (axes_buf .values ).astype (int ).flatten ().tolist ()
980+ self .operatorRepresentation ['axes' ] = axes_vals
981+ # Do not deploy the axes tensor
982+ axes_buf ._live = False
983+ axes_buf ._deploy = False
972984
973985 return ctxt , True
974986
0 commit comments