Skip to content

Commit 9975e07

Browse files
committed
Fix failures due to incorrect mfma selection
1 parent 4099bfa commit 9975e07

File tree

5 files changed

+53
-26
lines changed

5 files changed

+53
-26
lines changed

external/llvm-project/mlir/lib/Conversion/GPUToROCDL/LowerGpuOpsToROCDLOps.cpp

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
#include "mlir/Conversion/LLVMCommon/Pattern.h"
2727
#include "mlir/Conversion/LLVMCommon/TypeConverter.h"
2828
#include "mlir/Conversion/MathToLLVM/MathToLLVM.h"
29+
#include "mlir/Conversion/VectorToLLVM/ConvertVectorToLLVM.h"
2930
#include "mlir/Conversion/MathToROCDL/MathToROCDL.h"
3031
#include "mlir/Dialect/AMDGPU/IR/AMDGPUDialect.h"
3132
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
@@ -42,6 +43,7 @@
4243
#include "mlir/Transforms/DialectConversion.h"
4344
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
4445
#include "llvm/Support/FormatVariadic.h"
46+
#include "mlir/Dialect/Vector/Transforms/VectorRewritePatterns.h"
4547

4648
#include "../GPUCommon/GPUOpsLowering.h"
4749
#include "../GPUCommon/IndexIntrinsicsOpLowering.h"
@@ -375,6 +377,11 @@ struct LowerGpuOpsToROCDLOpsPass final
375377

376378
populateAMDGPUToROCDLConversionPatterns(converter, llvmPatterns,
377379
*maybeChipset);
380+
// TODO (rocmlir): remove hardcoded passes
381+
// related PR: https://github.com/llvm/llvm-project/pull/124439
382+
mlir::vector::populateVectorInsertExtractStridedSliceTransforms(
383+
llvmPatterns);
384+
// TODO: ends here
378385
populateGpuToROCDLConversionPatterns(converter, llvmPatterns, runtime,
379386
*maybeChipset);
380387
configureGpuToROCDLConversionLegality(target);
@@ -411,6 +418,9 @@ void mlir::configureGpuToROCDLConversionLegality(ConversionTarget &target) {
411418
target.addLegalDialect<::mlir::LLVM::LLVMDialect>();
412419
target.addLegalDialect<ROCDL::ROCDLDialect>();
413420
target.addIllegalDialect<gpu::GPUDialect>();
421+
// TODO (rocmlir): remove vector::VectorDialect
422+
// related PR: https://github.com/llvm/llvm-project/pull/124439
423+
target.addIllegalDialect<gpu::GPUDialect, vector::VectorDialect>();
414424
target.addIllegalOp<LLVM::CosOp, LLVM::ExpOp, LLVM::Exp2Op, LLVM::FCeilOp,
415425
LLVM::FFloorOp, LLVM::FRemOp, LLVM::LogOp, LLVM::Log10Op,
416426
LLVM::Log2Op, LLVM::PowOp, LLVM::SinOp>();

mlir/include/mlir/Dialect/Rock/IR/MfmaInsnGroup.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -136,7 +136,7 @@ class MfmaInsnGroup {
136136

137137
public:
138138
static FailureOr<MfmaInsnGroup> select(Type elementTypeA, Type elementTypeB,
139-
StringRef arch, int64_t mnPerXdl);
139+
StringRef arch, int64_t mnPerXdl, int64_t kPack, int64_t kPackPerBlock);
140140
MfmaInsnGroup(Type elementTypeA, Type elementTypeB, const MfmaInsn &insn,
141141
const MfmaInsnGroupAttr &groupAttr);
142142
int64_t getMRepeats(int64_t mPerWave);

mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp

Lines changed: 30 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
#include "mlir/Dialect/Rock/utility/math.h"
77
#include "mlir/IR/Builders.h"
88
#include "mlir/IR/BuiltinTypes.h"
9+
#include "mlir/Support/LLVM.h"
910

1011
#include "llvm/Support/Debug.h"
1112
#include "llvm/Support/ErrorHandling.h"
@@ -423,22 +424,22 @@ static auto getMfmaInsnGroupAttrMapGfx942 = []() {
423424
static auto getMfmaInsnGroupAttrMapGfx950 = []() {
424425
static MfmaInsnGroupMap groupAttrMap{
425426
// fp16 double rate
426-
{{MfmaTypeId::Fp16TyId, 16, 16},
427-
{ROCDL::mfma_f32_16x16x32_f16::getOperationName()}},
428427
{{MfmaTypeId::Fp16TyId, 32, 32},
429428
{ROCDL::mfma_f32_32x32x16_f16::getOperationName()}},
429+
{{MfmaTypeId::Fp16TyId, 16, 16},
430+
{ROCDL::mfma_f32_16x16x32_f16::getOperationName()}},
431+
430432
// bfp16 double rate
431-
{{MfmaTypeId::Bf16TyId, 16, 16},
432-
{ROCDL::mfma_f32_16x16x32_bf16::getOperationName()}},
433433
{{MfmaTypeId::Bf16TyId, 32, 32},
434434
{ROCDL::mfma_f32_32x32x16_bf16::getOperationName()}},
435+
{{MfmaTypeId::Bf16TyId, 16, 16},
436+
{ROCDL::mfma_f32_16x16x32_bf16::getOperationName()}},
437+
435438
// i8 double rate
436-
{{MfmaTypeId::I8TyId, 16, 16},
437-
{ROCDL::mfma_i32_16x16x64_i8::getOperationName()}},
438439
{{MfmaTypeId::I8TyId, 32, 32},
439-
{ROCDL::mfma_i32_32x32x32_i8::getOperationName()}}
440-
441-
};
440+
{ROCDL::mfma_i32_32x32x32_i8::getOperationName()}},
441+
{{MfmaTypeId::I8TyId, 16, 16},
442+
{ROCDL::mfma_i32_16x16x64_i8::getOperationName()}}};
442443
return groupAttrMap;
443444
};
444445

