-
Notifications
You must be signed in to change notification settings - Fork 14.9k
[mlir][IntRange] Fix materializing ShapedType constant values #158359
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Conversation
When materializing integer ranges of splat tensors or vectors as constants, the analysis should create a DenseElementsAttr of the shaped type, not an IntegerAttr of the element type; using the element-type IntegerAttr can violate the invariants of tensor/vector ops.
@llvm/pr-subscribers-mlir Author: Jeff Niu (Mogball) Changes: When materializing integer ranges of splat tensors or vectors as constants, the analysis should create a DenseElementsAttr of the shaped type, not an IntegerAttr of the element type; using the element-type IntegerAttr can violate the invariants of tensor/vector ops. Full diff: https://github.com/llvm/llvm-project/pull/158359.diff 3 Files Affected:
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index e79f6a8aec1cf..70b56ca77b2da 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -26,6 +26,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
@@ -76,9 +77,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
else
dialect = value.getParentBlock()->getParentOp()->getDialect();
- Type type = getElementTypeOrSelf(value);
- solver->propagateIfChanged(
- cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
+ Attribute cstAttr;
+ if (isa<IntegerType, IndexType>(value.getType())) {
+ cstAttr = IntegerAttr::get(value.getType(), *constant);
+ } else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
+ cstAttr = SplatElementsAttr::get(shapedTy, *constant);
+ } else {
+ llvm::report_fatal_error(
+ Twine("FIXME: Don't know how to create a constant for this type: ") +
+ mlir::debugString(value.getType()));
+ }
+ solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
}
LogicalResult IntegerRangeAnalysis::visitOperation(
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..2017905587b26 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -8,6 +8,7 @@
#include <utility>
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -485,6 +486,7 @@ struct IntRangeOptimizationsPass final
MLIRContext *ctx = op->getContext();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
+ solver.load<SparseConstantPropagation>();
solver.load<IntegerRangeAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index ea5969a100258..e6e48d30cece5 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -132,3 +132,19 @@ func.func @wraps() -> i8 {
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}
+
+// -----
+
+// CHECK-LABEL: @analysis_crash
+func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64> {
+ %c0_i32 = arith.constant 0 : i32
+ %cst = arith.constant dense<-1> : tensor<128xi32>
+ %splat = tensor.splat %arg0 : tensor<128xi32>
+ %0 = scf.for %arg2 = %c0_i32 to %arg0 step %arg0 iter_args(%arg3 = %splat) -> (tensor<128xi32>) : i32 {
+ scf.yield %arg3 : tensor<128xi32>
+ }
+ %1 = arith.select %arg1, %0#0, %cst : tensor<128xi1>, tensor<128xi32>
+ // Make sure the analysis doesn't crash when materializing the range as a tensor constant.
+ %2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
+ return %2 : tensor<128xi64>
+}
|
@llvm/pr-subscribers-mlir-arith Author: Jeff Niu (Mogball) Changes: When materializing integer ranges of splat tensors or vectors as constants, the analysis should create a DenseElementsAttr of the shaped type, not an IntegerAttr of the element type; using the element-type IntegerAttr can violate the invariants of tensor/vector ops. Full diff: https://github.com/llvm/llvm-project/pull/158359.diff 3 Files Affected:
diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
index e79f6a8aec1cf..70b56ca77b2da 100644
--- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
+++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
@@ -26,6 +26,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
+#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
@@ -76,9 +77,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
else
dialect = value.getParentBlock()->getParentOp()->getDialect();
- Type type = getElementTypeOrSelf(value);
- solver->propagateIfChanged(
- cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
+ Attribute cstAttr;
+ if (isa<IntegerType, IndexType>(value.getType())) {
+ cstAttr = IntegerAttr::get(value.getType(), *constant);
+ } else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
+ cstAttr = SplatElementsAttr::get(shapedTy, *constant);
+ } else {
+ llvm::report_fatal_error(
+ Twine("FIXME: Don't know how to create a constant for this type: ") +
+ mlir::debugString(value.getType()));
+ }
+ solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
}
LogicalResult IntegerRangeAnalysis::visitOperation(
diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
index 777ff0ecaa314..2017905587b26 100644
--- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
+++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
@@ -8,6 +8,7 @@
#include <utility>
+#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"
@@ -485,6 +486,7 @@ struct IntRangeOptimizationsPass final
MLIRContext *ctx = op->getContext();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
+ solver.load<SparseConstantPropagation>();
solver.load<IntegerRangeAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
diff --git a/mlir/test/Dialect/Arith/int-range-opts.mlir b/mlir/test/Dialect/Arith/int-range-opts.mlir
index ea5969a100258..e6e48d30cece5 100644
--- a/mlir/test/Dialect/Arith/int-range-opts.mlir
+++ b/mlir/test/Dialect/Arith/int-range-opts.mlir
@@ -132,3 +132,19 @@ func.func @wraps() -> i8 {
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}
+
+// -----
+
+// CHECK-LABEL: @analysis_crash
+func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64> {
+ %c0_i32 = arith.constant 0 : i32
+ %cst = arith.constant dense<-1> : tensor<128xi32>
+ %splat = tensor.splat %arg0 : tensor<128xi32>
+ %0 = scf.for %arg2 = %c0_i32 to %arg0 step %arg0 iter_args(%arg3 = %splat) -> (tensor<128xi32>) : i32 {
+ scf.yield %arg3 : tensor<128xi32>
+ }
+ %1 = arith.select %arg1, %0#0, %cst : tensor<128xi1>, tensor<128xi32>
+ // Make sure the analysis doesn't crash when materializing the range as a tensor constant.
+ %2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
+ return %2 : tensor<128xi64>
+}
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM!
Reverts #8144. That change exposed a bug in upstream MLIR; it will be re-landed once we integrate the upstream fix: llvm/llvm-project#158359
When materializing integer ranges of splat tensors or vectors as constants, the analysis should create a DenseElementsAttr of the shaped type, not an IntegerAttr of the element type; using the element-type IntegerAttr can violate the invariants of tensor/vector ops.