Skip to content

Conversation

@CoTinker
Copy link
Contributor

This PR fixes a crash in getSemiAffineExprFromFlatForm when localExpr is not AffineBinaryOpExpr. Fixes #144091.

This PR fixes a crash in `getSemiAffineExprFromFlatForm`
when localExpr is not `AffineBinaryOpExpr`.
@llvmbot llvmbot added mlir:core MLIR Core Infrastructure mlir:affine mlir labels Jun 21, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 21, 2025

@llvm/pr-subscribers-mlir

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes a crash in getSemiAffineExprFromFlatForm when localExpr is not AffineBinaryOpExpr. Fixes #144091.


Full diff: https://github.com/llvm/llvm-project/pull/145162.diff

2 Files Affected:

  • (modified) mlir/lib/IR/AffineExpr.cpp (+7-3)
  • (modified) mlir/test/Dialect/Affine/simplify-structures.mlir (+16)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index c8d9761511bec..cc81f9d19aca7 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1174,11 +1174,15 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
   // the indices in `coefficients` map, and affine expression corresponding to
   // in indices in `indexToExprMap` map.
   for (const auto &it : llvm::enumerate(localExprs)) {
-    AffineExpr expr = it.value();
     if (flatExprs[numDims + numSymbols + it.index()] == 0)
       continue;
-    AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
-    AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS();
+    AffineExpr expr = it.value();
+    auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+    if (!binaryExpr)
+      continue;
+
+    AffineExpr lhs = binaryExpr.getLHS();
+    AffineExpr rhs = binaryExpr.getRHS();
     if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
           (isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
            isa<AffineConstantExpr>(rhs)))) {
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index e4a8512b002ee..6f2737a982752 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -592,3 +592,19 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in
   //CHECK: affine.apply #[[$MAP]]()[%{{.*}}, %{{.*}}, %{{.*}}]
   return %a : index
 }
+
+// -----
+
+// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const
+func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c13 = arith.constant 13 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+  %b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+  // CHECK:      %[[C6:.*]] = arith.constant 6 : index
+  // CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index
+  // CHECK-NEXT: return %[[C6]], %[[C7]]
+  return %a, %b : index, index
+}

@llvmbot
Copy link
Member

llvmbot commented Jun 21, 2025

@llvm/pr-subscribers-mlir-core

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes a crash in getSemiAffineExprFromFlatForm when localExpr is not AffineBinaryOpExpr. Fixes #144091.


Full diff: https://github.com/llvm/llvm-project/pull/145162.diff

2 Files Affected:

  • (modified) mlir/lib/IR/AffineExpr.cpp (+7-3)
  • (modified) mlir/test/Dialect/Affine/simplify-structures.mlir (+16)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index c8d9761511bec..cc81f9d19aca7 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1174,11 +1174,15 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
   // the indices in `coefficients` map, and affine expression corresponding to
   // in indices in `indexToExprMap` map.
   for (const auto &it : llvm::enumerate(localExprs)) {
-    AffineExpr expr = it.value();
     if (flatExprs[numDims + numSymbols + it.index()] == 0)
       continue;
-    AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
-    AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS();
+    AffineExpr expr = it.value();
+    auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+    if (!binaryExpr)
+      continue;
+
+    AffineExpr lhs = binaryExpr.getLHS();
+    AffineExpr rhs = binaryExpr.getRHS();
     if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
           (isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
            isa<AffineConstantExpr>(rhs)))) {
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index e4a8512b002ee..6f2737a982752 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -592,3 +592,19 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in
   //CHECK: affine.apply #[[$MAP]]()[%{{.*}}, %{{.*}}, %{{.*}}]
   return %a : index
 }
+
+// -----
+
+// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const
+func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c13 = arith.constant 13 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+  %b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+  // CHECK:      %[[C6:.*]] = arith.constant 6 : index
+  // CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index
+  // CHECK-NEXT: return %[[C6]], %[[C7]]
+  return %a, %b : index, index
+}

@llvmbot
Copy link
Member

llvmbot commented Jun 21, 2025

@llvm/pr-subscribers-mlir-affine

Author: Longsheng Mou (CoTinker)

Changes

This PR fixes a crash in getSemiAffineExprFromFlatForm when localExpr is not AffineBinaryOpExpr. Fixes #144091.


Full diff: https://github.com/llvm/llvm-project/pull/145162.diff

