Skip to content

Commit d3556c9

Browse files
committed
[mlir] update create APIs (1/n)
1 parent 4775b96 commit d3556c9

File tree

4 files changed

+229
-67
lines changed

4 files changed

+229
-67
lines changed

mlir/include/mlir/Dialect/Affine/IR/AffineOps.h

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -114,6 +114,21 @@ class AffineDmaStartOp
114114
AffineMap tagMap, ValueRange tagIndices, Value numElements,
115115
Value stride = nullptr, Value elementsPerStride = nullptr);
116116

117+
static AffineDmaStartOp
118+
create(OpBuilder &builder, Location location, Value srcMemRef,
119+
AffineMap srcMap, ValueRange srcIndices, Value destMemRef,
120+
AffineMap dstMap, ValueRange destIndices, Value tagMemRef,
121+
AffineMap tagMap, ValueRange tagIndices, Value numElements,
122+
Value stride = nullptr, Value elementsPerStride = nullptr);
123+
124+
static AffineDmaStartOp create(ImplicitLocOpBuilder &builder, Value srcMemRef,
125+
AffineMap srcMap, ValueRange srcIndices,
126+
Value destMemRef, AffineMap dstMap,
127+
ValueRange destIndices, Value tagMemRef,
128+
AffineMap tagMap, ValueRange tagIndices,
129+
Value numElements, Value stride = nullptr,
130+
Value elementsPerStride = nullptr);
131+
117132
/// Returns the operand index of the source memref.
118133
unsigned getSrcMemRefOperandIndex() { return 0; }
119134

@@ -319,6 +334,12 @@ class AffineDmaWaitOp
319334

320335
static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
321336
AffineMap tagMap, ValueRange tagIndices, Value numElements);
337+
static AffineDmaWaitOp create(OpBuilder &builder, Location location,
338+
Value tagMemRef, AffineMap tagMap,
339+
ValueRange tagIndices, Value numElements);
340+
static AffineDmaWaitOp create(ImplicitLocOpBuilder &builder, Value tagMemRef,
341+
AffineMap tagMap, ValueRange tagIndices,
342+
Value numElements);
322343

323344
static StringRef getOperationName() { return "affine.dma_wait"; }
324345

mlir/include/mlir/Dialect/Arith/IR/Arith.h

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,15 +59,27 @@ class ConstantIntOp : public arith::ConstantOp {
5959
/// Build a constant int op that produces an integer of the specified width.
6060
static void build(OpBuilder &builder, OperationState &result, int64_t value,
6161
unsigned width);
62+
static ConstantIntOp create(OpBuilder &builder, Location location,
63+
int64_t value, unsigned width);
64+
static ConstantIntOp create(ImplicitLocOpBuilder &builder, int64_t value,
65+
unsigned width);
6266

6367
/// Build a constant int op that produces an integer of the specified type,
6468
/// which must be an integer type.
6569
static void build(OpBuilder &builder, OperationState &result, Type type,
6670
int64_t value);
71+
static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
72+
int64_t value);
73+
static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type,
74+
int64_t value);
6775

6876
/// Build a constant int op that produces an integer from an APInt
6977
static void build(OpBuilder &builder, OperationState &result, Type type,
7078
const APInt &value);
79+
static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
80+
const APInt &value);
81+
static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type,
82+
const APInt &value);
7183

7284
inline int64_t value() {
7385
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
@@ -85,6 +97,10 @@ class ConstantFloatOp : public arith::ConstantOp {
8597
/// Build a constant float op that produces a float of the specified type.
8698
static void build(OpBuilder &builder, OperationState &result, FloatType type,
8799
const APFloat &value);
100+
static ConstantFloatOp create(OpBuilder &builder, Location location,
101+
FloatType type, const APFloat &value);
102+
static ConstantFloatOp create(ImplicitLocOpBuilder &builder, FloatType type,
103+
const APFloat &value);
88104

89105
inline APFloat value() {
90106
return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();
@@ -100,6 +116,9 @@ class ConstantIndexOp : public arith::ConstantOp {
100116
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
101117
/// Build a constant int op that produces an index.
102118
static void build(OpBuilder &builder, OperationState &result, int64_t value);
119+
static ConstantIndexOp create(OpBuilder &builder, Location location,
120+
int64_t value);
121+
static ConstantIndexOp create(ImplicitLocOpBuilder &builder, int64_t value);
103122

104123
inline int64_t value() {
105124
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();

mlir/lib/Dialect/Affine/IR/AffineOps.cpp

Lines changed: 79 additions & 36 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,7 @@ void AffineDialect::initialize() {
228228
addOperations<AffineDmaStartOp, AffineDmaWaitOp,
229229
#define GET_OP_LIST
230230
#include "mlir/Dialect/Affine/IR/AffineOps.cpp.inc"
231+
231232
>();
232233
addInterfaces<AffineInlinerInterface>();
233234
declarePromisedInterfaces<ValueBoundsOpInterface, AffineApplyOp, AffineMaxOp,
@@ -240,7 +241,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder,
240241
Attribute value, Type type,
241242
Location loc) {
242243
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
243-
return builder.create<ub::PoisonOp>(loc, type, poison);
244+
return ub::PoisonOp::create(builder, loc, type, poison);
244245
return arith::ConstantOp::materialize(builder, value, type, loc);
245246
}
246247

@@ -1282,7 +1283,7 @@ mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
12821283
map = foldAttributesIntoMap(b, map, operands, valueOperands);
12831284
composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
12841285
assert(map);
1285-
return b.create<AffineApplyOp>(loc, map, valueOperands);
1286+
return AffineApplyOp::create(b, loc, map, valueOperands);
12861287
}
12871288

12881289
AffineApplyOp
@@ -1389,7 +1390,7 @@ static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
13891390
SmallVector<Value> valueOperands;
13901391
map = foldAttributesIntoMap(b, map, operands, valueOperands);
13911392
composeMultiResultAffineMap(map, valueOperands);
1392-
return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
1393+
return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);
13931394
}
13941395

13951396
AffineMinOp
@@ -1747,6 +1748,32 @@ void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
17471748
}
17481749
}
17491750

