Skip to content

[mlir] update affine+arith create APIs (1/n) #149656

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Jul 19, 2025
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
21 changes: 21 additions & 0 deletions mlir/include/mlir/Dialect/Affine/IR/AffineOps.h
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,21 @@ class AffineDmaStartOp
AffineMap tagMap, ValueRange tagIndices, Value numElements,
Value stride = nullptr, Value elementsPerStride = nullptr);

static AffineDmaStartOp
create(OpBuilder &builder, Location location, Value srcMemRef,
AffineMap srcMap, ValueRange srcIndices, Value destMemRef,
AffineMap dstMap, ValueRange destIndices, Value tagMemRef,
AffineMap tagMap, ValueRange tagIndices, Value numElements,
Value stride = nullptr, Value elementsPerStride = nullptr);

static AffineDmaStartOp create(ImplicitLocOpBuilder &builder, Value srcMemRef,
AffineMap srcMap, ValueRange srcIndices,
Value destMemRef, AffineMap dstMap,
ValueRange destIndices, Value tagMemRef,
AffineMap tagMap, ValueRange tagIndices,
Value numElements, Value stride = nullptr,
Value elementsPerStride = nullptr);

/// Returns the operand index of the source memref.
unsigned getSrcMemRefOperandIndex() { return 0; }

Expand Down Expand Up @@ -319,6 +334,12 @@ class AffineDmaWaitOp

static void build(OpBuilder &builder, OperationState &result, Value tagMemRef,
AffineMap tagMap, ValueRange tagIndices, Value numElements);
static AffineDmaWaitOp create(OpBuilder &builder, Location location,
Value tagMemRef, AffineMap tagMap,
ValueRange tagIndices, Value numElements);
static AffineDmaWaitOp create(ImplicitLocOpBuilder &builder, Value tagMemRef,
AffineMap tagMap, ValueRange tagIndices,
Value numElements);

static StringRef getOperationName() { return "affine.dma_wait"; }

Expand Down
19 changes: 19 additions & 0 deletions mlir/include/mlir/Dialect/Arith/IR/Arith.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,15 +59,27 @@ class ConstantIntOp : public arith::ConstantOp {
/// Build a constant int op that produces an integer of the specified width.
static void build(OpBuilder &builder, OperationState &result, int64_t value,
unsigned width);
static ConstantIntOp create(OpBuilder &builder, Location location,
int64_t value, unsigned width);
static ConstantIntOp create(ImplicitLocOpBuilder &builder, int64_t value,
unsigned width);

/// Build a constant int op that produces an integer of the specified type,
/// which must be an integer type.
static void build(OpBuilder &builder, OperationState &result, Type type,
int64_t value);
static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
int64_t value);
static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type,
int64_t value);

/// Build a constant int op that produces an integer from an APInt
static void build(OpBuilder &builder, OperationState &result, Type type,
const APInt &value);
static ConstantIntOp create(OpBuilder &builder, Location location, Type type,
const APInt &value);
static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type,
const APInt &value);

inline int64_t value() {
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
Expand All @@ -85,6 +97,10 @@ class ConstantFloatOp : public arith::ConstantOp {
/// Build a constant float op that produces a float of the specified type.
static void build(OpBuilder &builder, OperationState &result, FloatType type,
const APFloat &value);
static ConstantFloatOp create(OpBuilder &builder, Location location,
FloatType type, const APFloat &value);
static ConstantFloatOp create(ImplicitLocOpBuilder &builder, FloatType type,
const APFloat &value);

inline APFloat value() {
return cast<FloatAttr>(arith::ConstantOp::getValue()).getValue();
Expand All @@ -100,6 +116,9 @@ class ConstantIndexOp : public arith::ConstantOp {
static ::mlir::TypeID resolveTypeID() { return TypeID::get<ConstantOp>(); }
/// Build a constant int op that produces an index.
static void build(OpBuilder &builder, OperationState &result, int64_t value);
static ConstantIndexOp create(OpBuilder &builder, Location location,
int64_t value);
static ConstantIndexOp create(ImplicitLocOpBuilder &builder, int64_t value);

inline int64_t value() {
return cast<IntegerAttr>(arith::ConstantOp::getValue()).getInt();
Expand Down
114 changes: 78 additions & 36 deletions mlir/lib/Dialect/Affine/IR/AffineOps.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -240,7 +240,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder,
Attribute value, Type type,
Location loc) {
if (auto poison = dyn_cast<ub::PoisonAttr>(value))
return builder.create<ub::PoisonOp>(loc, type, poison);
return ub::PoisonOp::create(builder, loc, type, poison);
return arith::ConstantOp::materialize(builder, value, type, loc);
}

Expand Down Expand Up @@ -1282,7 +1282,7 @@ mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map,
map = foldAttributesIntoMap(b, map, operands, valueOperands);
composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin);
assert(map);
return b.create<AffineApplyOp>(loc, map, valueOperands);
return AffineApplyOp::create(b, loc, map, valueOperands);
}

AffineApplyOp
Expand Down Expand Up @@ -1389,7 +1389,7 @@ static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map,
SmallVector<Value> valueOperands;
map = foldAttributesIntoMap(b, map, operands, valueOperands);
composeMultiResultAffineMap(map, valueOperands);
return b.create<OpTy>(loc, b.getIndexType(), map, valueOperands);
return OpTy::create(b, loc, b.getIndexType(), map, valueOperands);
}

AffineMinOp
Expand Down Expand Up @@ -1747,6 +1747,32 @@ void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result,
}
}

