@@ -986,48 +986,23 @@ def __init__(self):
986986 super ().__init__ ()
987987
988988 def parseNode (self , node : gs .Node ) -> (bool ):
989+ if not all (['axes' in node .attrs , len (node .inputs ) == 1 , len (node .outputs ) == 1 ]):
990+ return False
989991
990- # ONNX v11: 'axes' is a node attribute
991- if 'axes' in node .attrs :
992- ret = all (['axes' in node .attrs , len (node .inputs ) == 1 , len (node .outputs ) == 1 ])
993- # ONNX v13+: 'axes' becomes an input with the data
994- # Source: https://onnx.ai/onnx/operators/onnx__Unsqueeze.html
995- else :
996- ret = all ([len (node .inputs ) == 2 , len (node .outputs ) == 1 ])
997-
998- if ret and 'axes' in node .attrs :
999- axes_attr = node .attrs ['axes' ]
1000- self .operatorRepresentation ['axes' ] = [int (axes_attr )] if isinstance (axes_attr , int ) \
1001- else [int (a ) for a in axes_attr ]
1002- # For opset 13+, axes will be extracted from the second input in parseNodeCtxt
1003-
1004- return ret
992+ self .operatorRepresentation ['axes' ] = node .attrs ['axes' ]
993+ return True
1005994
1006995 def parseNodeCtxt (self ,
1007996 ctxt : NetworkContext ,
1008997 node : gs .Node ,
1009998 channels_first : bool = True ) -> Tuple [NetworkContext , bool ]:
999+ inputs = ['data_in' ]
1000+ for idx , inputNode in enumerate (node .inputs ):
1001+ self .operatorRepresentation [inputs [idx ]] = ctxt .lookup (inputNode .name ).name
10101002
10111003 outputs = ['data_out' ]
1012- if len (node .inputs ) == 1 :
1013- inputs = ['data_in' ]
1014- for idx , inputNode in enumerate (node .inputs ):
1015- self .operatorRepresentation [inputs [idx ]] = ctxt .lookup (inputNode .name ).name
1016- for idx , outputNode in enumerate (node .outputs ):
1017- self .operatorRepresentation [outputs [idx ]] = ctxt .lookup (outputNode .name ).name
1018- else :
1019- data_in = ctxt .lookup (node .inputs [0 ].name )
1020- data_out = ctxt .lookup (node .outputs [0 ].name )
1021- self .operatorRepresentation ['data_in' ] = data_in .name
1022- self .operatorRepresentation ['data_out' ] = data_out .name
1023- # axes must be a constant; extract values
1024- axes_buf = ctxt .lookup (node .inputs [1 ].name )
1025- assert hasattr (axes_buf , 'values' ), "Unsqueeze: expected constant 'axes' input for opset 13+"
1026- axes_vals = np .array (axes_buf .values ).astype (int ).flatten ().tolist ()
1027- self .operatorRepresentation ['axes' ] = axes_vals
1028- # Do not deploy the axes tensor
1029- axes_buf ._live = False
1030- axes_buf ._deploy = False
1004+ for idx , outputNode in enumerate (node .outputs ):
1005+ self .operatorRepresentation [outputs [idx ]] = ctxt .lookup (outputNode .name ).name
10311006
10321007 return ctxt , True
10331008
0 commit comments