Skip to content

Commit e980509

Browse files
committed
Add custom PULP Reshape Binding
1 parent 74a02d9 commit e980509

File tree

3 files changed

+84
-15
lines changed

3 files changed

+84
-15
lines changed

Deeploy/Targets/PULPOpen/Bindings.py

Lines changed: 19 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -31,7 +31,7 @@
3131
from Deeploy.AbstractDataTypes import PointerClass
3232
from Deeploy.CommonExtensions.CodeTransformationPasses.Closure import ClosureGeneration, MemoryAwareClosureGeneration
3333
from Deeploy.CommonExtensions.CodeTransformationPasses.MemoryAllocation import ArgumentStructGeneration, \
34-
MemoryManagementGeneration
34+
MemoryManagementGeneration, MemoryPassthroughGeneration
3535
from Deeploy.CommonExtensions.DataTypes import IntegerDataTypes, SignedIntegerDataTypes, float32_t, int8_t, int32_t, \
3636
uint8_t
3737
from Deeploy.DeeployTypes import CodeTransformation, NodeBinding, NodeTemplate
@@ -41,8 +41,8 @@
4141
GatherTemplate, QuantTemplate, RQSiGELUTemplate, iHardswishTemplate
4242
from Deeploy.Targets.Generic.TypeCheckers import AddChecker, ConcatChecker, ConvChecker, DequantChecker, \
4343
GatherChecker, GELUChecker, GEMMChecker, HardswishChecker, LayerNormChecker, MatMulChecker, MulChecker, \
44-
QuantChecker, ReduceMeanChecker, ReluChecker, RQAddChecker, RQHardswishChecker, SGDChecker, SliceChecker, \
45-
SoftmaxChecker, SoftmaxCrossEntropyLossChecker, TransposeChecker
44+
QuantChecker, ReduceMeanChecker, ReluChecker, ReshapeChecker, RQAddChecker, RQHardswishChecker, SGDChecker, \
45+
SliceChecker, SoftmaxChecker, SoftmaxCrossEntropyLossChecker, TransposeChecker
4646
from Deeploy.Targets.PULPOpen.CodeTransformationPasses.PULPClusterSynch import PULPSynchCoresPass
4747
from Deeploy.Targets.PULPOpen.CodeTransformationPasses.PULPClusterTiling import PULPClusterTiling
4848
from Deeploy.Targets.PULPOpen.CodeTransformationPasses.PULPL3Tiling import PULPL3Tiling
@@ -53,8 +53,8 @@
5353
from Deeploy.Targets.PULPOpen.Templates import ConvTemplate, FloatAddTemplate, FloatConvTemplate, FloatGELUTemplate, \
5454
FloatGemmTemplate, FloatLayernormTemplate, FloatMatMulTemplate, FloatMaxPoolTemplate, FloatMulTemplate, \
5555
FloatReluTemplate, FloatSoftmaxTemplate, GEMMTemplate, MatrixVectorTemplate, MaxPool2DTemplate, MulTemplate, \
56-
ReduceMeanTemplate, RequantShiftTemplate, RQAddTemplate, RQSiHardswishTemplate, SGDTemplate, SliceTemplate, \
57-
SoftmaxCrossEntropyLossTemplate, TallGEMMTemplate, TransposeTemplate, UniformRequantShiftTemplate, \
56+
ReduceMeanTemplate, RequantShiftTemplate, ReshapeTemplate, RQAddTemplate, RQSiHardswishTemplate, SGDTemplate, \
57+
SliceTemplate, SoftmaxCrossEntropyLossTemplate, TallGEMMTemplate, TransposeTemplate, UniformRequantShiftTemplate, \
5858
iRMSNormTemplate, iSoftmaxTemplate
5959
from Deeploy.Targets.PULPOpen.TypeCheckers import PULPConvChecker, PULPLinearChecker, PULPMaxPoolChecker, \
6060
PULPRequantShiftChecker
@@ -76,6 +76,12 @@
7676
pi_cl_team_fork(NUM_CORES, (void*)${closureName}, &${closureStructArgName});
7777
""")
7878

79+
SkipTransformer = CodeTransformation(
80+
[ArgumentStructGeneration(),
81+
MemoryPassthroughGeneration("L.*"),
82+
MemoryPassthroughGeneration(),
83+
FutureGeneration()])
84+
7985
FunctionCallClosure = partial(ClosureGeneration, closureSuffix = "_closure")
8086
ClusterClosure = partial(ClosureGeneration,
8187
closureSuffix = "_cluster_entry",
@@ -169,6 +175,14 @@
169175
for type in IntegerDataTypes
170176
]
171177

178+
PULPReshapeBindings = [
179+
NodeBinding(ReshapeChecker([PointerClass(type), PointerClass(int32_t)], [PointerClass(type)]),
180+
ReshapeTemplate.referenceTemplate, SkipTransformer) for type in IntegerDataTypes
181+
] + [
182+
NodeBinding(ReshapeChecker([PointerClass(float32_t), PointerClass(type)], [PointerClass(float32_t)]),
183+
ReshapeTemplate.referenceTemplate, SkipTransformer) for type in IntegerDataTypes
184+
]
185+
172186
PULPRQAddBindings = [
173187
NodeBinding(RQAddChecker([PointerClass(_type), PointerClass(_type2)], [PointerClass(_type3)]),
174188
RQAddTemplate.referenceTemplate, ForkTransformer)
Lines changed: 59 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,59 @@
1+
# ----------------------------------------------------------------------
2+
#
3+
# File: ReshapeTemplate.py
4+
#
5+
# Last edited: 16.12.2021
6+
#
7+
# Copyright (C) 2021, ETH Zurich and University of Bologna.
8+
#
9+
# Author: Moritz Scherer, ETH Zurich
10+
#
11+
# ----------------------------------------------------------------------
12+
# SPDX-License-Identifier: Apache-2.0
13+
#
14+
# Licensed under the Apache License, Version 2.0 (the License); you may
15+
# not use this file except in compliance with the License.
16+
# You may obtain a copy of the License at
17+
#
18+
# www.apache.org/licenses/LICENSE-2.0
19+
#
20+
# Unless required by applicable law or agreed to in writing, software
21+
# distributed under the License is distributed on an AS IS BASIS, WITHOUT
22+
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
23+
# See the License for the specific language governing permissions and
24+
# limitations under the License.
25+
26+
from typing import Dict, List, Tuple
27+
28+
from Deeploy.DeeployTypes import NetworkContext, NodeTemplate, OperatorRepresentation
29+
30+
31+
class _ReshapeTemplate(NodeTemplate):
32+
33+
def __init__(self, templateStr):
34+
super().__init__(templateStr)
35+
36+
def alignToContext(self, ctxt: NetworkContext,
37+
operatorRepresentation: OperatorRepresentation) -> Tuple[NetworkContext, Dict, List[str]]:
38+
39+
# SCHEREMO: Selectively mark 'indices' dead, since we don't need them
40+
if 'indices' in operatorRepresentation.keys():
41+
ctxt.globalObjects[operatorRepresentation['indices']]._deploy = False
42+
ctxt.globalObjects[operatorRepresentation['indices']]._live = False
43+
44+
# Same for "shape"
45+
if "shape" in operatorRepresentation.keys():
46+
ctxt.globalObjects[operatorRepresentation["shape"]]._deploy = False
47+
ctxt.globalObjects[operatorRepresentation["shape"]]._live = False
48+
49+
inBuffer = ctxt.lookup(operatorRepresentation['data_in'])
50+
outBuffer = ctxt.lookup(operatorRepresentation['data_out'])
51+
outBuffer._alias = inBuffer.name
52+
53+
return ctxt, operatorRepresentation, []
54+
55+
56+
referenceTemplate = _ReshapeTemplate("""
57+
// Reshape (Name: ${nodeName}, Op: ${nodeOp})
58+
${data_out} = ${data_in};
59+
""")

Deeploy/Targets/PULPOpen/Tiler.py

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@
2626

2727
import copy
2828

29-
from Deeploy.CommonExtensions.CodeTransformationPasses.MemoryAllocation import MemoryPassthroughGeneration
30-
from Deeploy.DeeployTypes import CodeTransformation
31-
from Deeploy.Targets.Generic.Bindings import BasicReshapeBindings
3229
from Deeploy.Targets.Generic.TileConstraints.AddTileConstraint import AddTileConstraint
3330
from Deeploy.Targets.Generic.TileConstraints.ConcatTileConstraint import ConcatTileConstraint
3431
from Deeploy.Targets.Generic.TileConstraints.iHardswishTileConstraint import iHardswishTileConstraint
@@ -43,10 +40,11 @@
4340
from Deeploy.Targets.PULPOpen.Bindings import PULPAddBindings, PULPConcatBindings, PULPFloatConv2DBindings, \
4441
PULPFloatGELUBinding, PULPFloatGEMMBindings, PULPGatherBindings, PULPiHardswishBindings, PULPiRMSNormBindings, \
4542
PULPiRQSGELUBindings, PULPLayernormBinding, PULPMatMulBindings, PULPMaxPool2DBindings, PULPMulBindings, \
46-
PULPReduceSumBindings, PULPReluBinding, PULPRQAddBindings, PULPRQSBindings, PULPRQSConv2DBindings, \
47-
PULPRQSDWConv2DBindings, PULPRQSGEMMBindings, PULPRQSiHardswishBindings, PULPRQSMatrixVecBindings, \
48-
PULPRQSTallGEMMBindings, PULPSGDBindings, PULPSoftmaxBindings, PULPSoftmaxCrossEntropyLossBindings, \
49-
PULPSoftmaxCrossEntropyLossGradBindings, PULPSoftmaxGradBindings, PULPTransposeBindings, PULPUniformRQSBindings
43+
PULPReduceSumBindings, PULPReluBinding, PULPReshapeBindings, PULPRQAddBindings, PULPRQSBindings, \
44+
PULPRQSConv2DBindings, PULPRQSDWConv2DBindings, PULPRQSGEMMBindings, PULPRQSiHardswishBindings, \
45+
PULPRQSMatrixVecBindings, PULPRQSTallGEMMBindings, PULPSGDBindings, PULPSoftmaxBindings, \
46+
PULPSoftmaxCrossEntropyLossBindings, PULPSoftmaxCrossEntropyLossGradBindings, PULPSoftmaxGradBindings, \
47+
PULPTransposeBindings, PULPUniformRQSBindings
5048
from Deeploy.Targets.PULPOpen.TileConstraints.ConvTileConstraint import Conv2DTileConstraint, RQConv2DTileConstraint
5149
from Deeploy.Targets.PULPOpen.TileConstraints.DWConvTileConstraint import DWConv2DTileConstraint
5250
from Deeploy.Targets.PULPOpen.TileConstraints.GatherTileConstraint import GatherTileConstraint
@@ -95,9 +93,7 @@
9593
PULPRQSiHardswishTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = PULPRQSiHardswishBindings,
9694
tileConstraint = RQSiHardswishTileConstraint())
9795

98-
_BasicFlattenBindings = copy.deepcopy(BasicReshapeBindings)
99-
for binding in _BasicFlattenBindings:
100-
binding.codeTransformer = CodeTransformation([MemoryPassthroughGeneration("L.*"), MemoryPassthroughGeneration()])
96+
_BasicFlattenBindings = copy.deepcopy(PULPReshapeBindings)
10197

10298
PULPFlattenTilingReadyBindings = TilingReadyNodeBindings(nodeBindings = _BasicFlattenBindings,
10399
tileConstraint = NOPTileConstraint())

0 commit comments

Comments
 (0)