Skip to content

Commit 7180f42

Browse files
committed
Canonicalize (un)squeeze operations as pre-opset-13, i.e., put axes into node attributes to ommit creating a buffer for it
1 parent fffe8f3 commit 7180f42

File tree

2 files changed

+25
-36
lines changed

2 files changed

+25
-36
lines changed

Deeploy/OperatorDescriptor.py

Lines changed: 16 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -476,15 +476,29 @@ def canonicalize(self, node: gs.Node, opset: int) -> bool:
476476
attrDescriptors = [AttrDesc("axis", IntUnpack, default = 0)],
477477
)
478478

479+
480+
class SqueezeDescriptor(OperatorDescriptor):
481+
482+
def canonicalize(self, node: gs.Node, opset: int) -> bool:
483+
if opset >= 13:
484+
assert len(node.inputs) == 2, f"Expected 2 inputs but received {len(node.inputs)}"
485+
axes = node.inputs[1]
486+
assert isinstance(axes,
487+
gs.Constant), f"Expected axes to be a constant but received axes of type {type(axes)}"
488+
node.attrs["axes"] = axes.values
489+
axes.outputs.clear()
490+
return super().canonicalize(node, opset)
491+
492+
479493
# Opset <= 11
480-
unsqueezeDesc = OperatorDescriptor(
494+
unsqueezeDesc = SqueezeDescriptor(
481495
inputDescriptor = IoDesc("data_in"),
482496
outputDescriptor = IoDesc("data_out"),
483497
attrDescriptors = [AttrDesc("axes", IntTupleUnpack)],
484498
)
485499

486500
# Opset <= 11
487-
squeezeDesc = OperatorDescriptor(
501+
squeezeDesc = SqueezeDescriptor(
488502
inputDescriptor = IoDesc("data_in"),
489503
outputDescriptor = IoDesc("data_out"),
490504
attrDescriptors = [AttrDesc("axes", IntTupleUnpack)],

Deeploy/Targets/Generic/Parsers.py

Lines changed: 9 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -986,48 +986,23 @@ def __init__(self):
986986
super().__init__()
987987

988988
def parseNode(self, node: gs.Node) -> (bool):
989+
if not all(['axes' in node.attrs, len(node.inputs) == 1, len(node.outputs) == 1]):
990+
return False
989991

990-
# ONNX v11: 'axes' is a node attribute
991-
if 'axes' in node.attrs:
992-
ret = all(['axes' in node.attrs, len(node.inputs) == 1, len(node.outputs) == 1])
993-
# ONNX v13+: 'axes' becomes an input with the data
994-
# Source: https://onnx.ai/onnx/operators/onnx__Unsqueeze.html
995-
else:
996-
ret = all([len(node.inputs) == 2, len(node.outputs) == 1])
997-
998-
if ret and 'axes' in node.attrs:
999-
axes_attr = node.attrs['axes']
1000-
self.operatorRepresentation['axes'] = [int(axes_attr)] if isinstance(axes_attr, int) \
1001-
else [int(a) for a in axes_attr]
1002-
# For opset 13+, axes will be extracted from the second input in parseNodeCtxt
1003-
1004-
return ret
992+
self.operatorRepresentation['axes'] = node.attrs['axes']
993+
return True
1005994

1006995
def parseNodeCtxt(self,
1007996
ctxt: NetworkContext,
1008997
node: gs.Node,
1009998
channels_first: bool = True) -> Tuple[NetworkContext, bool]:
999+
inputs = ['data_in']
1000+
for idx, inputNode in enumerate(node.inputs):
1001+
self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name
10101002

10111003
outputs = ['data_out']
1012-
if len(node.inputs) == 1:
1013-
inputs = ['data_in']
1014-
for idx, inputNode in enumerate(node.inputs):
1015-
self.operatorRepresentation[inputs[idx]] = ctxt.lookup(inputNode.name).name
1016-
for idx, outputNode in enumerate(node.outputs):
1017-
self.operatorRepresentation[outputs[idx]] = ctxt.lookup(outputNode.name).name
1018-
else:
1019-
data_in = ctxt.lookup(node.inputs[0].name)
1020-
data_out = ctxt.lookup(node.outputs[0].name)
1021-
self.operatorRepresentation['data_in'] = data_in.name
1022-
self.operatorRepresentation['data_out'] = data_out.name
1023-
# axes must be a constant; extract values
1024-
axes_buf = ctxt.lookup(node.inputs[1].name)
1025-
assert hasattr(axes_buf, 'values'), "Unsqueeze: expected constant 'axes' input for opset 13+"
1026-
axes_vals = np.array(axes_buf.values).astype(int).flatten().tolist()
1027-
self.operatorRepresentation['axes'] = axes_vals
1028-
# Do not deploy the axes tensor
1029-
axes_buf._live = False
1030-
axes_buf._deploy = False
1004+
for idx, outputNode in enumerate(node.outputs):
1005+
self.operatorRepresentation[outputs[idx]] = ctxt.lookup(outputNode.name).name
10311006

10321007
return ctxt, True
10331008

0 commit comments

Comments
 (0)