AffineDmaStartOp AffineDmaStartOp::create(
OpBuilder &builder, Location location, Value srcMemRef, AffineMap srcMap,
ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
ValueRange tagIndices, Value numElements, Value stride,
Value elementsPerStride) {
mlir::OperationState state(location, getOperationName());
build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap,
destIndices, tagMemRef, tagMap, tagIndices, numElements, stride,
elementsPerStride);
auto result = llvm::dyn_cast<AffineDmaStartOp>(builder.create(state));
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

nit: I don't think we need these llvm:: namespaces being explicit in front of casts. Also everywhere else.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed in both files

assert(result && "builder didn't return the right type");
return result;
}

AffineDmaStartOp AffineDmaStartOp::create(
ImplicitLocOpBuilder &builder, Value srcMemRef, AffineMap srcMap,
ValueRange srcIndices, Value destMemRef, AffineMap dstMap,
ValueRange destIndices, Value tagMemRef, AffineMap tagMap,
ValueRange tagIndices, Value numElements, Value stride,
Value elementsPerStride) {
return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices,
destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices,
numElements, stride, elementsPerStride);
}

void AffineDmaStartOp::print(OpAsmPrinter &p) {
p << " " << getSrcMemRef() << '[';
p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices());
Expand Down Expand Up @@ -1917,6 +1943,25 @@ void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result,
result.addOperands(numElements);
}

AffineDmaWaitOp AffineDmaWaitOp::create(OpBuilder &builder, Location location,
Value tagMemRef, AffineMap tagMap,
ValueRange tagIndices,
Value numElements) {
mlir::OperationState state(location, getOperationName());
build(builder, state, tagMemRef, tagMap, tagIndices, numElements);
auto result = llvm::dyn_cast<AffineDmaWaitOp>(builder.create(state));
assert(result && "builder didn't return the right type");
return result;
}

AffineDmaWaitOp AffineDmaWaitOp::create(ImplicitLocOpBuilder &builder,
Value tagMemRef, AffineMap tagMap,
ValueRange tagIndices,
Value numElements) {
return create(builder, builder.getLoc(), tagMemRef, tagMap, tagIndices,
numElements);
}

void AffineDmaWaitOp::print(OpAsmPrinter &p) {
p << " " << getTagMemRef() << '[';
SmallVector<Value, 2> operands(getTagIndices());
Expand Down Expand Up @@ -2688,8 +2733,8 @@ FailureOr<LoopLikeOpInterface> AffineForOp::replaceWithAdditionalYields(
rewriter.setInsertionPoint(getOperation());
auto inits = llvm::to_vector(getInits());
inits.append(newInitOperands.begin(), newInitOperands.end());
AffineForOp newLoop = rewriter.create<AffineForOp>(
getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
AffineForOp newLoop = AffineForOp::create(
rewriter, getLoc(), getLowerBoundOperands(), getLowerBoundMap(),
getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits);

// Generate the new yield values and append them to the scf.yield operation.
Expand Down Expand Up @@ -2831,7 +2876,7 @@ static void buildAffineLoopNestImpl(
OpBuilder::InsertionGuard nestedGuard(nestedBuilder);
bodyBuilderFn(nestedBuilder, nestedLoc, ivs);
}
nestedBuilder.create<AffineYieldOp>(nestedLoc);
AffineYieldOp::create(nestedBuilder, nestedLoc);
};

// Delegate actual loop creation to the callback in order to dispatch
Expand All @@ -2846,8 +2891,8 @@ static AffineForOp
buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb,
int64_t ub, int64_t step,
AffineForOp::BodyBuilderFn bodyBuilderFn) {
return builder.create<AffineForOp>(loc, lb, ub, step,
/*iterArgs=*/ValueRange(), bodyBuilderFn);
return AffineForOp::create(builder, loc, lb, ub, step,
/*iterArgs=*/ValueRange(), bodyBuilderFn);
}

/// Creates an affine loop from the bounds that may or may not be constants.
Expand All @@ -2860,9 +2905,9 @@ buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub,
if (lbConst && ubConst)
return buildAffineLoopFromConstants(builder, loc, lbConst.value(),
ubConst.value(), step, bodyBuilderFn);
return builder.create<AffineForOp>(loc, lb, builder.getDimIdentityMap(), ub,
builder.getDimIdentityMap(), step,
/*iterArgs=*/ValueRange(), bodyBuilderFn);
return AffineForOp::create(builder, loc, lb, builder.getDimIdentityMap(), ub,
builder.getDimIdentityMap(), step,
/*iterArgs=*/ValueRange(), bodyBuilderFn);
}

