Skip to content

Commit a33196c

Browse files
committed
[mlir][affine] Support vector types in affine.apply
`affine.apply` is generally useful outside of affine to generate various index computations. Add support for vectors of index to enable vectorized code generation. All operands and result types must match. Type is optional in asm format and assumed `index` if missing so it's backward compatible with exisiting text IR, to reduce churn.
1 parent 6a030b3 commit a33196c

File tree

8 files changed

+120
-44
lines changed

8 files changed

+120
-44
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
#ifndef AFFINE_OPS
1414
#define AFFINE_OPS
1515

16-
include "mlir/Dialect/Arith/IR/ArithBase.td"
1716
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
17+
include "mlir/Dialect/Arith/IR/ArithBase.td"
1818
include "mlir/Interfaces/ControlFlowInterfaces.td"
1919
include "mlir/Interfaces/InferTypeOpInterface.td"
2020
include "mlir/Interfaces/LoopLikeInterface.td"
2121
include "mlir/Interfaces/SideEffectInterfaces.td"
22+
include "mlir/IR/CommonTypeConstraints.td"
2223

2324
def Affine_Dialect : Dialect {
2425
let name = "affine";
@@ -57,18 +58,22 @@ def AffineApplyOp : Affine_Op<"apply", [Pure]> {
5758
%2 = affine.apply affine_map<(i)[s0] -> (i+s0)> (%42)[%n]
5859
```
5960
}];
60-
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$mapOperands);
61-
let results = (outs Index);
61+
let arguments = (ins AffineMapAttr:$map, Variadic<IndexLike>:$mapOperands);
62+
let results = (outs IndexLike);
6263

6364
// TODO: The auto-generated builders should check to see if the return type
6465
// has a constant builder. That way we wouldn't need to explicitly specify the
6566
// result types here.
6667
let builders = [
67-
OpBuilder<(ins "ArrayRef<AffineExpr> ":$exprList,"ValueRange":$mapOperands),
68+
OpBuilder<(ins "ArrayRef<AffineExpr>":$exprList,"ValueRange":$mapOperands),
6869
[{
6970
build($_builder, $_state, $_builder.getIndexType(),
7071
AffineMap::inferFromExprList(exprList, $_builder.getContext())
7172
.front(), mapOperands);
73+
}]>,
74+
OpBuilder<(ins "AffineMap":$map,"ValueRange":$mapOperands),
75+
[{
76+
build($_builder, $_state, $_builder.getIndexType(), map, mapOperands);
7277
}]>
7378
];
7479

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,14 +294,14 @@ void createAffineComputationSlice(Operation *opInst,
294294
/// Emit code that computes the given affine expression using standard
295295
/// arithmetic operations applied to the provided dimension and symbol values.
296296
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
297-
ValueRange dimValues, ValueRange symbolValues);
297+
ValueRange dimValues, ValueRange symbolValues,
298+
Type type = {});
298299

299300
/// Create a sequence of operations that implement the `affineMap` applied to
300301
/// the given `operands` (as it it were an AffineApplyOp).
301-
std::optional<SmallVector<Value, 8>> expandAffineMap(OpBuilder &builder,
302-
Location loc,
303-
AffineMap affineMap,
304-
ValueRange operands);
302+
std::optional<SmallVector<Value, 8>>
303+
expandAffineMap(OpBuilder &builder, Location loc, AffineMap affineMap,
304+
ValueRange operands, Type type = {});
305305

306306
/// Holds the result of (div a, b) and (mod a, b).
307307
struct DivModValue {

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -892,6 +892,10 @@ class TypeOrValueSemanticsContainer<Type allowedType, string name>
892892
// bools.
893893
def BoolLike : TypeOrValueSemanticsContainer<I1, "bool-like">;
894894

895+
// Type constraint for index-like types: index, vectors of index, tensors of
896+
// index.
897+
def IndexLike : TypeOrValueSemanticsContainer<Index, "index-like">;
898+
895899
// Type constraint for signless-integer-like types: signless integers or
896900
// value-semantics containers of signless integers.
897901
def SignlessIntegerLike : TypeOrValueSemanticsContainer<

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
336336
PatternRewriter &rewriter) const override {
337337
auto maybeExpandedMap =
338338
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
339-
llvm::to_vector<8>(op.getOperands()));
339+
llvm::to_vector<8>(op.getOperands()), op.getType());
340340
if (!maybeExpandedMap)
341341
return failure();
342342
rewriter.replaceOp(op, *maybeExpandedMap);

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -491,20 +491,37 @@ static void printDimAndSymbolList(Operation::operand_iterator begin,
491491
printer << '[' << operands.drop_front(numDims) << ']';
492492
}
493493

494-
/// Parses dimension and symbol list and returns true if parsing failed.
495-
ParseResult mlir::affine::parseDimAndSymbolList(
496-
OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
497-
SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
494+
/// Parse dimension and symbol list, but not resolve yet, as we may not know the
495+
/// operands types.
496+
static ParseResult parseDimAndSymbolListImpl(
497+
OpAsmParser &parser,
498+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &opInfos,
499+
unsigned &numDims) {
498500
if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
499501
return failure();
502+
500503
// Store number of dimensions for validation by caller.
501504
numDims = opInfos.size();
502505

503506
// Parse the optional symbol operands.
507+
if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::OptionalSquare))
508+
return failure();
509+
510+
return success();
511+
}
512+
513+
/// Parses dimension and symbol list and returns true if parsing failed.
514+
ParseResult mlir::affine::parseDimAndSymbolList(
515+
OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
516+
SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
517+
if (parseDimAndSymbolListImpl(parser, opInfos, numDims))
518+
return failure();
519+
504520
auto indexTy = parser.getBuilder().getIndexType();
505-
return failure(parser.parseOperandList(
506-
opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
507-
parser.resolveOperands(opInfos, indexTy, operands));
521+
if (parser.resolveOperands(opInfos, indexTy, operands))
522+
return failure();
523+
524+
return success();
508525
}
509526

510527
/// Utility function to verify that a set of operands are valid dimension and
@@ -538,14 +555,25 @@ AffineValueMap AffineApplyOp::getAffineValueMap() {
538555

539556
ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
540557
auto &builder = parser.getBuilder();
541-
auto indexTy = builder.getIndexType();
542558

543559
AffineMapAttr mapAttr;
544560
unsigned numDims;
561+
SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
545562
if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
546-
parseDimAndSymbolList(parser, result.operands, numDims) ||
563+
parseDimAndSymbolListImpl(parser, opInfos, numDims) ||
547564
parser.parseOptionalAttrDict(result.attributes))
548565
return failure();
566+
567+
Type type;
568+
if (parser.parseOptionalColon()) {
569+
type = builder.getIndexType();
570+
} else if (parser.parseType(type)) {
571+
return failure();
572+
}
573+
574+
if (parser.resolveOperands(opInfos, type, result.operands))
575+
return failure();
576+
549577
auto map = mapAttr.getValue();
550578

551579
if (map.getNumDims() != numDims ||
@@ -554,7 +582,7 @@ ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
554582
"dimension or symbol index mismatch");
555583
}
556584

557-
result.types.append(map.getNumResults(), indexTy);
585+
result.types.append(map.getNumResults(), type);
558586
return success();
559587
}
560588

@@ -563,9 +591,18 @@ void AffineApplyOp::print(OpAsmPrinter &p) {
563591
printDimAndSymbolList(operand_begin(), operand_end(),
564592
getAffineMap().getNumDims(), p);
565593
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
594+
Type resType = getType();
595+
if (!isa<IndexType>(resType))
596+
p << ":" << resType;
566597
}
567598

568599
LogicalResult AffineApplyOp::verify() {
600+
// Check all operand and result types are the same.
601+
// We cannot use `SameOperandsAndResultType` as it expects at least 1 operand.
602+
if (!llvm::all_equal(
603+
llvm::concat<Type>(getOperandTypes(), (*this)->getResultTypes())))
604+
return emitOpError("requires the same type for all operands and results");
605+
569606
// Check input and output dimensions match.
570607
AffineMap affineMap = getMap();
571608

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ class AffineApplyExpander
4646
/// This internal class expects arguments to be non-null, checks must be
4747
/// performed at the call site.
4848
AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
49-
ValueRange symbolValues, Location loc)
49+
ValueRange symbolValues, Location loc, Type type)
5050
: builder(builder), dimValues(dimValues), symbolValues(symbolValues),
51-
loc(loc) {}
51+
loc(loc), type(type) {}
5252

5353
template <typename OpTy>
5454
Value buildBinaryExpr(AffineBinaryOpExpr expr,
@@ -189,8 +189,16 @@ class AffineApplyExpander
189189
}
190190

191191
Value visitConstantExpr(AffineConstantExpr expr) {
192-
auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
193-
return op.getResult();
192+
int64_t value = expr.getValue();
193+
if (isa<IndexType>(type))
194+
return builder.create<arith::ConstantIndexOp>(loc, value);
195+
196+
if (auto shaped = dyn_cast<ShapedType>(type)) {
197+
auto elements = DenseIntElementsAttr::get(shaped, value);
198+
return builder.create<arith::ConstantOp>(loc, elements);
199+
}
200+
201+
llvm_unreachable("AffineApplyExpander: Unsupported type");
194202
}
195203

196204
Value visitDimExpr(AffineDimExpr expr) {
@@ -211,6 +219,7 @@ class AffineApplyExpander
211219
ValueRange symbolValues;
212220

213221
Location loc;
222+
Type type;
214223
};
215224
} // namespace
216225

@@ -219,23 +228,28 @@ class AffineApplyExpander
219228
mlir::Value mlir::affine::expandAffineExpr(OpBuilder &builder, Location loc,
220229
AffineExpr expr,
221230
ValueRange dimValues,
222-
ValueRange symbolValues) {
223-
return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
231+
ValueRange symbolValues, Type type) {
232+
if (!type)
233+
type = builder.getIndexType();
234+
235+
return AffineApplyExpander(builder, dimValues, symbolValues, loc, type)
236+
.visit(expr);
224237
}
225238

226239
/// Create a sequence of operations that implement the `affineMap` applied to
227240
/// the given `operands` (as it it were an AffineApplyOp).
228241
std::optional<SmallVector<Value, 8>>
229242
mlir::affine::expandAffineMap(OpBuilder &builder, Location loc,
230-
AffineMap affineMap, ValueRange operands) {
243+
AffineMap affineMap, ValueRange operands,
244+
Type type) {
231245
auto numDims = affineMap.getNumDims();
232246
auto expanded = llvm::to_vector<8>(
233-
llvm::map_range(affineMap.getResults(),
234-
[numDims, &builder, loc, operands](AffineExpr expr) {
235-
return expandAffineExpr(builder, loc, expr,
236-
operands.take_front(numDims),
237-
operands.drop_front(numDims));
238-
}));
247+
llvm::map_range(affineMap.getResults(), [numDims, &builder, loc, operands,
248+
type](AffineExpr expr) {
249+
return expandAffineExpr(builder, loc, expr,
250+
operands.take_front(numDims),
251+
operands.drop_front(numDims), type);
252+
}));
239253
if (llvm::all_of(expanded, [](Value v) { return v; }))
240254
return expanded;
241255
return std::nullopt;

mlir/test/Conversion/AffineToStandard/lower-affine.mlir

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -429,8 +429,9 @@ func.func @min_reduction_tree(%v1 : index, %v2 : index, %v3 : index, %v4 : index
429429
#map5 = affine_map<(d0,d1,d2) -> (d0,d1,d2)>
430430
#map6 = affine_map<(d0,d1,d2) -> (d0 + d1 + d2)>
431431

432-
// CHECK-LABEL: func @affine_applies(
433-
func.func @affine_applies(%arg0 : index) {
432+
// CHECK-LABEL: func @affine_applies
433+
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: vector<4xindex>)
434+
func.func @affine_applies(%arg0 : index, %arg1 : vector<4xindex>) {
434435
// CHECK: %[[c0:.*]] = arith.constant 0 : index
435436
%zero = affine.apply #map0()
436437

@@ -448,24 +449,29 @@ func.func @affine_applies(%arg0 : index) {
448449
%one = affine.apply #map3(%symbZero)[%zero]
449450

450451
// CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index
451-
// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] overflow<nsw> : index
452-
// CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] : index
452+
// CHECK-NEXT: %[[v2:.*]] = arith.muli %[[ARG0]], %[[c2]] overflow<nsw> : index
453+
// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[ARG0]], %[[v2]] : index
453454
// CHECK-NEXT: %[[c3:.*]] = arith.constant 3 : index
454-
// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] overflow<nsw> : index
455+
// CHECK-NEXT: %[[v4:.*]] = arith.muli %[[ARG0]], %[[c3]] overflow<nsw> : index
455456
// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] : index
456457
// CHECK-NEXT: %[[c4:.*]] = arith.constant 4 : index
457-
// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] overflow<nsw> : index
458+
// CHECK-NEXT: %[[v6:.*]] = arith.muli %[[ARG0]], %[[c4]] overflow<nsw> : index
458459
// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] : index
459460
// CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : index
460-
// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] overflow<nsw> : index
461+
// CHECK-NEXT: %[[v8:.*]] = arith.muli %[[ARG0]], %[[c5]] overflow<nsw> : index
461462
// CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] : index
462463
// CHECK-NEXT: %[[c6:.*]] = arith.constant 6 : index
463-
// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] overflow<nsw> : index
464+
// CHECK-NEXT: %[[v10:.*]] = arith.muli %[[ARG0]], %[[c6]] overflow<nsw> : index
464465
// CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] : index
465466
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
466-
// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] overflow<nsw> : index
467+
// CHECK-NEXT: %[[v12:.*]] = arith.muli %[[ARG0]], %[[c7]] overflow<nsw> : index
467468
// CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index
468469
%four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]
470+
471+
// CHECK-NEXT: %[[v14:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : vector<4xindex>
472+
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1> : vector<4xindex>
473+
// CHECK-NEXT: %[[v15:.*]] = arith.addi %[[v14]], %[[cst]] : vector<4xindex>
474+
%vec = affine.apply #map3(%arg1)[%arg1] : vector<4xindex>
469475
return
470476
}
471477

mlir/test/Dialect/Affine/invalid.mlir

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
func.func @affine_apply_operand_non_index(%arg0 : i32) {
66
// Custom parser automatically assigns all arguments the `index` so we must
77
// use the generic syntax here to exercise the verifier.
8-
// expected-error@+1 {{op operand #0 must be variadic of index, but got 'i32'}}
8+
// expected-error@+1 {{op operand #0 must be variadic of index-like, but got 'i32'}}
99
%0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (i32) -> (index)
1010
return
1111
}
@@ -15,11 +15,21 @@ func.func @affine_apply_operand_non_index(%arg0 : i32) {
1515
func.func @affine_apply_resul_non_index(%arg0 : index) {
1616
// Custom parser automatically assigns `index` as the result type so we must
1717
// use the generic syntax here to exercise the verifier.
18-
// expected-error@+1 {{op result #0 must be index, but got 'i32'}}
18+
// expected-error@+1 {{op result #0 must be index-like, but got 'i32'}}
1919
%0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (index) -> (i32)
2020
return
2121
}
2222

23+
// -----
24+
25+
func.func @affine_apply_types_match(%arg0 : index) {
26+
// We are now supporting vectors of index, but all operands and result types
27+
// must match.
28+
// expected-error@+1 {{op requires the same type for all operands and results}}
29+
%0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (index) -> (vector<4xindex>)
30+
return
31+
}
32+
2333
// -----
2434
func.func @affine_load_invalid_dim(%M : memref<10xi32>) {
2535
"unknown"() ({

0 commit comments

Comments
 (0)