
Commit 252d6bd

[AMD] Added fixes to refine-ops-pass after rebase with upstream

1 parent 28d84c6

File tree

6 files changed: +50 additions, -51 deletions

6 files changed

+50
-51
lines changed

bin/RegisterTritonDialects.h

Lines changed: 1 addition & 4 deletions

@@ -92,10 +92,7 @@ inline void registerTritonDialects(mlir::DialectRegistry &registry) {
 
   // NVWS passes
   mlir::triton::registerNVWSTransformsPasses();
-
-  // NVGPU transform passes
-  mlir::registerNVHopperTransformsPasses();
-  mlir::triton::registerTritonAMDGPURefineOps();
+  mlir::registerTritonAMDGPURefineOps();
 
   registry.insert<
       mlir::triton::TritonDialect, mlir::cf::ControlFlowDialect,
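
Note: a sketch of the generic MLIR pattern behind registration calls like
mlir::registerTritonAMDGPURefineOps(). In-tree, this boilerplate is normally
generated from Passes.td; the pass name and class below are hypothetical
stand-ins, not the actual generated code.

#include <memory>

#include "mlir/IR/BuiltinOps.h"
#include "mlir/Pass/Pass.h"
#include "mlir/Pass/PassRegistry.h"
#include "mlir/Support/TypeID.h"

namespace {
// Hypothetical stand-in for a tablegen-generated pass class.
struct MyRefinePass
    : mlir::PassWrapper<MyRefinePass, mlir::OperationPass<mlir::ModuleOp>> {
  MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(MyRefinePass)
  llvm::StringRef getArgument() const final { return "my-refine-ops"; }
  llvm::StringRef getDescription() const final {
    return "Illustrative refinement pass";
  }
  void runOnOperation() final { /* pattern rewrites would go here */ }
};
} // namespace

// Registering the allocator makes the pass visible to opt-style tools
// (e.g. triton-opt) under its command-line argument.
void registerMyRefinePass() {
  mlir::registerPass([] { return std::make_unique<MyRefinePass>(); });
}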

third_party/amd/include/TritonAMDGPUTransforms/DotTiling.h

Lines changed: 4 additions & 3 deletions

@@ -42,9 +42,10 @@ unsigned getCyclesPerMfma(DotOp dotOp) {
   bool allowXF32 =
       dotOp.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3;
 
-  FailureOr<MfmaIntrinsic> maybeMfmaInsn = MfmaIntrinsic::selectFor(
-      mfmaVersion, mDim, nDim, kDimOperandSize, elemTyA, elemTyB,
-      /*withScale=*/false, allowXF32);
+  FailureOr<MfmaIntrinsic> maybeMfmaInsn =
+      MfmaIntrinsic::selectFor(dotOp->getLoc(), mfmaVersion, mDim, nDim,
+                               kDimOperandSize, elemTyA, elemTyB,
+                               /*withScale=*/false, allowXF32);
 
   if (failed(maybeMfmaInsn))
     llvm::report_fatal_error("No match found in MFMA database\n");
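
Note: this hunk and the matching one in RefineOps.cpp below thread the op's
Location into MfmaIntrinsic::selectFor, presumably so a failed lookup can
emit a diagnostic anchored at the offending dot op. A minimal sketch of that
location-threading pattern, using a hypothetical selectThing() rather than
the real MfmaIntrinsic API:

#include "mlir/IR/Diagnostics.h"
#include "mlir/IR/Location.h"
#include "mlir/Support/LogicalResult.h"

// Hypothetical lookup: returns a value on success, or emits an error
// attached to the caller's source location and returns failure.
mlir::FailureOr<int> selectThing(mlir::Location loc, int version) {
  if (version != 3) {
    mlir::emitError(loc) << "no matching intrinsic for version " << version;
    return mlir::failure();
  }
  return 42;
}

// Call sites pass op->getLoc() so the diagnostic points at the op:
//   mlir::FailureOr<int> insn = selectThing(dotOp->getLoc(), mfmaVersion);
//   if (mlir::failed(insn)) { /* bail out or report */ }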

third_party/amd/lib/Dialect/TritonAMDGPU/IR/Dialect.cpp

Lines changed: 26 additions & 8 deletions

@@ -458,10 +458,18 @@ struct CanonicalizeConcatOpFromExtractSlice
     if (!concatOp)
       return failure();
 
-    auto offset = op.getStaticOffsets();
-    auto coords = concatOp.getCoords();
-    if (coords.size() != offset.size())
+    auto origTensorType = dyn_cast<RankedTensorType>(op.getOperand().getType());
+    auto concatTensorType =
+        dyn_cast<RankedTensorType>(op.getResult().getType());
+
+    if (!(origTensorType && concatTensorType)) {
       return failure();
+    }
+    auto origTensorShape = origTensorType.getShape();
+    auto concatTensorShape = concatTensorType.getShape();
+    if (origTensorShape.equals(concatTensorShape)) {
+      return failure();
+    }
 
     auto sliceResult = op.getResult();
     auto sliceResultType = sliceResult.getType();
@@ -476,10 +484,16 @@ struct CanonicalizeConcatOpFromExtractSlice
       return failure();
 
     auto concatItemShape = concatItemType.getShape();
+    llvm::SmallVector<int64_t> coords(concatTensorShape.size());
+    for (size_t i = 0; i < coords.size(); ++i) {
+      coords[i] = concatTensorShape[i] / concatItemShape[i];
+    }
+
     SmallVector<int64_t> dimScales(concatItemShape.size(), 1);
     int64_t concatItemIndex = 0;
     std::exclusive_scan(coords.rbegin(), coords.rend(), dimScales.rbegin(), 1,
                         std::multiplies<>());
+    auto offset = op.getStaticOffsets();
     for (auto [idx, itemDimSize] : llvm::enumerate(concatItemShape)) {
       if ((offset[idx] % itemDimSize) != 0)
         return failure();
@@ -504,12 +518,16 @@ struct CanonicalizeConcatOp : public mlir::OpRewritePattern<amdgpu::ConcatOp> {
 
     auto result = op.getResult();
     auto sources = op.getSources();
-    auto offsets = op.getCoords();
     if (sources.size() == 1) {
-      assert(product(offsets) == 1);
-      auto source = sources.front();
-      result.replaceAllUsesWith(source);
-      return success();
+      auto resultShape =
+          cast<RankedTensorType>(op.getResult().getType()).getShape();
+      auto sourceShape =
+          cast<RankedTensorType>(sources.front().getType()).getShape();
+      if (resultShape.equals(sourceShape)) {
+        auto source = sources.front();
+        result.replaceAllUsesWith(source);
+        return success();
+      }
     }
 
     return failure();
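
Note: the rewritten canonicalization no longer reads coords off the ConcatOp;
it derives per-dimension tile counts by dividing the concat shape by the item
shape, then turns them into row-major strides with std::exclusive_scan over
the reversed counts. A standalone sketch of that linearization (the shapes in
the example are illustrative, not taken from the pass):

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int64_t linearTileIndex(const std::vector<int64_t> &concatShape,
                        const std::vector<int64_t> &itemShape,
                        const std::vector<int64_t> &offset) {
  size_t rank = concatShape.size();
  std::vector<int64_t> coords(rank);
  for (size_t i = 0; i < rank; ++i)
    coords[i] = concatShape[i] / itemShape[i]; // tiles along dimension i

  // Reverse exclusive scan: dimScales[i] = product of coords[i+1..rank-1],
  // i.e. the row-major stride of dimension i, in units of tiles.
  std::vector<int64_t> dimScales(rank, 1);
  std::exclusive_scan(coords.rbegin(), coords.rend(), dimScales.rbegin(), 1,
                      std::multiplies<>());

  int64_t index = 0;
  for (size_t i = 0; i < rank; ++i)
    index += (offset[i] / itemShape[i]) * dimScales[i]; // tile coord * stride
  return index;
}

// Example: a 64x64 concat of 32x16 items forms a 2x4 tile grid; the slice at
// offset (32, 48) is tile (1, 3), so linearTileIndex(...) == 1 * 4 + 3 == 7.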

third_party/amd/lib/TritonAMDGPUDialectToLLVM/ConcatOpToLLVM.cpp

Lines changed: 0 additions & 1 deletion

@@ -115,4 +115,3 @@ void populateConcatOpToLLVMPatterns(mlir::LLVMTypeConverter &typeConverter,
   patterns.add<ConcatOpConversion>(typeConverter, benefit);
 }
 } // namespace mlir::triton::AMD
-

third_party/amd/lib/TritonAMDGPUTransforms/RefineOps.cpp

Lines changed: 18 additions & 35 deletions

@@ -218,9 +218,10 @@ struct DotOpMFMAConverter {
   bool allowXF32 =
       dotOp.getInputPrecision() == InputPrecision::TF32 && mfmaVersion == 3;
 
-  FailureOr<MfmaIntrinsic> maybeMfmaInsn = MfmaIntrinsic::selectFor(
-      mfmaVersion, mDim, nDim, kDimOperandSize, elemTyA, elemTyB,
-      /*withScale=*/false, allowXF32);
+  FailureOr<MfmaIntrinsic> maybeMfmaInsn =
+      MfmaIntrinsic::selectFor(dotOp->getLoc(), mfmaVersion, mDim, nDim,
+                               kDimOperandSize, elemTyA, elemTyB,
+                               /*withScale=*/false, allowXF32);
 
   SmallVector<unsigned> mfmaShape = {16, 16, 16};
   if (failed(maybeMfmaInsn)) {
@@ -265,15 +266,15 @@ struct DotOpMFMAConverter {
   auto extractSliceTypeA =
       RankedTensorType::get(refinedShapeA, elemTyA, encodeA);
   rewriter.setInsertionPointAfter(dotOp);
-  SmallVector<SmallVector<amdgpu::ExtractSliceOp>> subtilesA;
+  SmallVector<SmallVector<triton::amdgpu::ExtractSliceOp>> subtilesA;
   unsigned tileIdx = 0;
   for (int32_t k = 0; k < numRepK; ++k) {
-    SmallVector<amdgpu::ExtractSliceOp> subtilesK;
+    SmallVector<triton::amdgpu::ExtractSliceOp> subtilesK;
     for (int32_t i = 0; i < numRepM; ++i) {
       LDBG("local_load_a[" << i << "][" << k << "] extract_slice");
       int32_t shiftM = i * elementsPerSliceM;
       int32_t shiftK = k * elementsPerSliceK;
-      auto extract = rewriter.create<amdgpu::ExtractSliceOp>(
+      auto extract = rewriter.create<triton::amdgpu::ExtractSliceOp>(
           loc, Type{extractSliceTypeA}, Value{a},
           DenseI64ArrayAttr::get(ctx, {shiftM, shiftK}));
       // Add dot-tile info to local_load's slice;
@@ -304,15 +305,15 @@ struct DotOpMFMAConverter {
   // Extract slices for B operands.
   auto extractSliceTypeB =
       RankedTensorType::get(refinedShapeB, elemTyB, encodeB);
-  SmallVector<SmallVector<amdgpu::ExtractSliceOp>> subtilesB;
+  SmallVector<SmallVector<triton::amdgpu::ExtractSliceOp>> subtilesB;
   tileIdx = 0;
   for (int32_t k = 0; k < numRepK; ++k) {
-    SmallVector<amdgpu::ExtractSliceOp> subtilesK;
+    SmallVector<triton::amdgpu::ExtractSliceOp> subtilesK;
     for (int32_t j = 0; j < numRepN; ++j) {
      LDBG("local_load_b[" << k << "][" << j << "] extact_slice");
       int32_t shiftN = j * elementsPerSliceN;
       int32_t shiftK = k * elementsPerSliceK;
-      auto extract = rewriter.create<amdgpu::ExtractSliceOp>(
+      auto extract = rewriter.create<triton::amdgpu::ExtractSliceOp>(
           loc, Type{extractSliceTypeB}, Value{b},
           DenseI64ArrayAttr::get(ctx, {shiftK, shiftN}));
       // Add dot-tile info to local_load's slice;
@@ -404,9 +405,8 @@ struct DotOpMFMAConverter {
     }
   }
 
-  auto concatDims = DenseI64ArrayAttr::get(ctx, {numRepM, numRepN});
   auto joinedDotsResult = rewriter.create<triton::amdgpu::ConcatOp>(
-      loc, dTensorTy, refinedDotValues, concatDims);
+      loc, dTensorTy, refinedDotValues);
 
   d.replaceAllUsesWith(joinedDotsResult);
 
@@ -545,17 +545,8 @@ struct LocalLoadOpPattern
     }
   }
 
-  // concat dims is correct shape 8x1 vs 1x8, else gives wrong output shape.
-  std::vector<int64_t> loweringOrder(numReps2D.size());
-  int64_t counter = 0;
-  auto increment = [&counter](int64_t &val) { val = counter++; };
-  if (opIdx == 0)
-    std::for_each(loweringOrder.rbegin(), loweringOrder.rend(), increment);
-  else
-    std::for_each(loweringOrder.begin(), loweringOrder.end(), increment);
-
-  auto joinedResult = rewriter.create<triton::amdgpu::ConcatOp>(
-      loc, resultType, subtiles, numReps2D, loweringOrder);
+  auto joinedResult =
+      rewriter.create<triton::amdgpu::ConcatOp>(loc, resultType, subtiles);
   LDBG("ConcatOp: " << *joinedResult);
 
   rewriter.replaceOp(op, joinedResult);
@@ -617,9 +608,8 @@ struct LoadOpPattern : public RefineRewritePattern<triton::LoadOp> {
     refinedTensors.push_back(refinedTensor);
   }
 
-  auto concatDims = DenseI64ArrayAttr::get(ctx, refinedBlock.numPerDims);
   auto joinedResult = rewriter.create<triton::amdgpu::ConcatOp>(
-      loc, origResultType, refinedTensors, concatDims);
+      loc, origResultType, refinedTensors);
 
   origResult.replaceAllUsesWith(joinedResult);
   return success();
@@ -831,10 +821,8 @@ struct ReduceOpPattern : public RefineRewritePattern<triton::ReduceOp> {
 
   // Concat reduce slices.
   auto reduceResultType = op.getResultTypes()[0];
-  SmallVector<int64_t> concatDimShape = {numReps};
-  auto concatDims = DenseI64ArrayAttr::get(ctx, concatDimShape);
   auto concatOp = rewriter.create<triton::amdgpu::ConcatOp>(
-      loc, reduceResultType, refinedReduces, concatDims);
+      loc, reduceResultType, refinedReduces);
   auto origOpResult = op.getResult();
   origOpResult.replaceAllUsesWith(concatOp);
   rewriter.eraseOp(op);
@@ -952,9 +940,8 @@ struct ElementWiseOpPattern : public RefineRewritePattern<OpTy> {
 
   // Concat slices.
   auto resultType = op->getResultTypes()[0];
-  auto concatDims = DenseI64ArrayAttr::get(op->getContext(), numReps);
   auto concatOp = rewriter.create<triton::amdgpu::ConcatOp>(
-      op.getLoc(), resultType, refinedOps, concatDims);
+      op.getLoc(), resultType, refinedOps);
 
   auto origOpResult = op.getResult();
   origOpResult.replaceAllUsesWith(concatOp);
@@ -1076,11 +1063,8 @@ struct ExpandDimsOpPattern : public RefineRewritePattern<triton::ExpandDimsOp> {
 
   // Concat refined ops.
   auto reduceResultType = op->getResultTypes()[0];
-  // Expand dims of numReps also before concat.
-  numReps.insert(numReps.begin() + op.getAxis(), 1);
-  auto concatDims = DenseI64ArrayAttr::get(op->getContext(), numReps);
   auto concatOp = rewriter.create<triton::amdgpu::ConcatOp>(
-      op.getLoc(), reduceResultType, refinedReduces, concatDims);
+      op.getLoc(), reduceResultType, refinedReduces);
   auto origOpResult = op.getResult();
 
   auto checkLL = triton::gpu::toLinearEncoding(
@@ -1195,9 +1179,8 @@ struct BroadcastOpPattern : public RefineRewritePattern<BroadcastOp> {
 
   // Concat refined ops.
   auto reduceResultType = op->getResultTypes()[0];
-  auto concatDims = DenseI64ArrayAttr::get(op->getContext(), numReps);
   auto concatOp = rewriter.create<triton::amdgpu::ConcatOp>(
-      op.getLoc(), reduceResultType, refinedBroadcasts, concatDims);
+      op.getLoc(), reduceResultType, refinedBroadcasts);
 
   auto origOpResult = op.getResult();
   origOpResult.replaceAllUsesWith(concatOp);
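
Note: every call site in this file now builds triton::amdgpu::ConcatOp from
just the location, result type, and source values; the tile grid that the
removed concatDims / loweringOrder arguments used to spell out is implied by
the result and source tensor shapes (which is exactly what the canonicalizer
in Dialect.cpp above recomputes). A hedged sketch of the shared call-site
shape, where resultType, refinedValues, and origResult stand in for each
pattern's locals:

  // Result type plus the flat list of refined values is enough; the
  // builder no longer takes an explicit dims attribute.
  auto concatOp = rewriter.create<triton::amdgpu::ConcatOp>(
      loc, resultType, refinedValues);
  origResult.replaceAllUsesWith(concatOp);
  rewriter.eraseOp(op);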

third_party/amd/lib/TritonAMDGPUTransforms/RescheduleOps.cpp

Lines changed: 1 addition & 0 deletions

@@ -1,6 +1,7 @@
 #include "mlir/Conversion/ArithToLLVM/ArithToLLVM.h"
 #include "mlir/Conversion/GPUToROCDL/GPUToROCDLPass.h"
 #include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
+#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
 #include "mlir/Dialect/LLVMIR/ROCDLDialect.h"
 #include "mlir/Pass/Pass.h"
 #include "third_party/amd/include/Dialect/TritonAMDGPU/IR/Dialect.h"
