@mplatings
Collaborator

Lower unrealized_conversion_cast of signed/unsigned/signless integer types of the same size to spirv.Bitcast.

arith.bitcast is specifically for signless types, hence it is not used for such casts and unrealized_conversion_cast is used instead.

Co-authored-by: Thomas Preud'homme <[email protected]>
@llvmbot
Member

llvmbot commented Aug 26, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-spirv

Author: Michael Platings (mplatings)

Changes

Lower unrealized_conversion_cast of signed/unsigned/signless integer types of the same size to spirv.Bitcast.

arith.bitcast is specifically for signless types, hence it is not used for such casts and unrealized_conversion_cast is used instead.


Full diff: https://github.com/llvm/llvm-project/pull/155388.diff

2 Files Affected:

  • (modified) mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp (+96-31)
  • (modified) mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir (+29)
diff --git a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
index 265293b83f84c..ee694104dc918 100644
--- a/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
+++ b/mlir/lib/Conversion/ArithToSPIRV/ArithToSPIRV.cpp
@@ -856,6 +856,17 @@ convertArithRoundingModeToSPIRV(arith::RoundingMode roundingMode) {
   llvm_unreachable("Unhandled rounding mode");
 }
 
+static bool isSignednessCast(Type srcType, Type dstType) {
+  if (srcType.isInteger() && dstType.isInteger()) {
+    return srcType.getIntOrFloatBitWidth() == dstType.getIntOrFloatBitWidth();
+  }
+  if (isa<VectorType>(srcType) && isa<VectorType>(dstType)) {
+    return isSignednessCast(cast<VectorType>(srcType).getElementType(),
+                            cast<VectorType>(dstType).getElementType());
+  }
+  return false;
+}
+
 /// Converts type-casting standard operations to SPIR-V operations.
 template <typename Op, typename SPIRVOp>
 struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
@@ -864,42 +875,86 @@ struct TypeCastingOpPattern final : public OpConversionPattern<Op> {
   LogicalResult
   matchAndRewrite(Op op, typename Op::Adaptor adaptor,
                   ConversionPatternRewriter &rewriter) const override {
-    Type srcType = llvm::getSingleElement(adaptor.getOperands()).getType();
-    Type dstType = this->getTypeConverter()->convertType(op.getType());
-    if (!dstType)
-      return getTypeConversionFailure(rewriter, op);
+    TypeRange dstTypes;
+    SmallVector<Type> newDstTypes;
+    SmallVector<Value> unrealizedConvCastSrcs;
+    SmallVector<Type> unrealizedConvCastDstTypes;
+    constexpr bool isUnrealizedConvCast =
+        std::is_same_v<Op, UnrealizedConversionCastOp>;
+    if constexpr (isUnrealizedConvCast)
+      dstTypes = op.getOutputs().getTypes();
+    else
+      dstTypes = op.getType();
+    LogicalResult matched = failure();
+    for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), dstTypes)) {
+      Type srcType = src.getType();
+      // Use UnrealizedConversionCast as the bridge so that we don't need to
+      // pull in patterns for other dialects.
+      if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) {
+        newDstTypes.push_back(dstType);
+        unrealizedConvCastSrcs.push_back(src);
+        unrealizedConvCastDstTypes.push_back(dstType);
+        continue;
+      }
+      dstType = this->getTypeConverter()->convertType(dstType);
+      if (!dstType)
+        return getTypeConversionFailure(rewriter, op);
+
+      if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
+        return failure();
+      matched = success();
+      newDstTypes.push_back(dstType);
+    }
 
-    if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
+    if (failed(matched))
       return failure();
 
-    if (dstType == srcType) {
-      // Due to type conversion, we are seeing the same source and target type.
-      // Then we can just erase this operation by forwarding its operand.
-      rewriter.replaceOp(op, adaptor.getOperands().front());
-    } else {
-      // Compute new rounding mode (if any).
-      std::optional<spirv::FPRoundingMode> rm = std::nullopt;
-      if (auto roundingModeOp =
-              dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
-        if (arith::RoundingModeAttr roundingMode =
-                roundingModeOp.getRoundingModeAttr()) {
-          if (!(rm =
-                    convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
-            return rewriter.notifyMatchFailure(
-                op->getLoc(),
-                llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
-          }
+    // Compute new rounding mode (if any).
+    Location loc = op->getLoc();
+    std::optional<spirv::FPRoundingMode> rm = std::nullopt;
+    if (auto roundingModeOp =
+            dyn_cast<arith::ArithRoundingModeInterface>(*op)) {
+      if (arith::RoundingModeAttr roundingMode =
+              roundingModeOp.getRoundingModeAttr()) {
+        if (!(rm = convertArithRoundingModeToSPIRV(roundingMode.getValue()))) {
+          return rewriter.notifyMatchFailure(
+              loc,
+              llvm::formatv("unsupported rounding mode '{0}'", roundingMode));
         }
       }
-      // Create replacement op and attach rounding mode attribute (if any).
-      auto newOp = rewriter.template replaceOpWithNewOp<SPIRVOp>(
-          op, dstType, adaptor.getOperands());
-      if (rm) {
-        newOp->setAttr(
-            getDecorationString(spirv::Decoration::FPRoundingMode),
-            spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
+    }
+
+    llvm::DenseMap<Value, Value> unrealizedConvCastSrcDstMap;
+    if (!unrealizedConvCastSrcs.empty()) {
+      auto newOp = rewriter.create<UnrealizedConversionCastOp>(
+          loc, unrealizedConvCastDstTypes, unrealizedConvCastSrcs);
+      for (auto [src, dst] :
+           llvm::zip(unrealizedConvCastSrcs, newOp.getResults()))
+        unrealizedConvCastSrcDstMap[src] = dst;
+    }
+
+    SmallVector<Value> newValues;
+    for (auto [src, dstType] : llvm::zip(adaptor.getOperands(), newDstTypes)) {
+      Type srcType = src.getType();
+      if (dstType == srcType) {
+        // Due to type conversion, we are seeing the same source and target
+        // type. Then we can just erase this operation by forwarding its
+        // operand.
+        newValues.push_back(src);
+      } else if (isUnrealizedConvCast && !isSignednessCast(srcType, dstType)) {
+        newValues.push_back(unrealizedConvCastSrcDstMap[src]);
+      } else {
+        // Create replacement op and attach rounding mode attribute (if any).
+        auto newOp = rewriter.template create<SPIRVOp>(loc, dstType, src);
+        if (rm) {
+          newOp->setAttr(
+              getDecorationString(spirv::Decoration::FPRoundingMode),
+              spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
+        }
+        newValues.push_back(newOp.getResult());
       }
     }
+    rewriter.replaceOp(op, newValues);
     return success();
   }
 };
@@ -1331,6 +1386,7 @@ void mlir::arith::populateArithToSPIRVPatterns(
     TypeCastingOpPattern<arith::IndexCastOp, spirv::SConvertOp>,
     TypeCastingOpPattern<arith::IndexCastUIOp, spirv::UConvertOp>,
     TypeCastingOpPattern<arith::BitcastOp, spirv::BitcastOp>,
+    TypeCastingOpPattern<UnrealizedConversionCastOp, spirv::BitcastOp>,
     CmpIOpBooleanPattern, CmpIOpPattern,
     CmpFOpNanNonePattern, CmpFOpPattern,
     AddUIExtendedOpPattern,
@@ -1385,8 +1441,17 @@ struct ConvertArithToSPIRVPass
     SPIRVTypeConverter typeConverter(targetAttr, options);
 
     // Use UnrealizedConversionCast as the bridge so that we don't need to pull
-    // in patterns for other dialects.
-    target->addLegalOp<UnrealizedConversionCastOp>();
+    // in patterns for other dialects. If the UnrealizedConversionCast is
+    // between integers of the same bitwidth, it is either a nop or a
+    // signedness cast which the corresponding pattern convert to Bitcast.
+    target->addDynamicallyLegalOp<UnrealizedConversionCastOp>(
+        [&](UnrealizedConversionCastOp op) {
+          for (auto [srcType, dstType] :
+               llvm::zip(op.getOperandTypes(), op.getResultTypes()))
+            if (isSignednessCast(srcType, dstType))
+              return false;
+          return true;
+        });
 
     // Fail hard when there are any remaining 'arith' ops.
     target->addIllegalDialect<arith::ArithDialect>();
diff --git a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
index 6e2352e706acc..b9a4232758a17 100644
--- a/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
+++ b/mlir/test/Conversion/ArithToSPIRV/arith-to-spirv.mlir
@@ -743,6 +743,35 @@ func.func @bit_cast(%arg0: vector<2xf32>, %arg1: i64) {
   return
 }
 
+// CHECK-LABEL: @unrealized_conversion_cast
+func.func @unrealized_conversion_cast(%arg0: vector<3xi64>, %arg1: i16, %arg2: f32) {
+  // CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
+  %0 = builtin.unrealized_conversion_cast %arg0 : vector<3xi64> to vector<3xui64>
+  // CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
+  %1 = builtin.unrealized_conversion_cast %arg1 : i16 to ui16
+
+  // CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
+  // CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
+  %2:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xui64>, ui16
+
+  // CHECK-NEXT: spirv.Bitcast %{{.+}} : i16 to ui16
+  %3:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xi64>, ui16
+  // CHECK-NEXT: spirv.Bitcast %{{.+}} : vector<3xi64> to vector<3xui64>
+  %4:2 = builtin.unrealized_conversion_cast %arg0, %arg1 : vector<3xi64>, i16 to vector<3xui64>, i16
+
+  // bitcast from float to int should be represented using arith.bitcast
+  // CHECK-NEXT: builtin.unrealized_conversion_cast %{{.+}} : f32 to i32
+  %5 = builtin.unrealized_conversion_cast %arg2 : f32 to i32
+
+  // test mixed signedness and non-signedness cast
+  // CHECK-NEXT: builtin.unrealized_conversion_cast %{{.+}} : f32 to f16
+  // CHECK-NEXT: spirv.Bitcast %{{.+}} : i32 to ui32
+  %6:2 = builtin.unrealized_conversion_cast %5, %arg2 : i32, f32 to ui32, f16
+
+  // CHECK-NEXT: return
+  return
+}
+
 // CHECK-LABEL: @fpext1
 func.func @fpext1(%arg0: f16) -> f64 {
   // CHECK: spirv.FConvert %{{.*}} : f16 to f64
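
The predicate and the dynamic legality rule in the diff above can be modeled without MLIR. The following is a minimal sketch under stated assumptions: `ToyType`, `Kind`, `makeInt`, `makeFloat`, `makeVector`, and `isLegalCast` are hypothetical stand-ins invented for illustration and are not part of the MLIR API. Only the decision logic is meant to mirror `isSignednessCast` and the `addDynamicallyLegalOp` callback.

```cpp
#include <cassert>
#include <memory>
#include <utility>
#include <vector>

// Hypothetical stand-in for mlir::Type. Signedness is deliberately not
// stored, because the predicate in the diff never inspects it.
enum class Kind { Integer, Float, Vector };

struct ToyType {
  Kind kind;
  unsigned bitWidth;                 // used for Integer and Float
  std::shared_ptr<ToyType> element;  // used for Vector
};

ToyType makeInt(unsigned width) { return ToyType{Kind::Integer, width, nullptr}; }
ToyType makeFloat(unsigned width) { return ToyType{Kind::Float, width, nullptr}; }
ToyType makeVector(const ToyType &element) {
  return ToyType{Kind::Vector, 0, std::make_shared<ToyType>(element)};
}

// Same decision logic as isSignednessCast in the diff: two integers of equal
// width, or two vectors whose element types satisfy the same rule.
bool isSignednessCast(const ToyType &src, const ToyType &dst) {
  if (src.kind == Kind::Integer && dst.kind == Kind::Integer)
    return src.bitWidth == dst.bitWidth;
  if (src.kind == Kind::Vector && dst.kind == Kind::Vector) {
    assert(src.element && dst.element && "vector types need an element type");
    return isSignednessCast(*src.element, *dst.element);
  }
  return false;
}

// Same decision logic as the dynamic legality callback: a cast op stays legal
// only if none of its (operand type, result type) pairs is a signedness cast.
bool isLegalCast(const std::vector<std::pair<ToyType, ToyType>> &pairs) {
  for (const auto &[src, dst] : pairs)
    if (isSignednessCast(src, dst))
      return false;
  return true;
}
```

Two properties of the diff are visible in this model. The predicate also returns true for identical integer types, which is why the pattern keeps a separate `dstType == srcType` branch that forwards the operand. It also compares only element types for vectors, not shapes. A cast such as the test's `i32, f32 to ui32, f16` is illegal under `isLegalCast` because one pair matches, so the pattern emits `spirv.Bitcast` for that pair and recreates an `unrealized_conversion_cast` for the other.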

unrealizedConvCastDstTypes.push_back(dstType);
continue;
}
dstType = this->getTypeConverter()->convertType(dstType);
Contributor

nit: could it make sense to create a new variable newDstType to avoid confusion?

}

if (isBoolScalarOrVector(srcType) || isBoolScalarOrVector(dstType))
if (failed(matched))
Contributor

@fabrizio-indirli fabrizio-indirli Aug 26, 2025

could we get rid of the matched variable, and instead create a const bool variable after the loop to check if !newDstTypes.empty() && (newDstTypes.size() > unrealizedConvCastDstTypes.size()) here, or am I missing something?

spirv::FPRoundingModeAttr::get(rewriter.getContext(), *rm));
}

llvm::DenseMap<Value, Value> unrealizedConvCastSrcDstMap;
Contributor

@fabrizio-indirli fabrizio-indirli Aug 26, 2025

nit: could it make sense to add a comment explaining what's happening here? E.g. "Recreate unrealized_conversion_cast ops for unhandled casts" (if I understood the logic correctly)?

@fabrizio-indirli
Contributor

Thanks for the PR! Looks mostly good to me, I left just a few comments :)

Member

@kuhar kuhar left a comment

What produces these unrealized_conversion_casts? Could you show the input IR and the result of conversion to SPIR-V that runs into the signedness issue?

It's not clear to me that we should have dedicated handling for unrealized_conversion_casts.

@mplatings
Collaborator Author

What produces these unrealized_conversion_casts? Could you show the input SPIR-V that runs into the signedness issue?

One instance is #141096.
https://github.com/arm/ai-ml-sdk-model-converter can also generate content that ends up as unsigned inputs.

The problem I'm encountering is around TOSA rescale because there's tooling that wants to use tensors of unsigned type if input_unsigned or output_unsigned are true.

@kuhar
Member

kuhar commented Aug 26, 2025

One instance is #141096.

This fix looks fishy to me; I wouldn't expect any unrealized_casts as an input for conversion to spirv.

A motivating example at the level of arith/vector/scf or the test dialect would help me understand if this needs handling.

@mplatings
Collaborator Author

One instance is #141096.

This fix looks fishy to me; I wouldn't expect any unrealized_casts as an input for conversion to spirv.

A motivating example at the level of arith/vector/scf or the test dialect would help me understand if this needs handling.

I can create an arith example by running mlir-opt -pass-pipeline="builtin.module(func.func(tosa-to-linalg))" on this test from #141096:

func.func @rescale_i8_unsigned_output_explicit(%arg0 : tensor<2xui8>) -> () {
  %multiplier = "tosa.const"() {values = dense<19689> : tensor<1xi16> } : () -> tensor<1xi16>
  %shift = "tosa.const"() {values = dense<15> : tensor<1xi8> } : () -> tensor<1xi8>
  %input_zp = "tosa.const"() {values = dense<17> : tensor<1xui8>} : () -> tensor<1xui8>
  %output_zp = "tosa.const"() {values = dense<-22> : tensor<1xi8>} : () -> tensor<1xi8>
  %1 = tosa.rescale %arg0, %multiplier, %shift, %input_zp, %output_zp {scale32 = false, rounding_mode = "SINGLE_ROUND", per_channel = false, input_unsigned = true, output_unsigned = true} : (tensor<2xui8>, tensor<1xi16>, tensor<1xi8>, tensor<1xui8>, tensor<1xi8>) -> tensor<2xui8>
  return
}

which gives you:

#map = affine_map<(d0) -> (d0)>
module {
  func.func @rescale_i8_unsigned_output_explicit(%arg0: tensor<2xui8>) {
    %0 = "tosa.const"() <{values = dense<19689> : tensor<1xi16>}> : () -> tensor<1xi16>
    %1 = "tosa.const"() <{values = dense<15> : tensor<1xi8>}> : () -> tensor<1xi8>
    %2 = "tosa.const"() <{values = dense<17> : tensor<1xui8>}> : () -> tensor<1xui8>
    %3 = "tosa.const"() <{values = dense<-22> : tensor<1xi8>}> : () -> tensor<1xi8>
    %c19689_i32 = arith.constant 19689 : i32
    %c15_i8 = arith.constant 15 : i8
    %4 = tensor.empty() : tensor<2xui8>
    %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xui8>) outs(%4 : tensor<2xui8>) {
    ^bb0(%in: ui8, %out: ui8):
      %c17_i32 = arith.constant 17 : i32
      %c234_i32 = arith.constant 234 : i32
      %6 = builtin.unrealized_conversion_cast %in : ui8 to i8
      %7 = arith.extui %6 : i8 to i32
      %8 = arith.subi %7, %c17_i32 : i32
      %9 = tosa.apply_scale %8, %c19689_i32, %c15_i8 {rounding_mode = "SINGLE_ROUND"} : (i32, i32, i8) -> i32
      %10 = arith.addi %9, %c234_i32 : i32
      %c0_i32 = arith.constant 0 : i32
      %c255_i32 = arith.constant 255 : i32
      %11 = arith.maxsi %c0_i32, %10 : i32
      %12 = arith.minsi %c255_i32, %11 : i32
      %13 = arith.trunci %12 : i32 to i8
      %14 = builtin.unrealized_conversion_cast %13 : i8 to ui8
      linalg.yield %14 : ui8
    } -> tensor<2xui8>
    return
  }
}

Converting that with mlir-opt -convert-arith-to-spirv gives you:

module attributes {spirv.target_env = #spirv.target_env<#spirv.vce<v1.0, [Int8, Int16, Int64, Float16, Float64, Shader], []>, #spirv.resource_limits<>>} {
  func.func @rescale_i8_unsigned_output_explicit(%arg0: tensor<2xui8>) {
    %0 = "tosa.const"() <{values = dense<19689> : tensor<1xi16>}> : () -> tensor<1xi16>
    %1 = "tosa.const"() <{values = dense<15> : tensor<1xi8>}> : () -> tensor<1xi8>
    %2 = "tosa.const"() <{values = dense<17> : tensor<1xui8>}> : () -> tensor<1xui8>
    %3 = "tosa.const"() <{values = dense<-22> : tensor<1xi8>}> : () -> tensor<1xi8>
    %cst19689_i32 = spirv.Constant 19689 : i32
    %cst15_i8 = spirv.Constant 15 : i8
    %4 = tensor.empty() : tensor<2xui8>
    %5 = linalg.generic {indexing_maps = [#map, #map], iterator_types = ["parallel"]} ins(%arg0 : tensor<2xui8>) outs(%4 : tensor<2xui8>) {
    ^bb0(%in: ui8, %out: ui8):
      %cst17_i32 = spirv.Constant 17 : i32
      %cst234_i32 = spirv.Constant 234 : i32
      %6 = builtin.unrealized_conversion_cast %in : ui8 to i8
      %7 = spirv.UConvert %6 : i8 to i32
      %8 = spirv.ISub %7, %cst17_i32 : i32
      %9 = tosa.apply_scale %8, %cst19689_i32, %cst15_i8 {rounding_mode = "SINGLE_ROUND"} : (i32, i32, i8) -> i32
      %10 = spirv.IAdd %9, %cst234_i32 : i32
      %cst0_i32 = spirv.Constant 0 : i32
      %cst255_i32 = spirv.Constant 255 : i32
      %11 = spirv.GL.SMax %10, %cst0_i32 : i32
      %12 = spirv.GL.SMin %11, %cst255_i32 : i32
      %13 = spirv.SConvert %12 : i32 to i8
      %14 = builtin.unrealized_conversion_cast %13 : i8 to ui8
      linalg.yield %14 : ui8
    } -> tensor<2xui8>
    return
  }
}

So there's a problem at %6 = builtin.unrealized_conversion_cast %in : ui8 to i8 because that should have been lowered to spirv.Bitcast

(Apologies if I've misunderstood you, I'm relatively new to MLIR)

@kuhar
Member

kuhar commented Aug 26, 2025

So there's a problem at %6 = builtin.unrealized_conversion_cast %in : ui8 to i8 because that should have been lowered to spirv.Bitcast

I think the real issue is that step 2 (after tosa to linalg) has an unresolved cast. In general, these are supposed to be inserted automatically by the dialect conversion driver and eventually cancel out. Once you start emitting those manually, this assumption may no longer hold.
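
The "eventually cancel out" assumption can be pictured with a toy model. This is only a sketch: `chainCancelsOut` is a hypothetical helper, types are represented as plain strings rather than MLIR types, and the rule shown (a chain of casts is removable when it ends at the type it started from) is a simplification of what a cleanup step such as MLIR's `reconcile-unrealized-casts` pass looks for.

```cpp
#include <string>
#include <vector>

// Toy model, not MLIR API: a chain of casts is listed by the types it
// visits, e.g. {"i8", "ui8", "i8"} stands for i8 -> ui8 -> i8.
bool chainCancelsOut(const std::vector<std::string> &typesVisited) {
  // A chain needs at least one cast (two types) and must end where it began.
  return typesVisited.size() >= 2 &&
         typesVisited.front() == typesVisited.back();
}
```

Under this model, a driver-inserted pair such as `i8 -> ui8 -> i8` cancels, whereas the lone `ui8 -> i8` cast on `%in` in the IR above has no partner and therefore survives the conversion.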

@mplatings
Collaborator Author

I think the real issue is that step 2 (after tosa to linalg) has an unresolved cast

I kind of agree, but it seems that unrealized_conversion_cast is the only way to represent signedness casts, given that arith.bitcast is explicitly only for signless casts.

Any advice for how to proceed?

@kuhar
Member

kuhar commented Aug 27, 2025

Can you add a new op to represent signedness casts? This seems useful in general to get signed/unsigned values in and out of arith.

@mplatings
Collaborator Author

Potentially. Do you have any idea why arith.bitcast is specifically signless? The easy thing to do seems to be to relax that constraint

@kuhar
Member

kuhar commented Aug 27, 2025

Potentially. Do you have any idea why arith.bitcast is specifically signless? The easy thing to do seems to be to relax that constraint

I think it would be worth discussing on discourse. This has probably been considered before, but I'm not aware of any specific discussions.

@mplatings mplatings marked this pull request as draft August 28, 2025 08:25
@mplatings
Collaborator Author

I found this old discussion: https://discourse.llvm.org/t/rfc-signednesscastop/3253. From that I conclude that signs should be removed from types before lowering to arith et al. And I agree that #141096 doesn't look right.
@lhutton1 has added a conversion pass --tosa-convert-integer-type-to-signless so I will try to use that to work around the problem.
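
The idea of removing signs from types before lowering can be illustrated at the level of type spellings. This is a toy sketch and not the `--tosa-convert-integer-type-to-signless` pass: `toSignless` is a hypothetical helper that works on strings, assuming the usual MLIR spellings `si8` (signed), `ui8` (unsigned), and `i8` (signless).

```cpp
#include <cctype>
#include <cstddef>
#include <string>

// Hypothetical helper: maps "si<N>" and "ui<N>" to "i<N>" and leaves every
// other spelling untouched. It does not parse real MLIR types.
std::string toSignless(const std::string &name) {
  const bool hasSignPrefix =
      name.size() >= 3 && (name[0] == 's' || name[0] == 'u') && name[1] == 'i';
  if (!hasSignPrefix)
    return name;
  // Everything after the "si"/"ui" prefix must be the bit width.
  for (std::size_t i = 2; i < name.size(); ++i)
    if (!std::isdigit(static_cast<unsigned char>(name[i])))
      return name;
  return name.substr(1);  // drop the leading 's' or 'u'
}
```

With types normalized this way before lowering, a value that started as `ui8` is already `i8` when it reaches arith, so no signedness cast is needed at that level.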

@mplatings mplatings closed this Sep 17, 2025