Skip to content

Commit 59ac79e

Browse files
authored
Fix PULP GEMM batch serialization (#109)
PULPOpen: Fix serialization of the batch variable in MatMul and GEMM tile constraints
1 parent a23b15f commit 59ac79e

File tree

11 files changed

+91
-186
lines changed

11 files changed

+91
-186
lines changed

.github/workflows/ci-platform-siracusa-tiled.yml

Lines changed: 3 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -55,7 +55,9 @@ jobs:
5555
{"name":"testFloatSoftmax","L1":[4000]},
5656
{"name":"testFloatTranspose","L1":[2000]},
5757
{"name":"testFloatMul","L1":[2000]},
58-
{"name":"largeFloatAdd","L1":[220000]}
58+
{"name":"largeFloatAdd","L1":[220000]},
59+
{"name":"testRQGEMMwBatch","L1":[20000]},
60+
{"name":"testMatMulBatch","L1":[20000]}
5961
]
6062
num-cores: 8
6163

CHANGELOG.md

Lines changed: 2 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -4,6 +4,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid
44
## Unreleased (Planned Release Target: v0.2.1)
55

66
### List of Pull Requests
7+
- Fix PULP GEMM `batch` serialization [#109](https://github.com/pulp-platform/Deeploy/pull/109)
78
- Split CI Workflows by Platform and Task, Improve Formatting and Linting Reliability [#108](https://github.com/pulp-platform/Deeploy/pull/108)
89
- Refactor tiling code generation [#105](https://github.com/pulp-platform/Deeploy/pull/105)
910
- Change order of typeMatching entries [#68](https://github.com/pulp-platform/Deeploy/pull/68)
@@ -61,6 +62,7 @@ This file contains the changelog for the Deeploy project. The changelog is divid
6162
- Prevent node duplication for graphs generated via GraphSurgeon
6263
- Resolved issue with missing `id` in the `Build Cache for Docker` step, used in the `Inject build-cache` step.
6364
- Fix license CI check and prevent potential issues with `jq` installation
65+
- Fix PULP GEMM `batch` variable serialization
6466

6567
### Removed
6668
- Delete outdated and unused `.gitlab-ci.yml` file

Deeploy/Targets/PULPOpen/TileConstraints/GEMMTileConstraint.py

Lines changed: 58 additions & 135 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
# See the License for the specific language governing permissions and
2525
# limitations under the License.
2626

27+
import math
2728
from typing import Dict, List, Tuple
2829

2930
from Deeploy.AbstractDataTypes import PointerClass
@@ -135,25 +136,22 @@ def serializeTilingSolution(
135136

136137
# Every output is constructed by a pair of inputs. Reconstruct this pair.
137138
for cube in outputCubes:
139+
MOffset, OOffset = cube.offset[-2:]
140+
MSize, OSize = cube.dims[-2:]
138141

139-
BSize = 1
140-
BOffset = 0
141-
BatchSize = 1
142-
BatchOffset = 0
143-
144-
if len(cube.offset) == 2:
145-
(MOffset, OOffset) = cube.offset
146-
(MSize, OSize) = cube.dims
147-
elif len(cube.offset) == 3:
148-
(BatchOffset, MOffset, OOffset) = cube.offset
149-
(BatchSize, MSize, OSize) = cube.dims
142+
if len(cube.offset) > 2:
143+
BatchSize = math.prod(cube.dims[:-2])
144+
145+
if len(cube.offset) > 3:
146+
assert all(off == 0 for off in cube.offset[:-3]), (
147+
f"Unsupported tiling across leading batch dims: offsets={cube.offset}. "
148+
"Only the last batch dim (besides M/O) may be tiled.")
150149
else:
151-
(BatchOffset, BOffset, MOffset, OOffset) = cube.offset
152-
(BatchSize, BSize, MSize, OSize) = cube.dims
150+
BatchSize = 1
153151

154152
replacements["M"].append(MSize)
155153
replacements["O"].append(OSize)
156-
replacements["batch"].append(BSize)
154+
replacements["batch"].append(BatchSize)
157155

158156
if transA == 0:
159157
AMatrixOffsets = (MOffset, NOffset)
@@ -162,49 +160,30 @@ def serializeTilingSolution(
162160
AMatrixOffsets = (NOffset, MOffset)
163161
AMatrixShape = (NSize, MSize)
164162

163+
if len(buffA.shape) > 2:
164+
batchDimCount = len(buffA.shape) - 2
165+
AMatrixOffsets = tuple(cube.offset[:-2][-batchDimCount:]) + AMatrixOffsets
166+
AMatrixShape = tuple(cube.dims[:-2][-batchDimCount:]) + AMatrixShape
167+
168+
ACube = HyperRectangle(AMatrixOffsets, AMatrixShape)
169+
inputACubes.append(ACube)
170+
165171
if transB == 0:
166172
BMatrixOffsets = (NOffset, OOffset)
167173
BMatrixShape = (NSize, OSize)
168174
else:
169175
BMatrixOffsets = (OOffset, NOffset)
170176
BMatrixShape = (OSize, NSize)
171177

172-
if len(buffA.shape) == 2:
173-
ACube = HyperRectangle(AMatrixOffsets, AMatrixShape)
174-
elif len(buffA.shape) == 3:
175-
ACube = HyperRectangle((BatchOffset,) + AMatrixOffsets, (BatchSize,) + AMatrixShape)
176-
else:
177-
ACube = HyperRectangle(
178-
(
179-
BatchOffset,
180-
BOffset,
181-
) + AMatrixOffsets,
182-
(
183-
BatchSize,
184-
BSize,
185-
) + AMatrixShape,
186-
)
187-
188-
if len(buffB.shape) == 2:
189-
BCube = HyperRectangle(BMatrixOffsets, BMatrixShape)
190-
elif len(buffB.shape) == 3:
191-
BCube = HyperRectangle((BatchOffset,) + BMatrixOffsets, (BatchSize,) + BMatrixShape)
192-
else:
193-
BCube = HyperRectangle(
194-
(
195-
BatchOffset,
196-
BOffset,
197-
) + BMatrixOffsets,
198-
(
199-
BatchSize,
200-
BSize,
201-
) + BMatrixShape,
202-
)
203-
204-
RequantCube = HyperRectangle((OOffset,), (OSize,))
178+
if len(buffB.shape) > 2:
179+
batchDimCount = len(buffB.shape) - 2
180+
BMatrixOffsets = tuple(cube.offset[:-2][-batchDimCount:]) + BMatrixOffsets
181+
BMatrixShape = tuple(cube.dims[:-2][-batchDimCount:]) + BMatrixShape
205182

206-
inputACubes.append(ACube)
183+
BCube = HyperRectangle(BMatrixOffsets, BMatrixShape)
207184
inputBCubes.append(BCube)
185+
186+
RequantCube = HyperRectangle((OOffset,), (OSize,))
208187
inputMulCubes.append(RequantCube)
209188
inputAddCubes.append(RequantCube)
210189

@@ -231,40 +210,6 @@ def serializeTilingSolution(
231210
return VariableReplacementScheme(replacements, replacementTypes), schedule
232211

233212

234-
class MatrixVecTileConstraint(GEMMTileConstraint):
235-
236-
@staticmethod
237-
def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
238-
239-
tm = GEMMTileConstraint.addGeometricalConstraint(tilerModel, parseDict, ctxt)
240-
241-
return tm
242-
243-
@staticmethod
244-
def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
245-
246-
tm = GEMMTileConstraint.addPolicyConstraint(tilerModel, parseDict, ctxt)
247-
248-
return tm
249-
250-
251-
class TallGEMMTileConstraint(GEMMTileConstraint):
252-
253-
@staticmethod
254-
def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
255-
256-
tm = GEMMTileConstraint.addGeometricalConstraint(tilerModel, parseDict, ctxt)
257-
258-
return tm
259-
260-
@staticmethod
261-
def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkContext) -> TilerModel:
262-
263-
tm = GEMMTileConstraint.addPolicyConstraint(tilerModel, parseDict, ctxt)
264-
265-
return tm
266-
267-
268213
class FloatGEMMTileConstraint(TileConstraint):
269214

270215
@staticmethod
@@ -367,25 +312,22 @@ def serializeTilingSolution(
367312

368313
# Every output is constructed by a pair of inputs. Reconstruct this pair.
369314
for cube in outputCubes:
315+
MOffset, OOffset = cube.offset[-2:]
316+
MSize, OSize = cube.dims[-2:]
370317

371-
BSize = 1
372-
BOffset = 0
373-
BatchSize = 1
374-
BatchOffset = 0
375-
376-
if len(cube.offset) == 2:
377-
(MOffset, OOffset) = cube.offset
378-
(MSize, OSize) = cube.dims
379-
elif len(cube.offset) == 3:
380-
(BatchOffset, MOffset, OOffset) = cube.offset
381-
(BatchSize, MSize, OSize) = cube.dims
318+
if len(cube.offset) > 2:
319+
BatchSize = math.prod(cube.dims[:-2])
320+
321+
if len(cube.offset) > 3:
322+
assert all(off == 0 for off in cube.offset[:-3]), (
323+
f"Unsupported tiling across leading batch dims: offsets={cube.offset}. "
324+
"Only the last batch dim (besides M/O) may be tiled.")
382325
else:
383-
(BatchOffset, BOffset, MOffset, OOffset) = cube.offset
384-
(BatchSize, BSize, MSize, OSize) = cube.dims
326+
BatchSize = 1
385327

386328
replacements["M"].append(MSize)
387329
replacements["O"].append(OSize)
388-
replacements["batch"].append(BSize)
330+
replacements["batch"].append(BatchSize)
389331

390332
if transA == 0:
391333
AMatrixOffsets = (MOffset, NOffset)
@@ -394,57 +336,38 @@ def serializeTilingSolution(
394336
AMatrixOffsets = (NOffset, MOffset)
395337
AMatrixShape = (NSize, MSize)
396338

339+
if len(buffA.shape) > 2:
340+
batchDimCount = len(buffA.shape) - 2
341+
AMatrixOffsets = tuple(cube.offset[:-2][-batchDimCount:]) + AMatrixOffsets
342+
AMatrixShape = tuple(cube.dims[:-2][-batchDimCount:]) + AMatrixShape
343+
344+
ACube = HyperRectangle(AMatrixOffsets, AMatrixShape)
345+
inputACubes.append(ACube)
346+
397347
if transB == 0:
398348
BMatrixOffsets = (NOffset, OOffset)
399349
BMatrixShape = (NSize, OSize)
400350
else:
401351
BMatrixOffsets = (OOffset, NOffset)
402352
BMatrixShape = (OSize, NSize)
403353

404-
if len(buffA.shape) == 2:
405-
ACube = HyperRectangle(AMatrixOffsets, AMatrixShape)
406-
elif len(buffA.shape) == 3:
407-
ACube = HyperRectangle((BatchOffset,) + AMatrixOffsets, (BatchSize,) + AMatrixShape)
408-
else:
409-
ACube = HyperRectangle(
410-
(
411-
BatchOffset,
412-
BOffset,
413-
) + AMatrixOffsets,
414-
(
415-
BatchSize,
416-
BSize,
417-
) + AMatrixShape,
418-
)
419-
420-
if len(buffB.shape) == 2:
421-
BCube = HyperRectangle(BMatrixOffsets, BMatrixShape)
422-
elif len(buffB.shape) == 3:
423-
BCube = HyperRectangle((BatchOffset,) + BMatrixOffsets, (BatchSize,) + BMatrixShape)
424-
else:
425-
BCube = HyperRectangle(
426-
(
427-
BatchOffset,
428-
BOffset,
429-
) + BMatrixOffsets,
430-
(
431-
BatchSize,
432-
BSize,
433-
) + BMatrixShape,
434-
)
354+
if len(buffB.shape) > 2:
355+
batchDimCount = len(buffB.shape) - 2
356+
BMatrixOffsets = tuple(cube.offset[:-2][-batchDimCount:]) + BMatrixOffsets
357+
BMatrixShape = tuple(cube.dims[:-2][-batchDimCount:]) + BMatrixShape
358+
359+
BCube = HyperRectangle(BMatrixOffsets, BMatrixShape)
360+
inputBCubes.append(BCube)
435361

436362
CMatrixOffsets = (MOffset, OOffset)
437363
CMatrixShape = (MSize, OSize)
438364

439-
if len(buffC.shape) == 2:
440-
CCube = HyperRectangle(CMatrixOffsets, CMatrixShape)
441-
elif len(buffC.shape) == 3:
442-
CCube = HyperRectangle((BatchOffset,) + CMatrixOffsets, (BatchSize,) + CMatrixShape)
443-
else:
444-
CCube = HyperRectangle((BatchOffset, BOffset) + CMatrixOffsets, (BatchSize, BSize) + CMatrixShape)
365+
if len(buffC.shape) > 2:
366+
batchDimCount = len(buffC.shape) - 2
367+
CMatrixOffsets = tuple(cube.offset[:-2][-batchDimCount:]) + CMatrixOffsets
368+
CMatrixShape = tuple(cube.dims[:-2][-batchDimCount:]) + CMatrixShape
445369

446-
inputACubes.append(ACube)
447-
inputBCubes.append(BCube)
370+
CCube = HyperRectangle(CMatrixOffsets, CMatrixShape)
448371
inputAddCubes.append(CCube)
449372

450373
inputLoadSchedule = []

Deeploy/Targets/PULPOpen/TileConstraints/MatMulTileConstraint.py

Lines changed: 25 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -24,6 +24,7 @@
2424
# See the License for the specific language governing permissions and
2525
# limitations under the License.
2626

27+
import math
2728
from typing import Dict, List, Tuple
2829

2930
from Deeploy.AbstractDataTypes import PointerClass
@@ -125,65 +126,43 @@ def serializeTilingSolution(
125126

126127
# Every output is constructed by a pair of inputs. Reconstruct this pair.
127128
for cube in outputCubes:
129+
MOffset, OOffset = cube.offset[-2:]
130+
MSize, OSize = cube.dims[-2:]
128131

129-
BSize = 1
130-
BOffset = 0
131-
BatchSize = 1
132-
BatchOffset = 0
133-
134-
if len(cube.offset) == 2:
135-
(MOffset, OOffset) = cube.offset
136-
(MSize, OSize) = cube.dims
137-
elif len(cube.offset) == 3:
138-
(BatchOffset, MOffset, OOffset) = cube.offset
139-
(BatchSize, MSize, OSize) = cube.dims
132+
if len(cube.offset) > 2:
133+
BatchSize = math.prod(cube.dims[:-2])
134+
135+
if len(cube.offset) > 3:
136+
assert all(off == 0 for off in cube.offset[:-3]), (
137+
f"Unsupported tiling across leading batch dims: offsets={cube.offset}. "
138+
"Only the last batch dim (besides M/O) may be tiled.")
140139
else:
141-
(BatchOffset, BOffset, MOffset, OOffset) = cube.offset
142-
(BatchSize, BSize, MSize, OSize) = cube.dims
140+
BatchSize = 1
143141

144142
replacements["M"].append(MSize)
145143
replacements["O"].append(OSize)
146-
replacements["batch"].append(BSize)
144+
replacements["batch"].append(BatchSize)
147145

148146
AMatrixOffsets = (MOffset, NOffset)
149147
AMatrixShape = (MSize, NSize)
150148

149+
if len(buffA.shape) > 2:
150+
batchDimCount = len(buffA.shape) - 2
151+
AMatrixOffsets = tuple(cube.offset[:-2][-batchDimCount:]) + AMatrixOffsets
152+
AMatrixShape = tuple(cube.dims[:-2][-batchDimCount:]) + AMatrixShape
153+
154+
ACube = HyperRectangle(AMatrixOffsets, AMatrixShape)
155+
inputACubes.append(ACube)
156+
151157
BMatrixOffsets = (NOffset, OOffset)
152158
BMatrixShape = (NSize, OSize)
153159

154-
if len(buffA.shape) == 2:
155-
ACube = HyperRectangle(AMatrixOffsets, AMatrixShape)
156-
elif len(buffA.shape) == 3:
157-
ACube = HyperRectangle((BatchOffset,) + AMatrixOffsets, (BatchSize,) + AMatrixShape)
158-
else:
159-
ACube = HyperRectangle(
160-
(
161-
BatchOffset,
162-
BOffset,
163-
) + AMatrixOffsets,
164-
(
165-
BatchSize,
166-
BSize,
167-
) + AMatrixShape,
168-
)
169-
170-
if len(buffB.shape) == 2:
171-
BCube = HyperRectangle(BMatrixOffsets, BMatrixShape)
172-
elif len(buffB.shape) == 3:
173-
BCube = HyperRectangle((BatchOffset,) + BMatrixOffsets, (BatchSize,) + BMatrixShape)
174-
else:
175-
BCube = HyperRectangle(
176-
(
177-
BatchOffset,
178-
BOffset,
179-
) + BMatrixOffsets,
180-
(
181-
BatchSize,
182-
BSize,
183-
) + BMatrixShape,
184-
)
160+
if len(buffB.shape) > 2:
161+
batchDimCount = len(buffB.shape) - 2
162+
BMatrixOffsets = tuple(cube.offset[:-2][-batchDimCount:]) + BMatrixOffsets
163+
BMatrixShape = tuple(cube.dims[:-2][-batchDimCount:]) + BMatrixShape
185164

186-
inputACubes.append(ACube)
165+
BCube = HyperRectangle(BMatrixOffsets, BMatrixShape)
187166
inputBCubes.append(BCube)
188167

189168
inputLoadSchedule = []

0 commit comments

Comments
 (0)