Skip to content

Conversation

@Lewuathe
Copy link
Member

Similar to the issue reported in
#128299 (review), ExpandMath pattern for rsqrt expects the static shaped operands. Otherwise, it crashes due to the assertion violation.

See: #128299

@llvmbot
Copy link
Member

llvmbot commented Feb 27, 2025

@llvm/pr-subscribers-mlir

@llvm/pr-subscribers-mlir-math

Author: Kai Sasaki (Lewuathe)

Changes

Similar to the issue reported in
#128299 (review), ExpandMath pattern for rsqrt expects the static shaped operands. Otherwise, it crashes due to the assertion violation.

See: #128299


Full diff: https://github.com/llvm/llvm-project/pull/129006.diff

2 Files Affected:

  • (modified) mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp (+5)
  • (modified) mlir/test/Dialect/Math/expand-math.mlir (+26)
diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
index bb592c667549c..7b5350ca26b60 100644
--- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
+++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp
@@ -646,6 +646,11 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op,
 
   auto operand = op.getOperand();
   auto operandTy = operand.getType();
+  // Operand type must be shatic shaped type to create const float.
+  auto shapedOperandType = dyn_cast<ShapedType>(operandTy);
+  if (shapedOperandType && !shapedOperandType.hasStaticShape())
+    return failure();
+
   auto eTy = getElementTypeOrSelf(operandTy);
   if (!isa<FloatType>(eTy))
     return failure();
diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir
index 946a411e4cc4b..8743efec5ecb4 100644
--- a/mlir/test/Dialect/Math/expand-math.mlir
+++ b/mlir/test/Dialect/Math/expand-math.mlir
@@ -787,3 +787,29 @@ func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
   %a = math.ceil %arg : tensor<*xf32>
   return %a: tensor<*xf32>
 }
+
+// -----
+
+// CHECK-LABEL:    func.func @non_static_shape_rsqrt_op
+// CHECK-SAME:     (%[[ARG:.*]]: tensor<?xf32>)
+// CHECK-SAME:     -> tensor<?xf32>
+// CHECK:          %[[CEIL:.*]] = math.rsqrt %[[ARG]] : tensor<?xf32>
+// CHECK:          return %[[CEIL]] : tensor<?xf32>
+
+func.func @non_static_shape_rsqrt_op(%arg: tensor<?xf32>) -> tensor<?xf32>{
+  %a = math.rsqrt %arg : tensor<?xf32>
+  return %a: tensor<?xf32>
+}
+
+// -----
+
+// CHECK-LABEL:    func.func @unranked_rsqrt_op
+// CHECK-SAME:     (%[[ARG:.*]]: tensor<*xf32>)
+// CHECK-SAME:     -> tensor<*xf32>
+// CHECK:          %[[CEIL:.*]] = math.rsqrt %[[ARG]] : tensor<*xf32>
+// CHECK:          return %[[CEIL]] : tensor<*xf32>
+
+func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{
+  %a = math.rsqrt %arg : tensor<*xf32>
+  return %a: tensor<*xf32>
+}

Similar to the issue reported in
https://github.com/llvm/llvm-project/pull/128299/files, ExpandMath
pattern for rsqrt expects the static shaped operands. Otherwise, it
crashes due to the assertion violation.
@Lewuathe Lewuathe force-pushed the rsqrt-in-expand-math-expect-static-shape branch from 4d899de to d01daf2 Compare February 27, 2025 06:02
Copy link
Contributor

@cferry-AMD cferry-AMD left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That was fast. Thanks once more!

@Lewuathe Lewuathe merged commit 55f2547 into llvm:main Feb 28, 2025
11 checks passed
@Lewuathe Lewuathe deleted the rsqrt-in-expand-math-expect-static-shape branch February 28, 2025 04:37
cheezeburglar pushed a commit to cheezeburglar/llvm-project that referenced this pull request Feb 28, 2025
…vm#129006)

Similar to the issue reported in

llvm#128299 (review),
ExpandMath pattern for rsqrt expects the static shaped operands.
Otherwise, it crashes due to the assertion violation.

See: llvm#128299
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

3 participants