Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 9 additions & 4 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.td
Original file line number Diff line number Diff line change
Expand Up @@ -13,12 +13,13 @@
#ifndef AFFINE_OPS
#define AFFINE_OPS

include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Dialect/Affine/IR/AffineMemoryOpInterfaces.td"
include "mlir/Dialect/Arith/IR/ArithBase.td"
include "mlir/Interfaces/ControlFlowInterfaces.td"
include "mlir/Interfaces/InferTypeOpInterface.td"
include "mlir/Interfaces/LoopLikeInterface.td"
include "mlir/Interfaces/SideEffectInterfaces.td"
include "mlir/IR/CommonTypeConstraints.td"

def Affine_Dialect : Dialect {
let name = "affine";
Expand Down Expand Up @@ -57,18 +58,22 @@ def AffineApplyOp : Affine_Op<"apply", [Pure]> {
%2 = affine.apply affine_map<(i)[s0] -> (i+s0)> (%42)[%n]
```
}];
let arguments = (ins AffineMapAttr:$map, Variadic<Index>:$mapOperands);
let results = (outs Index);
let arguments = (ins AffineMapAttr:$map, Variadic<IndexLike>:$mapOperands);
let results = (outs IndexLike);

// TODO: The auto-generated builders should check to see if the return type
// has a constant builder. That way we wouldn't need to explicitly specify the
// result types here.
let builders = [
OpBuilder<(ins "ArrayRef<AffineExpr> ":$exprList,"ValueRange":$mapOperands),
OpBuilder<(ins "ArrayRef<AffineExpr>":$exprList,"ValueRange":$mapOperands),
[{
build($_builder, $_state, $_builder.getIndexType(),
AffineMap::inferFromExprList(exprList, $_builder.getContext())
.front(), mapOperands);
}]>,
OpBuilder<(ins "AffineMap":$map,"ValueRange":$mapOperands),
[{
build($_builder, $_state, $_builder.getIndexType(), map, mapOperands);
}]>
];

Expand Down
10 changes: 5 additions & 5 deletions mlir/include/mlir/Dialect/Affine/Utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -294,14 +294,14 @@ void createAffineComputationSlice(Operation *opInst,
/// Emit code that computes the given affine expression using standard
/// arithmetic operations applied to the provided dimension and symbol values.
Value expandAffineExpr(OpBuilder &builder, Location loc, AffineExpr expr,
ValueRange dimValues, ValueRange symbolValues);
ValueRange dimValues, ValueRange symbolValues,
Type type = {});

/// Create a sequence of operations that implement the `affineMap` applied to
/// the given `operands` (as it it were an AffineApplyOp).
std::optional<SmallVector<Value, 8>> expandAffineMap(OpBuilder &builder,
Location loc,
AffineMap affineMap,
ValueRange operands);
std::optional<SmallVector<Value, 8>>
expandAffineMap(OpBuilder &builder, Location loc, AffineMap affineMap,
ValueRange operands, Type type = {});

/// Holds the result of (div a, b) and (mod a, b).
struct DivModValue {
Expand Down
4 changes: 4 additions & 0 deletions mlir/include/mlir/IR/CommonTypeConstraints.td
Original file line number Diff line number Diff line change
Expand Up @@ -892,6 +892,10 @@ class TypeOrValueSemanticsContainer<Type allowedType, string name>
// bools.
def BoolLike : TypeOrValueSemanticsContainer<I1, "bool-like">;

// Type constraint for index-like types: index, vectors of index, tensors of
// index.
def IndexLike : TypeOrValueSemanticsContainer<Index, "index-like">;

// Type constraint for signless-integer-like types: signless integers or
// value-semantics containers of signless integers.
def SignlessIntegerLike : TypeOrValueSemanticsContainer<
Expand Down
2 changes: 1 addition & 1 deletion mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ class AffineApplyLowering : public OpRewritePattern<AffineApplyOp> {
PatternRewriter &rewriter) const override {
auto maybeExpandedMap =
expandAffineMap(rewriter, op.getLoc(), op.getAffineMap(),
llvm::to_vector<8>(op.getOperands()));
llvm::to_vector<8>(op.getOperands()), op.getType());
if (!maybeExpandedMap)
return failure();
rewriter.replaceOp(op, *maybeExpandedMap);
Expand Down
57 changes: 47 additions & 10 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -491,20 +491,37 @@ static void printDimAndSymbolList(Operation::operand_iterator begin,
printer << '[' << operands.drop_front(numDims) << ']';
}

/// Parses dimension and symbol list and returns true if parsing failed.
ParseResult mlir::affine::parseDimAndSymbolList(
OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
/// Parse dimension and symbol list, but not resolve yet, as we may not know the
/// operands types.
static ParseResult parseDimAndSymbolListImpl(
OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &opInfos,
unsigned &numDims) {
if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::Paren))
return failure();

// Store number of dimensions for validation by caller.
numDims = opInfos.size();

// Parse the optional symbol operands.
if (parser.parseOperandList(opInfos, OpAsmParser::Delimiter::OptionalSquare))
return failure();

return success();
}

/// Parses dimension and symbol list and returns true if parsing failed.
ParseResult mlir::affine::parseDimAndSymbolList(
OpAsmParser &parser, SmallVectorImpl<Value> &operands, unsigned &numDims) {
SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
if (parseDimAndSymbolListImpl(parser, opInfos, numDims))
return failure();

auto indexTy = parser.getBuilder().getIndexType();
return failure(parser.parseOperandList(
opInfos, OpAsmParser::Delimiter::OptionalSquare) ||
parser.resolveOperands(opInfos, indexTy, operands));
if (parser.resolveOperands(opInfos, indexTy, operands))
return failure();

return success();
}

/// Utility function to verify that a set of operands are valid dimension and
Expand Down Expand Up @@ -538,14 +555,25 @@ AffineValueMap AffineApplyOp::getAffineValueMap() {

ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
auto &builder = parser.getBuilder();
auto indexTy = builder.getIndexType();

AffineMapAttr mapAttr;
unsigned numDims;
SmallVector<OpAsmParser::UnresolvedOperand, 8> opInfos;
if (parser.parseAttribute(mapAttr, "map", result.attributes) ||
parseDimAndSymbolList(parser, result.operands, numDims) ||
parseDimAndSymbolListImpl(parser, opInfos, numDims) ||
parser.parseOptionalAttrDict(result.attributes))
return failure();

Type type;
if (parser.parseOptionalColon()) {
type = builder.getIndexType();
} else if (parser.parseType(type)) {
return failure();
}

if (parser.resolveOperands(opInfos, type, result.operands))
return failure();

auto map = mapAttr.getValue();

if (map.getNumDims() != numDims ||
Expand All @@ -554,7 +582,7 @@ ParseResult AffineApplyOp::parse(OpAsmParser &parser, OperationState &result) {
"dimension or symbol index mismatch");
}

result.types.append(map.getNumResults(), indexTy);
result.types.append(map.getNumResults(), type);
return success();
}

Expand All @@ -563,9 +591,18 @@ void AffineApplyOp::print(OpAsmPrinter &p) {
printDimAndSymbolList(operand_begin(), operand_end(),
getAffineMap().getNumDims(), p);
p.printOptionalAttrDict((*this)->getAttrs(), /*elidedAttrs=*/{"map"});
Type resType = getType();
if (!isa<IndexType>(resType))
p << ":" << resType;
}

LogicalResult AffineApplyOp::verify() {
// Check all operand and result types are the same.
// We cannot use `SameOperandsAndResultType` as it expects at least 1 operand.
if (!llvm::all_equal(
llvm::concat<Type>(getOperandTypes(), (*this)->getResultTypes())))
return emitOpError("requires the same type for all operands and results");

// Check input and output dimensions match.
AffineMap affineMap = getMap();

Expand Down
40 changes: 27 additions & 13 deletions mlir/lib/Dialect/Affine/Utils/Utils.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -46,9 +46,9 @@ class AffineApplyExpander
/// This internal class expects arguments to be non-null, checks must be
/// performed at the call site.
AffineApplyExpander(OpBuilder &builder, ValueRange dimValues,
ValueRange symbolValues, Location loc)
ValueRange symbolValues, Location loc, Type type)
: builder(builder), dimValues(dimValues), symbolValues(symbolValues),
loc(loc) {}
loc(loc), type(type) {}

template <typename OpTy>
Value buildBinaryExpr(AffineBinaryOpExpr expr,
Expand Down Expand Up @@ -189,8 +189,16 @@ class AffineApplyExpander
}

Value visitConstantExpr(AffineConstantExpr expr) {
auto op = builder.create<arith::ConstantIndexOp>(loc, expr.getValue());
return op.getResult();
int64_t value = expr.getValue();
if (isa<IndexType>(type))
return builder.create<arith::ConstantIndexOp>(loc, value);

if (auto shaped = dyn_cast<ShapedType>(type)) {
auto elements = DenseIntElementsAttr::get(shaped, value);
return builder.create<arith::ConstantOp>(loc, elements);
}

llvm_unreachable("AffineApplyExpander: Unsupported type");
}

Value visitDimExpr(AffineDimExpr expr) {
Expand All @@ -211,6 +219,7 @@ class AffineApplyExpander
ValueRange symbolValues;

Location loc;
Type type;
};
} // namespace

Expand All @@ -219,23 +228,28 @@ class AffineApplyExpander
mlir::Value mlir::affine::expandAffineExpr(OpBuilder &builder, Location loc,
AffineExpr expr,
ValueRange dimValues,
ValueRange symbolValues) {
return AffineApplyExpander(builder, dimValues, symbolValues, loc).visit(expr);
ValueRange symbolValues, Type type) {
if (!type)
type = builder.getIndexType();

return AffineApplyExpander(builder, dimValues, symbolValues, loc, type)
.visit(expr);
}

/// Create a sequence of operations that implement the `affineMap` applied to
/// the given `operands` (as it it were an AffineApplyOp).
std::optional<SmallVector<Value, 8>>
mlir::affine::expandAffineMap(OpBuilder &builder, Location loc,
AffineMap affineMap, ValueRange operands) {
AffineMap affineMap, ValueRange operands,
Type type) {
auto numDims = affineMap.getNumDims();
auto expanded = llvm::to_vector<8>(
llvm::map_range(affineMap.getResults(),
[numDims, &builder, loc, operands](AffineExpr expr) {
return expandAffineExpr(builder, loc, expr,
operands.take_front(numDims),
operands.drop_front(numDims));
}));
llvm::map_range(affineMap.getResults(), [numDims, &builder, loc, operands,
type](AffineExpr expr) {
return expandAffineExpr(builder, loc, expr,
operands.take_front(numDims),
operands.drop_front(numDims), type);
}));
if (llvm::all_of(expanded, [](Value v) { return v; }))
return expanded;
return std::nullopt;
Expand Down
24 changes: 15 additions & 9 deletions mlir/test/Conversion/AffineToStandard/lower-affine.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -429,8 +429,9 @@ func.func @min_reduction_tree(%v1 : index, %v2 : index, %v3 : index, %v4 : index
#map5 = affine_map<(d0,d1,d2) -> (d0,d1,d2)>
#map6 = affine_map<(d0,d1,d2) -> (d0 + d1 + d2)>

// CHECK-LABEL: func @affine_applies(
func.func @affine_applies(%arg0 : index) {
// CHECK-LABEL: func @affine_applies
// CHECK-SAME: (%[[ARG0:.*]]: index, %[[ARG1:.*]]: vector<4xindex>)
func.func @affine_applies(%arg0 : index, %arg1 : vector<4xindex>) {
// CHECK: %[[c0:.*]] = arith.constant 0 : index
%zero = affine.apply #map0()

Expand All @@ -448,24 +449,29 @@ func.func @affine_applies(%arg0 : index) {
%one = affine.apply #map3(%symbZero)[%zero]

// CHECK-NEXT: %[[c2:.*]] = arith.constant 2 : index
// CHECK-NEXT: %[[v2:.*]] = arith.muli %arg0, %[[c2]] overflow<nsw> : index
// CHECK-NEXT: %[[v3:.*]] = arith.addi %arg0, %[[v2]] : index
// CHECK-NEXT: %[[v2:.*]] = arith.muli %[[ARG0]], %[[c2]] overflow<nsw> : index
// CHECK-NEXT: %[[v3:.*]] = arith.addi %[[ARG0]], %[[v2]] : index
// CHECK-NEXT: %[[c3:.*]] = arith.constant 3 : index
// CHECK-NEXT: %[[v4:.*]] = arith.muli %arg0, %[[c3]] overflow<nsw> : index
// CHECK-NEXT: %[[v4:.*]] = arith.muli %[[ARG0]], %[[c3]] overflow<nsw> : index
// CHECK-NEXT: %[[v5:.*]] = arith.addi %[[v3]], %[[v4]] : index
// CHECK-NEXT: %[[c4:.*]] = arith.constant 4 : index
// CHECK-NEXT: %[[v6:.*]] = arith.muli %arg0, %[[c4]] overflow<nsw> : index
// CHECK-NEXT: %[[v6:.*]] = arith.muli %[[ARG0]], %[[c4]] overflow<nsw> : index
// CHECK-NEXT: %[[v7:.*]] = arith.addi %[[v5]], %[[v6]] : index
// CHECK-NEXT: %[[c5:.*]] = arith.constant 5 : index
// CHECK-NEXT: %[[v8:.*]] = arith.muli %arg0, %[[c5]] overflow<nsw> : index
// CHECK-NEXT: %[[v8:.*]] = arith.muli %[[ARG0]], %[[c5]] overflow<nsw> : index
// CHECK-NEXT: %[[v9:.*]] = arith.addi %[[v7]], %[[v8]] : index
// CHECK-NEXT: %[[c6:.*]] = arith.constant 6 : index
// CHECK-NEXT: %[[v10:.*]] = arith.muli %arg0, %[[c6]] overflow<nsw> : index
// CHECK-NEXT: %[[v10:.*]] = arith.muli %[[ARG0]], %[[c6]] overflow<nsw> : index
// CHECK-NEXT: %[[v11:.*]] = arith.addi %[[v9]], %[[v10]] : index
// CHECK-NEXT: %[[c7:.*]] = arith.constant 7 : index
// CHECK-NEXT: %[[v12:.*]] = arith.muli %arg0, %[[c7]] overflow<nsw> : index
// CHECK-NEXT: %[[v12:.*]] = arith.muli %[[ARG0]], %[[c7]] overflow<nsw> : index
// CHECK-NEXT: %[[v13:.*]] = arith.addi %[[v11]], %[[v12]] : index
%four = affine.apply #map4(%arg0, %arg0, %arg0, %arg0)[%arg0, %arg0, %arg0]

// CHECK-NEXT: %[[v14:.*]] = arith.addi %[[ARG1]], %[[ARG1]] : vector<4xindex>
// CHECK-NEXT: %[[cst:.*]] = arith.constant dense<1> : vector<4xindex>
// CHECK-NEXT: %[[v15:.*]] = arith.addi %[[v14]], %[[cst]] : vector<4xindex>
%vec = affine.apply #map3(%arg1)[%arg1] : vector<4xindex>
return
}

Expand Down
14 changes: 12 additions & 2 deletions mlir/test/Dialect/Affine/invalid.mlir
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
func.func @affine_apply_operand_non_index(%arg0 : i32) {
// Custom parser automatically assigns all arguments the `index` so we must
// use the generic syntax here to exercise the verifier.
// expected-error@+1 {{op operand #0 must be variadic of index, but got 'i32'}}
// expected-error@+1 {{op operand #0 must be variadic of index-like, but got 'i32'}}
%0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (i32) -> (index)
return
}
Expand All @@ -15,11 +15,21 @@ func.func @affine_apply_operand_non_index(%arg0 : i32) {
func.func @affine_apply_resul_non_index(%arg0 : index) {
// Custom parser automatically assigns `index` as the result type so we must
// use the generic syntax here to exercise the verifier.
// expected-error@+1 {{op result #0 must be index, but got 'i32'}}
// expected-error@+1 {{op result #0 must be index-like, but got 'i32'}}
%0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (index) -> (i32)
return
}

// -----

func.func @affine_apply_types_match(%arg0 : index) {
// We are now supporting vectors of index, but all operands and result types
// must match.
// expected-error@+1 {{op requires the same type for all operands and results}}
%0 = "affine.apply"(%arg0) {map = affine_map<(d0) -> (d0)>} : (index) -> (vector<4xindex>)
return
}

// -----
func.func @affine_load_invalid_dim(%M : memref<10xi32>) {
"unknown"() ({
Expand Down
Loading