From d01daf2396a5affe8e7e2dc6be421958f4e714cf Mon Sep 17 00:00:00 2001 From: Kai Sasaki Date: Thu, 27 Feb 2025 14:54:25 +0900 Subject: [PATCH] [mlir][math] Rsqrt math expand pass expects static shaped operand Similar to the issue reported in https://github.com/llvm/llvm-project/pull/128299/files, ExpandMath pattern for rsqrt expects the static shaped operands. Otherwise, it crashes due to the assertion violation. --- .../Math/Transforms/ExpandPatterns.cpp | 5 ++++ mlir/test/Dialect/Math/expand-math.mlir | 26 +++++++++++++++++++ 2 files changed, 31 insertions(+) diff --git a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp index bb592c667549c..7b5350ca26b60 100644 --- a/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp +++ b/mlir/lib/Dialect/Math/Transforms/ExpandPatterns.cpp @@ -646,6 +646,11 @@ static LogicalResult convertRsqrtOp(math::RsqrtOp op, auto operand = op.getOperand(); auto operandTy = operand.getType(); + // Operand type must be shatic shaped type to create const float. + auto shapedOperandType = dyn_cast(operandTy); + if (shapedOperandType && !shapedOperandType.hasStaticShape()) + return failure(); + auto eTy = getElementTypeOrSelf(operandTy); if (!isa(eTy)) return failure(); diff --git a/mlir/test/Dialect/Math/expand-math.mlir b/mlir/test/Dialect/Math/expand-math.mlir index 946a411e4cc4b..1420acaa40d35 100644 --- a/mlir/test/Dialect/Math/expand-math.mlir +++ b/mlir/test/Dialect/Math/expand-math.mlir @@ -787,3 +787,29 @@ func.func @unranked_ceil_op(%arg: tensor<*xf32>) -> tensor<*xf32>{ %a = math.ceil %arg : tensor<*xf32> return %a: tensor<*xf32> } + +// ----- + +// CHECK-LABEL: func.func @non_static_shape_rsqrt_op +// CHECK-SAME: (%[[ARG:.*]]: tensor) +// CHECK-SAME: -> tensor +// CHECK: %[[RSQRT:.*]] = math.rsqrt %[[ARG]] : tensor +// CHECK: return %[[RSQRT]] : tensor + +func.func @non_static_shape_rsqrt_op(%arg: tensor) -> tensor{ + %a = math.rsqrt %arg : tensor + return %a: tensor +} + +// ----- + +// CHECK-LABEL: func.func @unranked_rsqrt_op +// CHECK-SAME: (%[[ARG:.*]]: tensor<*xf32>) +// CHECK-SAME: -> tensor<*xf32> +// CHECK: %[[RSQRT:.*]] = math.rsqrt %[[ARG]] : tensor<*xf32> +// CHECK: return %[[RSQRT]] : tensor<*xf32> + +func.func @unranked_rsqrt_op(%arg: tensor<*xf32>) -> tensor<*xf32>{ + %a = math.rsqrt %arg : tensor<*xf32> + return %a: tensor<*xf32> +}