2 Files Affected:

  • (modified) mlir/lib/IR/AffineExpr.cpp (+7-3)
  • (modified) mlir/test/Dialect/Affine/simplify-structures.mlir (+16)
diff --git a/mlir/lib/IR/AffineExpr.cpp b/mlir/lib/IR/AffineExpr.cpp
index c8d9761511bec..cc81f9d19aca7 100644
--- a/mlir/lib/IR/AffineExpr.cpp
+++ b/mlir/lib/IR/AffineExpr.cpp
@@ -1174,11 +1174,15 @@ static AffineExpr getSemiAffineExprFromFlatForm(ArrayRef<int64_t> flatExprs,
   // the indices in `coefficients` map, and affine expression corresponding to
   // in indices in `indexToExprMap` map.
   for (const auto &it : llvm::enumerate(localExprs)) {
-    AffineExpr expr = it.value();
     if (flatExprs[numDims + numSymbols + it.index()] == 0)
       continue;
-    AffineExpr lhs = cast<AffineBinaryOpExpr>(expr).getLHS();
-    AffineExpr rhs = cast<AffineBinaryOpExpr>(expr).getRHS();
+    AffineExpr expr = it.value();
+    auto binaryExpr = dyn_cast<AffineBinaryOpExpr>(expr);
+    if (!binaryExpr)
+      continue;
+
+    AffineExpr lhs = binaryExpr.getLHS();
+    AffineExpr rhs = binaryExpr.getRHS();
     if (!((isa<AffineDimExpr>(lhs) || isa<AffineSymbolExpr>(lhs)) &&
           (isa<AffineDimExpr>(rhs) || isa<AffineSymbolExpr>(rhs) ||
            isa<AffineConstantExpr>(rhs)))) {
diff --git a/mlir/test/Dialect/Affine/simplify-structures.mlir b/mlir/test/Dialect/Affine/simplify-structures.mlir
index e4a8512b002ee..6f2737a982752 100644
--- a/mlir/test/Dialect/Affine/simplify-structures.mlir
+++ b/mlir/test/Dialect/Affine/simplify-structures.mlir
@@ -592,3 +592,19 @@ func.func @semiaffine_modulo_dim(%arg0: index, %arg1: index, %arg2: index) -> in
   //CHECK: affine.apply #[[$MAP]]()[%{{.*}}, %{{.*}}, %{{.*}}]
   return %a : index
 }
+
+// -----
+
+// CHECK-LABEL: func @semiaffine_simplification_floordiv_and_ceildiv_const
+func.func @semiaffine_simplification_floordiv_and_ceildiv_const(%arg0: tensor<?xf32>) -> (index, index) {
+  %c0 = arith.constant 0 : index
+  %c1 = arith.constant 1 : index
+  %c13 = arith.constant 13 : index
+  %dim = tensor.dim %arg0, %c0 : tensor<?xf32>
+  %a = affine.apply affine_map<()[s0, s1, s2] -> (s0 floordiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+  %b = affine.apply affine_map<()[s0, s1, s2] -> (s0 ceildiv (s1 + (-s1 + 2) * (-s1 + s1 * s2 + 1)))>()[%c13, %dim, %c1]
+  // CHECK:      %[[C6:.*]] = arith.constant 6 : index
+  // CHECK-NEXT: %[[C7:.*]] = arith.constant 7 : index
+  // CHECK-NEXT: return %[[C6]], %[[C7]]
+  return %a, %b : index, index
+}

Copy link
Member

@lipracer lipracer left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, LGTM if CI green.

Copy link
Contributor

@abdulraheembeigh abdulraheembeigh left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM. Thanks

@CoTinker CoTinker merged commit b00ddce into llvm:main Jun 23, 2025
11 checks passed
@CoTinker CoTinker deleted the affine_expr branch June 23, 2025 01:38
@arnab-polymage
Copy link
Contributor

The fix is not correct. It simply hides the original bug. Local expression is always supposed to be a binary expression -- it cannot be dimension, symbol or constant. This PR fixes this.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

mlir:affine mlir:core MLIR Core Infrastructure mlir

Projects

None yet

Development

Successfully merging this pull request may close these issues.

[mlir] Canonicalizer crashed in mlir::simplifyAffineExpr with Assertion `isa<To>(Val) && "cast<Ty>() argument of incompatible type!"' failed.

5 participants