Skip to content

Commit 6926a6b

Browse files
authored
[mlir][tosa] Allow shift operand of tosa::MulOp as non-constant (#155197)
The shift operand of tosa::MulOp could be non-constant when the dynamic extension enabled. Given that checkConstantOperandMul could check the shift operand according to the extension, we might able to relax the checking in TosaToLinalg. Relative discussion: https://discourse.llvm.org/t/tosa-ext-dynamic-clearification-needed/87478?u=r2333333.
1 parent fa883e1 commit 6926a6b

File tree

3 files changed

+46
-24
lines changed

3 files changed

+46
-24
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 35 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -126,12 +126,12 @@ static Value createLinalgBodyCalculationForElementwiseOp(
126126
if (isa<tosa::MulOp>(op)) {
127127
auto shiftVal = cast<tosa::MulOp>(op).getShift();
128128
DenseElementsAttr shiftElem;
129-
if (!matchPattern(shiftVal, m_Constant(&shiftElem))) {
130-
(void)rewriter.notifyMatchFailure(op, "shift value of mul not found");
131-
return nullptr;
132-
}
133-
134-
int32_t shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
129+
bool shiftIsConstant = true;
130+
int32_t shift = 0;
131+
if (matchPattern(shiftVal, m_Constant(&shiftElem)))
132+
shift = shiftElem.getValues<IntegerAttr>()[0].getInt();
133+
else
134+
shiftIsConstant = false;
135135

136136
if (isa<FloatType>(elementTy)) {
137137
if (shift != 0) {
@@ -147,23 +147,24 @@ static Value createLinalgBodyCalculationForElementwiseOp(
147147
Value a = args[0];
148148
Value b = args[1];
149149

150-
if (shift > 0) {
151-
auto shiftConst =
152-
arith::ConstantIntOp::create(rewriter, loc, shift, /*bitwidth=*/8);
150+
if (shift > 0 || !shiftIsConstant) {
151+
Value shiftConst;
152+
if (shiftIsConstant)
153+
shiftConst =
154+
rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
155+
153156
if (!a.getType().isInteger(32))
154157
a = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), a);
155158

156159
if (!b.getType().isInteger(32))
157160
b = arith::ExtSIOp::create(rewriter, loc, rewriter.getI32Type(), b);
158161

162+
auto shiftAmount = shiftIsConstant ? shiftConst : args[2];
159163
auto result = tosa::ApplyScaleOp::create(
160-
rewriter, loc, rewriter.getI32Type(), a, b, shiftConst,
164+
rewriter, loc, rewriter.getI32Type(), a, b, shiftAmount,
161165
rewriter.getStringAttr("SINGLE_ROUND"));
162166

163-
if (elementTy.isInteger(32))
164-
return result;
165-
166-
return arith::TruncIOp::create(rewriter, loc, elementTy, result);
167+
return result;
167168
}
168169

169170
int aWidth = a.getType().getIntOrFloatBitWidth();
@@ -918,6 +919,18 @@ broadcastDynamicDimensions(PatternRewriter &rewriter, Location loc,
918919
if (operands.size() == 1)
919920
return operands;
920921

922+
// No need to broadcast for static shape
923+
bool hasDynamic = false;
924+
for (auto op : operands) {
925+
const auto tType = dyn_cast<RankedTensorType>(op.getType());
926+
if (tType && !tType.hasStaticShape()) {
927+
hasDynamic = true;
928+
break;
929+
}
930+
}
931+
if (!hasDynamic)
932+
return operands;
933+
921934
// Broadcast dynamic dimensions operand by operand
922935
return llvm::map_to_vector(operands, [&](Value operand) {
923936
return broadcastDynamicDimensions(rewriter, loc, indexPool, operand,
@@ -990,8 +1003,14 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
9901003
static ValueRange getBroadcastableOperands(Operation *operation,
9911004
ValueRange operands) {
9921005
// Shift cannot broadcast
993-
if (isa<tosa::MulOp>(operation))
994-
return operands.take_front(2);
1006+
if (isa<tosa::MulOp>(operation)) {
1007+
DenseElementsAttr shiftElems;
1008+
// Shift cannot broadcast when it is constant
1009+
if (matchPattern(operation->getOperand(2), m_Constant(&shiftElems)))
1010+
return operands.take_front(2);
1011+
else
1012+
return operands.take_front(3);
1013+
}
9951014
// Input1_zp and output_zp cannot broadcast
9961015
if (isa<tosa::NegateOp>(operation))
9971016
return operands.take_front(1);

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 0 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -73,11 +73,3 @@ func.func @unranked_gather(%arg0: tensor<13x21x3xf32>, %arg1: tensor<13x26xi32>)
7373
%0 = tosa.gather %arg0, %arg1 : (tensor<13x21x3xf32>, tensor<13x26xi32>) -> tensor<*xf32>
7474
return %0 : tensor<*xf32>
7575
}
76-
77-
// -----
78-
79-
func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
80-
// expected-error@+1 {{failed to legalize operation 'tosa.mul'}}
81-
%0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
82-
return %0 : tensor<2x3xi32>
83-
}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2471,3 +2471,14 @@ func.func @test_0d_input(%arg0: tensor<i32>) -> () {
24712471

24722472
return
24732473
}
2474+
2475+
// -----
2476+
2477+
// CHECK-LABEL: @mul_no_const_shift
2478+
func.func @mul_no_const_shift(%arg0: tensor<2x3xi32>, %arg1: tensor<2x3xi32>, %arg2: tensor<1xi8>) -> tensor<2x3xi32> {
2479+
// CHECK: linalg.generic
2480+
// CHECK: ^bb0(%[[ARG0:.*]]: i32, %[[ARG1:.*]]: i32, %[[ARG2:.*]]: i8, %[[OUT:.*]]: i32):
2481+
// CHECK: tosa.apply_scale %[[ARG0]], %[[ARG1]], %[[ARG2]]
2482+
%0 = tosa.mul %arg0, %arg1, %arg2 : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<1xi8>) -> tensor<2x3xi32>
2483+
return %0 : tensor<2x3xi32>
2484+
}

0 commit comments

Comments
 (0)