Skip to content

Commit 0435b85

Browse files
committed
Update UnsqueezeParser with axes handling and input extraction based on ONNX Opset version and number of inputs
1 parent 0f04c32 commit 0435b85

File tree

1 file changed

+23
-11
lines changed

1 file changed

+23
-11
lines changed

Deeploy/Targets/Generic/Parsers.py

Lines changed: 23 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -943,14 +943,16 @@ def parseNode(self, node: gs.Node) -> (bool):
943943
# ONNX v11: 'axes' is a node attribute
944944
if 'axes' in node.attrs:
945945
ret = all(['axes' in node.attrs, len(node.inputs) == 1, len(node.outputs) == 1])
946-
# ONNX v13+: 'axes' becomes an input together with the data (source: https://onnx.ai/onnx/operators/onnx__Squeeze.html)
946+
# ONNX v13+: 'axes' becomes an input with the data
947+
# Source: https://onnx.ai/onnx/operators/onnx__Unsqueeze.html
947948
else:
948949
ret = all([len(node.inputs) == 2, len(node.outputs) == 1])
949950

950951
if ret and 'axes' in node.attrs:
951-
self.operatorRepresentation['axes'] = node.attrs['axes']
952-
elif ret:
953-
self.operatorRepresentation['axes'] = node.inputs[1]
952+
axes_attr = node.attrs['axes']
953+
self.operatorRepresentation['axes'] = [int(axes_attr)] if isinstance(axes_attr, int) \
954+
else [int(a) for a in axes_attr]
955+
# For opset 13+, axes will be extracted from the second input in parseNodeCtxt
954956

955957
return ret
956958

@@ -959,16 +961,26 @@ def parseNodeCtxt(self,
959961
node: gs.Node,
960962
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
961963

964+
outputs = ['data_out']
962965
if len(node.inputs) == 1:
963966
inputs = ['data_in']
967+
for idx, inputNode in enumerate(node.inputs):
968+
self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name
969+
for idx, outputNode in enumerate(node.outputs):
970+
self.operatorRepresentation[outputs[idx]] = ctxt.lookup(outputNode.name).name
964971
else:
965-
inputs = ['data_in','axes']
966-
outputs = ['data_out']
967-
968-
for idx, inputNode in enumerate(node.inputs):
969-
self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name
970-
for idx, outputNode in enumerate(node.outputs):
971-
self.operatorRepresentation[outputs[idx]] = ctxt.lookup(outputNode.name).name
972+
data_in = ctxt.lookup(node.inputs[0].name)
973+
data_out = ctxt.lookup(node.outputs[0].name)
974+
self.operatorRepresentation['data_in'] = data_in.name
975+
self.operatorRepresentation['data_out'] = data_out.name
976+
# axes must be a constant; extract values
977+
axes_buf = ctxt.lookup(node.inputs[1].name)
978+
assert hasattr(axes_buf, 'values'), "Unsqueeze: expected constant 'axes' input for opset 13+"
979+
axes_vals = np.array(axes_buf.values).astype(int).flatten().tolist()
980+
self.operatorRepresentation['axes'] = axes_vals
981+
# Do not deploy the axes tensor
982+
axes_buf._live = False
983+
axes_buf._deploy = False
972984

973985
return ctxt, True
974986

0 commit comments

Comments
 (0)