Skip to content

Commit ef241f9

Browse files
authored
[LLVMGPU] Cleanup VirtualMMA functions to match refactoring on base MMAAttr (iree-org#19144)
Apply cleanups similar to those done in iree-org#19098. For the most part we: 1. Templatize getVectorType and getOpaqueMMALayout to work on any intrinsic type. 2. Use the common getOpaqueMMALayout for VirtualMMA. 3. Update getABCElementTypes to be similar to the one for MMAAttr. 4. Rename the get.*MFMA functions to get.*MMA, since MFMA is CDNA-specific and the underlying instructions do not have to be MFMA instructions. Signed-off-by: Stanley Winata <[email protected]>
1 parent a70ea83 commit ef241f9

File tree

1 file changed

+33
-51
lines changed

1 file changed

+33
-51
lines changed

compiler/src/iree/compiler/Codegen/Dialect/GPU/IR/IREEGPUAttrs.cpp

Lines changed: 33 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -306,8 +306,9 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
306306
return {};
307307
}
308308

309-
static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
310-
MMAIntrinsic intrinsic) {
309+
template <typename MMAIntrinsicType>
310+
static OpaqueMmaLayout getOpaqueMMALayout(MLIRContext *context,
311+
MMAIntrinsicType intrinsic) {
311312
OpaqueMmaLayout o;
312313
std::tie(o.aType, o.bType, o.cType) = getABCElementTypes(context, intrinsic);
313314
auto lhs = getSingleSubgroupLayout(intrinsic, MMAFragment::Lhs);
@@ -369,9 +370,9 @@ getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) {
369370
PerDimLayoutAttr::get(context, labels[1], shape[1])};
370371
};
371372

