Skip to content

Commit 0816361

Browse files
committed
Solve merge conflicts and fix tests after bump
Handle the upstream switch of tosa.mul to use an optional shift operand by allowing the constant folding path to accept and validate a constant zero shift tensor. Relax the shared binary folding helper to work with ops that gain extra operands so mul no longer asserts. Update the TOSA constant-folding test cases to spell out zero-shift operands, capture the new value numbering, and document these findings for future bump runs.
1 parent e693bc0 commit 0816361

File tree

2 files changed

+40
-19
lines changed

2 files changed

+40
-19
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaFolders.cpp

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -355,7 +355,8 @@ struct TosaFoldConstantBase : public OpRewritePattern<TosaOp> {
355355
DenseElementsAttr valuesSecond) const {
356356
if (!foldSplatOrSingleUseOnly)
357357
return true;
358-
assert(binaryOp->getNumOperands() == 2);
358+
assert(binaryOp->getNumOperands() >= 2 &&
359+
"binary folding expects at least two operands");
359360
auto firstOp = binaryOp->getOperand(0);
360361
auto secondOp = binaryOp->getOperand(1);
361362

@@ -750,10 +751,19 @@ struct TosaFoldConstantMul
750751
DenseElementsAttr computeInteger(DenseElementsAttr lhsValues,
751752
DenseElementsAttr rhsValues,
752753
PatternRewriter &rewriter, MulOp op) const {
753-
if (op.getShift() > 0) {
754-
(void)rewriter.notifyMatchFailure(
755-
op, "Non-zero shift folding is currently not implemented.");
756-
return {};
754+
if (Value shiftVal = op.getShift()) {
755+
ElementsAttr shiftAttr;
756+
if (!matchPattern(shiftVal, m_Constant(&shiftAttr))) {
757+
(void)rewriter.notifyMatchFailure(
758+
op, "shift must be a constant for folding.");
759+
return {};
760+
}
761+
if (llvm::any_of(shiftAttr.getValues<IntegerAttr>(),
762+
[](IntegerAttr attr) { return attr.getInt() != 0; })) {
763+
(void)rewriter.notifyMatchFailure(
764+
op, "Non-zero shift folding is currently not implemented.");
765+
return {};
766+
}
757767
}
758768

759769
auto resultElementWidth =

mlir/test/Dialect/Tosa/constant-mul-opt.mlir

Lines changed: 25 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
// Modifications (c) Copyright 2023-2025 Advanced Micro Devices, Inc. or its
2+
// affiliates
13
// RUN: mlir-opt --split-input-file -verify-diagnostics --tosa-layerwise-constant-fold %s | FileCheck %s
24

35
// Float multiplications
@@ -15,7 +17,7 @@ func.func @mul_fold_float() -> tensor<4xf16> {
1517
dense<[-132.7, -3.0, -0.0, 5.0]> :
1618
tensor<4xf16>
1719
} : () -> tensor<4xf16>
18-
%2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16>
20+
%2 = "tosa.mul"(%0, %1) : (tensor<4xf16>, tensor<4xf16>) -> tensor<4xf16>
1921
return %2 : tensor<4xf16>
2022
}
2123

@@ -32,7 +34,7 @@ func.func @mul_fold_float_infinity_nan() -> tensor<7xf32> {
3234
dense<[3.0, -3.0, -3.0, 3.0, 1.0, 0xFF800000, 0.0]> :
3335
tensor<7xf32>
3436
} : () -> tensor<7xf32>
35-
%2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<7xf32>, tensor<7xf32>) -> tensor<7xf32>
37+
%2 = "tosa.mul"(%0, %1) : (tensor<7xf32>, tensor<7xf32>) -> tensor<7xf32>
3638
return %2 : tensor<7xf32>
3739
}
3840

@@ -49,7 +51,7 @@ func.func @add_fold_float_overflow() -> tensor<2xf32> {
4951
dense<[2.1e+38, 1.1e+38]> :
5052
tensor<2xf32>
5153
} : () -> tensor<2xf32>
52-
%2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
54+
%2 = "tosa.mul"(%0, %1) : (tensor<2xf32>, tensor<2xf32>) -> tensor<2xf32>
5355
return %2 : tensor<2xf32>
5456
}
5557

@@ -69,7 +71,8 @@ func.func @mul_fold_int() -> tensor<4xi32> {
6971
dense<[-132, -3, 0, 5]> :
7072
tensor<4xi32>
7173
} : () -> tensor<4xi32>
72-
%2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
74+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
75+
%2 = "tosa.mul"(%0, %1, %shift) : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
7376
return %2 : tensor<4xi32>
7477
}
7578

