Skip to content

Commit 2217053

Browse files
[mlir][Intrange] Fix materializing ShapedType constant values
When materializing integer ranges of splat tensors or vector as constants, they should use DenseElementsAttr of the shaped type, not IntegerAttrs of the element types, since this can violate the invariants of tensor/vector ops.
1 parent 3706070 commit 2217053

File tree

3 files changed

+30
-3
lines changed

3 files changed

+30
-3
lines changed

mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp

Lines changed: 12 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Interfaces/ControlFlowInterfaces.h"
2727
#include "mlir/Interfaces/InferIntRangeInterface.h"
2828
#include "mlir/Interfaces/LoopLikeInterface.h"
29+
#include "mlir/Support/DebugStringHelper.h"
2930
#include "mlir/Support/LLVM.h"
3031
#include "llvm/ADT/STLExtras.h"
3132
#include "llvm/Support/Casting.h"
@@ -76,9 +77,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
7677
else
7778
dialect = value.getParentBlock()->getParentOp()->getDialect();
7879

79-
Type type = getElementTypeOrSelf(value);
80-
solver->propagateIfChanged(
81-
cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
80+
Attribute cstAttr;
81+
if (isa<IntegerType, IndexType>(value.getType())) {
82+
cstAttr = IntegerAttr::get(value.getType(), *constant);
83+
} else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
84+
cstAttr = SplatElementsAttr::get(shapedTy, *constant);
85+
} else {
86+
llvm::report_fatal_error(
87+
Twine("FIXME: Don't know how to create a constant for this type: ") +
88+
mlir::debugString(value.getType()));
89+
}
90+
solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
8291
}
8392

8493
LogicalResult IntegerRangeAnalysis::visitOperation(

mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88

99
#include <utility>
1010

11+
#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
1112
#include "mlir/Analysis/DataFlowFramework.h"
1213
#include "mlir/Dialect/Arith/Transforms/Passes.h"
1314

@@ -485,6 +486,7 @@ struct IntRangeOptimizationsPass final
485486
MLIRContext *ctx = op->getContext();
486487
DataFlowSolver solver;
487488
solver.load<DeadCodeAnalysis>();
489+
solver.load<SparseConstantPropagation>();
488490
solver.load<IntegerRangeAnalysis>();
489491
if (failed(solver.initializeAndRun(op)))
490492
return signalPassFailure();

mlir/test/Dialect/Arith/int-range-opts.mlir

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,19 @@ func.func @wraps() -> i8 {
132132
%mod = arith.remsi %val, %c64 : i8
133133
return %mod : i8
134134
}
135+
136+
// -----
137+
138+
// CHECK-LABEL: @analysis_crash
139+
func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64> {
140+
%c0_i32 = arith.constant 0 : i32
141+
%cst = arith.constant dense<-1> : tensor<128xi32>
142+
%splat = tensor.splat %arg0 : tensor<128xi32>
143+
%0 = scf.for %arg2 = %c0_i32 to %arg0 step %arg0 iter_args(%arg3 = %splat) -> (tensor<128xi32>) : i32 {
144+
scf.yield %arg3 : tensor<128xi32>
145+
}
146+
%1 = arith.select %arg1, %0#0, %cst : tensor<128xi1>, tensor<128xi32>
147+
// Make sure the analysis doesn't crash when materializing the range as a tensor constant.
148+
%2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
149+
return %2 : tensor<128xi64>
150+
}

0 commit comments

Comments
 (0)