@@ -546,10 +547,9 @@ static MfmaTypeId convertTypesToId(Type dataTypeA, Type dataTypeB) {
546547
llvm_unreachable("Unsupported input argument type.");
547548
}
548549

549-
FailureOr<MfmaInsnGroup> MfmaInsnGroup::select(Type elementTypeA,
550-
Type elementTypeB,
551-
StringRef arch,
552-
int64_t mnPerXdl) {
550+
FailureOr<MfmaInsnGroup>
551+
MfmaInsnGroup::select(Type elementTypeA, Type elementTypeB, StringRef arch,
552+
int64_t mnPerXdl, int64_t kPack, int64_t kPackPerBlock) {
553553
LLVM_DEBUG(llvm::dbgs() << "Invoke Mfma group selection:\n"
554554
<< "elementType A: " << elementTypeA << "\n"
555555
<< "elementType B: " << elementTypeB << "\n"
@@ -594,6 +594,15 @@ FailureOr<MfmaInsnGroup> MfmaInsnGroup::select(Type elementTypeA,
594594
} else {
595595
// gfx950 has double rate instructions. Select from those first.
596596
selectFrom(getMfmaInsnGroupAttrMapGfx950());
597+
if (succeeded(result)) {
598+
if (result->isCoherentWithK(kPack, kPackPerBlock)) {
599+
LLVM_DEBUG(llvm::dbgs()
600+
<< "Selected gfx950 double rate instruction\n");
601+
return result;
602+
}
603+
// else select again
604+
result = failure();
605+
}
597606
selectFrom(getMfmaInsnGroupAttrMapGfx90aPlusBf16());
598607
}
599608
}
@@ -605,6 +614,14 @@ FailureOr<MfmaInsnGroup> MfmaInsnGroup::select(Type elementTypeA,
605614
} else if (isGfx95x) {
606615
// select from new double rate instructions first
607616
selectFrom(getMfmaInsnGroupAttrMapGfx950());
617+
if (succeeded(result)) {
618+
if (result->isCoherentWithK(kPack, kPackPerBlock)) {
619+
LLVM_DEBUG(llvm::dbgs() << "Selected gfx950 double rate instruction\n");
620+
return result;
621+
}
622+
// else select again
623+
result = failure();
624+
}
608625
// all previous instructions are still valid for gfx950
609626
selectFrom(getMfmaInsnGroupAttrMapGfx942());
610627
}

mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp

Lines changed: 9 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -525,7 +525,8 @@ LogicalResult PopulateParamsXDL::isValidBlockwiseGemm(
525525
mnPerXdl = derivedParam.getMnPerXdl();
526526
}
527527
auto maybeMfmaInsnGroup =
528-
MfmaInsnGroup::select(dataTypeA, dataTypeB, arch, mnPerXdl);
528+
MfmaInsnGroup::select(dataTypeA, dataTypeB, arch, mnPerXdl,
529+
param.getKpack(), param.getKpackPerBlock());
529530
if (failed(maybeMfmaInsnGroup)) {
530531
LLVM_DEBUG(llvm::dbgs() << "Failed to select xdlops instruction group.\n");
531532
return failure();
@@ -580,7 +581,8 @@ PopulateParamsXDL::getTuningParameters(KernelType opType, Type dataTypeA,
580581
[&](const InitParamsAccel &param) {
581582
int64_t mnPerXdl = param.gemmNPerWaveOrMnPerXdl;
582583
auto maybeMfmaInsnGroup =
583-
MfmaInsnGroup::select(dataTypeA, dataTypeB, arch, mnPerXdl);
584+
MfmaInsnGroup::select(dataTypeA, dataTypeB, arch, mnPerXdl,
585+
param.gemmKPack, param.gemmKPerBlock);
584586
if (failed(maybeMfmaInsnGroup)) {
585587
return false;
586588
}
@@ -653,9 +655,9 @@ LogicalResult PopulateParamsWmma::isValidBlockwiseGemm(
653655
if (minDPerWave <= 16) {
654656
validKPerWaveFactor = 4;
655657
}
656-
if (!((param.getMPerBlock() % minDPerWave == 0) &&
657-
(param.getNPerBlock() % minDPerWave == 0) &&
658-
(param.getKpackPerBlock() % validKPerWaveFactor == 0))) {
658+
if ((param.getMPerBlock() % minDPerWave != 0) ||
659+
(param.getNPerBlock() % minDPerWave != 0) ||
660+
(param.getKpackPerBlock() % validKPerWaveFactor != 0)) {
659661
return failure();
660662
}
661663

@@ -689,7 +691,7 @@ LogicalResult PopulateParamsWmma::isValidBlockwiseGemm(
689691

690692
// Sledgehammer hotfix because not unrolling sometimes makes the register
691693
// allocator break. This should be refined quickly.
692-
if (param.getForceUnroll() == false) {
694+
if (!param.getForceUnroll()) {
693695
return failure();
694696
}
695697

@@ -756,10 +758,7 @@ PopulateParamsWmma::getTuningParameters(KernelType opType, Type dataTypeA,
756758
return false;
757759
}
758760
WmmaInsn wmmaInsn = *maybeWmmaInsn;
759-
if (!wmmaInsn.isCoherentWithK(param.gemmKPack, param.gemmKPerBlock)) {
760-
return false;
761-
}
762-
return true;
761+
return wmmaInsn.isCoherentWithK(param.gemmKPack, param.gemmKPerBlock);
763762
});
764763
return res;
765764
}

mlir/lib/Dialect/Rock/utility/AccelEmitter.cpp

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1196,8 +1196,9 @@ AccelEmitter::select(GemmFeatures features, Type dataTypeA, Type dataTypeB,
11961196
if (isMfma) {
11971197
XdlopsGemmDerivedParamsAttr mfmaParams =
11981198
cast<XdlopsGemmDerivedParamsAttr>(tuningParams);
1199-
auto maybeMfmaInsnGroup = MfmaInsnGroup::select(dataTypeA, dataTypeB, arch,
1200-
mfmaParams.getMnPerXdl());
1199+
auto maybeMfmaInsnGroup = MfmaInsnGroup::select(
1200+
dataTypeA, dataTypeB, arch, mfmaParams.getMnPerXdl(),
1201+
mfmaParams.getKpack(), mfmaParams.getKpackPerBlock());
12011202
if (failed(maybeMfmaInsnGroup)) {
12021203
return nullptr;
12031204
}

0 commit comments

Comments
 (0)