void mlir::affine::buildAffineLoopNest(
Expand Down Expand Up @@ -4883,7 +4928,7 @@ struct DropUnitExtentBasis
Location loc = delinearizeOp->getLoc();
auto getZero = [&]() -> Value {
if (!zero)
zero = rewriter.create<arith::ConstantIndexOp>(loc, 0);
zero = arith::ConstantIndexOp::create(rewriter, loc, 0);
return zero.value();
};

Expand All @@ -4906,8 +4951,8 @@ struct DropUnitExtentBasis

if (!newBasis.empty()) {
// Will drop the leading nullptr from `basis` if there was no outer bound.
auto newDelinearizeOp = rewriter.create<affine::AffineDelinearizeIndexOp>(
loc, delinearizeOp.getLinearIndex(), newBasis);
auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create(
rewriter, loc, delinearizeOp.getLinearIndex(), newBasis);
int newIndex = 0;
// Map back the new delinearized indices to the values they replace.
for (auto &replacement : replacements) {
Expand Down Expand Up @@ -4971,12 +5016,12 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail
return success();
}

Value newLinearize = rewriter.create<affine::AffineLinearizeIndexOp>(
linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
Value newLinearize = affine::AffineLinearizeIndexOp::create(
rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches),
ArrayRef<OpFoldResult>{linearizeBasis}.drop_back(numMatches),
linearizeOp.getDisjoint());
auto newDelinearize = rewriter.create<affine::AffineDelinearizeIndexOp>(
delinearizeOp.getLoc(), newLinearize,
auto newDelinearize = affine::AffineDelinearizeIndexOp::create(
rewriter, delinearizeOp.getLoc(), newLinearize,
ArrayRef<OpFoldResult>{delinearizeBasis}.drop_back(numMatches),
delinearizeOp.hasOuterBound());
SmallVector<Value> mergedResults(newDelinearize.getResults());
Expand Down Expand Up @@ -5048,19 +5093,16 @@ struct SplitDelinearizeSpanningLastLinearizeArg final
delinearizeOp,
"need at least two elements to form the basis product");

Value linearizeWithoutBack =
rewriter.create<affine::AffineLinearizeIndexOp>(
linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
linearizeOp.getDynamicBasis(),
linearizeOp.getStaticBasis().drop_back(),
linearizeOp.getDisjoint());
auto delinearizeWithoutSplitPart =
rewriter.create<affine::AffineDelinearizeIndexOp>(
delinearizeOp.getLoc(), linearizeWithoutBack,
delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
delinearizeOp.hasOuterBound());
auto delinearizeBack = rewriter.create<affine::AffineDelinearizeIndexOp>(
delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create(
rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(),
linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(),
linearizeOp.getDisjoint());
auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create(
rewriter, delinearizeOp.getLoc(), linearizeWithoutBack,
delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit),
delinearizeOp.hasOuterBound());
auto delinearizeBack = affine::AffineDelinearizeIndexOp::create(
rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(),
basis.take_back(elemsToSplit), /*hasOuterBound=*/true);
SmallVector<Value> results = llvm::to_vector(
llvm::concat<Value>(delinearizeWithoutSplitPart.getResults(),
Expand Down Expand Up @@ -5272,7 +5314,7 @@ OpFoldResult computeProduct(Location loc, OpBuilder &builder,
}
if (auto constant = dyn_cast<AffineConstantExpr>(result))
return getAsIndexOpFoldResult(builder.getContext(), constant.getValue());
return builder.create<AffineApplyOp>(loc, result, dynamicPart).getResult();
return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult();
}

/// If conseceutive outputs of a delinearize_index are linearized with the same
Expand Down Expand Up @@ -5437,16 +5479,16 @@ struct CancelLinearizeOfDelinearizePortion final
newDelinBasis.erase(newDelinBasis.begin() + m.delinStart,
newDelinBasis.begin() + m.delinStart + m.length);
newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize);
auto newDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
auto newDelinearize = AffineDelinearizeIndexOp::create(
rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(),
newDelinBasis);

// Since there may be other uses of the indices we just merged together,
// create a residual affine.delinearize_index that delinearizes the
// merged output into its component parts.
Value combinedElem = newDelinearize.getResult(m.delinStart);
auto residualDelinearize = rewriter.create<AffineDelinearizeIndexOp>(
m.delinearize.getLoc(), combinedElem, basisToMerge);
auto residualDelinearize = AffineDelinearizeIndexOp::create(
rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge);

// Swap all the uses of the unaffected delinearize outputs to the new
// delinearization so that the old code can be removed if this
Expand Down
Loading
Loading