
Conversation

@FranklandJack
Contributor

The TOSA-v1.0 specification makes the shift attribute of the MUL (Hadamard product) operator an input. Move the shift parameter of the MUL operator in the MLIR TOSA dialect from an attribute to an input and update any lit tests appropriately.

Expand the verifier of the tosa::MulOp operation to check the various constraints defined in the TOSA-v1.0 specification. Specifically, ensure that all input operands (excluding the optional shift) have the same rank. This means that broadcasting tests which previously checked that tensors of differing rank would be broadcast are no longer valid and have been removed.
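
As a sketch of the syntax change (distilled from the lit-test updates in the diff below; %a and %b are placeholder values):

```mlir
// Before: the shift was carried as an i8 attribute.
%0 = tosa.mul %a, %b {shift = 1 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>

// After: the shift is passed as an optional rank-0 i8 tensor operand.
%shift = "tosa.const"() <{value = dense<1> : tensor<i8>}> : () -> tensor<i8>
%1 = tosa.mul %a, %b, %shift : (tensor<1xi32>, tensor<1xi32>, tensor<i8>) -> tensor<1xi32>

// When no shift is needed, the operand is simply omitted.
%2 = tosa.mul %a, %b : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
```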

@llvmbot
Member

llvmbot commented Jan 7, 2025

@llvm/pr-subscribers-mlir-linalg

@llvm/pr-subscribers-mlir

Author: Jack Frankland (FranklandJack)

Changes

The TOSA-v1.0 specification makes the shift attribute of the MUL (Hadamard product) operator an input. Move the shift parameter of the MUL operator in the MLIR TOSA dialect from an attribute to an input and update any lit tests appropriately.

Expand the verifier of the tosa::MulOp operation to check the various constraints defined in the TOSA-v1.0 specification. Specifically, ensure that all input operands (excluding the optional shift) have the same rank. This means that broadcasting tests which previously checked that tensors of differing rank would be broadcast are no longer valid and have been removed.


Patch is 30.80 KiB, truncated to 20.00 KiB below, full version: https://github.com/llvm/llvm-project/pull/121953.diff

14 Files Affected:

  • (modified) mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td (+1-1)
  • (modified) mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp (+55-34)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp (+13-2)
  • (modified) mlir/lib/Dialect/Tosa/IR/TosaOps.cpp (+78-3)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp (+1-1)
  • (modified) mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp (+1-1)
  • (modified) mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir (+6-4)
  • (modified) mlir/test/Dialect/Tosa/broadcast.mlir (-9)
  • (modified) mlir/test/Dialect/Tosa/canonicalize.mlir (+19-6)
  • (modified) mlir/test/Dialect/Tosa/constant-op-fold.mlir (+13-10)
  • (modified) mlir/test/Dialect/Tosa/invalid.mlir (+13-2)
  • (modified) mlir/test/Dialect/Tosa/ops.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-decompose-depthwise.mlir (+2-2)
  • (modified) mlir/test/Dialect/Tosa/tosa-infer-shapes.mlir (+8-6)
diff --git a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
index e3c725801d1629..dceb36116797c5 100644
--- a/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
+++ b/mlir/include/mlir/Dialect/Tosa/IR/TosaOps.td
@@ -800,7 +800,7 @@ def Tosa_MulOp : Tosa_ElementwiseOp<"mul", [
   let arguments = (ins
     Tosa_Tensor:$input1,
     Tosa_Tensor:$input2,
-    I8Attr:$shift
+    Optional<TosaTensorRankOf<[Tosa_Int8], [0]>>:$shift
   );
 
   let results = (outs
diff --git a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
index 88e544c4e4b5f1..4ec19db2116546 100644
--- a/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
+++ b/mlir/lib/Conversion/TosaToLinalg/TosaToLinalg.cpp
@@ -90,43 +90,58 @@ static Value createLinalgBodyCalculationForElementwiseOp(
   }
 
   // tosa::MulOp
-  if (isa<tosa::MulOp>(op) && isa<FloatType>(elementTy))
-    return rewriter.create<arith::MulFOp>(loc, resultTypes, args);
-
-  if (isa<tosa::MulOp>(op) && isa<IntegerType>(elementTy)) {
-    Value a = args[0];
-    Value b = args[1];
-    auto shift =
-        cast<IntegerAttr>(op->getAttr("shift")).getValue().getSExtValue();
-    if (shift > 0) {
-      auto shiftConst =
-          rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
-      if (!a.getType().isInteger(32))
-        a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
-
-      if (!b.getType().isInteger(32))
-        b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
-
-      auto result = rewriter.create<tosa::ApplyScaleOp>(
-          loc, rewriter.getI32Type(), a, b, shiftConst,
-          rewriter.getBoolAttr(false));
-
-      if (elementTy.isInteger(32))
-        return result;
-
-      return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+  if (isa<tosa::MulOp>(op)) {
+    auto shift_val = cast<tosa::MulOp>(op).getShift();
+    if (!elementTy.isInteger(32) && shift_val.getImpl()) {
+      (void)rewriter.notifyMatchFailure(
+          op, "Cannot have shift value for non i32 output");
+      return nullptr;
+    }
+
+    if (isa<FloatType>(elementTy)) {
+      return rewriter.create<arith::MulFOp>(loc, resultTypes, args[0], args[1]);
     }
 
-    int aWidth = a.getType().getIntOrFloatBitWidth();
-    int bWidth = b.getType().getIntOrFloatBitWidth();
-    int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+    if (isa<IntegerType>(elementTy)) {
+      int32_t shift = 0;
+      ElementsAttr shift_elem;
+      if (shift_val.getImpl() && matchPattern(shift_val, m_Constant(&shift_elem))) {
+        // Explicit shift is set.
+        shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+      }
+
+      Value a = args[0];
+      Value b = args[1];
+      if (shift > 0) {
+        auto shiftConst =
+            rewriter.create<arith::ConstantIntOp>(loc, shift, /*bitwidth=*/8);
+        if (!a.getType().isInteger(32))
+          a = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), a);
+
+        if (!b.getType().isInteger(32))
+          b = rewriter.create<arith::ExtSIOp>(loc, rewriter.getI32Type(), b);
+
+        auto result = rewriter.create<tosa::ApplyScaleOp>(
+            loc, rewriter.getI32Type(), a, b, shiftConst,
+            rewriter.getBoolAttr(false));
 
-    if (aWidth < cWidth)
-      a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
-    if (bWidth < cWidth)
-      b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
+        if (elementTy.isInteger(32))
+          return result;
 
-    return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+        return rewriter.create<arith::TruncIOp>(loc, elementTy, result);
+      }
+
+      int aWidth = a.getType().getIntOrFloatBitWidth();
+      int bWidth = b.getType().getIntOrFloatBitWidth();
+      int cWidth = resultTypes[0].getIntOrFloatBitWidth();
+
+      if (aWidth < cWidth)
+        a = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], a);
+      if (bWidth < cWidth)
+        b = rewriter.create<arith::ExtSIOp>(loc, resultTypes[0], b);
+
+      return rewriter.create<arith::MulIOp>(loc, resultTypes, a, b);
+    }
   }
 
   // tosa::NegateOp
@@ -931,7 +946,13 @@ elementwiseMatchAndRewriteHelper(Operation *operation, ValueRange operands,
   auto loc = operation->getLoc();
   auto rank =
       cast<RankedTensorType>(operation->getResultTypes().front()).getRank();
-  auto expandedOperands = expandInputRanks(rewriter, loc, operands, rank);
+  // For the mul op we need to avoid expanding the rank of the optional shift
+  // input.
+  auto operandsToExpand =
+      isa<tosa::MulOp>(operation) ? operands.take_front(2) : operands;
+
+  auto expandedOperands =
+      expandInputRanks(rewriter, loc, operandsToExpand, rank);
   auto [targetShape, masterOperands] =
       computeTargetShape(rewriter, loc, indexPool, expandedOperands);
   auto broadcastOperands = broadcastDynamicDimensions(
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
index 39d0ee122b1630..892421155733b3 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaCanonicalizations.cpp
@@ -614,7 +614,18 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
   auto rhsAttr =
       llvm::dyn_cast_if_present<DenseElementsAttr>(adaptor.getInput2());
 
-  const int64_t shift = llvm::isa<IntegerType>(resultETy) ? getShift() : 0;
+  // Apply the result right shift on the i32 data type only. For simplicity,
+  // synthesize a zero shift for other data types.
+  int32_t shift = 0;
+  if (resultETy.isInteger(32)) {
+    ElementsAttr shift_elem;
+    if (getShift().getImpl()) {
+      if (!matchPattern(getShift(), m_Constant(&shift_elem)))
+        // Cannot be folded when the shift value is unknown.
+        return {};
+      shift = shift_elem.getValues<IntegerAttr>()[0].getInt();
+    }
+  }
 
   if (rhsTy == resultTy) {
     if (isSplatZero(resultETy, lhsAttr))
@@ -629,7 +640,7 @@ OpFoldResult MulOp::fold(FoldAdaptor adaptor) {
       return lhs;
   }
 
-  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, getShift());
+  return mulBinaryFolder(lhsAttr, rhsAttr, resultTy, shift);
 }
 
 OpFoldResult SubOp::fold(FoldAdaptor adaptor) {
diff --git a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
index 631d3c48f2df02..07f5d5e49f6d44 100644
--- a/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
+++ b/mlir/lib/Dialect/Tosa/IR/TosaOps.cpp
@@ -866,9 +866,84 @@ LogicalResult tosa::SliceOp::verify() {
 }
 
 LogicalResult tosa::MulOp::verify() {
-  Type elementTy = getInput1().getType().getElementType();
-  if (isa<FloatType>(elementTy) && getShift() != 0)
-    return emitOpError() << "require shift to be 0 for float type";
+  auto resElemType = getElementTypeOrSelf(getOutput());
+
+  // Verify that the element types of the operands and the result match the
+  // TOSA specification.
+  if (auto resIntType = dyn_cast<IntegerType>(resElemType)) {
+    IntegerType lhsIntType =
+      cast<IntegerType>(getElementTypeOrSelf(getInput1()));
+    IntegerType rhsIntType =
+      cast<IntegerType>(getElementTypeOrSelf(getInput2()));
+    if (lhsIntType != rhsIntType)
+      return emitOpError(
+          "requires the same element type for all operands");
+
+    // Though the spec requires the element type of the result to be i32, a
+    // more relaxed constraint is applied at the dialect level to ease
+    // interoperation with other dialects.
+    if (lhsIntType.getWidth() > resIntType.getWidth())
+      return emitOpError("invalid data type size for operands or result");
+
+  } else {
+    // For the other supported types, the spec requires the same element type
+    // for all operands (excluding the `shift` operand) and results.
+    for (int i = 0; i < 2; ++i) {
+      if (getElementTypeOrSelf(getOperand(i)) != resElemType)
+        return emitOpError(
+            "requires the same element type for all operands and results");
+    }
+  }
+
+  // Check that no shift operand is present for non-i32 output types, as that
+  // is not allowed by the spec.
+  if (!(llvm::isa<IntegerType>(resElemType) && resElemType.isInteger(32)))
+    if (getShift().getImpl())
+        return emitOpError(
+            "right shift output only on i32 data type");
+
+  // Verify that the op has the same rank for all main operands (excluding
+  // extra operands such as the shift of the mul op; this is the only
+  // difference from the built-in `SameOperandsAndResultRank` trait) and
+  // result types, if known.
+
+  // Helper that returns true if the type is a shaped type with a known rank.
+  auto hasRank = [](const Type type) {
+    if (auto shaped_type = dyn_cast<ShapedType>(type))
+      return shaped_type.hasRank();
+
+    return false;
+  };
+
+  auto rankedOperandTypes =
+    llvm::to_vector(llvm::make_filter_range(getOperandTypes(), hasRank));
+
+  auto rankedResultTypes =
+    llvm::make_filter_range(getOperation()->getResultTypes(), hasRank);
+
+  // If all operands and results are unranked, then no further verification.
+  if (rankedOperandTypes.empty() && rankedResultTypes.empty())
+    return success();
+
+  // Helper that returns the rank of a shaped type with a known rank.
+  auto getRank = [](const Type type) {
+    return cast<ShapedType>(type).getRank();
+  };
+
+  auto rank = !rankedOperandTypes.empty() ? getRank(*rankedOperandTypes.begin())
+    : getRank(*rankedResultTypes.begin());
+
+  for (size_t i = 0; i < 2; ++i) {
+    if (rank != getRank(rankedOperandTypes[i])) {
+      return emitOpError("operands don't have matching ranks");
+    }
+  }
+
+  for (const auto type : rankedResultTypes) {
+    if (rank != getRank(type)) {
+      return emitOpError("result type has different rank than operands");
+    }
+  }
 
   return success();
 }
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
index e6fba211dc37ab..537c0cce04491e 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaDecomposeDepthwise.cpp
@@ -137,7 +137,7 @@ struct DepthwiseConv2DIsMul : public OpRewritePattern<tosa::DepthwiseConv2DOp> {
 
     Value mulValue = rewriter
                          .create<tosa::MulOp>(op.getLoc(), mulShapeType, input,
-                                              weight, /*shift=*/0)
+                                              weight, Value{} /* zero_shift */)
                          .getResult();
 
     // Reshape output to [N, H, W, C * M].
diff --git a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
index 2a990eed3f681e..79afc75fd6c8ee 100644
--- a/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
+++ b/mlir/lib/Dialect/Tosa/Transforms/TosaMakeBroadcastable.cpp
@@ -113,7 +113,7 @@ struct ConvertTosaOp<tosa::MulOp> : public OpRewritePattern<tosa::MulOp> {
 
     Value input1 = tosaBinaryOp.getInput1();
     Value input2 = tosaBinaryOp.getInput2();
-    int32_t shift = tosaBinaryOp.getShift();
+    Value shift = tosaBinaryOp.getShift();
     Value output = tosaBinaryOp.getResult();
     auto outputType = dyn_cast<RankedTensorType>(output.getType());
     if (!outputType)
diff --git a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
index 265a75986c6c8d..7137d24afde3c7 100644
--- a/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
+++ b/mlir/test/Conversion/TosaToLinalg/tosa-to-linalg.mlir
@@ -451,7 +451,7 @@ func.func @test_simple_f32(%arg0: tensor<1xf32>) -> () {
 
   // CHECK: linalg.generic
   // CHECK: arith.mulf
-  %4 = tosa.mul %0, %1 {shift = 0 : i8} : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
+  %4 = tosa.mul %0, %1 : (tensor<1xf32>, tensor<1xf32>) -> tensor<1xf32>
 
   // CHECK: linalg.generic
   // CHECK: arith.negf
@@ -597,7 +597,7 @@ func.func @test_simple_i16(%arg0: tensor<1xi16>) -> () {
   // CHECK: arith.extsi
   // CHECK: arith.extsi
   // CHECK: arith.muli
-  %0 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
+  %0 = tosa.mul %arg0, %arg0 : (tensor<1xi16>, tensor<1xi16>) -> tensor<1xi32>
 
   return
 }
@@ -625,12 +625,14 @@ func.func @test_simple_i32(%arg0: tensor<1xi32>, %unsigned: tensor<1xui32>, %uns
 
   // CHECK: linalg.generic
   // CHECK: arith.muli
-  %2 = tosa.mul %arg0, %arg0 {shift = 0 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %shift1 = "tosa.const"() <{value = dense<0> : tensor<i8>}> : () -> tensor<i8>
+  %2 = tosa.mul %arg0, %arg0, %shift1 : (tensor<1xi32>, tensor<1xi32>, tensor<i8>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: arith.constant 2
   // CHECK: apply_scale
-  %3 = tosa.mul %arg0, %arg0 {shift = 2 : i8} : (tensor<1xi32>, tensor<1xi32>) -> tensor<1xi32>
+  %shift2 = "tosa.const"() <{value = dense<2> : tensor<i8>}> : () -> tensor<i8>
+  %3 = tosa.mul %arg0, %arg0, %shift2: (tensor<1xi32>, tensor<1xi32>, tensor<i8>) -> tensor<1xi32>
 
   // CHECK: linalg.generic
   // CHECK: arith.divsi
diff --git a/mlir/test/Dialect/Tosa/broadcast.mlir b/mlir/test/Dialect/Tosa/broadcast.mlir
index 7613aa3b8dd03d..f33a1465b6009d 100644
--- a/mlir/test/Dialect/Tosa/broadcast.mlir
+++ b/mlir/test/Dialect/Tosa/broadcast.mlir
@@ -169,15 +169,6 @@ func.func @test_broadcast20(%arg0: tensor<3x3x4x1xf32>, %arg1: tensor<4x5xf32>)
   return %0 : tensor<3x3x4x5xf32>
 }
 
-// -----
-// CHECK-LABEL: broadcast_mul
-func.func @test_broadcast_mul(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
-  // CHECK-DAG: %[[VAR0:.*]] = tosa.reshape %arg0 {new_shape = array<i64: 1, 1, 15, 14>}
-  // CHECK: %[[VAR1:.*]] = tosa.mul %[[VAR0]], %arg1
-  %0 = tosa.mul %arg0, %arg1 {shift = 1 : i8 } : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
-  return %0 : tensor<17x16x15x14xi32>
-}
-
 // -----
 // CHECK-LABEL: broadcast_arithmetic_right_shift
 func.func @test_broadcast_arithmetic_right_shift(%arg0: tensor<15x14xi32>, %arg1: tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32> {
diff --git a/mlir/test/Dialect/Tosa/canonicalize.mlir b/mlir/test/Dialect/Tosa/canonicalize.mlir
index 67cd01f62f0bdf..0e783a9a23ae48 100644
--- a/mlir/test/Dialect/Tosa/canonicalize.mlir
+++ b/mlir/test/Dialect/Tosa/canonicalize.mlir
@@ -280,7 +280,7 @@ func.func @mul_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1.0> : tensor<2x3xf32>} : () -> tensor<2x3xf32>
-  %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %arg0, %ones : (tensor<2x3xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -291,7 +291,7 @@ func.func @mul_bcast_one_float(%arg0: tensor<2x3xf32>) -> tensor<2x3xf32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %ones, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %ones, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
   return %1 : tensor<2x3xf32>
 }
 
@@ -302,7 +302,20 @@ func.func @mul_one_int(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
   // CHECK: return %arg0
   // CHECK-NOT: tosa.mul
   %ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
-  %1 = tosa.mul %arg0, %ones {shift = 0 : i8} : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  %1 = tosa.mul %arg0, %ones : (tensor<2x3xi32>, tensor<2x3xi32>) -> tensor<2x3xi32>
+  return %1 : tensor<2x3xi32>
+}
+
+// -----
+
+// CHECK-LABEL: @mul_one_int_and_shift
+func.func @mul_one_int_and_shift(%arg0: tensor<2x3xi32>) -> tensor<2x3xi32> {
+  // CHECK-DAG: %[[VAL_1:.*]] = "tosa.const"() <{value = dense<1> : tensor<2x3xi32>}>
+  // CHECK-DAG: %[[VAL_2:.*]] = "tosa.const"() <{value = dense<31> : tensor<i8>}>
+  // CHECK: %[[VAL_3:.*]] = tosa.mul %arg0, %[[VAL_1]], %[[VAL_2]] : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<i8>)
+  %ones = "tosa.const"() {value = dense<1> : tensor<2x3xi32>} : () -> tensor<2x3xi32>
+  %shift = "tosa.const"() <{value = dense<31> : tensor<i8>}> : () -> tensor<i8>
+  %1 = tosa.mul %arg0, %ones, %shift : (tensor<2x3xi32>, tensor<2x3xi32>, tensor<i8>) -> tensor<2x3xi32>
   return %1 : tensor<2x3xi32>
 }
 
@@ -313,11 +326,11 @@ func.func @mul_zero_broadcast(%arg0: tensor<2x3xf32>) -> (tensor<2x3xf32>, tenso
   // CHECK: %[[ZERO:.*]] = "tosa.const"() <{value = dense<0.000000e+00> : tensor<2x3xf32>}
   // CHECK-NOT: tosa.mul
   %zeros = "tosa.const"() {value = dense<0.0> : tensor<1x1xf32>} : () -> tensor<1x1xf32>
-  %1 = tosa.mul %arg0, %zeros {shift = 0 : i8} : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
+  %1 = tosa.mul %arg0, %zeros : (tensor<2x3xf32>, tensor<1x1xf32>) -> tensor<2x3xf32>
 
   // CHECK-NOT: tosa.mul
   // CHECK: return %[[ZERO]], %[[ZERO]]
-  %2 = tosa.mul %zeros, %arg0 {shift = 0 : i8} : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
+  %2 = tosa.mul %zeros, %arg0 : (tensor<1x1xf32>, tensor<2x3xf32>) -> tensor<2x3xf32>
   return %1, %2 : tensor<2x3xf32>, tensor<2x3xf32>
 }
 
@@ -872,7 +885,7 @@ func.func @mul_quant_nofold() -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899
    // CHECK: tosa.mul
    %0 = "tosa.const"() {value = dense<0> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    %1 = "tosa.const"() {value = dense<1> : tensor<1xi8>} : () -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
-   %2 = tosa.mul %0, %1 { shift = 0 : i8} : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
+   %2 = tosa.mul %0, %1 : (tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>, tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>) -> tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
    return %2 : tensor<1x!quant.uniform<i8:f32, 3.0757404601899907E-5:-128>>
 }
 
diff --git a/mlir/test/Dialect/Tosa/constant-op-fold.mlir b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
index 2902c4a62009e9..b6947e05b26bf1 100644
--- a/mlir/test/Dialect/Tosa/constant-op-fold.mlir
+++ b/mlir/test/Dialect/Tosa/constant-op-fold.mlir
@@ -247,7 +247,7 @@ func.func @fold_div_splat_i32() -> tensor<i32> {
 func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
-  %mul = tosa.mul %arg0, %zero {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %mul = tosa.mul %arg0, %zero : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<f32>
 }
@@ -258,7 +258,7 @@ func.func @fold_mul_zero_rhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
   %zero = "tosa.const"() {value = dense<0.0> : tensor<f32>} : () -> tensor<f32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0.000000e+00>
-  %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (tensor<f32>, tensor<f32>) -> tensor<f32>
+  %mul = tosa.mul %zero, %arg0 : (tensor<f32>, tensor<f32>) -> tensor<f32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<f32>
 }
@@ -269,7 +269,7 @@ func.func @fold_mul_zero_lhs_f32(%arg0: tensor<f32>) -> tensor<f32> {
 func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0>
-  %mul = tosa.mul %arg0, %zero {shift = 0 : i8} : (tensor<i32>, tensor<i32>) -> tensor<i32>
+  %mul = tosa.mul %arg0, %zero : (tensor<i32>, tensor<i32>) -> tensor<i32>
   // CHECK: return %[[ZERO]]
   return %mul : tensor<i32>
 }
@@ -280,7 +280,7 @@ func.func @fold_mul_zero_rhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
 func.func @fold_mul_zero_lhs_i32(%arg0: tensor<i32>) -> tensor<i32> {
   %zero = "tosa.const"() {value = dense<0> : tensor<i32>} : () -> tensor<i32>
   // CHECK: %[[ZERO:.+]] = "tosa.const"() <{value = dense<0>
-  %mul = tosa.mul %zero, %arg0 {shift = 0 : i8} : (te...
[truncated]
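
The invalid.mlir additions are truncated out of the patch above; as a hedged sketch based on the verifier code in the diff (error strings taken from its emitOpError calls, and %arg* names are placeholders), the new checks reject IR along these lines:

```mlir
// Rejected: a shift operand with a non-i32 result
// ("right shift output only on i32 data type").
%shift = "tosa.const"() <{value = dense<1> : tensor<i8>}> : () -> tensor<i8>
%0 = tosa.mul %arg0, %arg1, %shift : (tensor<1xf32>, tensor<1xf32>, tensor<i8>) -> tensor<1xf32>

// Rejected: operands of differing rank ("operands don't have matching ranks").
// This is why the rank-broadcasting mul test was removed from broadcast.mlir.
%1 = tosa.mul %arg2, %arg3 : (tensor<15x14xi32>, tensor<17x16x15x14xi32>) -> tensor<17x16x15x14xi32>
```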

@llvmbot
Member

llvmbot commented Jan 7, 2025

@llvm/pr-subscribers-mlir-tosa


@github-actions

github-actions bot commented Jan 7, 2025

✅ With the latest revision this PR passed the C/C++ code formatter.

Contributor

@lhutton1 left a comment


Thanks for the PR @FranklandJack, just adding a couple of small comments on top of @GeorgeARM's

@FranklandJack force-pushed the mul_shift branch 2 times, most recently from a83d5a9 to ad05fa2, on January 27, 2025 at 14:48
The TOSA-v1.0 specification makes the shift attribute of the MUL
(Hadamard product) operator an input. Move the `shift` parameter of the
MUL operator in the MLIR TOSA dialect from an attribute to an input and
update any lit tests appropriately.

Expand the verifier of the `tosa::MulOp` operation to check the various
constraints defined in the TOSA-v1.0 specification. Specifically, ensure
that all input operands (excluding the optional shift) are of the same
rank. This means that broadcasting tests which previously checked that
tensors of differing rank would be broadcast are no longer valid and
have been removed.

Signed-off-by: Jack Frankland <[email protected]>
Contributor

@GeorgeARM left a comment


Leaving to @lhutton1 for a final review as well. Cheers Jack

Contributor

@lhutton1 left a comment


LGTM!

@FranklandJack merged commit a58e774 into llvm:main on Jan 28, 2025
8 checks passed
GleasonK pushed a commit to openxla/stablehlo that referenced this pull request Feb 1, 2025
…upstream and TOSA-v1.0 (#2702)

This PR updates the StableHLO to TOSA legalization pass to align with
the recent changes in the MLIR upstream. Specifically,
[the `tosa::MulOp` operation has being modified to comply with the
TOSA-v1.0
specification](llvm/llvm-project#121953). The
shift parameter of the MUL operator, which was previously an attribute,
has been moved to an SSA operand. 

The upstream changes caused a compilation failure in the StableHLO to
TOSA conversion pass because the `tosa.mul` operation now expects 3
operand groups instead of 2, as reflected in the error message:

```text
error: invalid number of operand groups for `tosa.mul`; expected 3, but got 2
     with op<tosa.mul>(input0, input1) {shift = attr<"0 : i8">};
          ^
```

This PR adds a zero constant as the additional argument for the shift
parameter to maintain compatibility with the updated `tosa.mul`
operation.
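
A minimal sketch of that workaround (operand names and element types are illustrative; the pattern mirrors the zero-shift constants in the lit tests above):

```mlir
// Materialize the zero shift as a rank-0 i8 constant and pass it as the
// third operand of tosa.mul.
%zero_shift = "tosa.const"() <{value = dense<0> : tensor<i8>}> : () -> tensor<i8>
%out = tosa.mul %lhs, %rhs, %zero_shift : (tensor<1xi32>, tensor<1xi32>, tensor<i8>) -> tensor<1xi32>
```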
