Skip to content

Commit c2ec6f0

Browse files
committed
add double rate insns
1 parent b3732c6 commit c2ec6f0

File tree

3 files changed

+246
-24
lines changed

3 files changed

+246
-24
lines changed

mlir/lib/Dialect/Rock/IR/MfmaInsnGroup.cpp

Lines changed: 72 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -17,8 +17,8 @@ using namespace mlir;
1717
using namespace mlir::rock;
1818

1919
// The static initialization will follow the defined ordering
20-
// of the below lambdas
21-
auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
20+
// of the below lambda
21+
static auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
2222
static llvm::StringMap<MfmaInsnInfo> insnInfo{
2323
// fp32
2424
{ROCDL::mfma_f32_32x32x1f32::getOperationName(),
@@ -37,8 +37,12 @@ auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
3737
{MfmaTypeId::Fp16TyId, 32, 4, 2}},
3838
{ROCDL::mfma_f32_32x32x8f16::getOperationName(),
3939
{MfmaTypeId::Fp16TyId, 32, 8, 1}},
40+
{ROCDL::mfma_f32_32x32x16_f16::getOperationName(),
41+
{MfmaTypeId::Fp16TyId, 32, 16, 1}},
4042
{ROCDL::mfma_f32_16x16x4f16::getOperationName(),
4143
{MfmaTypeId::Fp16TyId, 16, 4, 4}},
44+
{ROCDL::mfma_f32_16x16x32_f16::getOperationName(),
45+
{MfmaTypeId::Fp16TyId, 16, 32, 1}},
4246
{ROCDL::mfma_f32_16x16x16f16::getOperationName(),
4347
{MfmaTypeId::Fp16TyId, 16, 16, 1}},
4448
{ROCDL::mfma_f32_4x4x4f16::getOperationName(),
@@ -47,10 +51,14 @@ auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
4751
// bf16
4852
{ROCDL::mfma_f32_32x32x2bf16::getOperationName(),
4953
{MfmaTypeId::Bf16TyId, 32, 2, 2}},
54+
{ROCDL::mfma_f32_32x32x16_bf16::getOperationName(),
55+
{MfmaTypeId::Bf16TyId, 32, 16, 1}},
5056
{ROCDL::mfma_f32_32x32x4bf16::getOperationName(),
5157
{MfmaTypeId::Bf16TyId, 32, 4, 1}},
5258
{ROCDL::mfma_f32_16x16x2bf16::getOperationName(),
5359
{MfmaTypeId::Bf16TyId, 16, 2, 4}},
60+
{ROCDL::mfma_f32_16x16x32_bf16::getOperationName(),
61+
{MfmaTypeId::Bf16TyId, 16, 32, 1}},
5462
{ROCDL::mfma_f32_16x16x8bf16::getOperationName(),
5563
{MfmaTypeId::Bf16TyId, 16, 8, 1}},
5664
{ROCDL::mfma_f32_4x4x2bf16::getOperationName(),
@@ -77,8 +85,12 @@ auto getMfmaInsnInfoMap = []() -> const llvm::StringMap<MfmaInsnInfo> & {
7785
// i8 (new)
7886
{ROCDL::mfma_i32_32x32x16_i8::getOperationName(),
7987
{MfmaTypeId::I8TyId, 32, 16, 1}},
88+
{ROCDL::mfma_i32_32x32x32_i8::getOperationName(),
89+
{MfmaTypeId::I8TyId, 32, 32, 1}},
8090
{ROCDL::mfma_i32_16x16x32_i8::getOperationName(),
8191
{MfmaTypeId::I8TyId, 16, 32, 1}},
92+
{ROCDL::mfma_i32_16x16x64_i8::getOperationName(),
93+
{MfmaTypeId::I8TyId, 16, 64, 1}},
8294

8395
// fp8
8496
{ROCDL::mfma_f32_32x32x16_fp8_fp8::getOperationName(),
@@ -178,7 +190,7 @@ static MfmaInsnAttr deriveAttr(MfmaInsnInfo info) {
178190
isKReduction};
179191
}
180192

181-
auto getMfmaInsnAttrMap = []() -> const llvm::StringMap<MfmaInsnAttr> & {
193+
static auto getMfmaInsnAttrMap = []() -> const llvm::StringMap<MfmaInsnAttr> & {
182194
static llvm::StringMap<MfmaInsnAttr> insnDb;
183195
static std::once_flag once;
184196
std::call_once(once, [&]() {
@@ -194,7 +206,7 @@ auto getMfmaInsnAttrMap = []() -> const llvm::StringMap<MfmaInsnAttr> & {
194206
using MfmaInsnGroupMap =
195207
llvm::DenseMap<MfmaInsnGroupSelectKey, MfmaInsnGroupAttr,
196208
MfmaInsnGroupSelectKeyInfo>;
197-
auto getMfmaInsnGroupAttrMapAllArch = []() -> const MfmaInsnGroupMap & {
209+
static auto getMfmaInsnGroupAttrMapAllArch = []() -> const MfmaInsnGroupMap & {
198210
using amdgpu::MFMAPermB;
199211
static MfmaInsnGroupMap
200212
// f32
@@ -242,7 +254,8 @@ auto getMfmaInsnGroupAttrMapAllArch = []() -> const MfmaInsnGroupMap & {
242254
return groupAttrMap;
243255
};
244256

245-
auto getMfmaInsnGroupAttrMapGfx908Bf16 = []() -> const MfmaInsnGroupMap & {
257+
static auto getMfmaInsnGroupAttrMapGfx908Bf16 =
258+
[]() -> const MfmaInsnGroupMap & {
246259
using amdgpu::MFMAPermB;
247260
static MfmaInsnGroupMap
248261
// bf16
@@ -269,7 +282,7 @@ auto getMfmaInsnGroupAttrMapGfx908Bf16 = []() -> const MfmaInsnGroupMap & {
269282
return groupAttrMap;
270283
};
271284

272-
auto getMfmaInsnGroupAttrMapGfx90aPlusBf16 = []() {
285+
static auto getMfmaInsnGroupAttrMapGfx90aPlusBf16 = []() {
273286
using amdgpu::MFMAPermB;
274287
static llvm::DenseMap<MfmaInsnGroupSelectKey, MfmaInsnGroupAttr,
275288
MfmaInsnGroupSelectKeyInfo>
@@ -297,7 +310,7 @@ auto getMfmaInsnGroupAttrMapGfx90aPlusBf16 = []() {
297310
return groupAttrMap;
298311
};
299312

300-
auto getMfmaInsnGroupAttrMapPreGfx942Int8 = []() {
313+
static auto getMfmaInsnGroupAttrMapPreGfx942Int8 = []() {
301314
using amdgpu::MFMAPermB;
302315
static llvm::DenseMap<MfmaInsnGroupSelectKey, MfmaInsnGroupAttr,
303316
MfmaInsnGroupSelectKeyInfo>
@@ -321,7 +334,7 @@ auto getMfmaInsnGroupAttrMapPreGfx942Int8 = []() {
321334
};
322335

323336
// New I8 and all Float8
324-
auto getMfmaInsnGroupAttrMapGfx942Plus = []() {
337+
static auto getMfmaInsnGroupAttrMapGfx942 = []() {
325338
using amdgpu::MFMAPermB;
326339
static MfmaInsnGroupMap
327340
// Int8
@@ -407,6 +420,28 @@ auto getMfmaInsnGroupAttrMapGfx942Plus = []() {
407420
return groupAttrMap;
408421
};
409422

423+
static auto getMfmaInsnGroupAttrMapGfx950 = []() {
424+
static MfmaInsnGroupMap groupAttrMap{
425+
// fp16 double rate
426+
{{MfmaTypeId::Fp16TyId, 16, 16},
427+
{ROCDL::mfma_f32_16x16x32_f16::getOperationName()}},
428+
{{MfmaTypeId::Fp16TyId, 32, 32},
429+
{ROCDL::mfma_f32_32x32x16_f16::getOperationName()}},
430+
// bf16 double rate
431+
{{MfmaTypeId::Bf16TyId, 16, 16},
432+
{ROCDL::mfma_f32_16x16x32_bf16::getOperationName()}},
433+
{{MfmaTypeId::Bf16TyId, 32, 32},
434+
{ROCDL::mfma_f32_32x32x16_bf16::getOperationName()}},
435+
// i8 double rate
436+
{{MfmaTypeId::I8TyId, 16, 16},
437+
{ROCDL::mfma_i32_16x16x64_i8::getOperationName()}},
438+
{{MfmaTypeId::I8TyId, 32, 32},
439+
{ROCDL::mfma_i32_32x32x32_i8::getOperationName()}}
440+
441+
};
442+
return groupAttrMap;
443+
};
444+
410445
FailureOr<MfmaInsn> MfmaInsn::select(StringRef mfmaInsn) {
411446
auto mfmaInsnAttrMap = getMfmaInsnAttrMap();
412447
auto it = mfmaInsnAttrMap.find(mfmaInsn);
@@ -546,13 +581,35 @@ FailureOr<MfmaInsnGroup> MfmaInsnGroup::select(Type elementTypeA,
546581
result = MfmaInsnGroup(elementTypeA, elementTypeB, *maybeInsn, groupAttr);
547582
}
548583
};
549-
bool hasOldBf16 = arch.contains("gfx908");
550-
bool isPreGfx942 = arch.contains("gfx908") || arch.contains("gfx90a");
551-
if (elementTypeA.isBF16())
552-
selectFrom(hasOldBf16 ? getMfmaInsnGroupAttrMapGfx908Bf16()
553-
: getMfmaInsnGroupAttrMapGfx90aPlusBf16());
554-
selectFrom(isPreGfx942 ? getMfmaInsnGroupAttrMapPreGfx942Int8()
555-
: getMfmaInsnGroupAttrMapGfx942Plus());
584+
bool isGfx908 = arch.contains("gfx908");
585+
bool isGfx90a = arch.contains("gfx908") || arch.contains("gfx90a");
586+
bool isGfx94x = arch.contains("gfx942") || arch.contains("gfx940");
587+
bool isGfx95x = arch.contains("gfx950");
588+
// TODO: refactor this later to not keep multiple maps for different arches
589+
if (elementTypeA.isBF16()) {
590+
if (isGfx908) {
591+
selectFrom(getMfmaInsnGroupAttrMapGfx908Bf16());
592+
} else if (isGfx94x || isGfx90a) {
593+
selectFrom(getMfmaInsnGroupAttrMapGfx90aPlusBf16());
594+
} else {
595+
// gfx950 has double rate instructions. Select from those first.
596+
selectFrom(getMfmaInsnGroupAttrMapGfx950());
597+
selectFrom(getMfmaInsnGroupAttrMapGfx90aPlusBf16());
598+
}
599+
}
600+
601+
if (isGfx908 || isGfx90a) {
602+
selectFrom(getMfmaInsnGroupAttrMapPreGfx942Int8());
603+
} else if (isGfx94x) {
604+
selectFrom(getMfmaInsnGroupAttrMapGfx942());
605+
} else if (isGfx95x) {
606+
// select from new double rate instructions first
607+
selectFrom(getMfmaInsnGroupAttrMapGfx950());
608+
// all previous instructions are still valid for gfx950
609+
selectFrom(getMfmaInsnGroupAttrMapGfx942());
610+
}
611+
// select from all available instructions on all architectures if none has
612+
// been selected yet
556613
selectFrom(getMfmaInsnGroupAttrMapAllArch());
557614
if (failed(result)) {
558615
LLVM_DEBUG(llvm::dbgs() << "No match found in MFMA database\n");

mlir/lib/Dialect/Rock/Tuning/GridwiseGemmParams.cpp

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -463,10 +463,10 @@ LogicalResult PopulateParamsXDL::isValidBlockwiseGemm(
463463
if (minDPerWave <= 16) {
464464
validKPerWaveFactor = 4;
465465
}
466-
if (!((param.getMPerBlock() % minDPerWave == 0) &&
467-
(param.getNPerBlock() % minDPerWave == 0) &&
468-
((param.getKpackPerBlock() * param.getKpack()) % validKPerWaveFactor ==
469-
0))) {
466+
if ((param.getMPerBlock() % minDPerWave != 0) ||
467+
(param.getNPerBlock() % minDPerWave != 0) ||
468+
((param.getKpackPerBlock() * param.getKpack()) % validKPerWaveFactor !=
469+
0)) {
470470
return failure();
471471
}
472472

@@ -515,7 +515,7 @@ LogicalResult PopulateParamsXDL::isValidBlockwiseGemm(
515515

516516
// Sledgehammer hotfix because not unrolling sometimes makes the register
517517
// allocator break. This should be refined quickly.
518-
if (cast<RockTuningParamAttrInterface>(param).getForceUnroll() == false) {
518+
if (!cast<RockTuningParamAttrInterface>(param).getForceUnroll()) {
519519
return failure();
520520
}
521521

@@ -585,10 +585,7 @@ PopulateParamsXDL::getTuningParameters(KernelType opType, Type dataTypeA,
585585
return false;
586586
}
587587
MfmaInsnGroup mfmaGroup = *maybeMfmaInsnGroup;
588-
if (!mfmaGroup.isCoherentWithK(param.gemmKPack, param.gemmKPerBlock)) {
589-
return false;
590-
}
591-
return true;
588+
return mfmaGroup.isCoherentWithK(param.gemmKPack, param.gemmKPerBlock);
592589
});
593590
return res;
594591
}

mlir/test/Dialect/Rock/lowering_xdlops_gemm.mlir

Lines changed: 168 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -270,3 +270,171 @@ func.func @accel_gemm_fp8_bf8_ocp(%matrixA : memref<1x4xvector<8xf8E4M3FN>, #gpu
270270
} : memref<2x2xvector<16xf32>, #gpu.address_space<private>> += memref<1x4xvector<8xf8E4M3FN>, #gpu.address_space<private>> * memref<1x4xvector<8xf8E5M2>, #gpu.address_space<private>>
271271
return
272272
}
273+
274+
func.func @accel_gemm_gfx950_f16_16x16x32(%matrixA : memref<1x2xvector<8xf16>, 5>,
275+
%matrixB : memref<1x2xvector<8xf16>, 5>,
276+
%matrixC : memref<1x1xvector<4xf32>, 5>) {
277+
// CHECK-LABEL: func.func @accel_gemm_gfx950_f16_16x16x32
278+
// CHECK: rock.transforming_for
279+
// CHECK-SAME: bounds [1, 1, 1]
280+
// CHECK: amdgpu.mfma
281+
// CHECK-SAME: blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32
282+
// CHECK-NOT: amdgpu.mfma
283+
%c0 = arith.constant 0 : index
284+
rock.threadwise_accel_gemm %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma {
285+
arch = "amdgcn-amd-amdhsa:gfx950",
286+
params = #rock.xdlops_gemm_derived_params<
287+
kpackPerBlock = 8,
288+
kpack = 8,
289+
mPerWave = 16,
290+
nPerWave = 16,
291+
mPerBlock = 16,
292+
nPerBlock = 16,
293+
mnPerXdl = 16,
294+
splitKFactor = 1,
295+
scheduleVersion = 1,
296+
outputSwizzle = 2,
297+
forceUnroll = true>
298+
} : memref<1x1xvector<4xf32>, 5> += memref<1x2xvector<8xf16>, 5> * memref<1x2xvector<8xf16>, 5>
299+
return
300+
}
301+
302+
func.func @accel_gemm_gfx950_bf16_16x16x32(%matrixA : memref<1x2xvector<8xbf16>, 5>,
303+
%matrixB : memref<1x2xvector<8xbf16>, 5>,
304+
%matrixC : memref<1x1xvector<4xf32>, 5>) {
305+
// CHECK-LABEL: func.func @accel_gemm_gfx950_bf16_16x16x32
306+
// CHECK: rock.transforming_for
307+
// CHECK-SAME: bounds [1, 1, 1]
308+
// CHECK: amdgpu.mfma
309+
// CHECK-SAME: blocks = 1 : i32, k = 32 : i32, m = 16 : i32, n = 16 : i32
310+
// CHECK-NOT: amdgpu.mfma
311+
%c0 = arith.constant 0 : index
312+
rock.threadwise_accel_gemm %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma {
313+
arch = "amdgcn-amd-amdhsa:gfx950",
314+
params = #rock.xdlops_gemm_derived_params<
315+
kpackPerBlock = 8,
316+
kpack = 8,
317+
mPerWave = 16,
318+
nPerWave = 16,
319+
mPerBlock = 16,
320+
nPerBlock = 16,
321+
mnPerXdl = 16,
322+
splitKFactor = 1,
323+
scheduleVersion = 1,
324+
outputSwizzle = 2,
325+
forceUnroll = true>
326+
} : memref<1x1xvector<4xf32>, 5> += memref<1x2xvector<8xbf16>, 5> * memref<1x2xvector<8xbf16>, 5>
327+
return
328+
}
329+
330+
func.func @accel_gemm_gfx950_f16_32x32x16(%matrixA : memref<1x2xvector<8xf16>, 5>,
331+
%matrixB : memref<1x2xvector<8xf16>, 5>,
332+
%matrixC : memref<1x1xvector<16xf32>, 5>) {
333+
// CHECK-LABEL: func.func @accel_gemm_gfx950_f16_32x32x16
334+
// CHECK: rock.transforming_for
335+
// CHECK-SAME: bounds [1, 1, 1]
336+
// CHECK: amdgpu.mfma
337+
// CHECK-SAME: blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32
338+
// CHECK-NOT: amdgpu.mfma
339+
%c0 = arith.constant 0 : index
340+
rock.threadwise_accel_gemm %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma {
341+
arch = "amdgcn-amd-amdhsa:gfx950",
342+
params = #rock.xdlops_gemm_derived_params<
343+
kpackPerBlock = 4,
344+
kpack = 8,
345+
mPerWave = 32,
346+
nPerWave = 32,
347+
mPerBlock = 32,
348+
nPerBlock = 32,
349+
mnPerXdl = 32,
350+
splitKFactor = 1,
351+
scheduleVersion = 1,
352+
outputSwizzle = 2,
353+
forceUnroll = true>
354+
} : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<8xf16>, 5> * memref<1x2xvector<8xf16>, 5>
355+
return
356+
}
357+
358+
func.func @accel_gemm_gfx950_bf16_32x32x16(%matrixA : memref<1x2xvector<8xbf16>, 5>,
359+
%matrixB : memref<1x2xvector<8xbf16>, 5>,
360+
%matrixC : memref<1x1xvector<16xf32>, 5>) {
361+
// CHECK-LABEL: func.func @accel_gemm_gfx950_bf16_32x32x16
362+
// CHECK: rock.transforming_for
363+
// CHECK-SAME: bounds [1, 1, 1]
364+
// CHECK: amdgpu.mfma
365+
// CHECK-SAME: blocks = 1 : i32, k = 16 : i32, m = 32 : i32, n = 32 : i32
366+
// CHECK-NOT: amdgpu.mfma
367+
%c0 = arith.constant 0 : index
368+
rock.threadwise_accel_gemm %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma {
369+
arch = "amdgcn-amd-amdhsa:gfx950",
370+
params = #rock.xdlops_gemm_derived_params<
371+
kpackPerBlock = 4,
372+
kpack = 8,
373+
mPerWave = 32,
374+
nPerWave = 32,
375+
mPerBlock = 32,
376+
nPerBlock = 32,
377+
mnPerXdl = 32,
378+
splitKFactor = 1,
379+
scheduleVersion = 1,
380+
outputSwizzle = 2,
381+
forceUnroll = true>
382+
} : memref<1x1xvector<16xf32>, 5> += memref<1x2xvector<8xbf16>, 5> * memref<1x2xvector<8xbf16>, 5>
383+
return
384+
}
385+
386+
func.func @accel_gemm_gfx950_i8_32x32x32(%matrixA : memref<1x4xvector<16xi8>, 5>,
387+
%matrixB : memref<1x4xvector<16xi8>, 5>,
388+
%matrixC : memref<1x1xvector<16xi32>, 5>) {
389+
// CHECK-LABEL: func.func @accel_gemm_gfx950_i8_32x32x32
390+
// CHECK: rock.transforming_for
391+
// CHECK-SAME: bounds [1, 1, 1]
392+
// CHECK: amdgpu.mfma
393+
// CHECK-SAME: blocks = 1 : i32, k = 32 : i32, m = 32 : i32, n = 32 : i32
394+
// CHECK-NOT: amdgpu.mfma
395+
%c0 = arith.constant 0 : index
396+
rock.threadwise_accel_gemm %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma {
397+
arch = "amdgcn-amd-amdhsa:gfx950",
398+
params = #rock.xdlops_gemm_derived_params<
399+
kpackPerBlock = 8,
400+
kpack = 16,
401+
mPerWave = 32,
402+
nPerWave = 32,
403+
mPerBlock = 32,
404+
nPerBlock = 32,
405+
mnPerXdl = 32,
406+
splitKFactor = 1,
407+
scheduleVersion = 1,
408+
outputSwizzle = 2,
409+
forceUnroll = true>
410+
} : memref<1x1xvector<16xi32>, 5> += memref<1x4xvector<16xi8>, 5> * memref<1x4xvector<16xi8>, 5>
411+
return
412+
}
413+
414+
func.func @accel_gemm_gfx950_i8_16x16x64(%matrixA : memref<1x2xvector<16xi8>, 5>,
415+
%matrixB : memref<1x2xvector<16xi8>, 5>,
416+
%matrixC : memref<1x1xvector<4xi32>, 5>) {
417+
// CHECK-LABEL: func.func @accel_gemm_gfx950_i8_16x16x64
418+
// CHECK: rock.transforming_for
419+
// CHECK-SAME: bounds [1, 1, 1]
420+
// CHECK: amdgpu.mfma
421+
// CHECK-SAME: blocks = 1 : i32, k = 64 : i32, m = 16 : i32, n = 16 : i32
422+
// CHECK-NOT: amdgpu.mfma
423+
%c0 = arith.constant 0 : index
424+
rock.threadwise_accel_gemm %matrixC += %matrixA * %matrixB at [%c0, %c0, %c0] features = mfma {
425+
arch = "amdgcn-amd-amdhsa:gfx950",
426+
params = #rock.xdlops_gemm_derived_params<
427+
kpackPerBlock = 8,
428+
kpack = 16,
429+
mPerWave = 16,
430+
nPerWave = 16,
431+
mPerBlock = 32,
432+
nPerBlock = 32,
433+
mnPerXdl = 16,
434+
splitKFactor = 1,
435+
scheduleVersion = 1,
436+
outputSwizzle = 2,
437+
forceUnroll = true>
438+
} : memref<1x1xvector<4xi32>, 5> += memref<1x2xvector<16xi8>, 5> * memref<1x2xvector<16xi8>, 5>
439+
return
440+
}

0 commit comments

Comments
 (0)