372-
static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
373-
MMAIntrinsic intrinsic) {
374-
auto opaque = getOpaqueMFMALayout(context, intrinsic);
373+
static ConcreteMmaLayout getConcreteMMALayout(MLIRContext *context,
374+
MMAIntrinsic intrinsic) {
375+
auto opaque = getOpaqueMMALayout(context, intrinsic);
375376
ConcreteMmaLayout concreteLayout;
376377
concreteLayout.base = opaque;
377378
auto lhsSwizzle = getIntrinsicSwizzle(intrinsic, MMAFragment::Lhs);
@@ -452,7 +453,7 @@ void MMAAttr::print(AsmPrinter &p) const {
452453
}
453454

454455
MMAAttr MMAAttr::get(MLIRContext *context, MMAIntrinsic type) {
455-
auto layout = getOpaqueMFMALayout(context, type);
456+
auto layout = getOpaqueMMALayout(context, type);
456457
return Base::get(context, MMAIntrinsicAttr::get(context, type), layout.mSize,
457458
layout.nSize, layout.kSize, layout.aType, layout.bType,
458459
layout.cType);
@@ -466,9 +467,11 @@ std::tuple<int64_t, int64_t, int64_t> MMAAttr::getMNKShape() const {
466467
return {getMSize(), getNSize(), getKSize()};
467468
}
468469

469-
static VectorType getVectorType(MLIRContext *context, MMAIntrinsic intrinsic,
470+
template <typename MMAIntrinsicType>
471+
static VectorType getVectorType(MLIRContext *context,
472+
MMAIntrinsicType intrinsic,
470473
MMAFragment fragment) {
471-
auto o = getOpaqueMFMALayout(context, intrinsic);
474+
auto o = getOpaqueMMALayout(context, intrinsic);
472475
auto s = getSingleSubgroupLayout(intrinsic, fragment);
473476
Type elemType = (fragment == MMAFragment::Lhs) ? o.aType
474477
: (fragment == MMAFragment::Rhs) ? o.bType
@@ -491,7 +494,7 @@ FailureOr<std::tuple<VectorLayoutInterface, VectorLayoutInterface,
491494
VectorLayoutInterface>>
492495
MMAAttr::getContractionLayout(vector::ContractionOp contract) const {
493496
ConcreteMmaLayout layout =
494-
getConcreteMFMALayout(contract->getContext(), getIntrinsic().getValue());
497+
getConcreteMMALayout(contract->getContext(), getIntrinsic().getValue());
495498
return IREE::GPU::getContractionLayout(contract, layout);
496499
}
497500

@@ -932,13 +935,13 @@ sliceSwizzledShape(const TileSwizzle &swizzle,
932935

933936
std::tuple<Type, Type, Type> DataTiledMMAAttr::getABCElementTypes() const {
934937
MLIRContext *ctx = getContext();
935-
auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue());
938+
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
936939
return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType};
937940
}
938941

939942
std::tuple<int64_t, int64_t, int64_t> DataTiledMMAAttr::getMNKShape() const {
940943
MLIRContext *ctx = getContext();
941-
auto opaqueLayout = getOpaqueMFMALayout(ctx, getIntrinsic().getValue());
944+
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
942945
return {opaqueLayout.mSize * getUnrollM() * getSubgroupsM(),
943946
opaqueLayout.nSize * getUnrollN() * getSubgroupsN(),
944947
opaqueLayout.kSize * getUnrollK()};
@@ -1228,68 +1231,47 @@ VirtualMMAAttr VirtualMMAAttr::get(MLIRContext *context,
12281231
return VirtualMMAAttr::get(context, intrinsicAttr);
12291232
}
12301233

1231-
static OpaqueMmaLayout getOpaqueVMMALayout(MLIRContext *context,
1232-
VirtualMMAIntrinsic type) {
1234+
static std::tuple<Type, Type, Type>
1235+
getABCElementTypes(MLIRContext *context, VirtualMMAIntrinsic type) {
12331236
Type f8E4M3FNUZ = Float8E4M3FNUZType::get(context);
12341237
Type f16 = Float16Type::get(context);
12351238
Type f32 = Float32Type::get(context);
12361239

12371240
switch (type) {
1238-
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
1239-
return OpaqueMmaLayout{16, 16, 32, f8E4M3FNUZ, f8E4M3FNUZ, f32};
1240-
}
1241-
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: {
1242-
return OpaqueMmaLayout{32, 32, 16, f8E4M3FNUZ, f8E4M3FNUZ, f32};
1243-
}
1241+
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
1242+
return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
1243+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1244+
return {f8E4M3FNUZ, f8E4M3FNUZ, f32};
12441245
// V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
12451246
// along the k dimension.
1246-
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: {
1247-
return OpaqueMmaLayout{16, 16, 32, f16, f16, f32};
1248-
}
1249-
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
1250-
return OpaqueMmaLayout{32, 32, 16, f16, f16, f32};
1251-
}
1247+
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1248+
return {f16, f16, f32};
1249+
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
1250+
return {f16, f16, f32};
12521251
}
12531252
assert(false && "unhandled virtual mma layout type.");
1254-
return OpaqueMmaLayout{};
1253+
return {};
12551254
}
12561255

12571256
std::tuple<Type, Type, Type> VirtualMMAAttr::getABCElementTypes() const {
12581257
MLIRContext *ctx = getContext();
1259-
auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue());
1258+
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
12601259
return {opaqueLayout.aType, opaqueLayout.bType, opaqueLayout.cType};
12611260
}
12621261

12631262
std::tuple<VectorType, VectorType, VectorType>
12641263
VirtualMMAAttr::getABCVectorTypes() const {
1265-
// Check https://github.com/ROCm/amd_matrix_instruction_calculator for
1266-
// instruction details. Note here we are returning the number elements, while
1267-
// amd_matrix_instruction_calculator tells us about the number of 32-bit
1268-
// registers. So need to adjust accordingly. All vectors should be 1-D.
1269-
auto [A, B, C] = getABCElementTypes();
1270-
switch (getIntrinsic().getValue()) {
1271-
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
1272-
case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: {
1273-
auto aType = VectorType::get({8}, A);
1274-
auto bType = VectorType::get({8}, B);
1275-
auto cType = VectorType::get({4}, C);
1276-
return {aType, bType, cType};
1277-
}
1278-
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1279-
case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
1280-
auto aType = VectorType::get({8}, A);
1281-
auto bType = VectorType::get({8}, B);
1282-
auto cType = VectorType::get({16}, C);
1283-
return {aType, bType, cType};
1284-
}
1285-
}
1286-
assert(false && "unhandled virtual mma layout type.");
1287-
return {VectorType{}, VectorType{}, VectorType{}};
1264+
MLIRContext *context = getContext();
1265+
VirtualMMAIntrinsic intrinsic = getIntrinsic().getValue();
1266+
VectorType aVecType = getVectorType(context, intrinsic, MMAFragment::Lhs);
1267+
VectorType bVecType = getVectorType(context, intrinsic, MMAFragment::Rhs);
1268+
VectorType cVecType = getVectorType(context, intrinsic, MMAFragment::Acc);
1269+
return {aVecType, bVecType, cVecType};
12881270
}
12891271

12901272
std::tuple<int64_t, int64_t, int64_t> VirtualMMAAttr::getMNKShape() const {
12911273
MLIRContext *ctx = getContext();
1292-
auto opaqueLayout = getOpaqueVMMALayout(ctx, getIntrinsic().getValue());
1274+
auto opaqueLayout = getOpaqueMMALayout(ctx, getIntrinsic().getValue());
12931275
return {opaqueLayout.mSize, opaqueLayout.nSize, opaqueLayout.kSize};
12941276
}
12951277

0 commit comments

Comments
 (0)