@@ -26,9 +26,6 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
2626 inputBufferName = parseDict ['data_in' ]
2727 outputBufferName = parseDict ['data_out' ]
2828
29- # Get I/O shapes
30- outputShape = parseDict ['data_out_shape' ]
31-
3229 # Get other necessary information
3330 reduceAxes = parseDict ['axes' ]
3431 keepDims = parseDict ['keepdims' ]
@@ -39,25 +36,24 @@ def addGeometricalConstraint(tilerModel: TilerModel, parseDict: Dict, ctxt: Netw
3936
4037 # ===== ADD CONSTRAINTS =====
4138 # Add constraints for the I/O dimensions
42- input_ax = 0
43- for idx in range (len (outputShape )):
44- # Get current dimension variables
45- outputDimensionVar = tilerModel .getTensorDimVar (tensorName = outputBufferName , dimIdx = idx )
46-
47- if idx in reduceAxes :
48- # For reduced axes, constrain to 1 if keepdims is set,
49- # otherwise skip this axis in the input tensor,
50- # as it needs to be eliminated.
39+ # Iterate over input axes and maintain an output index pointer
40+ inputShape = parseDict ['data_in_shape' ]
41+ output_idx = 0
42+ for input_ax in range (len (inputShape )):
43+ if input_ax in reduceAxes :
44+ # This axis is reduced
5145 if keepDims :
46+ # Get the output dimension variable and constrain it to 1
47+ outputDimensionVar = tilerModel .getTensorDimVar (tensorName = outputBufferName , dimIdx = output_idx )
5248 tilerModel .addConstraint (outputDimensionVar == 1 )
53- input_ax += 1
49+ output_idx += 1
50+ # If keepDims is false, this axis doesn't exist in output, so don't increment output_idx
5451 else :
55- # Otherwise, input and output dimensions need to be equal
52+ # This axis is not reduced, so input and output dimensions need to be equal
5653 inputDimensionVar = tilerModel .getTensorDimVar (tensorName = inputBufferName , dimIdx = input_ax )
57-
54+ outputDimensionVar = tilerModel . getTensorDimVar ( tensorName = outputBufferName , dimIdx = output_idx )
5855 tilerModel .addConstraint (outputDimensionVar == inputDimensionVar )
59-
60- input_ax += 1
56+ output_idx += 1
6157
6258 return tilerModel
6359
0 commit comments