2424# See the License for the specific language governing permissions and
2525# limitations under the License.
2626
27+ import math
2728from typing import Dict , List , Tuple
2829
2930from Deeploy .AbstractDataTypes import PointerClass
@@ -135,25 +136,22 @@ def serializeTilingSolution(
135136
136137 # Every output is constructed by a pair of inputs. Reconstruct this pair.
137138 for cube in outputCubes :
139+ MOffset , OOffset = cube .offset [- 2 :]
140+ MSize , OSize = cube .dims [- 2 :]
138141
139- BSize = 1
140- BOffset = 0
141- BatchSize = 1
142- BatchOffset = 0
143-
144- if len (cube .offset ) == 2 :
145- (MOffset , OOffset ) = cube .offset
146- (MSize , OSize ) = cube .dims
147- elif len (cube .offset ) == 3 :
148- (BatchOffset , MOffset , OOffset ) = cube .offset
149- (BatchSize , MSize , OSize ) = cube .dims
142+ if len (cube .offset ) > 2 :
143+ BatchSize = math .prod (cube .dims [:- 2 ])
144+
145+ if len (cube .offset ) > 3 :
146+ assert all (off == 0 for off in cube .offset [:- 3 ]), (
147+ f"Unsupported tiling across leading batch dims: offsets={ cube .offset } . "
148+ "Only the last batch dim (besides M/O) may be tiled." )
150149 else :
151- (BatchOffset , BOffset , MOffset , OOffset ) = cube .offset
152- (BatchSize , BSize , MSize , OSize ) = cube .dims
150+ BatchSize = 1
153151
154152 replacements ["M" ].append (MSize )
155153 replacements ["O" ].append (OSize )
156- replacements ["batch" ].append (BSize )
154+ replacements ["batch" ].append (BatchSize )
157155
158156 if transA == 0 :
159157 AMatrixOffsets = (MOffset , NOffset )
@@ -162,49 +160,30 @@ def serializeTilingSolution(
162160 AMatrixOffsets = (NOffset , MOffset )
163161 AMatrixShape = (NSize , MSize )
164162
163+ if len (buffA .shape ) > 2 :
164+ batchDimCount = len (buffA .shape ) - 2
165+ AMatrixOffsets = tuple (cube .offset [:- 2 ][- batchDimCount :]) + AMatrixOffsets
166+ AMatrixShape = tuple (cube .dims [:- 2 ][- batchDimCount :]) + AMatrixShape
167+
168+ ACube = HyperRectangle (AMatrixOffsets , AMatrixShape )
169+ inputACubes .append (ACube )
170+
165171 if transB == 0 :
166172 BMatrixOffsets = (NOffset , OOffset )
167173 BMatrixShape = (NSize , OSize )
168174 else :
169175 BMatrixOffsets = (OOffset , NOffset )
170176 BMatrixShape = (OSize , NSize )
171177
172- if len (buffA .shape ) == 2 :
173- ACube = HyperRectangle (AMatrixOffsets , AMatrixShape )
174- elif len (buffA .shape ) == 3 :
175- ACube = HyperRectangle ((BatchOffset ,) + AMatrixOffsets , (BatchSize ,) + AMatrixShape )
176- else :
177- ACube = HyperRectangle (
178- (
179- BatchOffset ,
180- BOffset ,
181- ) + AMatrixOffsets ,
182- (
183- BatchSize ,
184- BSize ,
185- ) + AMatrixShape ,
186- )
187-
188- if len (buffB .shape ) == 2 :
189- BCube = HyperRectangle (BMatrixOffsets , BMatrixShape )
190- elif len (buffB .shape ) == 3 :
191- BCube = HyperRectangle ((BatchOffset ,) + BMatrixOffsets , (BatchSize ,) + BMatrixShape )
192- else :
193- BCube = HyperRectangle (
194- (
195- BatchOffset ,
196- BOffset ,
197- ) + BMatrixOffsets ,
198- (
199- BatchSize ,
200- BSize ,
201- ) + BMatrixShape ,
202- )
203-
204- RequantCube = HyperRectangle ((OOffset ,), (OSize ,))
178+ if len (buffB .shape ) > 2 :
179+ batchDimCount = len (buffB .shape ) - 2
180+ BMatrixOffsets = tuple (cube .offset [:- 2 ][- batchDimCount :]) + BMatrixOffsets
181+ BMatrixShape = tuple (cube .dims [:- 2 ][- batchDimCount :]) + BMatrixShape
205182
206- inputACubes . append ( ACube )
183+ BCube = HyperRectangle ( BMatrixOffsets , BMatrixShape )
207184 inputBCubes .append (BCube )
185+
186+ RequantCube = HyperRectangle ((OOffset ,), (OSize ,))
208187 inputMulCubes .append (RequantCube )
209188 inputAddCubes .append (RequantCube )
210189
@@ -231,40 +210,6 @@ def serializeTilingSolution(
231210 return VariableReplacementScheme (replacements , replacementTypes ), schedule
232211
233212
234- class MatrixVecTileConstraint (GEMMTileConstraint ):
235-
236- @staticmethod
237- def addGeometricalConstraint (tilerModel : TilerModel , parseDict : Dict , ctxt : NetworkContext ) -> TilerModel :
238-
239- tm = GEMMTileConstraint .addGeometricalConstraint (tilerModel , parseDict , ctxt )
240-
241- return tm
242-
243- @staticmethod
244- def addPolicyConstraint (tilerModel : TilerModel , parseDict : Dict , ctxt : NetworkContext ) -> TilerModel :
245-
246- tm = GEMMTileConstraint .addPolicyConstraint (tilerModel , parseDict , ctxt )
247-
248- return tm
249-
250-
251- class TallGEMMTileConstraint (GEMMTileConstraint ):
252-
253- @staticmethod
254- def addGeometricalConstraint (tilerModel : TilerModel , parseDict : Dict , ctxt : NetworkContext ) -> TilerModel :
255-
256- tm = GEMMTileConstraint .addGeometricalConstraint (tilerModel , parseDict , ctxt )
257-
258- return tm
259-
260- @staticmethod
261- def addPolicyConstraint (tilerModel : TilerModel , parseDict : Dict , ctxt : NetworkContext ) -> TilerModel :
262-
263- tm = GEMMTileConstraint .addPolicyConstraint (tilerModel , parseDict , ctxt )
264-
265- return tm
266-
267-
268213class FloatGEMMTileConstraint (TileConstraint ):
269214
270215 @staticmethod
@@ -367,25 +312,22 @@ def serializeTilingSolution(
367312
368313 # Every output is constructed by a pair of inputs. Reconstruct this pair.
369314 for cube in outputCubes :
315+ MOffset , OOffset = cube .offset [- 2 :]
316+ MSize , OSize = cube .dims [- 2 :]
370317
371- BSize = 1
372- BOffset = 0
373- BatchSize = 1
374- BatchOffset = 0
375-
376- if len (cube .offset ) == 2 :
377- (MOffset , OOffset ) = cube .offset
378- (MSize , OSize ) = cube .dims
379- elif len (cube .offset ) == 3 :
380- (BatchOffset , MOffset , OOffset ) = cube .offset
381- (BatchSize , MSize , OSize ) = cube .dims
318+ if len (cube .offset ) > 2 :
319+ BatchSize = math .prod (cube .dims [:- 2 ])
320+
321+ if len (cube .offset ) > 3 :
322+ assert all (off == 0 for off in cube .offset [:- 3 ]), (
323+ f"Unsupported tiling across leading batch dims: offsets={ cube .offset } . "
324+ "Only the last batch dim (besides M/O) may be tiled." )
382325 else :
383- (BatchOffset , BOffset , MOffset , OOffset ) = cube .offset
384- (BatchSize , BSize , MSize , OSize ) = cube .dims
326+ BatchSize = 1
385327
386328 replacements ["M" ].append (MSize )
387329 replacements ["O" ].append (OSize )
388- replacements ["batch" ].append (BSize )
330+ replacements ["batch" ].append (BatchSize )
389331
390332 if transA == 0 :
391333 AMatrixOffsets = (MOffset , NOffset )
@@ -394,57 +336,38 @@ def serializeTilingSolution(
394336 AMatrixOffsets = (NOffset , MOffset )
395337 AMatrixShape = (NSize , MSize )
396338
339+ if len (buffA .shape ) > 2 :
340+ batchDimCount = len (buffA .shape ) - 2
341+ AMatrixOffsets = tuple (cube .offset [:- 2 ][- batchDimCount :]) + AMatrixOffsets
342+ AMatrixShape = tuple (cube .dims [:- 2 ][- batchDimCount :]) + AMatrixShape
343+
344+ ACube = HyperRectangle (AMatrixOffsets , AMatrixShape )
345+ inputACubes .append (ACube )
346+
397347 if transB == 0 :
398348 BMatrixOffsets = (NOffset , OOffset )
399349 BMatrixShape = (NSize , OSize )
400350 else :
401351 BMatrixOffsets = (OOffset , NOffset )
402352 BMatrixShape = (OSize , NSize )
403353
404- if len (buffA .shape ) == 2 :
405- ACube = HyperRectangle (AMatrixOffsets , AMatrixShape )
406- elif len (buffA .shape ) == 3 :
407- ACube = HyperRectangle ((BatchOffset ,) + AMatrixOffsets , (BatchSize ,) + AMatrixShape )
408- else :
409- ACube = HyperRectangle (
410- (
411- BatchOffset ,
412- BOffset ,
413- ) + AMatrixOffsets ,
414- (
415- BatchSize ,
416- BSize ,
417- ) + AMatrixShape ,
418- )
419-
420- if len (buffB .shape ) == 2 :
421- BCube = HyperRectangle (BMatrixOffsets , BMatrixShape )
422- elif len (buffB .shape ) == 3 :
423- BCube = HyperRectangle ((BatchOffset ,) + BMatrixOffsets , (BatchSize ,) + BMatrixShape )
424- else :
425- BCube = HyperRectangle (
426- (
427- BatchOffset ,
428- BOffset ,
429- ) + BMatrixOffsets ,
430- (
431- BatchSize ,
432- BSize ,
433- ) + BMatrixShape ,
434- )
354+ if len (buffB .shape ) > 2 :
355+ batchDimCount = len (buffB .shape ) - 2
356+ BMatrixOffsets = tuple (cube .offset [:- 2 ][- batchDimCount :]) + BMatrixOffsets
357+ BMatrixShape = tuple (cube .dims [:- 2 ][- batchDimCount :]) + BMatrixShape
358+
359+ BCube = HyperRectangle (BMatrixOffsets , BMatrixShape )
360+ inputBCubes .append (BCube )
435361
436362 CMatrixOffsets = (MOffset , OOffset )
437363 CMatrixShape = (MSize , OSize )
438364
439- if len (buffC .shape ) == 2 :
440- CCube = HyperRectangle (CMatrixOffsets , CMatrixShape )
441- elif len (buffC .shape ) == 3 :
442- CCube = HyperRectangle ((BatchOffset ,) + CMatrixOffsets , (BatchSize ,) + CMatrixShape )
443- else :
444- CCube = HyperRectangle ((BatchOffset , BOffset ) + CMatrixOffsets , (BatchSize , BSize ) + CMatrixShape )
365+ if len (buffC .shape ) > 2 :
366+ batchDimCount = len (buffC .shape ) - 2
367+ CMatrixOffsets = tuple (cube .offset [:- 2 ][- batchDimCount :]) + CMatrixOffsets
368+ CMatrixShape = tuple (cube .dims [:- 2 ][- batchDimCount :]) + CMatrixShape
445369
446- inputACubes .append (ACube )
447- inputBCubes .append (BCube )
370+ CCube = HyperRectangle (CMatrixOffsets , CMatrixShape )
448371 inputAddCubes .append (CCube )
449372
450373 inputLoadSchedule = []
0 commit comments