Skip to content

Commit c12c451

Browse files
committed
MatMul constraint fix
1 parent ec21316 commit c12c451

File tree

3 files changed

+12
-4
lines changed

3 files changed

+12
-4
lines changed

Deeploy/Targets/PULPOpen/TileConstraints/MatMulTileConstraint.py

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,18 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
5353
dimIdx = (tensorsShapeLenOutput - 1))
5454

5555
# ===== ADD CONSTRAINTS =====
56+
# Add batch constraints
57+
if (bufferA.shape[:-2] == bufferB.shape[:-2]):
58+
for idx in range(tensorsShapeLenA - 2):
59+
tilerModel.addConstraint(
60+
tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = tensorsShapeLenOutput - 3 - idx)
61+
== tilerModel.getTensorDimVar(tensorName = bufferA.name, dimIdx = tensorsShapeLenA - 3 - idx))
62+
63+
for idx in range(tensorsShapeLenB - 2):
64+
tilerModel.addConstraint(
65+
tilerModel.getTensorDimVar(tensorName = outputBuffer.name, dimIdx = tensorsShapeLenOutput - 3 - idx)
66+
== tilerModel.getTensorDimVar(tensorName = bufferB.name, dimIdx = tensorsShapeLenB - 3 - idx))
67+
5668
# Add GEMM geometrical constraints
5769
tilerModel.addConstraint(outputMatrixFirstDimVar == AMatrixFirstDimVar)
5870
tilerModel.addConstraint(outputMatrixSecondDimVar == BMatrixSecondDimVar)

Deeploy/Targets/PULPOpen/TileConstraints/ReduceMeanConstraint.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -62,13 +62,11 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
6262

6363
@staticmethod
6464
def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
65-
# TODO
6665
return tilerModel
6766

6867
@staticmethod
6968
def constructSymbolicNodeRep(tilerModel: TilerModel, parseDict: Dict,
7069
ctxt: NetworkContext) -> Dict[str, Union[int, IntVar]]:
71-
# TODO
7270
symbolicParseDict = parseDict.copy()
7371

7472
return symbolicParseDict

Deeploy/Targets/PULPOpen/TileConstraints/SliceConstraint.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,13 +59,11 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
5959

6060
@staticmethod
6161
def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
62-
# TODO
6362
return tilerModel
6463

6564
@staticmethod
6665
def constructSymbolicNodeRep(tilerModel: TilerModel, parseDict: Dict,
6766
ctxt: NetworkContext) -> Dict[str, Union[int, IntVar]]:
68-
# TODO
6967
symbolicParseDict = parseDict.copy()
7068

7169
return symbolicParseDict

0 commit comments

Comments
 (0)