Skip to content

Commit 09d7701

Browse files
committed
Improved previously found solution for convolution input constraint (clearly mark when data is needed from bordering regions)
1 parent 2f89b42 commit 09d7701

File tree

1 file changed

+6
-4
lines changed

1 file changed

+6
-4
lines changed

Deeploy/Targets/PULPOpen/TileConstraints/ConvTileConstraint.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -308,13 +308,15 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
308308
# (Depends on output height and width, kernel size, padding, dilations, and strides.
309309
# For more information on the connections, see ONNX and/or Torch Conv2D documentation).
310310
# Assume worst case scenario (data padding on all sides) when tiling on a ceratin dimension.
311-
effectiveHeight = inputHeightVar + ((pads[0] + pads[2]) * (inputHeightVar == inputBuffer.shape[1]))
312-
effectiveWidth = inputWidthVar + ((pads[1] + pads[3]) * (inputWidthVar == inputBuffer.shape[2]))
311+
effectiveHeight = inputHeightVar + ((pads[0] + pads[2]) * (inputHeightVar == inputBuffer.shape[1])) - (
312+
(weightHeightVar - 1) * (inputHeightVar != inputBuffer.shape[1]))
313+
effectiveWidth = inputWidthVar + ((pads[1] + pads[3]) * (inputWidthVar == inputBuffer.shape[2])) - (
314+
(weightWidthVar - 1) * (inputWidthVar != inputBuffer.shape[2]))
313315

314316
tilerModel.addConstraint(
315-
(outputHeightVar == (effectiveHeight - dilations[0] * (weightHeightVar - 1) - 1) // strides[0]))
317+
(outputHeightVar == (effectiveHeight - dilations[0] * (weightHeightVar - 1) - 1) // strides[0] + 1))
316318
tilerModel.addConstraint(
317-
(outputWidthVar == (effectiveWidth - dilations[1] * (weightWidthVar - 1) - 1) // strides[1]))
319+
(outputWidthVar == (effectiveWidth - dilations[1] * (weightWidthVar - 1) - 1) // strides[1] + 1))
318320

319321
# Add constraint for input channel size match
320322
# (Depends on weight output channel and conv grouping)

0 commit comments

Comments
 (0)