Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
170 changes: 85 additions & 85 deletions mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp

Large diffs are not rendered by default.

26 changes: 13 additions & 13 deletions mlir/lib/Conversion/AffineToStandard/AffineToStandard.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -51,9 +51,9 @@ static Value buildMinMaxReductionSeq(Location loc,
Value value = *valueIt++;
for (; valueIt != values.end(); ++valueIt) {
if (predicate == arith::CmpIPredicate::sgt)
value = builder.create<arith::MaxSIOp>(loc, value, *valueIt);
value = arith::MaxSIOp::create(builder, loc, value, *valueIt);
else
value = builder.create<arith::MinSIOp>(loc, value, *valueIt);
value = arith::MinSIOp::create(builder, loc, value, *valueIt);
}

return value;
Expand Down Expand Up @@ -155,8 +155,8 @@ class AffineForLowering : public OpRewritePattern<AffineForOp> {
Value lowerBound = lowerAffineLowerBound(op, rewriter);
Value upperBound = lowerAffineUpperBound(op, rewriter);
Value step =
rewriter.create<arith::ConstantIndexOp>(loc, op.getStepAsInt());
auto scfForOp = rewriter.create<scf::ForOp>(loc, lowerBound, upperBound,
arith::ConstantIndexOp::create(rewriter, loc, op.getStepAsInt());
auto scfForOp = scf::ForOp::create(rewriter, loc, lowerBound, upperBound,
step, op.getInits());
rewriter.eraseBlock(scfForOp.getBody());
rewriter.inlineRegionBefore(op.getRegion(), scfForOp.getRegion(),
Expand Down Expand Up @@ -198,15 +198,15 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
}
steps.reserve(op.getSteps().size());
for (int64_t step : op.getSteps())
steps.push_back(rewriter.create<arith::ConstantIndexOp>(loc, step));
steps.push_back(arith::ConstantIndexOp::create(rewriter, loc, step));

// Get the terminator op.
auto affineParOpTerminator =
cast<AffineYieldOp>(op.getBody()->getTerminator());
scf::ParallelOp parOp;
if (op.getResults().empty()) {
// Case with no reduction operations/return values.
parOp = rewriter.create<scf::ParallelOp>(loc, lowerBoundTuple,
parOp = scf::ParallelOp::create(rewriter, loc, lowerBoundTuple,
upperBoundTuple, steps,
/*bodyBuilderFn=*/nullptr);
rewriter.eraseBlock(parOp.getBody());
Expand Down Expand Up @@ -234,7 +234,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
identityVals.push_back(
arith::getIdentityValue(reductionOpValue, resultType, rewriter, loc));
}
parOp = rewriter.create<scf::ParallelOp>(
parOp = scf::ParallelOp::create(rewriter,
loc, lowerBoundTuple, upperBoundTuple, steps, identityVals,
/*bodyBuilderFn=*/nullptr);

Expand Down Expand Up @@ -262,7 +262,7 @@ class AffineParallelLowering : public OpRewritePattern<AffineParallelOp> {
Value reductionResult = arith::getReductionOp(
reductionOpValue, rewriter, loc, reductionBody.getArgument(0),
reductionBody.getArgument(1));
rewriter.create<scf::ReduceReturnOp>(loc, reductionResult);
scf::ReduceReturnOp::create(rewriter, loc, reductionResult);
}
rewriter.replaceOp(op, parOp.getResults());
return success();
Expand All @@ -279,7 +279,7 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {

// Now we just have to handle the condition logic.
auto integerSet = op.getIntegerSet();
Value zeroConstant = rewriter.create<arith::ConstantIndexOp>(loc, 0);
Value zeroConstant = arith::ConstantIndexOp::create(rewriter, loc, 0);
SmallVector<Value, 8> operands(op.getOperands());
auto operandsRef = llvm::ArrayRef(operands);

Expand All @@ -299,17 +299,17 @@ class AffineIfLowering : public OpRewritePattern<AffineIfOp> {
auto pred =
isEquality ? arith::CmpIPredicate::eq : arith::CmpIPredicate::sge;
Value cmpVal =
rewriter.create<arith::CmpIOp>(loc, pred, affResult, zeroConstant);
arith::CmpIOp::create(rewriter, loc, pred, affResult, zeroConstant);
cond = cond
? rewriter.create<arith::AndIOp>(loc, cond, cmpVal).getResult()
? arith::AndIOp::create(rewriter, loc, cond, cmpVal).getResult()
: cmpVal;
}
cond = cond ? cond
: rewriter.create<arith::ConstantIntOp>(loc, /*value=*/1,
: arith::ConstantIntOp::create(rewriter, loc, /*value=*/1,
/*width=*/1);

bool hasElseRegion = !op.getElseRegion().empty();
auto ifOp = rewriter.create<scf::IfOp>(loc, op.getResultTypes(), cond,
auto ifOp = scf::IfOp::create(rewriter, loc, op.getResultTypes(), cond,
hasElseRegion);
rewriter.inlineRegionBefore(op.getThenRegion(),
&ifOp.getThenRegion().back());
Expand Down
90 changes: 45 additions & 45 deletions mlir/lib/Conversion/ArithToAMDGPU/ArithToAMDGPU.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -89,9 +89,9 @@ static Value castF32To(Type desType, Value f32, Location loc,
if (elementType.isF32())
return f32;
if (elementType.getIntOrFloatBitWidth() < 32)
return rewriter.create<arith::TruncFOp>(loc, desType, f32);
return arith::TruncFOp::create(rewriter, loc, desType, f32);
if (elementType.getIntOrFloatBitWidth() > 32)
return rewriter.create<arith::ExtFOp>(loc, desType, f32);
return arith::ExtFOp::create(rewriter, loc, desType, f32);
llvm_unreachable("The only 32-bit float type is f32");
}

Expand All @@ -113,26 +113,26 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
Type outElemType = getElementTypeOrSelf(op.getOut().getType());
VectorType extResType = VectorType::get(2, rewriter.getF32Type());
if (!inVecType) {
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
Value asFloat = amdgpu::ExtPackedFp8Op::create(rewriter,
loc, rewriter.getF32Type(), in, 0);
Value result = castF32To(outElemType, asFloat, loc, rewriter);
rewriter.replaceOp(op, result);
return success();
}
int64_t numElements = inVecType.getNumElements();

Value zero = rewriter.create<arith::ConstantOp>(
Value zero = arith::ConstantOp::create(rewriter,
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
VectorType outType = cast<VectorType>(op.getOut().getType());

if (inVecType.getShape().empty()) {
Value zerodSplat =
rewriter.createOrFold<vector::SplatOp>(loc, outType, zero);
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
Value scalarExt =
rewriter.create<arith::ExtFOp>(loc, outElemType, scalarIn);
Value result = rewriter.create<vector::InsertOp>(loc, scalarExt, zerodSplat,
arith::ExtFOp::create(rewriter, loc, outElemType, scalarIn);
Value result = vector::InsertOp::create(rewriter, loc, scalarExt, zerodSplat,
ArrayRef<int64_t>{});
rewriter.replaceOp(op, result);
return success();
Expand All @@ -145,32 +145,32 @@ ExtFOnFloat8RewritePattern::matchAndRewrite(arith::ExtFOp op,
if (inVecType.getRank() > 1) {
inVecType = VectorType::get(SmallVector<int64_t>{numElements},
inVecType.getElementType());
in = rewriter.create<vector::ShapeCastOp>(loc, inVecType, in);
in = vector::ShapeCastOp::create(rewriter, loc, inVecType, in);
}

for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value inSlice = rewriter.create<vector::ExtractStridedSliceOp>(
Value inSlice = vector::ExtractStridedSliceOp::create(rewriter,
loc, in, i, elemsThisOp, 1);
for (int64_t j = 0; j < elemsThisOp; j += 2) {
if (i + j + 1 < numElements) { // Convert two 8-bit elements
Value asFloats = rewriter.create<amdgpu::ExtPackedFp8Op>(
Value asFloats = amdgpu::ExtPackedFp8Op::create(rewriter,
loc, extResType, inSlice, j / 2);
Type desType = VectorType::get(2, outElemType);
Value asType = castF32To(desType, asFloats, loc, rewriter);
result = rewriter.create<vector::InsertStridedSliceOp>(
result = vector::InsertStridedSliceOp::create(rewriter,
loc, asType, result, i + j, 1);
} else { // Convert a 8-bit element
Value asFloat = rewriter.create<amdgpu::ExtPackedFp8Op>(
Value asFloat = amdgpu::ExtPackedFp8Op::create(rewriter,
loc, rewriter.getF32Type(), inSlice, j / 2 * 2);
Value asType = castF32To(outElemType, asFloat, loc, rewriter);
result = rewriter.create<vector::InsertOp>(loc, asType, result, i + j);
result = vector::InsertOp::create(rewriter, loc, asType, result, i + j);
}
}
}

if (inVecType.getRank() != outType.getRank()) {
result = rewriter.create<vector::ShapeCastOp>(loc, outType, result);
result = vector::ShapeCastOp::create(rewriter, loc, outType, result);
}

rewriter.replaceOp(op, result);
Expand All @@ -182,9 +182,9 @@ static Value castToF32(Value value, Location loc, PatternRewriter &rewriter) {
if (type.isF32())
return value;
if (type.getIntOrFloatBitWidth() < 32)
return rewriter.create<arith::ExtFOp>(loc, rewriter.getF32Type(), value);
return arith::ExtFOp::create(rewriter, loc, rewriter.getF32Type(), value);
if (type.getIntOrFloatBitWidth() > 32)
return rewriter.create<arith::TruncFOp>(loc, rewriter.getF32Type(), value);
return arith::TruncFOp::create(rewriter, loc, rewriter.getF32Type(), value);
llvm_unreachable("The only 32-bit float type is f32");
}

Expand Down Expand Up @@ -224,13 +224,13 @@ static Value clampInput(PatternRewriter &rewriter, Location loc,
loc, arith::CmpFPredicate::OEQ, source, negInf);
Value isNan = rewriter.createOrFold<arith::CmpFOp>(
loc, arith::CmpFPredicate::UNO, source, source);
Value isNonFinite = rewriter.create<arith::OrIOp>(
loc, rewriter.create<arith::OrIOp>(loc, isInf, isNegInf), isNan);
Value isNonFinite = arith::OrIOp::create(rewriter,
loc, arith::OrIOp::create(rewriter, loc, isInf, isNegInf), isNan);

Value clampedBelow = rewriter.create<arith::MaximumFOp>(loc, source, minCst);
Value clamped = rewriter.create<arith::MinimumFOp>(loc, clampedBelow, maxCst);
Value clampedBelow = arith::MaximumFOp::create(rewriter, loc, source, minCst);
Value clamped = arith::MinimumFOp::create(rewriter, loc, clampedBelow, maxCst);
Value res =
rewriter.create<arith::SelectOp>(loc, isNonFinite, source, clamped);
arith::SelectOp::create(rewriter, loc, isNonFinite, source, clamped);
return res;
}

Expand Down Expand Up @@ -264,24 +264,24 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
VectorType truncResType = VectorType::get(4, outElemType);
if (!inVectorTy) {
Value asFloat = castToF32(in, loc, rewriter);
Value asF8s = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
Value asF8s = amdgpu::PackedTrunc2xFp8Op::create(rewriter,
loc, truncResType, asFloat, /*sourceB=*/nullptr, 0,
/*existing=*/nullptr);
Value result = rewriter.create<vector::ExtractOp>(loc, asF8s, 0);
Value result = vector::ExtractOp::create(rewriter, loc, asF8s, 0);
rewriter.replaceOp(op, result);
return success();
}

int64_t numElements = outVecType.getNumElements();
Value zero = rewriter.create<arith::ConstantOp>(
Value zero = arith::ConstantOp::create(rewriter,
loc, outElemType, rewriter.getFloatAttr(outElemType, 0.0));
if (outVecType.getShape().empty()) {
Value scalarIn =
rewriter.create<vector::ExtractOp>(loc, in, ArrayRef<int64_t>{});
vector::ExtractOp::create(rewriter, loc, in, ArrayRef<int64_t>{});
// Recurse to send the 0-D vector case to the 1-D vector case
Value scalarTrunc =
rewriter.create<arith::TruncFOp>(loc, outElemType, scalarIn);
Value result = rewriter.create<vector::InsertOp>(loc, scalarTrunc, zero,
arith::TruncFOp::create(rewriter, loc, outElemType, scalarIn);
Value result = vector::InsertOp::create(rewriter, loc, scalarTrunc, zero,
ArrayRef<int64_t>{});
rewriter.replaceOp(op, result);
return success();
Expand All @@ -294,32 +294,32 @@ TruncFToFloat8RewritePattern::matchAndRewrite(arith::TruncFOp op,
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
inVectorTy.getElementType());
in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
}

for (int64_t i = 0; i < numElements; i += 4) {
int64_t elemsThisOp = std::min(numElements, i + 4) - i;
Value thisResult = nullptr;
for (int64_t j = 0; j < elemsThisOp; j += 2) {
Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i + j);
Value elemA = vector::ExtractOp::create(rewriter, loc, in, i + j);
Value asFloatA = castToF32(elemA, loc, rewriter);
Value asFloatB = nullptr;
if (j + 1 < elemsThisOp) {
Value elemB = rewriter.create<vector::ExtractOp>(loc, in, i + j + 1);
Value elemB = vector::ExtractOp::create(rewriter, loc, in, i + j + 1);
asFloatB = castToF32(elemB, loc, rewriter);
}
thisResult = rewriter.create<amdgpu::PackedTrunc2xFp8Op>(
thisResult = amdgpu::PackedTrunc2xFp8Op::create(rewriter,
loc, truncResType, asFloatA, asFloatB, j / 2, thisResult);
}
if (elemsThisOp < 4)
thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
thisResult = vector::ExtractStridedSliceOp::create(rewriter,
loc, thisResult, 0, elemsThisOp, 1);
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
result, i, 1);
}

if (inVectorTy.getRank() != outVecType.getRank()) {
result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
}

rewriter.replaceOp(op, result);
Expand Down Expand Up @@ -347,10 +347,10 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(

// Handle the case where input type is not a vector type
if (!inVectorTy) {
auto sourceB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
auto sourceB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());
Value asF16s =
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, in, sourceB);
Value result = rewriter.create<vector::ExtractOp>(loc, asF16s, 0);
ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, in, sourceB);
Value result = vector::ExtractOp::create(rewriter, loc, asF16s, 0);
rewriter.replaceOp(op, result);
return success();
}
Expand All @@ -362,33 +362,33 @@ LogicalResult TruncfToFloat16RewritePattern::matchAndRewrite(
if (inVectorTy.getRank() > 1) {
inVectorTy = VectorType::get(SmallVector<int64_t>{numElements},
inVectorTy.getElementType());
in = rewriter.create<vector::ShapeCastOp>(loc, inVectorTy, in);
in = vector::ShapeCastOp::create(rewriter, loc, inVectorTy, in);
}

// Handle the vector case. We also handle the (uncommon) case where the vector
// length is odd
for (int64_t i = 0; i < numElements; i += 2) {
int64_t elemsThisOp = std::min(numElements, i + 2) - i;
Value thisResult = nullptr;
Value elemA = rewriter.create<vector::ExtractOp>(loc, in, i);
Value elemB = rewriter.create<LLVM::PoisonOp>(loc, rewriter.getF32Type());
Value elemA = vector::ExtractOp::create(rewriter, loc, in, i);
Value elemB = LLVM::PoisonOp::create(rewriter, loc, rewriter.getF32Type());

if (elemsThisOp == 2) {
elemB = rewriter.create<vector::ExtractOp>(loc, in, i + 1);
elemB = vector::ExtractOp::create(rewriter, loc, in, i + 1);
}

thisResult =
rewriter.create<ROCDL::CvtPkRtz>(loc, truncResType, elemA, elemB);
ROCDL::CvtPkRtz::create(rewriter, loc, truncResType, elemA, elemB);
// Place back the truncated result into the possibly larger vector. If we
// are operating on a size 2 vector, these operations should be folded away
thisResult = rewriter.create<vector::ExtractStridedSliceOp>(
thisResult = vector::ExtractStridedSliceOp::create(rewriter,
loc, thisResult, 0, elemsThisOp, 1);
result = rewriter.create<vector::InsertStridedSliceOp>(loc, thisResult,
result = vector::InsertStridedSliceOp::create(rewriter, loc, thisResult,
result, i, 1);
}

if (inVectorTy.getRank() != outVecType.getRank()) {
result = rewriter.create<vector::ShapeCastOp>(loc, outVecType, result);
result = vector::ShapeCastOp::create(rewriter, loc, outVecType, result);
}

rewriter.replaceOp(op, result);
Expand Down
6 changes: 3 additions & 3 deletions mlir/lib/Conversion/ArithToArmSME/ArithToArmSME.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -74,14 +74,14 @@ struct ConstantOpToArmSMELowering : public OpRewritePattern<arith::ConstantOp> {
VectorType tileSliceType = VectorType::Builder(tileType).dropDim(0);
auto denseAttr1D = DenseElementsAttr::get(
tileSliceType, denseAttr.getSplatValue<Attribute>());
auto constantOp1D = rewriter.create<arith::ConstantOp>(loc, denseAttr1D);
auto constantOp1D = arith::ConstantOp::create(rewriter, loc, denseAttr1D);

auto initTile = rewriter.create<arm_sme::GetTileOp>(loc, tileType);
auto initTile = arm_sme::GetTileOp::create(rewriter, loc, tileType);
auto makeLoopBody = [&](OpBuilder &b, Location loc, Value tileSliceIndex,
Value currentTile) {
// Create 'arm_sme.insert_tile_slice' to write vector to tile
// slice.
auto nextTile = b.create<arm_sme::InsertTileSliceOp>(
auto nextTile = arm_sme::InsertTileSliceOp::create(b,
loc, tileType, constantOp1D, currentTile, tileSliceIndex);
return nextTile.getResult();
};
Expand Down
Loading