diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td index 6cd3408e2b2e9..8be819323fd6f 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.td @@ -13,12 +13,13 @@ #ifndef AFFINE_OPS #define AFFINE_OPS -include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td" +include "mlir/Dialect/Arith/IR/ArithBase.td" include "mlir/Interfaces/ControlFlowInterfaces.td" include "mlir/Interfaces/InferTypeOpInterface.td" include "mlir/Interfaces/LoopLikeInterface.td" include "mlir/Interfaces/SideEffectInterfaces.td" +include "mlir/IR/CommonTypeConstraints.td" def Affine_Dialect : Dialect { let name = "affine"; @@ -57,18 +58,22 @@ def AffineApplyOp : Affine_Op<"apply", [Pure]> { %2 = affine.apply affine_map<(i)[s0] -> (i+s0)> (%42)[%n] ``` }]; - let arguments = (ins AffineMapAttr:$map, Variadic:$mapOperands); - let results = (outs Index); + let arguments = (ins AffineMapAttr:$map, Variadic:$mapOperands); + let results = (outs IndexLike); // TODO: The auto-generated builders should check to see if the return type // has a constant builder. That way we wouldn't need to explicitly specify the // result types here. let builders = [ - OpBuilder<(ins "ArrayRef ":$exprList,"ValueRange":$mapOperands), + OpBuilder<(ins "ArrayRef":$exprList,"ValueRange":$mapOperands), [{ build($_builder, $_state, $_builder.getIndexType(), AffineMap::inferFromExprList(exprList, $_builder.getContext()) .front(), mapOperands); + }]>, + OpBuilder<(ins "AffineMap":$map,"ValueRange":$mapOperands), + [{ + build($_builder, $_state, $_builder.getIndexType(), map, mapOperands); }]> ]; diff --git a/mlir/include/mlir/Dialect/Affine/Utils.h b/mlir/include/mlir/Dialect/Affine/Utils.h index ff1900bc8f2eb..ab39239a77312 100644 --- a/mlir/include/mlir/Dialect/Affine/Utils.h +++ b/mlir/include/mlir/Dialect/Affine/Utils.h @@ -294,14 +294,14 @@ void createAffineComputationSlice(Operation *opInst, /// Emit code that computes the given affine expression using standard /// arithmetic operations applied to the provided dimension and symbol values. Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, - ValueRange dimValues, ValueRange symbolValues); + ValueRange dimValues, ValueRange symbolValues, + Type type = {}); /// Create a sequence of operations that implement the `affineMap` applied to /// the given `operands` (as it it were an AffineApplyOp). -std::optional> expandAffineMap(OpBuilder &builder, - Location loc, - AffineMap affineMap, - ValueRange operands); +std::optional> +expandAffineMap(OpBuilder &builder, Location loc, AffineMap affineMap, + ValueRange operands, Type type = {}); /// Holds the result of (div a, b) and (mod a, b). struct DivModValue { diff --git a/mlir/include/mlir/IR/CommonTypeConstraints.td b/mlir/include/mlir/IR/CommonTypeConstraints.td index 601517717978e..65d2ddb6a5c85 100644 --- a/mlir/include/mlir/IR/CommonTypeConstraints.td +++ b/mlir/include/mlir/IR/CommonTypeConstraints.td @@ -892,6 +892,10 @@ class TypeOrValueSemanticsContainer // bools. def BoolLike : TypeOrValueSemanticsContainer; +// Type constraint for index-like types: index, vectors of index, tensors of +// index. +def IndexLike : TypeOrValueSemanticsContainer; + // Type constraint for signless-integer-like types: signless integers or // value-semantics containers of signless integers. def SignlessIntegerLike : TypeOrValueSemanticsContainer< diff --git a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp index 4fbe6a03f6bad..e9fb745a068a4 100644 --- a/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp +++ b/mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp @@ -336,7 +336,7 @@ class AffineApplyLowering : public OpRewritePattern { PatternRewriter &rewriter) const override { auto maybeExpandedMap = expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(), - llvm::to_vector<8>(op.getOperands())); + llvm::to_vector<8>(op.getOperands()), op.getType()); if (!maybeExpandedMap) return failure(); rewriter.replaceOp(op, *maybeExpandedMap); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index 8acb21d5074b4..717d3bbd8e3e9 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -491,20 +491,37 @@ static void printDimAndSymbolList(Operation::operand_iterator begin, printer << '[' << operands.drop_front(numDims) << ']'; } -/// Parses dimension and symbol list and returns true if parsing failed. -ParseResult mlir::affine::parseDimAndSymbolList( - OpAsmParser &parser, SmallVectorImpl &operands, unsigned &numDims) { - SmallVector opInfos; +/// Parse dimension and symbol list, but not resolve yet, as we may not know the +/// operands types. +static ParseResult parseDimAndSymbolListImpl( + OpAsmParser &parser, + SmallVectorImpl &opInfos, + unsigned &numDims) { if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren)) return failure(); + // Store number of dimensions for validation by caller. numDims = opInfos.size(); // Parse the optional symbol operands. + if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::OptionalSquare)) + return failure(); + + return success(); +} + +/// Parses dimension and symbol list and returns true if parsing failed. +ParseResult mlir::affine::parseDimAndSymbolList( + OpAsmParser &parser, SmallVectorImpl &operands, unsigned &numDims) { + SmallVector opInfos; + if (parseDimAndSymbolListImpl(parser, opInfos, numDims)) + return failure(); + auto indexTy = parser.getBuilder().getIndexType(); - return failure(parser.parseOperandList( - opInfos, OpAsmParser::Delimiter::OptionalSquare) || - parser.resolveOperands(opInfos, indexTy, operands)); + if (parser.resolveOperands(opInfos, indexTy, operands)) + return failure(); + + return success(); } /// Utility function to verify that a set of operands are valid dimension and @@ -538,14 +555,25 @@ AffineValueMap AffineApplyOp::getAffineValueMap() { ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) { auto &builder = parser.getBuilder(); - auto indexTy = builder.getIndexType(); AffineMapAttr mapAttr; unsigned numDims; + SmallVector opInfos; if (parser.parseAttribute(mapAttr, "map", result.attributes) || - parseDimAndSymbolList(parser, result.operands, numDims) || + parseDimAndSymbolListImpl(parser, opInfos, numDims) || parser.parseOptionalAttrDict(result.attributes)) return failure(); + + Type type; + if (parser.parseOptionalColon()) { + type = builder.getIndexType(); + } else if (parser.parseType(type)) { + return failure(); + } + + if (parser.resolveOperands(opInfos, type, result.operands)) + return failure(); + auto map = mapAttr.getValue(); if (map.getNumDims() != numDims || @@ -554,7 +582,7 @@ ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) { "dimension or symbol index mismatch"); } - result.types.append(map.getNumResults(), indexTy); + result.types.append(map.getNumResults(), type); return success(); } @@ -563,9 +591,18 @@ void AffineApplyOp::print(OpAsmPrinter &p) { printDimAndSymbolList(operand_begin(), operand_end(), getAffineMap().getNumDims(), p); p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"}); + Type resType = getType(); + if (!isa(resType)) + p << ":" << resType; } LogicalResult AffineApplyOp::verify() { + // Check all operand and result types are the same. + // We cannot use `SameOperandsAndResultType` as it expects at least 1 operand. + if (!llvm::all_equal( + llvm::concat(getOperandTypes(), (*this)->getResultTypes()))) + return emitOpError("requires the same type for all operands and results"); + // Check input and output dimensions match. AffineMap affineMap = getMap(); diff --git a/mlir/lib/Dialect/Affine/Utils/Utils.cpp b/mlir/lib/Dialect/Affine/Utils/Utils.cpp index 2723cff6900d0..0342ae3ac6908 100644 --- a/mlir/lib/Dialect/Affine/Utils/Utils.cpp +++ b/mlir/lib/Dialect/Affine/Utils/Utils.cpp @@ -46,9 +46,9 @@ class AffineApplyExpander /// This internal class expects arguments to be non-null, checks must be /// performed at the call site. AffineApplyExpander(OpBuilder &builder, ValueRange dimValues, - ValueRange symbolValues, Location loc) + ValueRange symbolValues, Location loc, Type type) : builder(builder), dimValues(dimValues), symbolValues(symbolValues), - loc(loc) {} + loc(loc), type(type) {} template Value buildBinaryExpr(AffineBinaryOpExpr expr, @@ -189,8 +189,16 @@ class AffineApplyExpander } Value visitConstantExpr(AffineConstantExpr expr) { - auto op = builder.create(loc, expr.getValue()); - return op.getResult(); + int64_t value = expr.getValue(); + if (isa(type)) + return builder.create(loc, value); + + if (auto shaped = dyn_cast(type)) { + auto elements = DenseIntElementsAttr::get(shaped, value); + return builder.create(loc, elements); + } + + llvm_unreachable("AffineApplyExpander: Unsupported type"); } Value visitDimExpr(AffineDimExpr expr) { @@ -211,6 +219,7 @@ class AffineApplyExpander ValueRange symbolValues; Location loc; + Type type; }; } // namespace @@ -219,23 +228,28 @@ class AffineApplyExpander mlir::Value mlir::affine::expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr, ValueRange dimValues, - ValueRange symbolValues) { - return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr); + ValueRange symbolValues, Type type) { + if (!type) + type = builder.getIndexType(); + + return AffineApplyExpander(builder, dimValues, symbolValues, loc, type) + .visit(expr); } /// Create a sequence of operations that implement the `affineMap` applied to /// the given `operands` (as it it were an AffineApplyOp). std::optional> mlir::affine::expandAffineMap(OpBuilder &builder, Location loc, - AffineMap affineMap, ValueRange operands) { + AffineMap affineMap, ValueRange operands, + Type type) { auto numDims = affineMap.getNumDims(); auto expanded = llvm::to_vector<8>( - llvm::map_range(affineMap.getResults(), - [numDims, &builder, loc, operands](AffineExpr expr) { - return expandAffineExpr(builder, loc, expr, - operands.take_front(numDims), - operands.drop_front(numDims)); - })); + llvm::map_range(affineMap.getResults(), [numDims, &builder, loc, operands, + type](AffineExpr expr) { + return expandAffineExpr(builder, loc, expr, + operands.take_front(numDims), + operands.drop_front(numDims), type); + })); if (llvm::all_of(expanded, [](Value v) { return v; })) return expanded; return std::nullopt; diff --git a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir index 550ea71882e14..4f01f05dfb6b4 100644 --- a/mlir/test/Conversion/AffineToStandard/lower-affine.mlir +++ b/mlir/test/Conversion/AffineToStandard/lower-affine.mlir @@ -429,8 +429,9 @@ func.func @min_reduction_tree(%v1 : index, %v2 : index, %v3 : index, %v4 : index #map5 = affine_map<(d0,d1,d2) -> (d0,d1,d2)> #map6 = affine_map<(d0,d1,d2) -> (d0 + d1 + d2)> -// CHECK-LABEL: func @affine_applies( -func.func @affine_applies(%arg0 : index) { +// CHECK-LABEL: func @affine_applies +// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: vector<4xindex>) +func.func @affine_applies(%arg0 : index, %arg1 : vector<4xindex>) { // CHECK: %[[c0:.*]] = arith.constant 0 : index %zero = affine.apply #map0() @@ -448,24 +449,29 @@ func.func @affine_applies(%arg0 : index) { %one = affine.apply #map3(%symbZero)[%zero] // CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index -// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] overflow : index -// CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] : index +// CHECK-NEXT: %[[v2:.*]] = arith.muli %[[ARG0]], %[[c2]] overflow : index +// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[ARG0]], %[[v2]] : index // CHECK-NEXT: %[[c3:.*]] = arith.constant 3 : index -// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] overflow : index +// CHECK-NEXT: %[[v4:.*]] = arith.muli %[[ARG0]], %[[c3]] overflow : index // CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] : index // CHECK-NEXT: %[[c4:.*]] = arith.constant 4 : index -// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] overflow : index +// CHECK-NEXT: %[[v6:.*]] = arith.muli %[[ARG0]], %[[c4]] overflow : index // CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] : index // CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : index -// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] overflow : index +// CHECK-NEXT: %[[v8:.*]] = arith.muli %[[ARG0]], %[[c5]] overflow : index // CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] : index // CHECK-NEXT: %[[c6:.*]] = arith.constant 6 : index -// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] overflow : index +// CHECK-NEXT: %[[v10:.*]] = arith.muli %[[ARG0]], %[[c6]] overflow : index // CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] : index // CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index -// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] overflow : index +// CHECK-NEXT: %[[v12:.*]] = arith.muli %[[ARG0]], %[[c7]] overflow : index // CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index %four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0] + +// CHECK-NEXT: %[[v14:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : vector<4xindex> +// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1> : vector<4xindex> +// CHECK-NEXT: %[[v15:.*]] = arith.addi %[[v14]], %[[cst]] : vector<4xindex> + %vec = affine.apply #map3(%arg1)[%arg1] : vector<4xindex> return } diff --git a/mlir/test/Dialect/Affine/invalid.mlir b/mlir/test/Dialect/Affine/invalid.mlir index 9bbd19c381163..af948aa56eef5 100644 --- a/mlir/test/Dialect/Affine/invalid.mlir +++ b/mlir/test/Dialect/Affine/invalid.mlir @@ -5,7 +5,7 @@ func.func @affine_apply_operand_non_index(%arg0 : i32) { // Custom parser automatically assigns all arguments the `index` so we must // use the generic syntax here to exercise the verifier. - // expected-error@+1 {{op operand #0 must be variadic of index, but got 'i32'}} + // expected-error@+1 {{op operand #0 must be variadic of index-like, but got 'i32'}} %0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (i32) -> (index) return } @@ -15,11 +15,21 @@ func.func @affine_apply_operand_non_index(%arg0 : i32) { func.func @affine_apply_resul_non_index(%arg0 : index) { // Custom parser automatically assigns `index` as the result type so we must // use the generic syntax here to exercise the verifier. - // expected-error@+1 {{op result #0 must be index, but got 'i32'}} + // expected-error@+1 {{op result #0 must be index-like, but got 'i32'}} %0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (index) -> (i32) return } +// ----- + +func.func @affine_apply_types_match(%arg0 : index) { + // We are now supporting vectors of index, but all operands and result types + // must match. + // expected-error@+1 {{op requires the same type for all operands and results}} + %0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (index) -> (vector<4xindex>) + return +} + // ----- func.func @affine_load_invalid_dim(%M : memref<10xi32>) { "unknown"() ({