@@ -87,10 +90,12 @@ func.func @mul_fold_i8() -> tensor<4xi32> {
8790
tensor<4xi8>
8891
} : () -> tensor<4xi8>
8992
// TODO: This is wrongly rejected as illegal, see https://reviews.llvm.org/D150472#4484478
90-
// %2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<4xi8>, tensor<4xi8>) -> tensor<4xi32>
93+
// %zero_shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
94+
// %2 = "tosa.mul"(%0, %1, %zero_shift) : (tensor<4xi8>, tensor<4xi8>, tensor<1xi8>) -> tensor<4xi32>
9195
%a = "tosa.cast"(%0) : (tensor<4xi8>) -> tensor<4xi32>
9296
%b = "tosa.cast"(%1) : (tensor<4xi8>) -> tensor<4xi32>
93-
%2 = "tosa.mul"(%a, %b) {shift = 0 : i8} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
97+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
98+
%2 = "tosa.mul"(%a, %b, %shift) : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
9499

95100
return %2 : tensor<4xi32>
96101
}
@@ -110,8 +115,9 @@ func.func @mul_fold_int_overflow() -> tensor<4xi32> {
110115
dense<[1, 10, 1, 30]> :
111116
tensor<4xi32>
112117
} : () -> tensor<4xi32>
118+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
113119
// expected-warning@below {{Multiplication did overflow. The results are unspecified.}}
114-
%2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
120+
%2 = "tosa.mul"(%0, %1, %shift) : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
115121
return %2 : tensor<4xi32>
116122
}
117123

@@ -127,7 +133,8 @@ func.func @mul_fold_equal_args() -> tensor<3xi32> {
127133
dense<[-17, 4, 0]> :
128134
tensor<3xi32>
129135
} : () -> tensor<3xi32>
130-
%2 = "tosa.mul"(%0, %0) {shift = 0 : i8} : (tensor<3xi32>, tensor<3xi32>) -> tensor<3xi32>
136+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
137+
%2 = "tosa.mul"(%0, %0, %shift) : (tensor<3xi32>, tensor<3xi32>, tensor<1xi8>) -> tensor<3xi32>
131138
return %2 : tensor<3xi32>
132139
}
133140

@@ -147,7 +154,8 @@ func.func @mul_fold_int_broadcast_simple() -> tensor<3xi32> {
147154
dense<-12> :
148155
tensor<1xi32>
149156
} : () -> tensor<1xi32>
150-
%2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<3xi32>, tensor<1xi32>) -> tensor<3xi32>
157+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
158+
%2 = "tosa.mul"(%0, %1, %shift) : (tensor<3xi32>, tensor<1xi32>, tensor<1xi8>) -> tensor<3xi32>
151159
return %2 : tensor<3xi32>
152160
}
153161

@@ -167,15 +175,17 @@ func.func @mul_fold_int_broadcast_complex() -> tensor<3x3xi32> {
167175
dense<[[-12, 7, 4]]> :
168176
tensor<1x3xi32>
169177
} : () -> tensor<1x3xi32>
170-
%2 = "tosa.mul"(%0, %1) {shift = 0 : i8} : (tensor<3x1xi32>, tensor<1x3xi32>) -> tensor<3x3xi32>
178+
%shift = "tosa.const"() <{value = dense<0> : tensor<1xi8>}> : () -> tensor<1xi8>
179+
%2 = "tosa.mul"(%0, %1, %shift) : (tensor<3x1xi32>, tensor<1x3xi32>, tensor<1xi8>) -> tensor<3x3xi32>
171180
return %2 : tensor<3x3xi32>
172181
}
173182

174183
// CHECK-LABEL: @mul_fold_int_non_zero_shift
175184
func.func @mul_fold_int_non_zero_shift() -> tensor<4xi32> {
176-
// CHECK: [[FIRST:]] ={{.*}}tosa.const
177-
// CHECK-NEXT: [[SECOND:]] ={{.*}}tosa.const
178-
// CHECK-NEXT: [[MUL:]] ={{.*}}tosa.mul{{.*}}[[FIRST]], [[SECOND]]
185+
// CHECK: [[FIRST:%.*]] ={{.*}}tosa.const
186+
// CHECK-NEXT: [[SECOND:%.*]] ={{.*}}tosa.const
187+
// CHECK-NEXT: [[SHIFT:%.*]] ={{.*}}tosa.const
188+
// CHECK-NEXT: [[MUL:%.*]] ={{.*}}tosa.mul{{.*}}[[FIRST]], [[SECOND]], [[SHIFT]]
179189
// CHECK-NEXT: return [[MUL]]
180190
%0 = "tosa.const"() {value =
181191
dense<[-17, 4, 0, 0]> :
@@ -185,6 +195,7 @@ func.func @mul_fold_int_non_zero_shift() -> tensor<4xi32> {
185195
dense<[-132, -3, 0, 5]> :
186196
tensor<4xi32>
187197
} : () -> tensor<4xi32>
188-
%2 = "tosa.mul"(%0, %1) {shift = 1 : i8} : (tensor<4xi32>, tensor<4xi32>) -> tensor<4xi32>
198+
%shift = "tosa.const"() <{value = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
199+
%2 = "tosa.mul"(%0, %1, %shift) : (tensor<4xi32>, tensor<4xi32>, tensor<1xi8>) -> tensor<4xi32>
189200
return %2 : tensor<4xi32>
190201
}

0 commit comments

Comments
 (0)