1751+
AffineDmaStartOp AffineDmaStartOp::create(
1752+
OpBuilder &builder, Location location, Value srcMemRef, AffineMap srcMap,
1753+
ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
1754+
ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
1755+
ValueRange tagIndices, Value numElements, Value stride,
1756+
Value elementsPerStride) {
1757+
mlir::OperationState state(location, getOperationName());
1758+
build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap,
1759+
destIndices, tagMemRef, tagMap, tagIndices, numElements, stride,
1760+
elementsPerStride);
1761+
auto result = llvm::dyn_cast<AffineDmaStartOp>(builder.create(state));
1762+
assert(result && "builder didn't return the right type");
1763+
return result;
1764+
}
1765+
1766+
AffineDmaStartOp AffineDmaStartOp::create(
1767+
ImplicitLocOpBuilder &builder, Value srcMemRef, AffineMap srcMap,
1768+
ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
1769+
ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
1770+
ValueRange tagIndices, Value numElements, Value stride,
1771+
Value elementsPerStride) {
1772+
return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices,
1773+
destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices,
1774+
numElements, stride, elementsPerStride);
1775+
}
1776+
17501777
void AffineDmaStartOp::print(OpAsmPrinter &p) {
17511778
p << " " << getSrcMemRef() << '[';
17521779
p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
@@ -1917,6 +1944,25 @@ void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
19171944
result.addOperands(numElements);
19181945
}
19191946

1947+
AffineDmaWaitOp AffineDmaWaitOp::create(OpBuilder &builder, Location location,
1948+
Value tagMemRef, AffineMap tagMap,
1949+
ValueRange tagIndices,
1950+
Value numElements) {
1951+
mlir::OperationState state(location, getOperationName());
1952+
build(builder, state, tagMemRef, tagMap, tagIndices, numElements);
1953+
auto result = llvm::dyn_cast<AffineDmaWaitOp>(builder.create(state));
1954+
assert(result && "builder didn't return the right type");
1955+
return result;
1956+
}
1957+
1958+
AffineDmaWaitOp AffineDmaWaitOp::create(ImplicitLocOpBuilder &builder,
1959+
Value tagMemRef, AffineMap tagMap,
1960+
ValueRange tagIndices,
1961+
Value numElements) {
1962+
return create(builder, builder.getLoc(), tagMemRef, tagMap, tagIndices,
1963+
numElements);
1964+
}
1965+
19201966
void AffineDmaWaitOp::print(OpAsmPrinter &p) {
19211967
p << " " << getTagMemRef() << '[';
19221968
SmallVector<Value, 2> operands(getTagIndices());
@@ -2688,8 +2734,8 @@ FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
26882734
rewriter.setInsertionPoint(getOperation());
26892735
auto inits = llvm::to_vector(getInits());
26902736
inits.append(newInitOperands.begin(), newInitOperands.end());
2691-
AffineForOp newLoop = rewriter.create<AffineForOp>(
2692-
getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
2737+
AffineForOp newLoop = AffineForOp::create(
2738+
rewriter, getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
26932739
getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);
26942740

26952741
// Generate the new yield values and append them to the scf.yield operation.
@@ -2831,7 +2877,7 @@ static void buildAffineLoopNestImpl(
28312877
OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
28322878
bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
28332879
}
2834-
nestedBuilder.create<AffineYieldOp>(nestedLoc);
2880+
AffineYieldOp::create(nestedBuilder, nestedLoc);
28352881
};
28362882

28372883
// Delegate actual loop creation to the callback in order to dispatch
@@ -2846,8 +2892,8 @@ static AffineForOp
28462892
buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
28472893
int64_t ub, int64_t step,
28482894
AffineForOp::BodyBuilderFn bodyBuilderFn) {
2849-
return builder.create<AffineForOp>(loc, lb, ub, step,
2850-
/*iterArgs=*/ValueRange(), bodyBuilderFn);
2895+
return AffineForOp::create(builder, loc, lb, ub, step,
2896+
/*iterArgs=*/ValueRange(), bodyBuilderFn);
28512897
}
28522898

