Skip to content

Commit 3c4ab4f

Browse files
authored
[mlir][tosa] Handle unsigned constants in TosaConvertIntegerTypeToSignless (#156483)
This commit fixes handling of unsigned constant data in the `TosaConvertIntegerTypeToSignless` pass. Previously, the type of the "values" attribute would remain unsigned, which caused an error in the const ops verifier: ``` error: 'tosa.const' op expected same attr/result element types %input_zp = "tosa.const"() {values = dense<17> : tensor<1xui8>} : () -> tensor<1xui8> ^ note: see current operation: %0 = "tosa.const"() <{values = dense<17> : tensor<1xui8>}> : () -> tensor<1xi8> ``` Now the constant data in "values" is transformed to signless as well.
1 parent 9e9edb5 commit 3c4ab4f

File tree

2 files changed

+55
-0
lines changed

2 files changed

+55
-0
lines changed

mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -103,6 +103,32 @@ class ConvertGenericOpWithIntegerTensorType : public ConversionPattern {
103103
}
104104
};
105105

106+
class ConvertTosaConstWithIntegerTensorType
107+
: public OpConversionPattern<tosa::ConstOp> {
108+
using OpConversionPattern::OpConversionPattern;
109+
110+
LogicalResult
111+
matchAndRewrite(tosa::ConstOp op, OpAdaptor adaptor,
112+
ConversionPatternRewriter &rewriter) const final {
113+
const ElementsAttr oldAttr = op.getValues();
114+
const auto oldTy = llvm::cast<ShapedType>(oldAttr.getType());
115+
const auto newTy =
116+
llvm::cast<ShapedType>(typeConverter->convertType(oldTy));
117+
if (oldTy == newTy)
118+
return success();
119+
120+
ElementsAttr newAttr = oldAttr;
121+
if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(oldAttr)) {
122+
newAttr = DenseElementsAttr::get(newTy, denseAttr.getRawData());
123+
} else {
124+
return rewriter.notifyMatchFailure(op, "unknown elements attribute type");
125+
}
126+
127+
rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, newTy, newAttr);
128+
return success();
129+
}
130+
};
131+
106132
class TosaConvertIntegerTypeToSignless
107133
: public impl::TosaConvertIntegerTypeToSignlessBase<
108134
TosaConvertIntegerTypeToSignless> {
@@ -116,6 +142,10 @@ class TosaConvertIntegerTypeToSignless
116142
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
117143
typeConverter.isLegal(&op.getBody());
118144
});
145+
target.addDynamicallyLegalOp<tosa::ConstOp>([&](tosa::ConstOp op) {
146+
return typeConverter.isLegal(op.getType()) &&
147+
typeConverter.isLegal(op.getValues().getType());
148+
});
119149
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
120150
return typeConverter.isLegal(op->getOperandTypes()) &&
121151
typeConverter.isLegal(op->getResultTypes());
@@ -125,6 +155,7 @@ class TosaConvertIntegerTypeToSignless
125155
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
126156
patterns, typeConverter);
127157
patterns.add<ConvertGenericOpWithIntegerTensorType>(typeConverter, context);
158+
patterns.add<ConvertTosaConstWithIntegerTensorType>(typeConverter, context);
128159

129160
if (failed(
130161
applyFullConversion(getOperation(), target, std::move(patterns))))

mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -32,6 +32,21 @@ func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui16>) -> (tensor<1x1xi
3232

3333
// -----
3434

35+
// CHECK-LABEL: test_rescale_unsigned_zp
36+
// CHECK: %[[ZP_IN:.*]] = "tosa.const"() <{values = dense<-2> : tensor<1xi8>}> : () -> tensor<1xi8>
37+
// CHECK: %[[ZP_OUT:.*]] = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
38+
// CHECK: tosa.rescale %arg0, %0, %1, %[[ZP_IN]], %[[ZP_OUT]] {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>)
39+
func.func @test_rescale_unsigned_zp(%arg0: tensor<1x1xui8>) -> tensor<1x1xi8> {
40+
%0 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
41+
%1 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
42+
%2 = "tosa.const"() <{values = dense<254> : tensor<1xui8>}> : () -> tensor<1xui8>
43+
%3 = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
44+
%r = tosa.rescale %arg0, %0, %1, %2, %3 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xui8>, tensor<1xi8>) -> tensor<1x1xi8>
45+
return %r : tensor<1x1xi8>
46+
}
47+
48+
// -----
49+
3550
// CHECK-LABEL: test_unsigned_function_signature
3651
// CHECK: %arg0: tensor<1xi8>, %arg1: tensor<1xi8>
3752
func.func @test_unsigned_function_signature(%arg0: tensor<1xui8>, %arg1: tensor<1xui8>) -> (tensor<1xui8>, tensor<1xui8>) {
@@ -41,6 +56,15 @@ func.func @test_unsigned_function_signature(%arg0: tensor<1xui8>, %arg1: tensor<
4156

4257
// -----
4358

59+
// CHECK-LABEL: test_unsigned_const_data
60+
// CHECK: "tosa.const"() <{values = dense<[-1, -2, 0, 1, -128]> : tensor<5xi8>}> : () -> tensor<5xi8>
61+
func.func @test_unsigned_const_data() -> tensor<5xui8> {
62+
%0 = "tosa.const"() <{values = dense<[255, 254, 0, 1, 128]> : tensor<5xui8>}> : () -> tensor<5xui8>
63+
return %0 : tensor<5xui8>
64+
}
65+
66+
// -----
67+
4468
// CHECK-LABEL: test_no_change
4569
// CHECK: %arg0: tensor<13x21x3xi8>
4670
func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {

0 commit comments

Comments
 (0)