Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 12 additions & 3 deletions mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
#include "mlir/Interfaces/ControlFlowInterfaces.h"
#include "mlir/Interfaces/InferIntRangeInterface.h"
#include "mlir/Interfaces/LoopLikeInterface.h"
#include "mlir/Support/DebugStringHelper.h"
#include "mlir/Support/LLVM.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/Support/Casting.h"
Expand Down Expand Up @@ -76,9 +77,17 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const {
else
dialect = value.getParentBlock()->getParentOp()->getDialect();

Type type = getElementTypeOrSelf(value);
solver->propagateIfChanged(
cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect)));
Attribute cstAttr;
if (isa<IntegerType, IndexType>(value.getType())) {
cstAttr = IntegerAttr::get(value.getType(), *constant);
} else if (auto shapedTy = dyn_cast<ShapedType>(value.getType())) {
cstAttr = SplatElementsAttr::get(shapedTy, *constant);
} else {
llvm::report_fatal_error(
Twine("FIXME: Don't know how to create a constant for this type: ") +
mlir::debugString(value.getType()));
}
solver->propagateIfChanged(cv, cv->join(ConstantValue(cstAttr, dialect)));
}

LogicalResult IntegerRangeAnalysis::visitOperation(
Expand Down
2 changes: 2 additions & 0 deletions mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@

#include <utility>

#include "mlir/Analysis/DataFlow/ConstantPropagationAnalysis.h"
#include "mlir/Analysis/DataFlowFramework.h"
#include "mlir/Dialect/Arith/Transforms/Passes.h"

Expand Down Expand Up @@ -485,6 +486,7 @@ struct IntRangeOptimizationsPass final
MLIRContext *ctx = op->getContext();
DataFlowSolver solver;
solver.load<DeadCodeAnalysis>();
solver.load<SparseConstantPropagation>();
solver.load<IntegerRangeAnalysis>();
if (failed(solver.initializeAndRun(op)))
return signalPassFailure();
Expand Down
16 changes: 16 additions & 0 deletions mlir/test/Dialect/Arith/int-range-opts.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -132,3 +132,19 @@ func.func @wraps() -> i8 {
%mod = arith.remsi %val, %c64 : i8
return %mod : i8
}

// -----

// CHECK-LABEL: @analysis_crash
func.func @analysis_crash(%arg0: i32, %arg1: tensor<128xi1>) -> tensor<128xi64> {
%c0_i32 = arith.constant 0 : i32
%cst = arith.constant dense<-1> : tensor<128xi32>
%splat = tensor.splat %arg0 : tensor<128xi32>
%0 = scf.for %arg2 = %c0_i32 to %arg0 step %arg0 iter_args(%arg3 = %splat) -> (tensor<128xi32>) : i32 {
scf.yield %arg3 : tensor<128xi32>
}
%1 = arith.select %arg1, %0#0, %cst : tensor<128xi1>, tensor<128xi32>
// Make sure the analysis doesn't crash when materializing the range as a tensor constant.
%2 = arith.extsi %1 : tensor<128xi32> to tensor<128xi64>
return %2 : tensor<128xi64>
}
Loading