Skip to content

Commit bd67b8f

Browse files
authored
[mlir][tosa] support NegateOp with dynamic extension in TosaToLinalg (#158782)
1 parent dfbd76b commit bd67b8f

File tree

2 files changed

+80
-35
lines changed

2 files changed

+80
-35
lines changed

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 47 additions & 35 deletions
Original file line number | Diff line number | Diff line change
@@ -186,56 +186,63 @@ static Value createLinalgBodyCalculationForElementwiseOp(
186186
if (isa<tosa::NegateOp>(op)) {
187187
auto negate = cast<tosa::NegateOp>(op);
188188

189+
int64_t inZp = 0, outZp = 0;
189190
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
190-
if (failed(maybeInZp)) {
191-
(void)rewriter.notifyMatchFailure(
192-
op, "input1 zero point cannot be statically determined");
193-
return nullptr;
194-
}
195-
196191
FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
197-
if (failed(maybeOutZp)) {
198-
(void)rewriter.notifyMatchFailure(
199-
op, "output zero point cannot be statically determined");
200-
return nullptr;
201-
}
202-
203-
int64_t inZp = *maybeInZp;
204-
int64_t outZp = *maybeOutZp;
192+
bool hasInZp = !failed(maybeInZp);
193+
bool hasOutZp = !failed(maybeOutZp);
194+
if (hasInZp)
195+
inZp = *maybeInZp;
196+
if (hasOutZp)
197+
outZp = *maybeOutZp;
205198

206199
if (isa<FloatType>(elementTy))
207200
return arith::NegFOp::create(rewriter, loc, resultTypes, args[0]);
208201

209202
if (isa<IntegerType>(elementTy)) {
210-
if (!inZp && !outZp) {
203+
if (hasInZp && hasOutZp && !inZp && !outZp) {
211204
auto constant = arith::ConstantOp::create(
212205
rewriter, loc, IntegerAttr::get(elementTy, 0));
213206
return arith::SubIOp::create(rewriter, loc, resultTypes, constant,
214207
args[0]);
215208
}
216209

210+
Value zpAddValue;
211+
Type intermediateType;
217212
// Compute the maximum value that can occur in the intermediate buffer.
218213
const int32_t inputBitWidth = elementTy.getIntOrFloatBitWidth();
219-
const int64_t zpAdd = inZp + outZp;
220-
const int64_t maxValue =
221-
APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
222-
std::abs(zpAdd) + 1;
223-
224-
// Convert that maximum value into the maximum bitwidth needed to
225-
// represent it. We assume 48-bit numbers may be supported further in
226-
// the pipeline.
227214
int intermediateBitWidth = 64;
228-
if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
229-
intermediateBitWidth = 16;
230-
} else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
231-
intermediateBitWidth = 32;
232-
} else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
233-
intermediateBitWidth = 48;
234-
}
235215

236-
Type intermediateType = rewriter.getIntegerType(intermediateBitWidth);
237-
Value zpAddValue = arith::ConstantOp::create(
238-
rewriter, loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
216+
if (hasInZp && hasOutZp) {
217+
// Compute the maximum value that can occur in the intermediate buffer.
218+
const int64_t zpAdd = inZp + outZp;
219+
const int64_t maxValue =
220+
APInt::getSignedMaxValue(inputBitWidth).getSExtValue() +
221+
std::abs(zpAdd) + 1;
222+
223+
// Convert that maximum value into the maximum bitwidth needed to
224+
// represent it. We assume 48-bit numbers may be supported further in
225+
// the pipeline.
226+
if (maxValue <= APInt::getSignedMaxValue(16).getSExtValue()) {
227+
intermediateBitWidth = 16;
228+
} else if (maxValue <= APInt::getSignedMaxValue(32).getSExtValue()) {
229+
intermediateBitWidth = 32;
230+
} else if (maxValue <= APInt::getSignedMaxValue(48).getSExtValue()) {
231+
intermediateBitWidth = 48;
232+
}
233+
234+
intermediateType = rewriter.getIntegerType(intermediateBitWidth);
235+
zpAddValue = rewriter.create<arith::ConstantOp>(
236+
loc, rewriter.getIntegerAttr(intermediateType, zpAdd));
237+
} else {
238+
intermediateType = rewriter.getIntegerType(intermediateBitWidth);
239+
auto arg1 =
240+
rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[1]);
241+
auto arg2 =
242+
rewriter.create<arith::ExtSIOp>(loc, intermediateType, args[2]);
243+
zpAddValue =
244+
rewriter.create<arith::AddIOp>(loc, intermediateType, arg1, arg2);
245+
}
239246

240247
// The negation can be applied by doing:
241248
// outputValue = inZp + outZp - inputValue
@@ -1013,9 +1020,14 @@ static ValueRange getBroadcastableOperands(Operation *operation,
10131020
else
10141021
return operands.take_front(3);
10151022
}
1016-
// Input1_zp and output_zp cannot broadcast
1017-
if (isa<tosa::NegateOp>(operation))
1023+
if (auto negate = dyn_cast<tosa::NegateOp>(operation)) {
1024+
FailureOr<int64_t> maybeInZp = negate.getInput1ZeroPoint();
1025+
FailureOr<int64_t> maybeOutZp = negate.getOutputZeroPoint();
1026+
if (failed(maybeOutZp) && failed(maybeInZp))
1027+
return operands;
1028+
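// Input1_zp and output_zp cannot broadcast when they are constants.
// NOTE(review): this path is also reached when only ONE of the two zero
// points is constant (the early return above requires both to be
// non-constant). The NegateOp body lowering reads args[1]/args[2] whenever
// either zero point is non-constant - confirm the mixed case is handled.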
10181029
return operands.take_front(1);
1030+
}
10191031
return operands;
10201032
}
10211033

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 33 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -899,6 +899,39 @@ func.func @test_negate_quantized(%arg0: tensor<1xi8>) -> () {
899899

900900
// -----
901901

902+
// CHECK-LABEL: @test_negate_no_const_1
903+
func.func @test_negate_no_const_1(%arg0: tensor<50x42xf16> ,%arg1: tensor<1xf16> , %arg2: tensor<1xf16> ) -> tensor<*xf16> {
904+
// CHECK: %[[GENERIC:.+]] = linalg.generic
905+
// CHECK: ^bb0([[ARG0:%.*]]: f16, [[ARG1:%.*]]: f16, [[ARG2:%.*]]: f16, [[OUT:%.*]]: f16)
906+
// CHECK: [[ELEMENT:%.*]] = arith.negf [[ARG0]] : f16
907+
%0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xf16>, tensor<1xf16>, tensor<1xf16>) -> tensor<50x42xf16>
908+
%cast = tensor.cast %0 : tensor<50x42xf16> to tensor<*xf16>
909+
return %cast : tensor<*xf16>
910+
}
911+
912+
// -----
913+
914+
// CHECK-LABEL: @test_negate_no_const_2
915+
func.func @test_negate_no_const_2(%arg0: tensor<50x42xi16> ,%arg1: tensor<1xi16> , %arg2: tensor<1xi16> ) -> tensor<*xi16> {
916+
// CHECK: %[[GENERIC:.+]] = linalg.generic
917+
// CHECK: ^bb0([[ARG0:%.*]]: i16, [[ARG1:%.*]]: i16, [[ARG2:%.*]]: i16, [[OUT:%.*]]: i16)
918+
// CHECK: [[EXTSI1:%.*]] = arith.extsi [[ARG1]] : i16 to i64
919+
// CHECK: [[EXTSI2:%.*]] = arith.extsi [[ARG2]] : i16 to i64
920+
// CHECK: [[SUM:%.*]] = arith.addi [[EXTSI1]], [[EXTSI2]] : i64
921+
// CHECK: [[EXTSI0:%.*]] = arith.extsi [[ARG0]] : i16 to i64
922+
// CHECK: [[SUB:%.*]] = arith.subi [[SUM]], [[EXTSI0]] : i64
923+
// CHECK: [[C_32768:%.*]] = arith.constant -32768 : i64
924+
// CHECK: [[C32767:%.*]] = arith.constant 32767 : i64
925+
// CHECK: [[MAX:%.*]] = arith.maxsi [[C_32768]], [[SUB]] : i64
926+
// CHECK: [[MIN:%.*]] = arith.minsi [[C32767]], [[MAX]] : i64
927+
// CHECK: [[TRUNC:%.*]] = arith.trunci [[MIN]] : i64 to i16
928+
%0 = tosa.negate %arg0, %arg1, %arg2 : (tensor<50x42xi16>, tensor<1xi16>, tensor<1xi16>) -> tensor<50x42xi16>
929+
%cast = tensor.cast %0 : tensor<50x42xi16> to tensor<*xi16>
930+
return %cast : tensor<*xi16>
931+
}
932+
933+
// -----
934+
902935
// CHECK-LABEL: @test_identity
903936
// CHECK-SAME: %[[ARG0:[0-9a-zA-Z_]*]]: tensor<1xf32>,
904937
// CHECK-SAME: %[[ARG1:[0-9a-zA-Z_]*]]: tensor<1xi32>

0 commit comments

Comments
 (0)