Skip to content

Conversation

Mogball
Copy link
Contributor

@Mogball Mogball commented Sep 12, 2025

When integer ranges of splat tensors or vectors are materialized as constants, the constants should be DenseElementsAttrs of the shaped type, not IntegerAttrs of the element type, since attaching an element-typed IntegerAttr to a shaped value can violate the invariants of tensor/vector ops.

When integer ranges of splat tensors or vectors are materialized as
constants, the constants should be DenseElementsAttrs of the shaped
type, not IntegerAttrs of the element type, since attaching an
element-typed IntegerAttr to a shaped value can violate the invariants
of tensor/vector ops.
@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2025

@llvm/pr-subscribers-mlir

Author: Jeff Niu (Mogball)

Changes

When integer ranges of splat tensors or vectors are materialized as constants, the constants should be DenseElementsAttrs of the shaped type, not IntegerAttrs of the element type, since attaching an element-typed IntegerAttr to a shaped value can violate the invariants of tensor/vector ops.


Full diff: https://github.com/llvm/llvm-project/pull/158359.diff

3 Files Affected:

  • (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+12-3)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+2)
  • (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+16)
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index e79f6a8aec1cf..70b56ca77b2da 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/DebugStringHelper.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
@@ -76,9 +77,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
   else
     dialect = value.getParentBlock()->getParentOp()->getDialect();
 
-  Type type = getElementTypeOrSelf(value);
-  solver->propagateIfChanged(
-      cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
+  Attribute cstAttr;
+  if (isa<IntegerType, IndexType>(value.getType())) {
+    cstAttr = IntegerAttr::get(value.getType(), *constant);
+  } else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
+    cstAttr = SplatElementsAttr::get(shapedTy, *constant);
+  } else {
+    llvm::report_fatal_error(
+        Twine("FIXME: Don't know how to create a constant for this type: ") +
+        mlir::debugString(value.getType()));
+  }
+  solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
 }
 
 LogicalResult IntegerRangeAnalysis::visitOperation(
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..2017905587b26 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -8,6 +8,7 @@
 
 #include <utility>
 
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 
@@ -485,6 +486,7 @@ struct IntRangeOptimizationsPass final
     MLIRContext *ctx = op->getContext();
     DataFlowSolver solver;
     solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
     solver.load<IntegerRangeAnalysis>();
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index ea5969a100258..e6e48d30cece5 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -132,3 +132,19 @@ func.func @wraps() -> i8 {
   %mod = arith.remsi %val, %c64 : i8
   return %mod : i8
 }
+
+// -----
+
+// CHECK-LABEL: @analysis_crash
+func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64> {
+  %c0_i32 = arith.constant 0 : i32
+  %cst = arith.constant dense<-1> : tensor<128xi32>
+  %splat = tensor.splat %arg0 : tensor<128xi32>
+  %0 = scf.for %arg2 = %c0_i32 to %arg0 step %arg0 iter_args(%arg3 = %splat) -> (tensor<128xi32>)  : i32 {
+    scf.yield %arg3 : tensor<128xi32>
+  }
+  %1 = arith.select %arg1, %0#0, %cst : tensor<128xi1>, tensor<128xi32>
+  // Make sure the analysis doesn't crash when materializing the range as a tensor constant.
+  %2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
+  return %2 : tensor<128xi64>
+}

@llvmbot
Copy link
Member

llvmbot commented Sep 12, 2025

@llvm/pr-subscribers-mlir-arith

Author: Jeff Niu (Mogball)

Changes

When integer ranges of splat tensors or vectors are materialized as constants, the constants should be DenseElementsAttrs of the shaped type, not IntegerAttrs of the element type, since attaching an element-typed IntegerAttr to a shaped value can violate the invariants of tensor/vector ops.


Full diff: https://github.com/llvm/llvm-project/pull/158359.diff

3 Files Affected:

  • (modified) mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp (+12-3)
  • (modified) mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp (+2)
  • (modified) mlir/test/Dialect/Arith/int-range-opts.mlir (+16)
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index e79f6a8aec1cf..70b56ca77b2da 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -26,6 +26,7 @@
 #include "mlir/Interfaces/ControlFlowInterfaces.h"
 #include "mlir/Interfaces/InferIntRangeInterface.h"
 #include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/DebugStringHelper.h"
 #include "mlir/Support/LLVM.h"
 #include "llvm/ADT/STLExtras.h"
 #include "llvm/Support/Casting.h"
@@ -76,9 +77,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
   else
     dialect = value.getParentBlock()->getParentOp()->getDialect();
 
-  Type type = getElementTypeOrSelf(value);
-  solver->propagateIfChanged(
-      cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
+  Attribute cstAttr;
+  if (isa<IntegerType, IndexType>(value.getType())) {
+    cstAttr = IntegerAttr::get(value.getType(), *constant);
+  } else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
+    cstAttr = SplatElementsAttr::get(shapedTy, *constant);
+  } else {
+    llvm::report_fatal_error(
+        Twine("FIXME: Don't know how to create a constant for this type: ") +
+        mlir::debugString(value.getType()));
+  }
+  solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
 }
 
 LogicalResult IntegerRangeAnalysis::visitOperation(
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..2017905587b26 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -8,6 +8,7 @@
 
 #include <utility>
 
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
 #include "mlir/Analysis/DataFlowFramework.h"
 #include "mlir/Dialect/Arith/Transforms/Passes.h"
 
@@ -485,6 +486,7 @@ struct IntRangeOptimizationsPass final
     MLIRContext *ctx = op->getContext();
     DataFlowSolver solver;
     solver.load<DeadCodeAnalysis>();
+    solver.load<SparseConstantPropagation>();
     solver.load<IntegerRangeAnalysis>();
     if (failed(solver.initializeAndRun(op)))
       return signalPassFailure();
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index ea5969a100258..e6e48d30cece5 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -132,3 +132,19 @@ func.func @wraps() -> i8 {
   %mod = arith.remsi %val, %c64 : i8
   return %mod : i8
 }
+
+// -----
+
+// CHECK-LABEL: @analysis_crash
+func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64> {
+  %c0_i32 = arith.constant 0 : i32
+  %cst = arith.constant dense<-1> : tensor<128xi32>
+  %splat = tensor.splat %arg0 : tensor<128xi32>
+  %0 = scf.for %arg2 = %c0_i32 to %arg0 step %arg0 iter_args(%arg3 = %splat) -> (tensor<128xi32>)  : i32 {
+    scf.yield %arg3 : tensor<128xi32>
+  }
+  %1 = arith.select %arg1, %0#0, %cst : tensor<128xi1>, tensor<128xi32>
+  // Make sure the analysis doesn't crash when materializing the range as a tensor constant.
+  %2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
+  return %2 : tensor<128xi64>
+}

Copy link
Contributor

@ThomasRaoux ThomasRaoux left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM!

@Mogball Mogball merged commit 86bcd1c into main Sep 12, 2025
12 checks passed
@Mogball Mogball deleted the users/mogball/fix_intrange branch September 12, 2025 20:53
ThomasRaoux added a commit to triton-lang/triton that referenced this pull request Sep 12, 2025
Reverts #8144

This exposed a bug in MLIR upstream. This will get merged when we
integrate the fix:
llvm/llvm-project#158359
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Projects

None yet

Development

Successfully merging this pull request may close these issues.

4 participants