diff --git a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h index 2091faa6b0b02..333de6bbd8a05 100644 --- a/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h +++ b/mlir/include/mlir/Dialect/Affine/IR/AffineOps.h @@ -114,6 +114,21 @@ class AffineDmaStartOp AffineMap tagMap, ValueRange tagIndices, Value numElements, Value stride = nullptr, Value elementsPerStride = nullptr); + static AffineDmaStartOp + create(OpBuilder &builder, Location location, Value srcMemRef, + AffineMap srcMap, ValueRange srcIndices, Value destMemRef, + AffineMap dstMap, ValueRange destIndices, Value tagMemRef, + AffineMap tagMap, ValueRange tagIndices, Value numElements, + Value stride = nullptr, Value elementsPerStride = nullptr); + + static AffineDmaStartOp create(ImplicitLocOpBuilder &builder, Value srcMemRef, + AffineMap srcMap, ValueRange srcIndices, + Value destMemRef, AffineMap dstMap, + ValueRange destIndices, Value tagMemRef, + AffineMap tagMap, ValueRange tagIndices, + Value numElements, Value stride = nullptr, + Value elementsPerStride = nullptr); + /// Returns the operand index of the source memref. unsigned getSrcMemRefOperandIndex() { return 0; } @@ -319,6 +334,12 @@ class AffineDmaWaitOp static void build(OpBuilder &builder, OperationState &result, Value tagMemRef, AffineMap tagMap, ValueRange tagIndices, Value numElements); + static AffineDmaWaitOp create(OpBuilder &builder, Location location, + Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, Value numElements); + static AffineDmaWaitOp create(ImplicitLocOpBuilder &builder, Value tagMemRef, + AffineMap tagMap, ValueRange tagIndices, + Value numElements); static StringRef getOperationName() { return "affine.dma_wait"; } diff --git a/mlir/include/mlir/Dialect/Arith/IR/Arith.h b/mlir/include/mlir/Dialect/Arith/IR/Arith.h index 7c50c2036ffdc..0fc3db8e993d8 100644 --- a/mlir/include/mlir/Dialect/Arith/IR/Arith.h +++ b/mlir/include/mlir/Dialect/Arith/IR/Arith.h @@ -59,15 +59,27 @@ class ConstantIntOp : public arith::ConstantOp { /// Build a constant int op that produces an integer of the specified width. static void build(OpBuilder &builder, OperationState &result, int64_t value, unsigned width); + static ConstantIntOp create(OpBuilder &builder, Location location, + int64_t value, unsigned width); + static ConstantIntOp create(ImplicitLocOpBuilder &builder, int64_t value, + unsigned width); /// Build a constant int op that produces an integer of the specified type, /// which must be an integer type. static void build(OpBuilder &builder, OperationState &result, Type type, int64_t value); + static ConstantIntOp create(OpBuilder &builder, Location location, Type type, + int64_t value); + static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type, + int64_t value); /// Build a constant int op that produces an integer from an APInt static void build(OpBuilder &builder, OperationState &result, Type type, const APInt &value); + static ConstantIntOp create(OpBuilder &builder, Location location, Type type, + const APInt &value); + static ConstantIntOp create(ImplicitLocOpBuilder &builder, Type type, + const APInt &value); inline int64_t value() { return cast(arith::ConstantOp::getValue()).getInt(); @@ -85,6 +97,10 @@ class ConstantFloatOp : public arith::ConstantOp { /// Build a constant float op that produces a float of the specified type. static void build(OpBuilder &builder, OperationState &result, FloatType type, const APFloat &value); + static ConstantFloatOp create(OpBuilder &builder, Location location, + FloatType type, const APFloat &value); + static ConstantFloatOp create(ImplicitLocOpBuilder &builder, FloatType type, + const APFloat &value); inline APFloat value() { return cast(arith::ConstantOp::getValue()).getValue(); @@ -100,6 +116,9 @@ class ConstantIndexOp : public arith::ConstantOp { static ::mlir::TypeID resolveTypeID() { return TypeID::get(); } /// Build a constant int op that produces an index. static void build(OpBuilder &builder, OperationState &result, int64_t value); + static ConstantIndexOp create(OpBuilder &builder, Location location, + int64_t value); + static ConstantIndexOp create(ImplicitLocOpBuilder &builder, int64_t value); inline int64_t value() { return cast(arith::ConstantOp::getValue()).getInt(); diff --git a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp index e8e8f624d806e..ee5db073ffc4e 100644 --- a/mlir/lib/Dialect/Affine/IR/AffineOps.cpp +++ b/mlir/lib/Dialect/Affine/IR/AffineOps.cpp @@ -49,7 +49,7 @@ using llvm::mod; /// top level of a `AffineScope` region is always a valid symbol for all /// uses in that region. bool mlir::affine::isTopLevelValue(Value value, Region *region) { - if (auto arg = llvm::dyn_cast(value)) + if (auto arg = dyn_cast(value)) return arg.getParentRegion() == region; return value.getDefiningOp()->getParentRegion() == region; } @@ -240,7 +240,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder, Attribute value, Type type, Location loc) { if (auto poison = dyn_cast(value)) - return builder.create(loc, type, poison); + return ub::PoisonOp::create(builder, loc, type, poison); return arith::ConstantOp::materialize(builder, value, type, loc); } @@ -249,7 +249,7 @@ Operation *AffineDialect::materializeConstant(OpBuilder &builder, /// conservatively assume it is not top-level. A value of index type defined at /// the top level is always a valid symbol. bool mlir::affine::isTopLevelValue(Value value) { - if (auto arg = llvm::dyn_cast(value)) { + if (auto arg = dyn_cast(value)) { // The block owning the argument may be unlinked, e.g. when the surrounding // region has not yet been attached to an Op, at which point the parent Op // is null. @@ -1282,7 +1282,7 @@ mlir::affine::makeComposedAffineApply(OpBuilder &b, Location loc, AffineMap map, map = foldAttributesIntoMap(b, map, operands, valueOperands); composeAffineMapAndOperands(&map, &valueOperands, composeAffineMin); assert(map); - return b.create(loc, map, valueOperands); + return AffineApplyOp::create(b, loc, map, valueOperands); } AffineApplyOp @@ -1389,7 +1389,7 @@ static OpTy makeComposedMinMax(OpBuilder &b, Location loc, AffineMap map, SmallVector valueOperands; map = foldAttributesIntoMap(b, map, operands, valueOperands); composeMultiResultAffineMap(map, valueOperands); - return b.create(loc, b.getIndexType(), map, valueOperands); + return OpTy::create(b, loc, b.getIndexType(), map, valueOperands); } AffineMinOp @@ -1747,6 +1747,32 @@ void AffineDmaStartOp::build(OpBuilder &builder, OperationState &result, } } +AffineDmaStartOp AffineDmaStartOp::create( + OpBuilder &builder, Location location, Value srcMemRef, AffineMap srcMap, + ValueRange srcIndices, Value destMemRef, AffineMap dstMap, + ValueRange destIndices, Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, Value numElements, Value stride, + Value elementsPerStride) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, srcMemRef, srcMap, srcIndices, destMemRef, dstMap, + destIndices, tagMemRef, tagMap, tagIndices, numElements, stride, + elementsPerStride); + auto result = dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +AffineDmaStartOp AffineDmaStartOp::create( + ImplicitLocOpBuilder &builder, Value srcMemRef, AffineMap srcMap, + ValueRange srcIndices, Value destMemRef, AffineMap dstMap, + ValueRange destIndices, Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, Value numElements, Value stride, + Value elementsPerStride) { + return create(builder, builder.getLoc(), srcMemRef, srcMap, srcIndices, + destMemRef, dstMap, destIndices, tagMemRef, tagMap, tagIndices, + numElements, stride, elementsPerStride); +} + void AffineDmaStartOp::print(OpAsmPrinter &p) { p << " " << getSrcMemRef() << '['; p.printAffineMapOfSSAIds(getSrcMapAttr(), getSrcIndices()); @@ -1917,6 +1943,25 @@ void AffineDmaWaitOp::build(OpBuilder &builder, OperationState &result, result.addOperands(numElements); } +AffineDmaWaitOp AffineDmaWaitOp::create(OpBuilder &builder, Location location, + Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, + Value numElements) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, tagMemRef, tagMap, tagIndices, numElements); + auto result = dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +AffineDmaWaitOp AffineDmaWaitOp::create(ImplicitLocOpBuilder &builder, + Value tagMemRef, AffineMap tagMap, + ValueRange tagIndices, + Value numElements) { + return create(builder, builder.getLoc(), tagMemRef, tagMap, tagIndices, + numElements); +} + void AffineDmaWaitOp::print(OpAsmPrinter &p) { p << " " << getTagMemRef() << '['; SmallVector operands(getTagIndices()); @@ -2153,7 +2198,7 @@ static ParseResult parseBound(bool isLower, OperationState &result, return failure(); // Parse full form - affine map followed by dim and symbol list. - if (auto affineMapAttr = llvm::dyn_cast(boundAttr)) { + if (auto affineMapAttr = dyn_cast(boundAttr)) { unsigned currentNumOperands = result.operands.size(); unsigned numDims; if (parseDimAndSymbolList(p, result.operands, numDims)) @@ -2186,7 +2231,7 @@ static ParseResult parseBound(bool isLower, OperationState &result, } // Parse custom assembly form. - if (auto integerAttr = llvm::dyn_cast(boundAttr)) { + if (auto integerAttr = dyn_cast(boundAttr)) { result.attributes.pop_back(); result.addAttribute( boundAttrStrName, @@ -2688,8 +2733,8 @@ FailureOr AffineForOp::replaceWithAdditionalYields( rewriter.setInsertionPoint(getOperation()); auto inits = llvm::to_vector(getInits()); inits.append(newInitOperands.begin(), newInitOperands.end()); - AffineForOp newLoop = rewriter.create( - getLoc(), getLowerBoundOperands(), getLowerBoundMap(), + AffineForOp newLoop = AffineForOp::create( + rewriter, getLoc(), getLowerBoundOperands(), getLowerBoundMap(), getUpperBoundOperands(), getUpperBoundMap(), getStepAsInt(), inits); // Generate the new yield values and append them to the scf.yield operation. @@ -2756,7 +2801,7 @@ bool mlir::affine::isAffineInductionVar(Value val) { } AffineForOp mlir::affine::getForInductionVarOwner(Value val) { - auto ivArg = llvm::dyn_cast(val); + auto ivArg = dyn_cast(val); if (!ivArg || !ivArg.getOwner() || !ivArg.getOwner()->getParent()) return AffineForOp(); if (auto forOp = @@ -2767,7 +2812,7 @@ AffineForOp mlir::affine::getForInductionVarOwner(Value val) { } AffineParallelOp mlir::affine::getAffineParallelInductionVarOwner(Value val) { - auto ivArg = llvm::dyn_cast(val); + auto ivArg = dyn_cast(val); if (!ivArg || !ivArg.getOwner()) return nullptr; Operation *containingOp = ivArg.getOwner()->getParentOp(); @@ -2831,7 +2876,7 @@ static void buildAffineLoopNestImpl( OpBuilder::InsertionGuard nestedGuard(nestedBuilder); bodyBuilderFn(nestedBuilder, nestedLoc, ivs); } - nestedBuilder.create(nestedLoc); + AffineYieldOp::create(nestedBuilder, nestedLoc); }; // Delegate actual loop creation to the callback in order to dispatch @@ -2846,8 +2891,8 @@ static AffineForOp buildAffineLoopFromConstants(OpBuilder &builder, Location loc, int64_t lb, int64_t ub, int64_t step, AffineForOp::BodyBuilderFn bodyBuilderFn) { - return builder.create(loc, lb, ub, step, - /*iterArgs=*/ValueRange(), bodyBuilderFn); + return AffineForOp::create(builder, loc, lb, ub, step, + /*iterArgs=*/ValueRange(), bodyBuilderFn); } /// Creates an affine loop from the bounds that may or may not be constants. @@ -2860,9 +2905,9 @@ buildAffineLoopFromValues(OpBuilder &builder, Location loc, Value lb, Value ub, if (lbConst && ubConst) return buildAffineLoopFromConstants(builder, loc, lbConst.value(), ubConst.value(), step, bodyBuilderFn); - return builder.create(loc, lb, builder.getDimIdentityMap(), ub, - builder.getDimIdentityMap(), step, - /*iterArgs=*/ValueRange(), bodyBuilderFn); + return AffineForOp::create(builder, loc, lb, builder.getDimIdentityMap(), ub, + builder.getDimIdentityMap(), step, + /*iterArgs=*/ValueRange(), bodyBuilderFn); } void mlir::affine::buildAffineLoopNest( @@ -3294,11 +3339,11 @@ OpFoldResult AffineLoadOp::fold(FoldAdaptor adaptor) { // Check if the global memref is a constant. auto cstAttr = - llvm::dyn_cast_or_null(global.getConstantInitValue()); + dyn_cast_or_null(global.getConstantInitValue()); if (!cstAttr) return {}; // If it's a splat constant, we can fold irrespective of indices. - if (auto splatAttr = llvm::dyn_cast(cstAttr)) + if (auto splatAttr = dyn_cast(cstAttr)) return splatAttr.getSplatValue(); // Otherwise, we can fold only if we know the indices. if (!getAffineMap().isConstant()) @@ -4065,19 +4110,19 @@ static bool isResultTypeMatchAtomicRMWKind(Type resultType, case arith::AtomicRMWKind::minimumf: return isa(resultType); case arith::AtomicRMWKind::maxs: { - auto intType = llvm::dyn_cast(resultType); + auto intType = dyn_cast(resultType); return intType && intType.isSigned(); } case arith::AtomicRMWKind::mins: { - auto intType = llvm::dyn_cast(resultType); + auto intType = dyn_cast(resultType); return intType && intType.isSigned(); } case arith::AtomicRMWKind::maxu: { - auto intType = llvm::dyn_cast(resultType); + auto intType = dyn_cast(resultType); return intType && intType.isUnsigned(); } case arith::AtomicRMWKind::minu: { - auto intType = llvm::dyn_cast(resultType); + auto intType = dyn_cast(resultType); return intType && intType.isUnsigned(); } case arith::AtomicRMWKind::ori: @@ -4134,7 +4179,7 @@ LogicalResult AffineParallelOp::verify() { // ops for (auto it : llvm::enumerate((getReductions()))) { Attribute attr = it.value(); - auto intAttr = llvm::dyn_cast(attr); + auto intAttr = dyn_cast(attr); if (!intAttr || !arith::symbolizeAtomicRMWKind(intAttr.getInt())) return emitOpError("invalid reduction attribute"); auto kind = arith::symbolizeAtomicRMWKind(intAttr.getInt()).value(); @@ -4883,7 +4928,7 @@ struct DropUnitExtentBasis Location loc = delinearizeOp->getLoc(); auto getZero = [&]() -> Value { if (!zero) - zero = rewriter.create(loc, 0); + zero = arith::ConstantIndexOp::create(rewriter, loc, 0); return zero.value(); }; @@ -4906,8 +4951,8 @@ struct DropUnitExtentBasis if (!newBasis.empty()) { // Will drop the leading nullptr from `basis` if there was no outer bound. - auto newDelinearizeOp = rewriter.create( - loc, delinearizeOp.getLinearIndex(), newBasis); + auto newDelinearizeOp = affine::AffineDelinearizeIndexOp::create( + rewriter, loc, delinearizeOp.getLinearIndex(), newBasis); int newIndex = 0; // Map back the new delinearized indices to the values they replace. for (auto &replacement : replacements) { @@ -4971,12 +5016,12 @@ struct CancelDelinearizeOfLinearizeDisjointExactTail return success(); } - Value newLinearize = rewriter.create( - linearizeOp.getLoc(), linearizeIns.drop_back(numMatches), + Value newLinearize = affine::AffineLinearizeIndexOp::create( + rewriter, linearizeOp.getLoc(), linearizeIns.drop_back(numMatches), ArrayRef{linearizeBasis}.drop_back(numMatches), linearizeOp.getDisjoint()); - auto newDelinearize = rewriter.create( - delinearizeOp.getLoc(), newLinearize, + auto newDelinearize = affine::AffineDelinearizeIndexOp::create( + rewriter, delinearizeOp.getLoc(), newLinearize, ArrayRef{delinearizeBasis}.drop_back(numMatches), delinearizeOp.hasOuterBound()); SmallVector mergedResults(newDelinearize.getResults()); @@ -5048,19 +5093,16 @@ struct SplitDelinearizeSpanningLastLinearizeArg final delinearizeOp, "need at least two elements to form the basis product"); - Value linearizeWithoutBack = - rewriter.create( - linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(), - linearizeOp.getDynamicBasis(), - linearizeOp.getStaticBasis().drop_back(), - linearizeOp.getDisjoint()); - auto delinearizeWithoutSplitPart = - rewriter.create( - delinearizeOp.getLoc(), linearizeWithoutBack, - delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit), - delinearizeOp.hasOuterBound()); - auto delinearizeBack = rewriter.create( - delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(), + Value linearizeWithoutBack = affine::AffineLinearizeIndexOp::create( + rewriter, linearizeOp.getLoc(), linearizeOp.getMultiIndex().drop_back(), + linearizeOp.getDynamicBasis(), linearizeOp.getStaticBasis().drop_back(), + linearizeOp.getDisjoint()); + auto delinearizeWithoutSplitPart = affine::AffineDelinearizeIndexOp::create( + rewriter, delinearizeOp.getLoc(), linearizeWithoutBack, + delinearizeOp.getDynamicBasis(), basis.drop_back(elemsToSplit), + delinearizeOp.hasOuterBound()); + auto delinearizeBack = affine::AffineDelinearizeIndexOp::create( + rewriter, delinearizeOp.getLoc(), linearizeOp.getMultiIndex().back(), basis.take_back(elemsToSplit), /*hasOuterBound=*/true); SmallVector results = llvm::to_vector( llvm::concat(delinearizeWithoutSplitPart.getResults(), @@ -5272,7 +5314,7 @@ OpFoldResult computeProduct(Location loc, OpBuilder &builder, } if (auto constant = dyn_cast(result)) return getAsIndexOpFoldResult(builder.getContext(), constant.getValue()); - return builder.create(loc, result, dynamicPart).getResult(); + return AffineApplyOp::create(builder, loc, result, dynamicPart).getResult(); } /// If conseceutive outputs of a delinearize_index are linearized with the same @@ -5437,16 +5479,16 @@ struct CancelLinearizeOfDelinearizePortion final newDelinBasis.erase(newDelinBasis.begin() + m.delinStart, newDelinBasis.begin() + m.delinStart + m.length); newDelinBasis.insert(newDelinBasis.begin() + m.delinStart, newSize); - auto newDelinearize = rewriter.create( - m.delinearize.getLoc(), m.delinearize.getLinearIndex(), + auto newDelinearize = AffineDelinearizeIndexOp::create( + rewriter, m.delinearize.getLoc(), m.delinearize.getLinearIndex(), newDelinBasis); // Since there may be other uses of the indices we just merged together, // create a residual affine.delinearize_index that delinearizes the // merged output into its component parts. Value combinedElem = newDelinearize.getResult(m.delinStart); - auto residualDelinearize = rewriter.create( - m.delinearize.getLoc(), combinedElem, basisToMerge); + auto residualDelinearize = AffineDelinearizeIndexOp::create( + rewriter, m.delinearize.getLoc(), combinedElem, basisToMerge); // Swap all the uses of the unaffected delinearize outputs to the new // delinearization so that the old code can be removed if this diff --git a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp index 4e40d4ebda004..910334b17748b 100644 --- a/mlir/lib/Dialect/Arith/IR/ArithOps.cpp +++ b/mlir/lib/Dialect/Arith/IR/ArithOps.cpp @@ -148,7 +148,7 @@ static FailureOr getIntOrSplatIntValue(Attribute attr) { static Attribute getBoolAttribute(Type type, bool value) { auto boolAttr = BoolAttr::get(type.getContext(), value); - ShapedType shapedType = llvm::dyn_cast_or_null(type); + ShapedType shapedType = dyn_cast_or_null(type); if (!shapedType) return boolAttr; return DenseElementsAttr::get(shapedType, boolAttr); @@ -169,7 +169,7 @@ namespace { /// Return the type of the same shape (scalar, vector or tensor) containing i1. static Type getI1SameShape(Type type) { auto i1Type = IntegerType::get(type.getContext(), 1); - if (auto shapedType = llvm::dyn_cast(type)) + if (auto shapedType = dyn_cast(type)) return shapedType.cloneWith(std::nullopt, i1Type); if (llvm::isa(type)) return UnrankedTensorType::get(i1Type); @@ -183,8 +183,8 @@ static Type getI1SameShape(Type type) { void arith::ConstantOp::getAsmResultNames( function_ref setNameFn) { auto type = getType(); - if (auto intCst = llvm::dyn_cast(getValue())) { - auto intType = llvm::dyn_cast(type); + if (auto intCst = dyn_cast(getValue())) { + auto intType = dyn_cast(type); // Sugar i1 constants with 'true' and 'false'. if (intType && intType.getWidth() == 1) @@ -228,7 +228,7 @@ LogicalResult arith::ConstantOp::verify() { bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { // The value's type must be the same as the provided type. - auto typedAttr = llvm::dyn_cast(value); + auto typedAttr = dyn_cast(value); if (!typedAttr || typedAttr.getType() != type) return false; // Integer values must be signless. @@ -242,7 +242,7 @@ bool arith::ConstantOp::isBuildableWith(Attribute value, Type type) { ConstantOp arith::ConstantOp::materialize(OpBuilder &builder, Attribute value, Type type, Location loc) { if (isBuildableWith(value, type)) - return builder.create(loc, cast(value)); + return arith::ConstantOp::create(builder, loc, cast(value)); return nullptr; } @@ -255,18 +255,66 @@ void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, builder.getIntegerAttr(type, value)); } +arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder, + Location location, + int64_t value, + unsigned width) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, value, width); + auto result = dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder, + int64_t value, + unsigned width) { + return create(builder, builder.getLoc(), value, width); +} + void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, Type type, int64_t value) { arith::ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } +arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder, + Location location, Type type, + int64_t value) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, type, value); + auto result = dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder, + Type type, int64_t value) { + return create(builder, builder.getLoc(), type, value); +} + void arith::ConstantIntOp::build(OpBuilder &builder, OperationState &result, Type type, const APInt &value) { arith::ConstantOp::build(builder, result, type, builder.getIntegerAttr(type, value)); } +arith::ConstantIntOp arith::ConstantIntOp::create(OpBuilder &builder, + Location location, Type type, + const APInt &value) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, type, value); + auto result = dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +arith::ConstantIntOp arith::ConstantIntOp::create(ImplicitLocOpBuilder &builder, + Type type, + const APInt &value) { + return create(builder, builder.getLoc(), type, value); +} + bool arith::ConstantIntOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return constOp.getType().isSignlessInteger(); @@ -279,6 +327,23 @@ void arith::ConstantFloatOp::build(OpBuilder &builder, OperationState &result, builder.getFloatAttr(type, value)); } +arith::ConstantFloatOp arith::ConstantFloatOp::create(OpBuilder &builder, + Location location, + FloatType type, + const APFloat &value) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, type, value); + auto result = dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +arith::ConstantFloatOp +arith::ConstantFloatOp::create(ImplicitLocOpBuilder &builder, FloatType type, + const APFloat &value) { + return create(builder, builder.getLoc(), type, value); +} + bool arith::ConstantFloatOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return llvm::isa(constOp.getType()); @@ -291,6 +356,21 @@ void arith::ConstantIndexOp::build(OpBuilder &builder, OperationState &result, builder.getIndexAttr(value)); } +arith::ConstantIndexOp arith::ConstantIndexOp::create(OpBuilder &builder, + Location location, + int64_t value) { + mlir::OperationState state(location, getOperationName()); + build(builder, state, value); + auto result = dyn_cast(builder.create(state)); + assert(result && "builder didn't return the right type"); + return result; +} + +arith::ConstantIndexOp +arith::ConstantIndexOp::create(ImplicitLocOpBuilder &builder, int64_t value) { + return create(builder, builder.getLoc(), value); +} + bool arith::ConstantIndexOp::classof(Operation *op) { if (auto constOp = dyn_cast_or_null(op)) return constOp.getType().isIndex(); @@ -304,7 +384,7 @@ Value mlir::arith::getZeroConstant(OpBuilder &builder, Location loc, "type doesn't have a zero representation"); TypedAttr zeroAttr = builder.getZeroAttr(type); assert(zeroAttr && "unsupported type for zero attribute"); - return builder.create(loc, zeroAttr); + return arith::ConstantOp::create(builder, loc, zeroAttr); } //===----------------------------------------------------------------------===// @@ -343,7 +423,7 @@ void arith::AddIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, std::optional> arith::AddUIExtendedOp::getShapeForUnroll() { - if (auto vt = llvm::dyn_cast(getType(0))) + if (auto vt = dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } @@ -489,7 +569,7 @@ void arith::MulIOp::getCanonicalizationPatterns(RewritePatternSet &patterns, std::optional> arith::MulSIExtendedOp::getShapeForUnroll() { - if (auto vt = llvm::dyn_cast(getType(0))) + if (auto vt = dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } @@ -535,7 +615,7 @@ void arith::MulSIExtendedOp::getCanonicalizationPatterns( std::optional> arith::MulUIExtendedOp::getShapeForUnroll() { - if (auto vt = llvm::dyn_cast(getType(0))) + if (auto vt = dyn_cast(getType(0))) return llvm::to_vector<4>(vt.getShape()); return std::nullopt; } @@ -1815,7 +1895,7 @@ OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) { return {}; /// Bitcast dense elements. - if (auto denseAttr = llvm::dyn_cast_or_null(operand)) + if (auto denseAttr = dyn_cast_or_null(operand)) return denseAttr.bitcast(llvm::cast(resType).getElementType()); /// Other shaped types unhandled. if (llvm::isa(resType)) @@ -1832,7 +1912,7 @@ OpFoldResult arith::BitcastOp::fold(FoldAdaptor adaptor) { assert(resType.getIntOrFloatBitWidth() == bits.getBitWidth() && "trying to fold on broken IR: operands have incompatible types"); - if (auto resFloatType = llvm::dyn_cast(resType)) + if (auto resFloatType = dyn_cast(resType)) return FloatAttr::get(resType, APFloat(resFloatType.getFloatSemantics(), bits)); return IntegerAttr::get(resType, bits); @@ -1896,10 +1976,10 @@ static bool applyCmpPredicateToEqualOperands(arith::CmpIPredicate predicate) { } static std::optional getIntegerWidth(Type t) { - if (auto intType = llvm::dyn_cast(t)) { + if (auto intType = dyn_cast(t)) { return intType.getWidth(); } - if (auto vectorIntType = llvm::dyn_cast(t)) { + if (auto vectorIntType = dyn_cast(t)) { return llvm::cast(vectorIntType.getElementType()).getWidth(); } return std::nullopt; @@ -1969,7 +2049,7 @@ OpFoldResult arith::CmpIOp::fold(FoldAdaptor adaptor) { // We are moving constants to the right side; So if lhs is constant rhs is // guaranteed to be a constant. - if (auto lhs = llvm::dyn_cast_if_present(adaptor.getLhs())) { + if (auto lhs = dyn_cast_if_present(adaptor.getLhs())) { return constFoldBinaryOp( adaptor.getOperands(), getI1SameShape(lhs.getType()), [pred = getPredicate()](const APInt &lhs, const APInt &rhs) { @@ -2039,8 +2119,8 @@ bool mlir::arith::applyCmpPredicate(arith::CmpFPredicate predicate, } OpFoldResult arith::CmpFOp::fold(FoldAdaptor adaptor) { - auto lhs = llvm::dyn_cast_if_present(adaptor.getLhs()); - auto rhs = llvm::dyn_cast_if_present(adaptor.getRhs()); + auto lhs = dyn_cast_if_present(adaptor.getLhs()); + auto rhs = dyn_cast_if_present(adaptor.getRhs()); // If one operand is NaN, making them both NaN does not change the result. if (lhs && lhs.getValue().isNaN()) @@ -2334,9 +2414,8 @@ class CmpFIntToFPConst final : public OpRewritePattern { // comparison. rewriter.replaceOpWithNewOp( op, pred, intVal, - rewriter.create( - op.getLoc(), intVal.getType(), - rewriter.getIntegerAttr(intVal.getType(), rhsInt))); + ConstantOp::create(rewriter, op.getLoc(), intVal.getType(), + rewriter.getIntegerAttr(intVal.getType(), rhsInt))); return success(); } }; @@ -2373,10 +2452,10 @@ struct SelectToExtUI : public OpRewritePattern { matchPattern(op.getFalseValue(), m_One())) { rewriter.replaceOpWithNewOp( op, op.getType(), - rewriter.create( - op.getLoc(), op.getCondition(), - rewriter.create( - op.getLoc(), op.getCondition().getType(), 1))); + arith::XOrIOp::create( + rewriter, op.getLoc(), op.getCondition(), + arith::ConstantIntOp::create(rewriter, op.getLoc(), + op.getCondition().getType(), 1))); return success(); } @@ -2440,11 +2519,11 @@ OpFoldResult arith::SelectOp::fold(FoldAdaptor adaptor) { // Constant-fold constant operands over non-splat constant condition. // select %cst_vec, %cst0, %cst1 => %cst2 if (auto cond = - llvm::dyn_cast_if_present(adaptor.getCondition())) { + dyn_cast_if_present(adaptor.getCondition())) { if (auto lhs = - llvm::dyn_cast_if_present(adaptor.getTrueValue())) { + dyn_cast_if_present(adaptor.getTrueValue())) { if (auto rhs = - llvm::dyn_cast_if_present(adaptor.getFalseValue())) { + dyn_cast_if_present(adaptor.getFalseValue())) { SmallVector results; results.reserve(static_cast(cond.getNumElements())); auto condVals = llvm::make_range(cond.value_begin(), @@ -2493,8 +2572,7 @@ void arith::SelectOp::print(OpAsmPrinter &p) { p << " " << getOperands(); p.printOptionalAttrDict((*this)->getAttrs()); p << " : "; - if (ShapedType condType = - llvm::dyn_cast(getCondition().getType())) + if (ShapedType condType = dyn_cast(getCondition().getType())) p << condType << ", "; p << getType(); } @@ -2692,7 +2770,7 @@ Value mlir::arith::getIdentityValue(AtomicRMWKind op, Type resultType, bool useOnlyFiniteValue) { auto attr = getIdentityValueAttr(op, resultType, builder, loc, useOnlyFiniteValue); - return builder.create(loc, attr); + return arith::ConstantOp::create(builder, loc, attr); } /// Return the value obtained by applying the reduction operation kind @@ -2701,33 +2779,33 @@ Value mlir::arith::getReductionOp(AtomicRMWKind op, OpBuilder &builder, Location loc, Value lhs, Value rhs) { switch (op) { case AtomicRMWKind::addf: - return builder.create(loc, lhs, rhs); + return arith::AddFOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::addi: - return builder.create(loc, lhs, rhs); + return arith::AddIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::mulf: - return builder.create(loc, lhs, rhs); + return arith::MulFOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::muli: - return builder.create(loc, lhs, rhs); + return arith::MulIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::maximumf: - return builder.create(loc, lhs, rhs); + return arith::MaximumFOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::minimumf: - return builder.create(loc, lhs, rhs); - case AtomicRMWKind::maxnumf: - return builder.create(loc, lhs, rhs); + return arith::MinimumFOp::create(builder, loc, lhs, rhs); + case AtomicRMWKind::maxnumf: + return arith::MaxNumFOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::minnumf: - return builder.create(loc, lhs, rhs); + return arith::MinNumFOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::maxs: - return builder.create(loc, lhs, rhs); + return arith::MaxSIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::mins: - return builder.create(loc, lhs, rhs); + return arith::MinSIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::maxu: - return builder.create(loc, lhs, rhs); + return arith::MaxUIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::minu: - return builder.create(loc, lhs, rhs); + return arith::MinUIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::ori: - return builder.create(loc, lhs, rhs); + return arith::OrIOp::create(builder, loc, lhs, rhs); case AtomicRMWKind::andi: - return builder.create(loc, lhs, rhs); + return arith::AndIOp::create(builder, loc, lhs, rhs); // TODO: Add remaining reduction operations. default: (void)emitOptionalError(loc, "Reduction operation type not supported");