Commit ee6e01f

[AutoBump] Merge with fixes of 8d23719 (Jun 27)
2 parents: 6d8ed84 + 8d23719

6 files changed: +76 -56 lines changed


mlir/include/mlir/Dialect/Tosa/Utils/ConversionUtils.h

Lines changed: 1 addition & 1 deletion
@@ -37,7 +37,7 @@ Value clampFloatHelper(Location loc, Value arg, Value min, Value max,
 // Takes the parameters for a clamp and turns it into a series of ops for
 // integer inputs.
 Value clampIntHelper(Location loc, Value arg, Value min, Value max,
-                     OpBuilder &rewriter, bool isUnsigned = false);
+                     OpBuilder &rewriter, bool isUnsigned);

 // Determines whether the integer value falls witin the range of integer type.
 bool validIntegerRange(IntegerType ty, int64_t value);
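
Note: the body of clampIntHelper is not part of this commit; the header change only drops the default so that every call site must spell out the signedness. A minimal sketch of an implementation consistent with the arith.maxui/arith.minui checks added to the tests below (the name clampIntHelperSketch and the exact structure are illustrative assumptions, not taken from this diff):

#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/IR/Builders.h"
#include "mlir/IR/Location.h"
#include "mlir/IR/Value.h"

// Sketch only: dispatch on isUnsigned to pick unsigned vs. signed
// comparisons when materializing the clamp.
static mlir::Value clampIntHelperSketch(mlir::Location loc, mlir::Value arg,
                                        mlir::Value min, mlir::Value max,
                                        mlir::OpBuilder &rewriter,
                                        bool isUnsigned) {
  using namespace mlir;
  if (isUnsigned) {
    auto minOrArg = rewriter.create<arith::MaxUIOp>(loc, min, arg);
    return rewriter.create<arith::MinUIOp>(loc, max, minOrArg);
  }
  auto minOrArg = rewriter.create<arith::MaxSIOp>(loc, min, arg);
  return rewriter.create<arith::MinSIOp>(loc, max, minOrArg);
}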

mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp

Lines changed: 40 additions & 32 deletions
@@ -191,7 +191,8 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     Value max = rewriter.create<arith::ConstantIntOp>(
         loc, APInt::getSignedMaxValue(inputBitWidth).getSExtValue(),
         intermediateType);
-    auto clamp = clampIntHelper(loc, sub, min, max, rewriter);
+    auto clamp =
+        clampIntHelper(loc, sub, min, max, rewriter, /*isUnsigned=*/false);

     // Truncate to the final value.
     return rewriter.create<arith::TruncIOp>(loc, elementTy, clamp);
@@ -402,24 +403,26 @@ static Value createLinalgBodyCalculationForElementwiseOp(
     int64_t max =
         cast<IntegerAttr>(op->getAttr("max_int")).getValue().getSExtValue();

+    int64_t minRepresentable = std::numeric_limits<int64_t>::min();
+    int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
     if (intTy.isUnsignedInteger()) {
-      if (intTy.getIntOrFloatBitWidth() > 63) {
-        (void)rewriter.notifyMatchFailure(
-            op, "support for 64-bit or larger integers is not implemented");
-        return {};
+      minRepresentable = 0;
+      if (intTy.getIntOrFloatBitWidth() <= 63) {
+        maxRepresentable = (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
+                               .getZExtValue();
       }
-      min = std::max(min, (int64_t)0);
-      max = std::min(max,
-                     (int64_t)APInt::getMaxValue(intTy.getIntOrFloatBitWidth())
-                         .getZExtValue());
-    } else {
-      min =
-          std::max(min, APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
-                            .getSExtValue());
-      max =
-          std::min(max, APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
-                            .getSExtValue());
+    } else if(intTy.getIntOrFloatBitWidth() <= 64) {
+      // Ensure that min & max fit into signed n-bit constants.
+      minRepresentable = APInt::getSignedMinValue(intTy.getIntOrFloatBitWidth())
+                             .getSExtValue();
+      maxRepresentable = APInt::getSignedMaxValue(intTy.getIntOrFloatBitWidth())
+                             .getSExtValue();
     }
+    // Ensure that the bounds are representable as n-bit signed/unsigned integers.
+    min = std::max(min, minRepresentable);
+    max = std::max(max, minRepresentable);
+    min = std::min(min, maxRepresentable);
+    max = std::min(max, maxRepresentable);

     auto minVal = rewriter.create<arith::ConstantIntOp>(
         loc, min, intTy.getIntOrFloatBitWidth());
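
The hunk above replaces the old hard failure on wide unsigned types with bound saturation: min_int/max_int are clamped into the representable range of the result element type before being materialized as constants. A standalone sketch of that arithmetic (plain C++ with shifts instead of the APInt helpers used upstream), walking the ui32 case checked by the new %u2 test in tosa-to-linalg.mlir below:

#include <algorithm>
#include <cstdint>
#include <iostream>
#include <limits>

int main() {
  // Attribute values from `tosa.clamp {min_int = -3, max_int = -2}` on ui32.
  int64_t min = -3, max = -2;
  const unsigned bitWidth = 32;
  const bool isUnsigned = true;

  int64_t minRepresentable = std::numeric_limits<int64_t>::min();
  int64_t maxRepresentable = std::numeric_limits<int64_t>::max();
  if (isUnsigned) {
    minRepresentable = 0;
    if (bitWidth <= 63) // ui64 keeps INT64_MAX as its upper bound
      maxRepresentable = (int64_t)((UINT64_C(1) << bitWidth) - 1);
  } else if (bitWidth < 64) { // upstream uses APInt, which also covers 64 bits
    minRepresentable = -(INT64_C(1) << (bitWidth - 1));
    maxRepresentable = (INT64_C(1) << (bitWidth - 1)) - 1;
  }

  // Saturate both attribute bounds into [minRepresentable, maxRepresentable].
  min = std::min(std::max(min, minRepresentable), maxRepresentable);
  max = std::min(std::max(max, minRepresentable), maxRepresentable);

  std::cout << min << " " << max << "\n"; // prints "0 0": both bounds
                                          // saturate to zero, matching the
                                          // `arith.constant 0` pair in %u2.
}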
@@ -666,10 +669,8 @@ static Value expandRank(PatternRewriter &rewriter, Location loc, Value tensor,
 }

 static SmallVector<Value> expandInputRanks(PatternRewriter &rewriter,
-                                           Location loc, Operation *operation,
-                                           ValueRange operands) {
-  auto rank =
-      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
+                                           Location loc, ValueRange operands,
+                                           int64_t rank) {
   return llvm::map_to_vector(operands, [&](Value operand) {
     return expandRank(rewriter, loc, operand, rank);
   });
@@ -898,10 +899,10 @@ static LogicalResult
 emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
                            Operation *operation, ValueRange operands,
                            ArrayRef<OpFoldResult> targetShape,
-                           const TypeConverter *converter) {
+                           const TypeConverter &converter) {
   // Generate output tensor
-  auto resultType = cast_or_null<RankedTensorType>(converter->convertType(
-      cast<RankedTensorType>(operation->getResultTypes().front())));
+  auto resultType = cast_or_null<RankedTensorType>(
+      converter.convertType(operation->getResultTypes().front()));
   if (!resultType) {
     return rewriter.notifyMatchFailure(operation, "failed to convert type");
   }
@@ -953,7 +954,7 @@ emitElementwiseComputation(ConversionPatternRewriter &rewriter, Location loc,
 static LogicalResult
 elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
                                  ConversionPatternRewriter &rewriter,
-                                 const TypeConverter *converter) {
+                                 const TypeConverter &converter) {

   // Collect op properties
   assert(operation->getNumResults() == 1 && "elementwise op expects 1 result");
@@ -966,7 +967,9 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
   // Lower operation
   IndexPool indexPool;
   auto loc = operation->getLoc();
-  auto expandedOperands = expandInputRanks(rewriter, loc, operation, operands);
+  auto rank =
+      cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
+  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
   auto [targetShape, masterOperands] =
       computeTargetShape(rewriter, loc, indexPool, expandedOperands);
   auto broadcastOperands = broadcastDynamicDimensions(
@@ -1173,8 +1176,8 @@ class PointwiseConverter : public OpConversionPattern<SrcOp> {
   LogicalResult
   matchAndRewrite(SrcOp op, OpAdaptor operands,
                   ConversionPatternRewriter &rewriter) const final {
-    return elementwiseMatchAndRewriteHelper(op, operands.getOperands(),
-                                            rewriter, this->getTypeConverter());
+    return elementwiseMatchAndRewriteHelper(
+        op, operands.getOperands(), rewriter, *this->getTypeConverter());
   }
 };

@@ -1398,7 +1401,7 @@ class RescaleConverter : public OpConversionPattern<tosa::RescaleOp> {
              loc, nestedBuilder.getI32IntegerAttr(intMax));

          value = clampIntHelper(nestedLoc, value, intMinVal, intMaxVal,
-                                nestedBuilder);
+                                nestedBuilder, /*isUnsigned=*/false);

          if (outIntType.getWidth() < 32) {
            value = nestedBuilder.create<arith::TruncIOp>(
@@ -1772,7 +1775,7 @@ class GenericResizeConverter : public OpConversionPattern<tosa::ResizeOp> {

      auto offset = b.create<arith::SelectOp>(pred, one, zeroI32);
      val = b.create<arith::AddIOp>(val, offset);
-      val = clampIntHelper(loc, val, zeroI32, max, b);
+      val = clampIntHelper(loc, val, zeroI32, max, b, /*isUnsigned=*/false);
      return b.create<arith::IndexCastOp>(b.getIndexType(), val);
    };

@@ -1793,8 +1796,10 @@ class GenericResizeConverter : public OpConversionPattern<tosa::ResizeOp> {
                           Value max, ImplicitLocOpBuilder &b) {
      val0 = in;
      val1 = b.create<arith::AddIOp>(val0, oneVal);
-      val0 = clampIntHelper(loc, val0, zeroI32, max, b);
-      val1 = clampIntHelper(loc, val1, zeroI32, max, b);
+      val0 =
+          clampIntHelper(loc, val0, zeroI32, max, b, /*isUnsigned=*/false);
+      val1 =
+          clampIntHelper(loc, val1, zeroI32, max, b, /*isUnsigned=*/false);
      val0 = b.create<arith::IndexCastOp>(b.getIndexType(), val0);
      val1 = b.create<arith::IndexCastOp>(b.getIndexType(), val1);
    };
@@ -2760,7 +2765,10 @@ void mlir::tosa::populateTosaToLinalgConversionPatterns(
       PointwiseConverter<tosa::CeilOp>,
       PointwiseConverter<tosa::FloorOp>,
       PointwiseConverter<tosa::ClampOp>,
-      PointwiseConverter<tosa::SigmoidOp>,
+      PointwiseConverter<tosa::SigmoidOp>
+    >(converter, patterns->getContext());
+
+  patterns->add<
       IdentityNConverter<tosa::IdentityOp>,
       ReduceConverter<tosa::ReduceAllOp>,
       ReduceConverter<tosa::ReduceAnyOp>,

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgNamed.cpp

Lines changed: 2 additions & 1 deletion
@@ -1110,7 +1110,8 @@ class AvgPool2dConverter : public OpRewritePattern<tosa::AvgPool2dOp> {
        auto max = rewriter.create<arith::ConstantIntOp>(
            loc, APInt::getSignedMaxValue(outBitwidth).getSExtValue(),
            accETy);
-        auto clamp = clampIntHelper(loc, scaled, min, max, rewriter);
+        auto clamp = clampIntHelper(loc, scaled, min, max, rewriter,
+                                    /*isUnsigned=*/false);

        poolVal = clamp;
        // Convert type.

mlir/lib/Conversion/TosaToLinalg/TosaToLinalgPass.cpp

Lines changed: 3 additions & 3 deletions
@@ -46,9 +46,6 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
   }

   void runOnOperation() override {
-    TypeConverter converter;
-    mlir::tosa::populateTosaToLinalgTypeConversion(converter);
-
     RewritePatternSet patterns(&getContext());
     ConversionTarget target(getContext());
     target.addLegalDialect<linalg::LinalgDialect, tensor::TensorDialect,
@@ -64,13 +61,16 @@ struct TosaToLinalg : public impl::TosaToLinalgBase<TosaToLinalg> {
     target.addLegalOp<tosa::SliceOp>();
     target.addLegalOp<tosa::ReshapeOp>();
     target.addLegalOp<tosa::PadOp>();
+    TypeConverter converter;
     target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) {
       return converter.isSignatureLegal(op.getFunctionType());
     });
     target.addDynamicallyLegalDialect<func::FuncDialect>(
         [&](Operation *op) { return converter.isLegal(op); });
     target.markUnknownOpDynamicallyLegal([](Operation *) { return true; });

+    tosa::populateTosaTypeConversion(converter);
+
     FunctionOpInterface func = getOperation();
     mlir::tosa::populateTosaToLinalgConversionPatterns(converter, &patterns);
     populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns,
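
The reordering above works because the dynamic-legality callbacks capture converter by reference: the TypeConverter only has to be declared before the callbacks are registered, and populating it via tosa::populateTosaTypeConversion can happen afterwards, as long as it precedes the point where the conversion is actually applied at the end of runOnOperation. A minimal standalone illustration of that capture-by-reference pattern (plain C++ with stand-in names, not the MLIR API):

#include <functional>
#include <iostream>
#include <string>
#include <vector>

int main() {
  std::vector<std::string> legalTypes; // stand-in for the TypeConverter

  // Registered early; captures legalTypes by reference, so it sees the
  // converter's latest state whenever it is eventually invoked.
  std::function<bool(const std::string &)> isLegal =
      [&legalTypes](const std::string &ty) {
        for (const std::string &known : legalTypes)
          if (known == ty)
            return true;
        return false;
      };

  // Populated after the callback was created, mirroring the pass above,
  // which configures target legality first and fills the converter later.
  legalTypes.push_back("i32");

  std::cout << std::boolalpha << isLegal("i32") << "\n"; // prints "true"
}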

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg-invalid.mlir

Lines changed: 0 additions & 9 deletions
@@ -30,15 +30,6 @@ func.func @unranked_add(%arg0 : tensor<10x10xf32> , %arg1 : tensor<10x10xf32>, %

 // -----

-// CHECK-LABEL: @clamp_on_large_int
-func.func @clamp_on_large_int(%arg0: tensor<1xui64>) -> tensor<1xui64> {
-  // expected-error@+1 {{failed to legalize operation 'tosa.clamp'}}
-  %0 = tosa.clamp %arg0 {min_int = -1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>
-  return %0 : tensor<1xui64>
-}
-
-// -----
-
 // CHECK-LABEL: @rfft2d_with_non_float_type
 func.func @rfft2d_with_non_float_type(%arg0 : tensor<1x1x1xi32>) -> (tensor<1x1x1xi32>, tensor<1x1x1xi32>) {
   // expected-error@+1 {{failed to legalize operation 'tosa.rfft2d'}}

mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir

Lines changed: 30 additions & 10 deletions
@@ -651,7 +651,7 @@ func.func @test_simple_ui8(%arg0: tensor<1xui8>) -> () {
 // -----

 // CHECK-LABEL: @test_simple_i32
-func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {
+func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %unsigned64: tensor<1xui64>) -> () {
   // CHECK: linalg.generic
   // CHECK: arith.addi
   %0 = tosa.add %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
@@ -674,7 +674,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {
   %40 = tosa.int_div %arg0, %arg0 : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

   // CHECK: arith.divui
-  %u4 = tosa.int_div %arg1, %arg1 : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>
+  %u4 = tosa.int_div %unsigned, %unsigned : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>

   // CHECK: linalg.generic
   // CHECK: ^bb0(%[[ARG1:.*]]: i32, %[[ARG2:.*]]: i32):
@@ -708,7 +708,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {

   // CHECK: linalg.generic
   // CHECK: arith.shrui
-  %u11 = tosa.arithmetic_right_shift %arg1, %arg1 {round = 0 : i1} : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>
+  %u11 = tosa.arithmetic_right_shift %unsigned, %unsigned {round = 0 : i1} : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>

   // CHECK: linalg.generic
   // CHECK: arith.constant 1
@@ -736,7 +736,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {
   // CHECK: and
   // CHECK: arith.extui
   // CHECK: arith.addi
-  %u12 = tosa.arithmetic_right_shift %arg1, %arg1 {round = 1 : i1} : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>
+  %u12 = tosa.arithmetic_right_shift %unsigned, %unsigned {round = 1 : i1} : (tensor<1xui32>, tensor<1xui32>) -> tensor<1xui32>

   // CHECK: math.ctlz
   %13 = tosa.clz %arg0 : (tensor<1xi32>) -> tensor<1xi32>
@@ -767,12 +767,32 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {
   %19 = tosa.clamp %0 {min_int = 1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xi32>) -> tensor<1xi32>

   // CHECK: linalg.generic
-  // CHECK: bb0(%[[IN:.*]]: i32,
+  // CHECK-DAG: %[[LB:.*]] = arith.constant 4 : i32
+  // CHECK-DAG: %[[UB:.*]] = arith.constant 32 : i32
+  // CHECK-DAG: arith.maxui %[[LB]],
+  // CHECK-DAG: arith.minui %[[UB]],
+  %u0 = tosa.clamp %unsigned {min_int = 4 : i64, max_int = 32 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[LB:.*]] = arith.constant -1 : i32
+  // CHECK-DAG: %[[UB:.*]] = arith.constant -1 : i32
+  // CHECK-DAG: arith.maxui %[[LB]],
+  // CHECK-DAG: arith.minui %[[UB]],
+  %u1 = tosa.clamp %unsigned {min_int = 9223372036854775807 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+  // CHECK: linalg.generic
   // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i32
-  // CHECK-DAG: %[[UB:.*]] = arith.constant 5 : i32
-  // CHECK-DAG: %[[MAX:.*]] = arith.maxui %[[LB]], %[[IN]]
-  // CHECK-DAG: arith.minui %[[UB]], %[[MAX]]
-  %u19 = tosa.clamp %arg1 {min_int = -1 : i64, max_int = 5 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+  // CHECK-DAG: %[[UB:.*]] = arith.constant 0 : i32
+  // CHECK-DAG: arith.maxui %[[LB]],
+  // CHECK-DAG: arith.minui %[[UB]],
+  %u2 = tosa.clamp %unsigned {min_int = -3 : i64, max_int = -2 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui32>) -> tensor<1xui32>
+
+  // CHECK: linalg.generic
+  // CHECK-DAG: %[[LB:.*]] = arith.constant 0 : i64
+  // CHECK-DAG: %[[UB:.*]] = arith.constant 9223372036854775807 : i64
+  // CHECK-DAG: arith.maxui %[[LB]],
+  // CHECK-DAG: arith.minui %[[UB]],
+  %u3 = tosa.clamp %unsigned64 {min_int = -3 : i64, max_int = 9223372036854775807 : i64, min_fp = 1.0 : f32, max_fp = 5.0 : f32} : (tensor<1xui64>) -> tensor<1xui64>

   // CHECK: linalg.generic
   // CHECK: arith.trunci
@@ -793,7 +813,7 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %arg1: tensor<1xui32>) -> () {

   // CHECK: linalg.generic
   // CHECK: arith.uitofp
-  %u23 = tosa.cast %arg1 : (tensor<1xui32>) -> tensor<1xf32>
+  %u23 = tosa.cast %unsigned : (tensor<1xui32>) -> tensor<1xf32>

   // CHECK: linalg.generic
   // CHECK: arith.constant 0
