diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td index b54a8b7fe8680..3f45d0804e045 100644 --- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td +++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td @@ -13,20 +13,21 @@ #ifndef MLIR_DIALECT_VECTOR_IR_VECTOR_OPS #define MLIR_DIALECT_VECTOR_IR_VECTOR_OPS -include "mlir/Dialect/Vector/IR/Vector.td" -include "mlir/Dialect/Vector/IR/VectorAttributes.td" include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Dialect/Arith/IR/ArithOpsInterfaces.td" include "mlir/Dialect/Vector/Interfaces/MaskableOpInterface.td" include "mlir/Dialect/Vector/Interfaces/MaskingOpInterface.td" -include "mlir/IR/EnumAttr.td" +include "mlir/Dialect/Vector/IR/Vector.td" +include "mlir/Dialect/Vector/IR/VectorAttributes.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/DestinationStyleOpInterface.td" +include "mlir/Interfaces/InferIntRangeInterface.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" include "mlir/Interfaces/VectorInterfaces.td" include "mlir/Interfaces/ViewLikeInterface.td" include "mlir/IR/BuiltinAttributes.td" +include "mlir/IR/EnumAttr.td" // TODO: Add an attribute to specify a different algebra with operators other // than the current set: {*, +}. @@ -346,6 +347,7 @@ def Vector_MultiDimReductionOp : def Vector_BroadcastOp : Vector_Op<"broadcast", [Pure, + DeclareOpInterfaceMethods, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>]>, Arguments<(ins AnyType:$source)>, @@ -627,6 +629,7 @@ def Vector_DeinterleaveOp : def Vector_ExtractElementOp : Vector_Op<"extractelement", [Pure, + DeclareOpInterfaceMethods, TypesMatchWith<"result type matches element type of vector operand", "vector", "result", "::llvm::cast($_self).getElementType()">]>, @@ -673,6 +676,7 @@ def Vector_ExtractElementOp : def Vector_ExtractOp : Vector_Op<"extract", [Pure, + DeclareOpInterfaceMethods, PredOpTrait<"operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, InferTypeOpAdaptorWithIsCompatible]> { @@ -810,6 +814,7 @@ def Vector_FromElementsOp : Vector_Op<"from_elements", [ def Vector_InsertElementOp : Vector_Op<"insertelement", [Pure, + DeclareOpInterfaceMethods, TypesMatchWith<"source operand type matches element type of result", "result", "source", "::llvm::cast($_self).getElementType()">, @@ -858,6 +863,7 @@ def Vector_InsertElementOp : def Vector_InsertOp : Vector_Op<"insert", [Pure, + DeclareOpInterfaceMethods, PredOpTrait<"source operand and result have same element type", TCresVTEtIsSameAsOpBase<0, 0>>, AllTypesMatch<["dest", "result"]>]> { @@ -2204,7 +2210,9 @@ def Vector_CompressStoreOp : } def Vector_ShapeCastOp : - Vector_Op<"shape_cast", [Pure]>, + Vector_Op<"shape_cast", [Pure, + DeclareOpInterfaceMethods + ]>, Arguments<(ins AnyVectorOfAnyRank:$source)>, Results<(outs AnyVectorOfAnyRank:$result)> { let summary = "shape_cast casts between vector shapes"; @@ -2801,6 +2809,7 @@ def Vector_FlatTransposeOp : Vector_Op<"flat_transpose", [Pure, def Vector_SplatOp : Vector_Op<"splat", [ Pure, + DeclareOpInterfaceMethods, TypesMatchWith<"operand type matches element type of result", "aggregate", "input", "::llvm::cast($_self).getElementType()"> diff --git a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp index bf9eabbedc3a1..a97e43708d9a3 100644 --- a/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp +++ b/mlir/lib/Analysis/DataFlow/IntegerRangeAnalysis.cpp @@ -19,6 +19,7 @@ #include "mlir/IR/BuiltinAttributes.h" #include "mlir/IR/Dialect.h" #include "mlir/IR/OpDefinition.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/IR/Value.h" #include "mlir/Interfaces/ControlFlowInterfaces.h" #include "mlir/Interfaces/InferIntRangeInterface.h" @@ -53,9 +54,10 @@ void IntegerValueRangeLattice::onUpdate(DataFlowSolver *solver) const { dialect = parent->getDialect(); else dialect = value.getParentBlock()->getParentOp()->getDialect(); + + Type type = getElementTypeOrSelf(value); solver->propagateIfChanged( - cv, cv->join(ConstantValue(IntegerAttr::get(value.getType(), *constant), - dialect))); + cv, cv->join(ConstantValue(IntegerAttr::get(type, *constant), dialect))); } LogicalResult IntegerRangeAnalysis::visitOperation( diff --git a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp index 462044417b5fb..8682294c8a697 100644 --- a/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp +++ b/mlir/lib/Dialect/Arith/IR/InferIntRangeInterfaceImpls.cpp @@ -35,10 +35,22 @@ convertArithOverflowFlags(arith::IntegerOverflowFlags flags) { void arith::ConstantOp::inferResultRanges(ArrayRef argRanges, SetIntRangeFn setResultRange) { - auto constAttr = llvm::dyn_cast_or_null(getValue()); - if (constAttr) { - const APInt &value = constAttr.getValue(); + if (auto scalarCstAttr = llvm::dyn_cast_or_null(getValue())) { + const APInt &value = scalarCstAttr.getValue(); setResultRange(getResult(), ConstantIntRanges::constant(value)); + return; + } + if (auto arrayCstAttr = + llvm::dyn_cast_or_null(getValue())) { + std::optional result; + for (const APInt &val : arrayCstAttr) { + auto range = ConstantIntRanges::constant(val); + result = (result ? result->rangeUnion(range) : range); + } + + assert(result && "Zero-sized vectors are not allowed"); + setResultRange(getResult(), *result); + return; } } diff --git a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp index 521138c1f6f4c..d494bba081f80 100644 --- a/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp +++ b/mlir/lib/Dialect/Arith/Transforms/IntRangeOptimizations.cpp @@ -51,21 +51,27 @@ static LogicalResult maybeReplaceWithConstant(DataFlowSolver &solver, if (!maybeConstValue.has_value()) return failure(); + Type type = value.getType(); + Location loc = value.getLoc(); Operation *maybeDefiningOp = value.getDefiningOp(); Dialect *valueDialect = maybeDefiningOp ? maybeDefiningOp->getDialect() : value.getParentRegion()->getParentOp()->getDialect(); - Attribute constAttr = - rewriter.getIntegerAttr(value.getType(), *maybeConstValue); - Operation *constOp = valueDialect->materializeConstant( - rewriter, constAttr, value.getType(), value.getLoc()); + + Attribute constAttr; + if (auto shaped = dyn_cast(type)) { + constAttr = mlir::DenseIntElementsAttr::get(shaped, *maybeConstValue); + } else { + constAttr = rewriter.getIntegerAttr(type, *maybeConstValue); + } + Operation *constOp = + valueDialect->materializeConstant(rewriter, constAttr, type, loc); // Fall back to arith.constant if the dialect materializer doesn't know what // to do with an integer constant. if (!constOp) constOp = rewriter.getContext() ->getLoadedDialect() - ->materializeConstant(rewriter, constAttr, value.getType(), - value.getLoc()); + ->materializeConstant(rewriter, constAttr, type, loc); if (!constOp) return failure(); diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 5d018bdbe0b24..d8913251e56e9 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -1221,6 +1221,11 @@ void ContractionOp::getCanonicalizationPatterns(RewritePatternSet &results, // ExtractElementOp //===----------------------------------------------------------------------===// +void ExtractElementOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges.front()); +} + void vector::ExtractElementOp::build(OpBuilder &builder, OperationState &result, Value source) { result.addOperands({source}); @@ -1273,6 +1278,11 @@ OpFoldResult vector::ExtractElementOp::fold(FoldAdaptor adaptor) { // ExtractOp //===----------------------------------------------------------------------===// +void ExtractOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges.front()); +} + void vector::ExtractOp::build(OpBuilder &builder, OperationState &result, Value source, int64_t position) { build(builder, result, source, ArrayRef{position}); @@ -2252,6 +2262,11 @@ void FromElementsOp::getCanonicalizationPatterns(RewritePatternSet &results, // BroadcastOp //===----------------------------------------------------------------------===// +void BroadcastOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges.front()); +} + /// Return the dimensions of the result vector that were formerly ones in the /// source tensor and thus correspond to "dim-1" broadcasting. static llvm::SetVector @@ -2713,6 +2728,11 @@ void ShuffleOp::getCanonicalizationPatterns(RewritePatternSet &results, // InsertElementOp //===----------------------------------------------------------------------===// +void InsertElementOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1])); +} + void InsertElementOp::build(OpBuilder &builder, OperationState &result, Value source, Value dest) { build(builder, result, source, dest, {}); @@ -2762,6 +2782,11 @@ OpFoldResult vector::InsertElementOp::fold(FoldAdaptor adaptor) { // InsertOp //===----------------------------------------------------------------------===// +void vector::InsertOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges[0].rangeUnion(argRanges[1])); +} + void vector::InsertOp::build(OpBuilder &builder, OperationState &result, Value source, Value dest, int64_t position) { build(builder, result, source, dest, ArrayRef{position}); @@ -5277,6 +5302,11 @@ void CompressStoreOp::getCanonicalizationPatterns(RewritePatternSet &results, // ShapeCastOp //===----------------------------------------------------------------------===// +void ShapeCastOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges.front()); +} + /// Returns true if each element of 'a' is equal to the product of a contiguous /// sequence of the elements of 'b'. Returns false otherwise. static bool isValidShapeCast(ArrayRef a, ArrayRef b) { @@ -6423,6 +6453,11 @@ OpFoldResult SplatOp::fold(FoldAdaptor adaptor) { return SplatElementsAttr::get(getType(), {constOperand}); } +void SplatOp::inferResultRanges(ArrayRef argRanges, + SetIntRangeFn setResultRanges) { + setResultRanges(getResult(), argRanges.front()); +} + //===----------------------------------------------------------------------===// // WarpExecuteOnLane0Op //===----------------------------------------------------------------------===// diff --git a/mlir/test/Dialect/Vector/int-range-interface.mlir b/mlir/test/Dialect/Vector/int-range-interface.mlir new file mode 100644 index 0000000000000..29282423089ba --- /dev/null +++ b/mlir/test/Dialect/Vector/int-range-interface.mlir @@ -0,0 +1,106 @@ +// RUN: mlir-opt -int-range-optimizations -canonicalize %s | FileCheck %s + + +// CHECK-LABEL: func @constant_vec +// CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index} +func.func @constant_vec() -> vector<8xindex> { + %0 = arith.constant dense<[0, 1, 2, 3, 4, 5, 6, 7]> : vector<8xindex> + %1 = test.reflect_bounds %0 : vector<8xindex> + func.return %1 : vector<8xindex> +} + +// CHECK-LABEL: func @constant_splat +// CHECK: test.reflect_bounds {smax = 3 : si32, smin = 3 : si32, umax = 3 : ui32, umin = 3 : ui32} +func.func @constant_splat() -> vector<8xi32> { + %0 = arith.constant dense<3> : vector<8xi32> + %1 = test.reflect_bounds %0 : vector<8xi32> + func.return %1 : vector<8xi32> +} + +// CHECK-LABEL: func @vector_splat +// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} +func.func @vector_splat() -> vector<4xindex> { + %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index + %1 = vector.splat %0 : vector<4xindex> + %2 = test.reflect_bounds %1 : vector<4xindex> + func.return %2 : vector<4xindex> +} + +// CHECK-LABEL: func @vector_broadcast +// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} +func.func @vector_broadcast() -> vector<4x16xindex> { + %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<16xindex> + %1 = vector.broadcast %0 : vector<16xindex> to vector<4x16xindex> + %2 = test.reflect_bounds %1 : vector<4x16xindex> + func.return %2 : vector<4x16xindex> +} + +// CHECK-LABEL: func @vector_shape_cast +// CHECK: test.reflect_bounds {smax = 5 : index, smin = 4 : index, umax = 5 : index, umin = 4 : index} +func.func @vector_shape_cast() -> vector<4x4xindex> { + %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<16xindex> + %1 = vector.shape_cast %0 : vector<16xindex> to vector<4x4xindex> + %2 = test.reflect_bounds %1 : vector<4x4xindex> + func.return %2 : vector<4x4xindex> +} + +// CHECK-LABEL: func @vector_extract +// CHECK: test.reflect_bounds {smax = 6 : index, smin = 5 : index, umax = 6 : index, umin = 5 : index} +func.func @vector_extract() -> index { + %0 = test.with_bounds { umin = 5 : index, umax = 6 : index, smin = 5 : index, smax = 6 : index } : vector<4xindex> + %1 = vector.extract %0[0] : index from vector<4xindex> + %2 = test.reflect_bounds %1 : index + func.return %2 : index +} + +// CHECK-LABEL: func @vector_extractelement +// CHECK: test.reflect_bounds {smax = 7 : index, smin = 6 : index, umax = 7 : index, umin = 6 : index} +func.func @vector_extractelement() -> index { + %c0 = arith.constant 0 : index + %0 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex> + %1 = vector.extractelement %0[%c0 : index] : vector<4xindex> + %2 = test.reflect_bounds %1 : index + func.return %2 : index +} + +// CHECK-LABEL: func @vector_add +// CHECK: test.reflect_bounds {smax = 12 : index, smin = 10 : index, umax = 12 : index, umin = 10 : index} +func.func @vector_add() -> vector<4xindex> { + %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : vector<4xindex> + %1 = test.with_bounds { umin = 6 : index, umax = 7 : index, smin = 6 : index, smax = 7 : index } : vector<4xindex> + %2 = arith.addi %0, %1 : vector<4xindex> + %3 = test.reflect_bounds %2 : vector<4xindex> + func.return %3 : vector<4xindex> +} + +// CHECK-LABEL: func @vector_insert +// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index} +func.func @vector_insert() -> vector<4xindex> { + %0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex> + %1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index + %2 = vector.insert %1, %0[0] : index into vector<4xindex> + %3 = test.reflect_bounds %2 : vector<4xindex> + func.return %3 : vector<4xindex> +} + +// CHECK-LABEL: func @vector_insertelement +// CHECK: test.reflect_bounds {smax = 8 : index, smin = 5 : index, umax = 8 : index, umin = 5 : index} +func.func @vector_insertelement() -> vector<4xindex> { + %c0 = arith.constant 0 : index + %0 = test.with_bounds { umin = 5 : index, umax = 7 : index, smin = 5 : index, smax = 7 : index } : vector<4xindex> + %1 = test.with_bounds { umin = 6 : index, umax = 8 : index, smin = 6 : index, smax = 8 : index } : index + %2 = vector.insertelement %1, %0[%c0 : index] : vector<4xindex> + %3 = test.reflect_bounds %2 : vector<4xindex> + func.return %3 : vector<4xindex> +} + +// CHECK-LABEL: func @test_loaded_vector_extract +// No bounds +// CHECK: test.reflect_bounds %{{.*}} : i32 +func.func @test_loaded_vector_extract(%memref : memref<16xi32>) -> i32 { + %c0 = arith.constant 0 : index + %v = vector.load %memref[%c0] : memref<16xi32>, vector<4xi32> + %e = vector.extract %v[0] : i32 from vector<4xi32> + %bounds = test.reflect_bounds %e : i32 + func.return %bounds : i32 +} diff --git a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp index 69091fb893fad..b268e549b93ab 100644 --- a/mlir/test/lib/Dialect/Test/TestOpDefs.cpp +++ b/mlir/test/lib/Dialect/Test/TestOpDefs.cpp @@ -760,12 +760,13 @@ void TestReflectBoundsOp::inferResultRanges( Type sIntTy, uIntTy; // For plain `IntegerType`s, we can derive the appropriate signed and unsigned // Types for the Attributes. - if (auto intTy = llvm::dyn_cast(getType())) { + Type type = getElementTypeOrSelf(getType()); + if (auto intTy = llvm::dyn_cast(type)) { unsigned bitwidth = intTy.getWidth(); sIntTy = b.getIntegerType(bitwidth, /*isSigned=*/true); uIntTy = b.getIntegerType(bitwidth, /*isSigned=*/false); } else - sIntTy = uIntTy = getType(); + sIntTy = uIntTy = type; setUminAttr(b.getIntegerAttr(uIntTy, range.umin())); setUmaxAttr(b.getIntegerAttr(uIntTy, range.umax())); diff --git a/mlir/test/lib/Dialect/Test/TestOps.td b/mlir/test/lib/Dialect/Test/TestOps.td index bc6c6cf213ea4..cfe19a2fd5c08 100644 --- a/mlir/test/lib/Dialect/Test/TestOps.td +++ b/mlir/test/lib/Dialect/Test/TestOps.td @@ -2781,11 +2781,21 @@ def TestGraphLoopOp : TEST_Op<"graph_loop", //===----------------------------------------------------------------------===// // Test InferIntRangeInterface //===----------------------------------------------------------------------===// -def InferIntRangeType : AnyTypeOf<[AnyInteger, Index]>; +def InferIntRangeType : AnyTypeOf<[AnyInteger, Index, VectorOf<[AnyInteger, Index]>]>; def TestWithBoundsOp : TEST_Op<"with_bounds", [DeclareOpInterfaceMethods, NoMemoryEffect]> { + let description = [{ + Creates a value with specified [min, max] range for integer range analysis. + + Example: + + ```mlir + %0 = test.with_bounds { umin = 4 : index, umax = 5 : index, smin = 4 : index, smax = 5 : index } : index + ``` + }]; + let arguments = (ins APIntAttr:$umin, APIntAttr:$umax, APIntAttr:$smin, @@ -2819,6 +2829,18 @@ def TestIncrementOp : TEST_Op<"increment", def TestReflectBoundsOp : TEST_Op<"reflect_bounds", [DeclareOpInterfaceMethods, AllTypesMatch<["value", "result"]>]> { + let description = [{ + Integer range analysis will update this op to reflect inferred integer range + of the input, so it can be checked with FileCheck + + Example: + + ```mlir + CHECK: test.reflect_bounds {smax = 7 : index, smin = 0 : index, umax = 7 : index, umin = 0 : index} + %1 = test.reflect_bounds %0 : index + ``` + }]; + let arguments = (ins InferIntRangeType:$value, OptionalAttr:$umin, OptionalAttr:$umax,