@@ -248,7 +248,7 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
248248 """
249249
250250 # ===== GET NECESSARY INFORMATION =====
251- # Get to-be-tiled tensor buffers
251+ # Get to-be-tiled tensor buffers
252252 inputBufferName = parseDict ['data_in' ]
253253 outputBufferName = parseDict ['data_out' ]
254254
@@ -257,7 +257,7 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
257257
258258 inputBuffer = ctxt .lookup (inputBufferName )
259259
260- # Get other information
260+ # Get other information
261261 has_bias = False if parseDict ['has_bias' ] == "false" else True
262262
263263 pads = parseDict ["pads" ]
@@ -289,8 +289,8 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
289289 outputChannelVar = tilerModel .getTensorDimVar (tensorName = outputBufferName , dimIdx = 3 )
290290
291291 # Weight
292- # C_out - H - W layout - C_in (depthwise convolution weights,
293- # with c_in used for grouping different than number of channels)
292+ # C_out - H - W layout - C_in
293+ # ( with c_in used for grouping different than number of channels)
294294 weightOutChannelVar = tilerModel .getTensorDimVar (tensorName = weightBufferName , dimIdx = 0 )
295295 weightHeightVar = tilerModel .getTensorDimVar (tensorName = weightBufferName , dimIdx = 1 )
296296 weightWidthVar = tilerModel .getTensorDimVar (tensorName = weightBufferName , dimIdx = 2 )
@@ -301,7 +301,7 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
301301 biasDimVar = tilerModel .getTensorDimVar (tensorName = biasBufferName , dimIdx = 0 )
302302
303303 # ===== ADD CONSTRAINTS =====
304- # Add constraint for batch size match between input and output
304+ # Add constraint for batch size match between input and output
305305 tilerModel .addConstraint (outputBatchVar == inputBatchVar )
306306
307307 # Add constraint for input width and height sizes match
@@ -335,33 +335,41 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
335335 @staticmethod
336336 def addPolicyConstraint (tilerModel : TilerModel , parseDict : Dict , ctxt : NetworkContext ) -> TilerModel :
337337
338- # Get to-be-tiled tensor's buffers
338+ # ===== GET NECESSARY INFORMATION =====
339+ # Get to-be-tiled tensor buffers
339340 inputBuffer = ctxt .lookup (name = parseDict ['data_in' ])
340341 weightBuffer = ctxt .lookup (name = parseDict ['weight' ])
341342
343+ # Get other information
344+ strides = parseDict ["strides" ]
345+
346+ # ===== EXTRACT TENSOR DIMS AS VARS =====
347+ # Input
348+ # NHWC layout
342349 inputHeightVar = tilerModel .getTensorDimVar (tensorName = inputBuffer .name , dimIdx = 1 )
343350 inputWidthVar = tilerModel .getTensorDimVar (tensorName = inputBuffer .name , dimIdx = 2 )
344351 inputChannelVar = tilerModel .getTensorDimVar (tensorName = inputBuffer .name , dimIdx = 3 )
345352
346- outputChannelVar = tilerModel .getTensorDimVar (tensorName = weightBuffer .name , dimIdx = 0 )
353+ # Weight
354+ # C_out - H - W layout - C_in
355+ # (with c_in used for grouping different than number of channels)
347356 weightHeightVar = tilerModel .getTensorDimVar (tensorName = weightBuffer .name , dimIdx = 1 )
348357 weightWidthVar = tilerModel .getTensorDimVar (tensorName = weightBuffer .name , dimIdx = 2 )
349358 weightInChannelVar = tilerModel .getTensorDimVar (tensorName = weightBuffer .name , dimIdx = 3 )
350359
351- strides = parseDict ["strides" ]
352-
353- # Keep input entire channels (required for im2col algorithm)
360+ # ===== ADD CONSTRAINTS =====
361+ # Keep whole input channels (required for im2col algorithm)
354362 tilerModel .addConstraint (inputChannelVar == parseDict ['ch_im_in' ])
355363
356- # Require minimum spatial dimensions to be at least kernel size
364+ # Require minimum input spatial dimensions to be at least kernel size for proper convolution application
357365 tilerModel .addConstraint (inputHeightVar >= parseDict ['dim_kernel_x' ])
358366 tilerModel .addConstraint (inputWidthVar >= parseDict ['dim_kernel_y' ])
359367
360- # Ensure input tiles are compatible with stride
368+ # Ensure input tiles are compatible with stride
361369 tilerModel .addConstraint ((inputHeightVar % strides [0 ]) == 0 )
362370 tilerModel .addConstraint ((inputWidthVar % strides [1 ]) == 0 )
363371
364- # Keep entire weight dimensions
372+ # Weight should not be tiled
365373 tilerModel .addConstraint (weightHeightVar == parseDict ['dim_kernel_x' ])
366374 tilerModel .addConstraint (weightWidthVar == parseDict ['dim_kernel_y' ])
367375 tilerModel .addConstraint (weightInChannelVar == parseDict ['ch_im_in' ])
@@ -401,14 +409,18 @@ def computeInputCube(
401409 outputAbsoluteOffsets : Optional [Tuple [int , int , int , int ]] = None ,
402410 ) -> Tuple [HyperRectangle , Tuple [int , int , int , int ]]:
403411
404- (outputBatchOffset , outputHOffset , outputWOffset , outputCOffset ) = outputCube .offset
405- (outputBatchSize , outputHSize , outputWSize , outputCSize ) = outputCube .dims
406- (outputBatchAbsoluteOffset , outputHAbsoluteOffset , outputWAbsoluteOffset ,
407- outputCAbsoluteOffset ) = outputAbsoluteOffsets if outputAbsoluteOffsets is not None else outputCube .offset
412+ # Obtain relative and absolute information about the output tile
413+ (outputBatchOffset , outputHOffset , outputWOffset , _ ) = outputCube .offset
414+ (outputBatchSize , outputHSize , outputWSize , _ ) = outputCube .dims
415+ (_ , outputHAbsoluteOffset , outputWAbsoluteOffset ,
416+ _ ) = outputAbsoluteOffsets if outputAbsoluteOffsets is not None else outputCube .offset
408417
418+ # Extract individual pads and strides
409419 padTop , padLeft , padBottom , padRight = pads
410420 strideH , strideW = strides
411421
422+ # Compute actuale tile padding, depending on tile position (keep padding only for margins situated at the edge).
423+ # Required for the Im2Col kernel that handles 0-padding internally.
412424 tilePadTop = padTop if (outputHAbsoluteOffset == 0 ) else 0
413425 tilePadLeft = padLeft if (outputWAbsoluteOffset == 0 ) else 0
414426 tilePadBottom = padBottom if (outputHAbsoluteOffset + outputHSize == outputDims [1 ]) else 0
@@ -431,6 +443,7 @@ def computeInputCube(
431443 inputHSize = min (inputHSize , inputDims [1 ] - inputHOffset )
432444 inputWSize = min (inputWSize , inputDims [2 ] - inputWOffset )
433445
446+ # Generate input tile object
434447 InCube = HyperRectangle ((outputBatchOffset , inputHOffset , inputWOffset , 0 ),
435448 (outputBatchSize , inputHSize , inputWSize , inputCSize ))
436449
@@ -505,8 +518,8 @@ def serializeTilingSolution(
505518 # Iterate throught the cubes in which the output will be split for tiling
506519 for idx , cube in enumerate (outputCubes ):
507520 # Obtain current cube offsets and dimensions
508- ( BatchOffset , HOffset , WOffset , COffset ) = cube .offset
509- (BatchSize , HSize , WSize , CSize ) = cube .dims
521+ COffset = cube .offset [ 3 ]
522+ (_ , HSize , WSize , CSize ) = cube .dims
510523
511524 # Compute input cube
512525 InCube , padding_tuple = Conv2DTileConstraint .computeInputCube (
0 commit comments