Skip to content

Commit b6564b6

Browse files
committed
[mlir][affine] Support vector types in affine.apply
`affine.apply` is generally useful outside of affine to generate various index computations. Add support for vectors of index to enable vectorized code generation. All operands and result types must match. Type is optional in asm format and assumed `index` if missing so it's backward compatible with exisiting text IR, to reduce churn.
1 parent 00a0b0b commit b6564b6

File tree

8 files changed

+121
-45
lines changed

8 files changed

+121
-45
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.td

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -13,12 +13,13 @@
1313
#ifndef AFFINE_OPS
1414
#define AFFINE_OPS
1515

16-
include "mlir/Dialect/Arith/IR/ArithBase.td"
1716
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
17+
include "mlir/Dialect/Arith/IR/ArithBase.td"
1818
include "mlir/Interfaces/ControlFlowInterfaces.td"
1919
include "mlir/Interfaces/InferTypeOpInterface.td"
2020
include "mlir/Interfaces/LoopLikeInterface.td"
2121
include "mlir/Interfaces/SideEffectInterfaces.td"
22+
include "mlir/IR/CommonTypeConstraints.td"
2223

2324
def Affine_Dialect : Dialect {
2425
let name = "affine";
@@ -57,18 +58,22 @@ def AffineApplyOp : Affine_Op<"apply", [Pure]> {
5758
%2 = affine.apply affine_map<(i)[s0] -> (i+s0)> (%42)[%n]
5859
```
5960
}];
60-
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$mapOperands);
61-
let results = (outs Index);
61+
let arguments = (ins AffineMapAttr:$map, Variadic<IndexLike>:$mapOperands);
62+
let results = (outs IndexLike);
6263

6364
// TODO: The auto-generated builders should check to see if the return type
6465
// has a constant builder. That way we wouldn't need to explicitly specify the
6566
// result types here.
6667
let builders = [
67-
OpBuilder<(ins "ArrayRef<AffineExpr> ":$exprList,"ValueRange":$mapOperands),
68+
OpBuilder<(ins "ArrayRef<AffineExpr>":$exprList,"ValueRange":$mapOperands),
6869
[{
6970
build($_builder, $_state, $_builder.getIndexType(),
7071
AffineMap::inferFromExprList(exprList, $_builder.getContext())
7172
.front(), mapOperands);
73+
}]>,
74+
OpBuilder<(ins "AffineMap":$map,"ValueRange":$mapOperands),
75+
[{
76+
build($_builder, $_state, $_builder.getIndexType(), map, mapOperands);
7277
}]>
7378
];
7479

mlir/include/mlir/Dialect/Affine/Utils.h

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -294,14 +294,14 @@ void createAffineComputationSlice(Operation *opInst,
294294
/// Emit code that computes the given affine expression using standard
295295
/// arithmetic operations applied to the provided dimension and symbol values.
296296
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
297-
ValueRange dimValues, ValueRange symbolValues);
297+
ValueRange dimValues, ValueRange symbolValues,
298+
Type type = {});
298299

299300
/// Create a sequence of operations that implement the `affineMap` applied to
300301
/// the given `operands` (as it it were an AffineApplyOp).
301-
std::optional<SmallVector<Value, 8>> expandAffineMap(OpBuilder &builder,
302-
Location loc,
303-
AffineMap affineMap,
304-
ValueRange operands);
302+
std::optional<SmallVector<Value, 8>>
303+
expandAffineMap(OpBuilder &builder, Location loc, AffineMap affineMap,
304+
ValueRange operands, Type type = {});
305305

306306
/// Holds the result of (div a, b) and (mod a, b).
307307
struct DivModValue {

mlir/include/mlir/IR/CommonTypeConstraints.td

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ class NestedTupleOf<list<Type> allowedTypes> :
883883
// Type constraint for types that are "like" some type or set of types T, that is
884884
// they're either a T or a mapable container of Ts.
885885
class TypeOrValueSemanticsContainer<Type allowedType, string name>
886-
: TypeConstraint<Or<[
886+
: Type<Or<[
887887
allowedType.predicate,
888888
ValueSemanticsContainerOf<[allowedType]>.predicate]>,
889889
name>;
@@ -892,6 +892,10 @@ class TypeOrValueSemanticsContainer<Type allowedType, string name>
892892
// bools.
893893
def BoolLike : TypeOrValueSemanticsContainer<I1, "bool-like">;
894894

895+
// Type constraint for index-like types: index, vectors of index, tensors of
896+
// index.
897+
def IndexLike : TypeOrValueSemanticsContainer<Index, "index-like">;
898+
895899
// Type constraint for signless-integer-like types: signless integers or
896900
// value-semantics containers of signless integers.
897901
def SignlessIntegerLike : TypeOrValueSemanticsContainer<

mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -336,7 +336,7 @@ class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
336336
PatternRewriter &rewriter) const override {
337337
auto maybeExpandedMap =
338338
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
339-
llvm::to_vector<8>(op.getOperands()));
339+
llvm::to_vector<8>(op.getOperands()), op.getType());
340340
if (!maybeExpandedMap)
341341
return failure();
342342
rewriter.replaceOp(op, *maybeExpandedMap);

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 47 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -481,20 +481,37 @@ static void printDimAndSymbolList(Operation::operand_iterator begin,
481481
printer << '[' << operands.drop_front(numDims) << ']';
482482
}
483483

484-
/// Parses dimension and symbol list and returns true if parsing failed.
485-
ParseResult mlir::affine::parseDimAndSymbolList(
486-
OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
487-
SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
484+
/// Parse dimension and symbol list, but not resolve yet, as we may not know the
485+
/// operands types.
486+
static ParseResult parseDimAndSymbolListImpl(
487+
OpAsmParser &parser,
488+
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &opInfos,
489+
unsigned &numDims) {
488490
if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
489491
return failure();
492+
490493
// Store number of dimensions for validation by caller.
491494
numDims = opInfos.size();
492495

493496
// Parse the optional symbol operands.
497+
if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::OptionalSquare))
498+
return failure();
499+
500+
return success();
501+
}
502+
503+
/// Parses dimension and symbol list and returns true if parsing failed.
504+
ParseResult mlir::affine::parseDimAndSymbolList(
505+
OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
506+
SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
507+
if (parseDimAndSymbolListImpl(parser, opInfos, numDims))
508+
return failure();
509+
494510
auto indexTy = parser.getBuilder().getIndexType();
495-
return failure(parser.parseOperandList(
496-
opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
497-
parser.resolveOperands(opInfos, indexTy, operands));
511+
if (parser.resolveOperands(opInfos, indexTy, operands))
512+
return failure();
513+
514+
return success();
498515
}
499516

500517
/// Utility function to verify that a set of operands are valid dimension and
@@ -528,14 +545,25 @@ AffineValueMap AffineApplyOp::getAffineValueMap() {
528545

529546
ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
530547
auto &builder = parser.getBuilder();
531-
auto indexTy = builder.getIndexType();
532548

533549
AffineMapAttr mapAttr;
534550
unsigned numDims;
551+
SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
535552
if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
536-
parseDimAndSymbolList(parser, result.operands, numDims) ||
553+
parseDimAndSymbolListImpl(parser, opInfos, numDims) ||
537554
parser.parseOptionalAttrDict(result.attributes))
538555
return failure();
556+
557+
Type type;
558+
if (parser.parseOptionalColon()) {
559+
type = builder.getIndexType();
560+
} else if (parser.parseType(type)) {
561+
return failure();
562+
}
563+
564+
if (parser.resolveOperands(opInfos, type, result.operands))
565+
return failure();
566+
539567
auto map = mapAttr.getValue();
540568

541569
if (map.getNumDims() != numDims ||
@@ -544,7 +572,7 @@ ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
544572
"dimension or symbol index mismatch");
545573
}
546574

547-
result.types.append(map.getNumResults(), indexTy);
575+
result.types.append(map.getNumResults(), type);
548576
return success();
549577
}
550578

@@ -553,9 +581,18 @@ void AffineApplyOp::print(OpAsmPrinter &p) {
553581
printDimAndSymbolList(operand_begin(), operand_end(),
554582
getAffineMap().getNumDims(), p);
555583
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
584+
Type resType = getType();
585+
if (!isa<IndexType>(resType))
586+
p << ":" << resType;
556587
}
557588

558589
LogicalResult AffineApplyOp::verify() {
590+
// Check all operand and result types are the same.
591+
// We cannot use `SameOperandsAndResultType` as it expects at least 1 operand.
592+
if (!llvm::all_equal(
593+
llvm::concat<Type>(getOperandTypes(), (*this)->getResultTypes())))
594+
return emitOpError("requires the same type for all operands and results");
595+
559596
// Check input and output dimensions match.
560597
AffineMap affineMap = getMap();
561598

mlir/lib/Dialect/Affine/Utils/Utils.cpp

Lines changed: 27 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -46,9 +46,9 @@ class AffineApplyExpander
4646
/// This internal class expects arguments to be non-null, checks must be
4747
/// performed at the call site.
4848
AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
49-
ValueRange symbolValues, Location loc)
49+
ValueRange symbolValues, Location loc, Type type)
5050
: builder(builder), dimValues(dimValues), symbolValues(symbolValues),
51-
loc(loc) {}
51+
loc(loc), type(type) {}
5252

5353
template <typename OpTy>
5454
Value buildBinaryExpr(AffineBinaryOpExpr expr,
@@ -189,8 +189,16 @@ class AffineApplyExpander
189189
}
190190

191191
Value visitConstantExpr(AffineConstantExpr expr) {
192-
auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
193-
return op.getResult();
192+
int64_t value = expr.getValue();
193+
if (isa<IndexType>(type))
194+
return builder.create<arith::ConstantIndexOp>(loc, value);
195+
196+
if (auto shaped = dyn_cast<ShapedType>(type)) {
197+
auto elements = DenseIntElementsAttr::get(shaped, value);
198+
return builder.create<arith::ConstantOp>(loc, elements);
199+
}
200+
201+
llvm_unreachable("AffineApplyExpander: Unsupported type");
194202
}
195203

196204
Value visitDimExpr(AffineDimExpr expr) {
@@ -211,6 +219,7 @@ class AffineApplyExpander
211219
ValueRange symbolValues;
212220

213221
Location loc;
222+
Type type;
214223
};
215224
} // namespace
216225

@@ -219,23 +228,28 @@ class AffineApplyExpander
219228
mlir::Value mlir::affine::expandAffineExpr(OpBuilder &builder, Location loc,
220229
AffineExpr expr,
221230
ValueRange dimValues,
222-
ValueRange symbolValues) {
223-
return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
231+
ValueRange symbolValues, Type type) {
232+
if (!type)
233+
type = builder.getIndexType();
234+
235+
return AffineApplyExpander(builder, dimValues, symbolValues, loc, type)
236+
.visit(expr);
224237
}
225238

226239
/// Create a sequence of operations that implement the `affineMap` applied to
227240
/// the given `operands` (as it it were an AffineApplyOp).
228241
std::optional<SmallVector<Value, 8>>
229242
mlir::affine::expandAffineMap(OpBuilder &builder, Location loc,
230-
AffineMap affineMap, ValueRange operands) {
243+
AffineMap affineMap, ValueRange operands,
244+
Type type) {
231245
auto numDims = affineMap.getNumDims();
232246
auto expanded = llvm::to_vector<8>(
233-
llvm::map_range(affineMap.getResults(),
234-
[numDims, &builder, loc, operands](AffineExpr expr) {
235-
return expandAffineExpr(builder, loc, expr,
236-
operands.take_front(numDims),
237-
operands.drop_front(numDims));
238-
}));
247+
llvm::map_range(affineMap.getResults(), [numDims, &builder, loc, operands,
248+
type](AffineExpr expr) {
249+
return expandAffineExpr(builder, loc, expr,
250+
operands.take_front(numDims),
251+
operands.drop_front(numDims), type);
252+
}));
239253
if (llvm::all_of(expanded, [](Value v) { return v; }))
240254
return expanded;
241255
return std::nullopt;

mlir/test/Conversion/AffineToStandard/lower-affine.mlir

Lines changed: 15 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -429,8 +429,9 @@ func.func @min_reduction_tree(%v1 : index, %v2 : index, %v3 : index, %v4 : index
429429
#map5 = affine_map<(d0,d1,d2) -> (d0,d1,d2)>
430430
#map6 = affine_map<(d0,d1,d2) -> (d0 + d1 + d2)>
431431

432-
// CHECK-LABEL: func @affine_applies(
433-
func.func @affine_applies(%arg0 : index) {
432+
// CHECK-LABEL: func @affine_applies
433+
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: vector<4xindex>)
434+
func.func @affine_applies(%arg0 : index, %arg1 : vector<4xindex>) {
434435
// CHECK: %[[c0:.*]] = arith.constant 0 : index
435436
%zero = affine.apply #map0()
436437

@@ -448,24 +449,29 @@ func.func @affine_applies(%arg0 : index) {
448449
%one = affine.apply #map3(%symbZero)[%zero]
449450

450451
// CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index
451-
// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] overflow<nsw> : index
452-
// CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] : index
452+
// CHECK-NEXT: %[[v2:.*]] = arith.muli %[[ARG0]], %[[c2]] overflow<nsw> : index
453+
// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[ARG0]], %[[v2]] : index
453454
// CHECK-NEXT: %[[c3:.*]] = arith.constant 3 : index
454-
// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] overflow<nsw> : index
455+
// CHECK-NEXT: %[[v4:.*]] = arith.muli %[[ARG0]], %[[c3]] overflow<nsw> : index
455456
// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] : index
456457
// CHECK-NEXT: %[[c4:.*]] = arith.constant 4 : index
457-
// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] overflow<nsw> : index
458+
// CHECK-NEXT: %[[v6:.*]] = arith.muli %[[ARG0]], %[[c4]] overflow<nsw> : index
458459
// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] : index
459460
// CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : index
460-
// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] overflow<nsw> : index
461+
// CHECK-NEXT: %[[v8:.*]] = arith.muli %[[ARG0]], %[[c5]] overflow<nsw> : index
461462
// CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] : index
462463
// CHECK-NEXT: %[[c6:.*]] = arith.constant 6 : index
463-
// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] overflow<nsw> : index
464+
// CHECK-NEXT: %[[v10:.*]] = arith.muli %[[ARG0]], %[[c6]] overflow<nsw> : index
464465
// CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] : index
465466
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
466-
// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] overflow<nsw> : index
467+
// CHECK-NEXT: %[[v12:.*]] = arith.muli %[[ARG0]], %[[c7]] overflow<nsw> : index
467468
// CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index
468469
%four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]
470+
471+
// CHECK-NEXT: %[[v14:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : vector<4xindex>
472+
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1> : vector<4xindex>
473+
// CHECK-NEXT: %[[v15:.*]] = arith.addi %[[v14]], %[[cst]] : vector<4xindex>
474+
%vec = affine.apply #map3(%arg1)[%arg1] : vector<4xindex>
469475
return
470476
}
471477

mlir/test/Dialect/Affine/invalid.mlir

Lines changed: 12 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@
55
func.func @affine_apply_operand_non_index(%arg0 : i32) {
66
// Custom parser automatically assigns all arguments the `index` so we must
77
// use the generic syntax here to exercise the verifier.
8-
// expected-error@+1 {{op operand #0 must be variadic of index, but got 'i32'}}
8+
// expected-error@+1 {{op operand #0 must be variadic of index-like, but got 'i32'}}
99
%0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (i32) -> (index)
1010
return
1111
}
@@ -15,11 +15,21 @@ func.func @affine_apply_operand_non_index(%arg0 : i32) {
1515
func.func @affine_apply_resul_non_index(%arg0 : index) {
1616
// Custom parser automatically assigns `index` as the result type so we must
1717
// use the generic syntax here to exercise the verifier.
18-
// expected-error@+1 {{op result #0 must be index, but got 'i32'}}
18+
// expected-error@+1 {{op result #0 must be index-like, but got 'i32'}}
1919
%0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (index) -> (i32)
2020
return
2121
}
2222

23+
// -----
24+
25+
func.func @affine_apply_types_match(%arg0 : index) {
26+
// We are now supporting vectors of index, but all operands and result types
27+
// must match.
28+
// expected-error@+1 {{op requires the same type for all operands and results}}
29+
%0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (index) -> (vector<4xindex>)
30+
return
31+
}
32+
2333
// -----
2434
func.func @affine_load_invalid_dim(%M : memref<10xi32>) {
2535
"unknown"() ({

0 commit comments

Comments
 (0)