@@ -300,23 +300,24 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
300300 if has_bias :
301301 biasDimVar = tilerModel .getTensorDimVar (tensorName = biasBufferName , dimIdx = 0 )
302302
303+ # ===== COMPUTE EFFECTIVE INPUT HEIGHT AND WIDTH =====
304+ # Assume worst case scenario (data padding on all sides) when tiling on a ceratin dimension.
305+ effectiveInputHeight = inputHeightVar + ((pads [0 ] + pads [2 ]) * (inputHeightVar == inputBuffer .shape [1 ])) - (
306+ (weightHeightVar - 1 ) * (inputHeightVar != inputBuffer .shape [1 ]))
307+ effectiveInputWidth = inputWidthVar + ((pads [1 ] + pads [3 ]) * (inputWidthVar == inputBuffer .shape [2 ])) - (
308+ (weightWidthVar - 1 ) * (inputWidthVar != inputBuffer .shape [2 ]))
309+
303310 # ===== ADD CONSTRAINTS =====
304311 # Add constraint for batch size match between input and output
305312 tilerModel .addConstraint (outputBatchVar == inputBatchVar )
306313
307314 # Add constraint for input width and height sizes match
308315 # (Depends on output height and width, kernel size, padding, dilations, and strides.
309316 # For more information on the connections, see ONNX and/or Torch Conv2D documentation).
310- # Assume worst case scenario (data padding on all sides) when tiling on a ceratin dimension.
311- effectiveHeight = inputHeightVar + ((pads [0 ] + pads [2 ]) * (inputHeightVar == inputBuffer .shape [1 ])) - (
312- (weightHeightVar - 1 ) * (inputHeightVar != inputBuffer .shape [1 ]))
313- effectiveWidth = inputWidthVar + ((pads [1 ] + pads [3 ]) * (inputWidthVar == inputBuffer .shape [2 ])) - (
314- (weightWidthVar - 1 ) * (inputWidthVar != inputBuffer .shape [2 ]))
315-
316317 tilerModel .addConstraint (
317- (outputHeightVar == (effectiveHeight - dilations [0 ] * (weightHeightVar - 1 ) - 1 ) // strides [0 ] + 1 ))
318+ (outputHeightVar == (effectiveInputHeight - dilations [0 ] * (weightHeightVar - 1 ) - 1 ) // strides [0 ] + 1 ))
318319 tilerModel .addConstraint (
319- (outputWidthVar == (effectiveWidth - dilations [1 ] * (weightWidthVar - 1 ) - 1 ) // strides [1 ] + 1 ))
320+ (outputWidthVar == (effectiveInputWidth - dilations [1 ] * (weightWidthVar - 1 ) - 1 ) // strides [1 ] + 1 ))
320321
321322 # Add constraint for input channel size match
322323 # (Depends on weight output channel and conv grouping)
@@ -341,6 +342,7 @@ def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkCo
341342 weightBuffer = ctxt .lookup (name = parseDict ['weight' ])
342343
343344 # Get other information
345+ pads = parseDict ["pads" ]
344346 strides = parseDict ["strides" ]
345347
346348 # ===== EXTRACT TENSOR DIMS AS VARS =====
@@ -357,17 +359,24 @@ def addPolicyConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: NetworkCo
357359 weightWidthVar = tilerModel .getTensorDimVar (tensorName = weightBuffer .name , dimIdx = 2 )
358360 weightInChannelVar = tilerModel .getTensorDimVar (tensorName = weightBuffer .name , dimIdx = 3 )
359361
362+ # ===== COMPUTE EFFECTIVE INPUT HEIGHT AND WIDTH =====
363+ # Assume worst case scenario (data padding on all sides) when tiling on a ceratin dimension.
364+ effectiveInputHeight = inputHeightVar + ((pads [0 ] + pads [2 ]) * (inputHeightVar == inputBuffer .shape [1 ])) - (
365+ (weightHeightVar - 1 ) * (inputHeightVar != inputBuffer .shape [1 ]))
366+ effectiveInputWidth = inputWidthVar + ((pads [1 ] + pads [3 ]) * (inputWidthVar == inputBuffer .shape [2 ])) - (
367+ (weightWidthVar - 1 ) * (inputWidthVar != inputBuffer .shape [2 ]))
368+
360369 # ===== ADD CONSTRAINTS =====
361370 # Keep whole input channels (required for im2col algorithm)
362371 tilerModel .addConstraint (inputChannelVar == parseDict ['ch_im_in' ])
363372
364373 # Require minimum input spatial dimensions to be at least kernel size for proper convolution application
365- tilerModel .addConstraint (inputHeightVar >= parseDict ['dim_kernel_x' ])
366- tilerModel .addConstraint (inputWidthVar >= parseDict ['dim_kernel_y' ])
374+ tilerModel .addConstraint (effectiveInputHeight >= parseDict ['dim_kernel_x' ])
375+ tilerModel .addConstraint (effectiveInputWidth >= parseDict ['dim_kernel_y' ])
367376
368377 # Ensure input tiles are compatible with stride
369- tilerModel .addConstraint ((inputHeightVar % strides [0 ]) == 0 )
370- tilerModel .addConstraint ((inputWidthVar % strides [1 ]) == 0 )
378+ tilerModel .addConstraint ((effectiveInputHeight % strides [0 ]) == 0 )
379+ tilerModel .addConstraint ((effectiveInputWidth % strides [1 ]) == 0 )
371380
372381 # Weight should not be tiled
373382 tilerModel .addConstraint (weightHeightVar == parseDict ['dim_kernel_x' ])
0 commit comments