Skip to content

Commit bc01b44

Browse files
committed
Remove the _parseNode override since we don't use the memory-aware node bindings anymore
1 parent c00fda5 commit bc01b44

File tree

2 files changed

+2
-51
lines changed

2 files changed

+2
-51
lines changed

Deeploy/CommonExtensions/NetworkDeployers/NetworkDeployerWrapper.py

Lines changed: 1 addition & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,7 @@
22
#
33
# SPDX-License-Identifier: Apache-2.0
44

5-
from typing import Any, Tuple, Union
5+
from typing import Any, Union
66

77
import onnx_graphsurgeon as gs
88

@@ -63,11 +63,6 @@ def lower(self, graph: gs.Graph) -> gs.Graph:
6363
def codeTransform(self, verbose: CodeGenVerbosity = _NoVerbosity):
6464
return self._innerObject.codeTransform(verbose)
6565

66-
# MemoryAwareDeployer augment
67-
def _parseNode(self, node: ONNXLayer, ctxt: NetworkContext,
68-
default_channels_first: bool) -> Tuple[NetworkContext, bool]:
69-
return self._innerObject._parseNode(node, ctxt, default_channels_first)
70-
7166
# PULPDeployer augment
7267
def generateBufferAllocationCode(self) -> str:
7368
return self._innerObject.generateBufferAllocationCode()

Deeploy/MemoryLevelExtension/NetworkDeployers/MemoryLevelDeployer.py

Lines changed: 1 addition & 45 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
from Deeploy.CommonExtensions.NetworkDeployers.NetworkDeployerWrapper import NetworkDeployerWrapper
1212
from Deeploy.CommonExtensions.NetworkDeployers.SignPropDeployer import SignPropDeployer
1313
from Deeploy.DeeployTypes import CodeGenVerbosity, ConstantBuffer, DeploymentEngine, DeploymentPlatform, \
14-
NetworkContext, NetworkDeployer, NetworkOptimizationPass, NetworkOptimizer, ONNXLayer, Schedule, StructBuffer, \
14+
NetworkContext, NetworkDeployer, NetworkOptimizationPass, NetworkOptimizer, Schedule, StructBuffer, \
1515
TopologyOptimizer, TransientBuffer, VariableBuffer, _NoVerbosity
1616
from Deeploy.Logging import DEFAULT_LOGGER as log
1717
from Deeploy.MemoryLevelExtension.MemoryLevels import MemoryHierarchy, MemoryLevel
@@ -128,18 +128,6 @@ def getTargetMemoryLevelMapping(self) -> TargetMemoryLevelMapping:
128128
f"Platform should be a MemoryPlatform or MemoryPlatformWrapper! Got {type(self.Platform).__name__}"
129129
return TargetMemoryLevelMapping(self.graph, self.Platform, self.ctxt)
130130

131-
def _parseNode(self, node: ONNXLayer, ctxt: NetworkContext,
132-
default_channels_first: bool) -> Tuple[NetworkContext, bool]:
133-
134-
newCtxt, parsePass = super()._parseNode(node, ctxt, default_channels_first)
135-
136-
if not parsePass:
137-
return ctxt, False
138-
139-
newCtxt, self.graph = self.memoryLevelAnnotationOptimizer.optimize(newCtxt, self.graph)
140-
141-
return newCtxt, parsePass
142-
143131
def bind(self):
144132

145133
ret = super().bind()
@@ -181,22 +169,6 @@ def getTargetMemoryLevelMapping(self) -> TargetMemoryLevelMapping:
181169
f"Platform should be a MemoryPlatform or MemoryPlatformWrapper! Got {type(self.Platform).__name__}"
182170
return TargetMemoryLevelMapping(self.graph, self.Platform, self.ctxt)
183171

184-
def _parseNode(self, node: ONNXLayer, ctxt: NetworkContext,
185-
default_channels_first: bool) -> Tuple[NetworkContext, bool]:
186-
187-
newCtxt, parsePass = node.parse(ctxt.copy(), default_channels_first)
188-
189-
if not parsePass:
190-
return ctxt, False
191-
192-
newCtxt, self.graph = self.memoryLevelAnnotationOptimizer.optimize(newCtxt, self.graph)
193-
newCtxt, LayerBindSuccess = node.typeCheck(newCtxt)
194-
195-
if not LayerBindSuccess:
196-
return ctxt, False
197-
198-
return newCtxt, True
199-
200172
def bind(self):
201173

202174
ret = super().bind()
@@ -229,22 +201,6 @@ def getTargetMemoryLevelMapping(self) -> TargetMemoryLevelMapping:
229201
f"Platform should be a MemoryPlatform or MemoryPlatformWrapper! Got {type(self.Platform).__name__}"
230202
return TargetMemoryLevelMapping(self.graph, self.Platform, self.ctxt)
231203

232-
def _parseNode(self, node: ONNXLayer, ctxt: NetworkContext,
233-
default_channels_first: bool) -> Tuple[NetworkContext, bool]:
234-
235-
newCtxt, parsePass = node.parse(ctxt.copy(), default_channels_first)
236-
237-
if not parsePass:
238-
return ctxt, False
239-
240-
newCtxt, self.graph = self.memoryLevelAnnotationOptimizer.optimize(newCtxt, self.graph)
241-
newCtxt, LayerBindSuccess = node.typeCheck(newCtxt)
242-
243-
if not LayerBindSuccess:
244-
return ctxt, False
245-
246-
return newCtxt, True
247-
248204
def bind(self):
249205

250206
ret = super().bind()

0 commit comments

Comments
 (0)