28532899
/// Creates an affine loop from the bounds that may or may not be constants.
@@ -2860,9 +2906,9 @@ buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
28602906
if (lbConst && ubConst)
28612907
return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
28622908
ubConst.value(), step, bodyBuilderFn);
2863-
return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
2864-
builder.getDimIdentityMap(), step,
2865-
/*iterArgs=*/ValueRange(), bodyBuilderFn);
2909+
return AffineForOp::create(builder, loc, lb, builder.getDimIdentityMap(), ub,
2910+
builder.getDimIdentityMap(), step,
2911+
/*iterArgs=*/ValueRange(), bodyBuilderFn);
28662912
}
28672913

28682914
void mlir::affine::buildAffineLoopNest(
@@ -4883,7 +4929,7 @@ struct DropUnitExtentBasis
48834929
Location loc = delinearizeOp->getLoc();
48844930
auto getZero = [&]() -> Value {
48854931
if (!zero)
4886-
zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
4932+
zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
48874933
return zero.value();
48884934
};
48894935

@@ -4906,8 +4952,8 @@ struct DropUnitExtentBasis
49064952

49074953
if (!newBasis.empty()) {
49084954
// Will drop the leading nullptr from `basis` if there was no outer bound.
4909-
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
4910-
loc, delinearizeOp.getLinearIndex(), newBasis);
4955+
auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
4956+
rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
49114957
int newIndex = 0;
49124958
// Map back the new delinearized indices to the values they replace.
49134959
for (auto &replacement : replacements) {
@@ -4971,12 +5017,12 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
49715017
return success();
49725018
}
49735019

4974-
Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
4975-
linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
5020+
Value newLinearize = affine::AffineLinearizeIndexOp::create(
5021+
rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
49765022
ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
49775023
linearizeOp.getDisjoint());
4978-
auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
4979-
delinearizeOp.getLoc(), newLinearize,
5024+
auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
5025+
rewriter, delinearizeOp.getLoc(), newLinearize,
49805026
ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
49815027
delinearizeOp.hasOuterBound());
49825028
SmallVector<Value> mergedResults(newDelinearize.getResults());
@@ -5048,19 +5094,16 @@ struct SplitDelinearizeSpanningLastLinearizeArg final
50485094
delinearizeOp,
50495095
"need at least two elements to form the basis product");
50505096

5051-
Value linearizeWithoutBack =
5052-
rewriter.create<affine::AffineLinearizeIndexOp>(
5053-
linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
5054-
linearizeOp.getDynamicBasis(),
5055-
linearizeOp.getStaticBasis().drop_back(),
5056-
linearizeOp.getDisjoint());
5057-
auto delinearizeWithoutSplitPart =
5058-
rewriter.create<affine::AffineDelinearizeIndexOp>(
5059-
delinearizeOp.getLoc(), linearizeWithoutBack,
5060-
delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
5061-
delinearizeOp.hasOuterBound());
5062-
auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
5063-
delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
5097+
Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
5098+
rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
5099+
linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
5100+
linearizeOp.getDisjoint());
5101+
auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
5102+
rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
5103+
delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
5104+
delinearizeOp.hasOuterBound());
5105+
auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
5106+
rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
50645107
basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
50655108
SmallVector<Value> results = llvm::to_vector(
50665109
llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
@@ -5272,7 +5315,7 @@ OpFoldResult computeProduct(Location loc, OpBuilder &builder,
52725315
}
52735316
if (auto constant = dyn_cast<AffineConstantExpr>(result))
52745317
return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
5275-
return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
5318+
return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
52765319
}
52775320

52785321
/// If conseceutive outputs of a delinearize_index are linearized with the same
@@ -5437,16 +5480,16 @@ struct CancelLinearizeOfDelinearizePortion final
54375480
newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
54385481
newDelinBasis.begin() + m.delinStart + m.length);
54395482
newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
5440-
auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5441-
m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
5483+
auto newDelinearize = AffineDelinearizeIndexOp::create(
5484+
rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
54425485
newDelinBasis);
54435486

54445487
// Since there may be other uses of the indices we just merged together,
54455488
// create a residual affine.delinearize_index that delinearizes the
54465489
// merged output into its component parts.
54475490
Value combinedElem = newDelinearize.getResult(m.delinStart);
5448-
auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
5449-
m.delinearize.getLoc(), combinedElem, basisToMerge);
5491+
auto residualDelinearize = AffineDelinearizeIndexOp::create(
5492+
rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);
54505493

54515494
// Swap all the uses of the unaffected delinearize outputs to the new
54525495
// delinearization so that the old code can be removed if this

0 commit comments

Comments
 (0)