Skip to content

Commit 54acf5b

Browse files
committed
Applied ReduceMeanConstraint automatically recommended by Coder Rabbit
1 parent f65282b commit 54acf5b

File tree

1 file changed

+13
-17
lines changed

1 file changed

+13
-17
lines changed

Deeploy/Targets/PULPOpen/TileConstraints/ReduceMeanConstraint.py

Lines changed: 13 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -26,9 +26,6 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
2626
inputBufferName = parseDict['data_in']
2727
outputBufferName = parseDict['data_out']
2828

29-
# Get I/O shapes
30-
outputShape = parseDict['data_out_shape']
31-
3229
# Get other necessary information
3330
reduceAxes = parseDict['axes']
3431
keepDims = parseDict['keepdims']
@@ -39,25 +36,24 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
3936

4037
# ===== ADD CONSTRAINTS =====
4138
# Add constraints for the I/O dimensions
42-
input_ax = 0
43-
for idx in range(len(outputShape)):
44-
# Get current dimension variables
45-
outputDimensionVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = idx)
46-
47-
if idx in reduceAxes:
48-
# For reduced axes, constrain to 1 if keepdims is set,
49-
# otherwise skip this axis in the input tensor,
50-
# as it needs to be eliminated.
39+
# Iterate over input axes and maintain an output index pointer
40+
inputShape = parseDict['data_in_shape']
41+
output_idx = 0
42+
for input_ax in range(len(inputShape)):
43+
if input_ax in reduceAxes:
44+
# This axis is reduced
5145
if keepDims:
46+
# Get the output dimension variable and constrain it to 1
47+
outputDimensionVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = output_idx)
5248
tilerModel.addConstraint(outputDimensionVar == 1)
53-
input_ax += 1
49+
output_idx += 1
50+
# If keepDims is false, this axis doesn't exist in output, so don't increment output_idx
5451
else:
55-
# Otherwise, input and output dimensions need to be equal
52+
# This axis is not reduced, so input and output dimensions need to be equal
5653
inputDimensionVar = tilerModel.getTensorDimVar(tensorName = inputBufferName, dimIdx = input_ax)
57-
54+
outputDimensionVar = tilerModel.getTensorDimVar(tensorName = outputBufferName, dimIdx = output_idx)
5855
tilerModel.addConstraint(outputDimensionVar == inputDimensionVar)
59-
60-
input_ax += 1
56+
output_idx += 1
6157

6258
return tilerModel
6359

0 commit comments

Comments
 (0)