@@ -392,24 +392,28 @@ def constructSymbolicNodeRep(tilerModel: TilerModel, parseDict: Dict,
392392
393393 @staticmethod
394394 def computeInputCube (
395- kernelShape : Tuple [int , int ],
396- pads : Tuple [int , int , int , int ],
397- strides : Tuple [int , int ],
398- inputCSize : int ,
399- outputCube : HyperRectangle ,
400- outputDims : Tuple [int , int , int ],
401- inputDims : Optional [Tuple [int , int , int ]] = None ) -> Tuple [HyperRectangle , Tuple [int , int , int , int ]]:
395+ kernelShape : Tuple [int , int ],
396+ pads : Tuple [int , int , int , int ],
397+ strides : Tuple [int , int ],
398+ inputCSize : int ,
399+ outputCube : HyperRectangle ,
400+ outputDims : Tuple [int , int , int ],
401+ inputDims : Optional [Tuple [int , int , int ]] = None ,
402+ outputAbsoluteOffsets : Optional [Tuple [int , int , int , int ]] = None ,
403+ ) -> Tuple [HyperRectangle , Tuple [int , int , int , int ]]:
402404
403405 (outputBatchOffset , outputHOffset , outputWOffset , outputCOffset ) = outputCube .offset
404406 (outputBatchSize , outputHSize , outputWSize , outputCSize ) = outputCube .dims
407+ (outputBatchAbsoluteOffset , outputHAbsoluteOffset , outputWAbsoluteOffset ,
408+ outputCAbsoluteOffset ) = outputAbsoluteOffsets
405409
406410 padTop , padLeft , padBottom , padRight = pads
407411 strideH , strideW = strides
408412
409- tilePadTop = padTop if (outputHOffset == 0 ) else 0
410- tilePadLeft = padLeft if (outputWOffset == 0 ) else 0
411- tilePadBottom = padBottom if (outputHOffset + outputHSize == outputDims [1 ]) else 0
412- tilePadRight = padRight if (outputWOffset + outputWSize == outputDims [2 ]) else 0
413+ tilePadTop = padTop if (outputHAbsoluteOffset == 0 ) else 0
414+ tilePadLeft = padLeft if (outputWAbsoluteOffset == 0 ) else 0
415+ tilePadBottom = padBottom if (outputHAbsoluteOffset + outputHSize == outputDims [1 ]) else 0
416+ tilePadRight = padRight if (outputWAbsoluteOffset + outputWSize == outputDims [2 ]) else 0
413417
414418 # LMACAN: Calculating the per-dimension relative tile offset without padding
415419 # The offset is relative to the upstream bigger tile, and represents the offset to
@@ -500,7 +504,7 @@ def serializeTilingSolution(
500504 strides = operatorRepresentation ['strides' ]
501505
502506 # Iterate throught the cubes in which the output will be split for tiling
503- for cube in outputCubes :
507+ for idx , cube in enumerate ( outputCubes ) :
504508 # Obtain current cube offsets and dimensions
505509 (BatchOffset , HOffset , WOffset , COffset ) = cube .offset
506510 (BatchSize , HSize , WSize , CSize ) = cube .dims
@@ -514,7 +518,7 @@ def serializeTilingSolution(
514518 outputCube = cube ,
515519 inputDims = ctxt .lookup (varIn ).shape ,
516520 outputDims = ctxt .lookup (varOut ).shape ,
517- )
521+ outputAbsoluteOffsets = absoluteOutputCubes [ idx ]. absoluteOffset )
518522
519523 # Extract individual padding
520524 padding_left , padding_right , padding_top , padding_bottom = padding_tuple
0 commit comments