-
Notifications
You must be signed in to change notification settings - Fork 15.4k
[mlir][tosa] Handle unsigned constants in TosaConvertIntegerTypeToSignless
#156483
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
…gnless`
This commit fixes handling of unsigned constant data in the
`TosaConvertIntegerTypeToSignless` pass. Previoulsy, the type of the
"values" attribute would remain unsigned, which caused an error in
the const ops verifier:
```
error: 'tosa.const' op expected same attr/result element types
%input_zp = "tosa.const"() {values = dense<17> : tensor<1xui8>} : () -> tensor<1xui8>
^
note: see current operation: %0 = "tosa.const"() <{values = dense<17> : tensor<1xui8>}> : () -> tensor<1xi8>
```
Now the constant data in "values" is transformed to signless as well.
Change-Id: I49f492ca1643a37885b5f44fc11f7614a6e158a3
|
@llvm/pr-subscribers-mlir @llvm/pr-subscribers-mlir-tosa Author: Luke Hutton (lhutton1) ChangesThis commit fixes handling of unsigned constant data in the Now the constant data in "values" is transformed to signless as well. Full diff: https://github.com/llvm/llvm-project/pull/156483.diff 2 Files Affected:
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
index 706b5ddd22e72..4b131333b956a 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaConvertIntegerTypeToSignless.cpp
@@ -103,6 +103,32 @@ class ConvertGenericOpWithIntegerTensorType : public ConversionPattern {
}
};
+class ConvertTosaConstWithIntegerTensorType
+ : public OpConversionPattern<tosa::ConstOp> {
+ using OpConversionPattern::OpConversionPattern;
+
+ LogicalResult
+ matchAndRewrite(tosa::ConstOp op, OpAdaptor adaptor,
+ ConversionPatternRewriter &rewriter) const final {
+ const ElementsAttr oldAttr = op.getValues();
+ const auto oldTy = llvm::cast<ShapedType>(oldAttr.getType());
+ const auto newTy =
+ llvm::cast<ShapedType>(typeConverter->convertType(oldTy));
+ if (oldTy == newTy)
+ return success();
+
+ ElementsAttr newAttr = oldAttr;
+ if (auto denseAttr = llvm::dyn_cast<DenseElementsAttr>(oldAttr)) {
+ newAttr = DenseElementsAttr::get(newTy, denseAttr.getRawData());
+ } else {
+ return rewriter.notifyMatchFailure(op, "unknown elements attribute type");
+ }
+
+ rewriter.replaceOpWithNewOp<tosa::ConstOp>(op, newTy, newAttr);
+ return success();
+ }
+};
+
class TosaConvertIntegerTypeToSignless
: public impl::TosaConvertIntegerTypeToSignlessBase<
TosaConvertIntegerTypeToSignless> {
@@ -116,6 +142,10 @@ class TosaConvertIntegerTypeToSignless
return typeConverter.isSignatureLegal(op.getFunctionType()) &&
typeConverter.isLegal(&op.getBody());
});
+ target.addDynamicallyLegalOp<tosa::ConstOp>([&](tosa::ConstOp op) {
+ return typeConverter.isLegal(op.getType()) &&
+ typeConverter.isLegal(op.getValues().getType());
+ });
target.markUnknownOpDynamicallyLegal([&](Operation *op) {
return typeConverter.isLegal(op->getOperandTypes()) &&
typeConverter.isLegal(op->getResultTypes());
@@ -125,6 +155,7 @@ class TosaConvertIntegerTypeToSignless
populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(
patterns, typeConverter);
patterns.add<ConvertGenericOpWithIntegerTensorType>(typeConverter, context);
+ patterns.add<ConvertTosaConstWithIntegerTensorType>(typeConverter, context);
if (failed(
applyFullConversion(getOperation(), target, std::move(patterns))))
diff --git a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
index a64f69a8931fb..b7dbf9faf0936 100644
--- a/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
+++ b/mlir/test/Dialect/Tosa/tosa-convert-integer-type-to-signless.mlir
@@ -32,6 +32,21 @@ func.func @test_rescale_input_unsigned(%arg0: tensor<1x1xui16>) -> (tensor<1x1xi
// -----
+// CHECK-LABEL: test_rescale_unsigned_zp
+// CHECK: %[[ZP_IN:.*]] = "tosa.const"() <{values = dense<-2> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: %[[ZP_OUT:.*]] = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
+// CHECK: tosa.rescale %arg0, %0, %1, %[[ZP_IN]], %[[ZP_OUT]] {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xi8>, tensor<1xi32>, tensor<1xi8>, tensor<1xi8>, tensor<1xi8>)
+func.func @test_rescale_unsigned_zp(%arg0: tensor<1x1xui8>) -> tensor<1x1xi8> {
+ %0 = "tosa.const"() <{values = dense<2> : tensor<1xi32>}> : () -> tensor<1xi32>
+ %1 = "tosa.const"() <{values = dense<1> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %2 = "tosa.const"() <{values = dense<254> : tensor<1xui8>}> : () -> tensor<1xui8>
+ %3 = "tosa.const"() <{values = dense<2> : tensor<1xi8>}> : () -> tensor<1xi8>
+ %r = tosa.rescale %arg0, %0, %1, %2, %3 {input_unsigned = true, output_unsigned = false, per_channel = false, rounding_mode = SINGLE_ROUND, scale32 = true} : (tensor<1x1xui8>, tensor<1xi32>, tensor<1xi8>, tensor<1xui8>, tensor<1xi8>) -> tensor<1x1xi8>
+ return %r : tensor<1x1xi8>
+}
+
+// -----
+
// CHECK-LABEL: test_unsigned_function_signature
// CHECK: %arg0: tensor<1xi8>, %arg1: tensor<1xi8>
func.func @test_unsigned_function_signature(%arg0: tensor<1xui8>, %arg1: tensor<1xui8>) -> (tensor<1xui8>, tensor<1xui8>) {
@@ -41,6 +56,15 @@ func.func @test_unsigned_function_signature(%arg0: tensor<1xui8>, %arg1: tensor<
// -----
+// CHECK-LABEL: test_unsigned_const_data
+// CHECK: "tosa.const"() <{values = dense<[-1, -2, 0, 1, -128]> : tensor<5xi8>}> : () -> tensor<5xi8>
+func.func @test_unsigned_const_data() -> tensor<5xui8> {
+ %0 = "tosa.const"() <{values = dense<[255, 254, 0, 1, 128]> : tensor<5xui8>}> : () -> tensor<5xui8>
+ return %0 : tensor<5xui8>
+}
+
+// -----
+
// CHECK-LABEL: test_no_change
// CHECK: %arg0: tensor<13x21x3xi8>
func.func @test_no_change(%arg0: tensor<13x21x3xi8>) -> tensor<13x21x3xi8> {
|
mplatings
left a comment
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
This commit fixes handling of unsigned constant data in the
TosaConvertIntegerTypeToSignlesspass. Previously, the type of the "values" attribute would remain unsigned, which caused an error in the const ops verifier:Now the constant data in "values" is transformed to signless as well.