@@ -306,8 +306,9 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
306306 return {};
307307}
308308
309- static OpaqueMmaLayout getOpaqueMFMALayout (MLIRContext *context,
310- MMAIntrinsic intrinsic) {
309+ template <typename MMAIntrinsicType>
310+ static OpaqueMmaLayout getOpaqueMMALayout (MLIRContext *context,
311+ MMAIntrinsicType intrinsic) {
311312 OpaqueMmaLayout o;
312313 std::tie (o.aType , o.bType , o.cType ) = getABCElementTypes (context, intrinsic);
313314 auto lhs = getSingleSubgroupLayout (intrinsic, MMAFragment::Lhs);
@@ -369,9 +370,9 @@ getPerDimLayoutAttrs(MLIRContext *context, TileSwizzle swizzle) {
369370 PerDimLayoutAttr::get (context, labels[1 ], shape[1 ])};
370371};
371372
372- static ConcreteMmaLayout getConcreteMFMALayout (MLIRContext *context,
373- MMAIntrinsic intrinsic) {
374- auto opaque = getOpaqueMFMALayout (context, intrinsic);
373+ static ConcreteMmaLayout getConcreteMMALayout (MLIRContext *context,
374+ MMAIntrinsic intrinsic) {
375+ auto opaque = getOpaqueMMALayout (context, intrinsic);
375376 ConcreteMmaLayout concreteLayout;
376377 concreteLayout.base = opaque;
377378 auto lhsSwizzle = getIntrinsicSwizzle (intrinsic, MMAFragment::Lhs);
@@ -452,7 +453,7 @@ void MMAAttr::print(AsmPrinter &p) const {
452453}
453454
454455MMAAttr MMAAttr::get (MLIRContext *context, MMAIntrinsic type) {
455- auto layout = getOpaqueMFMALayout (context, type);
456+ auto layout = getOpaqueMMALayout (context, type);
456457 return Base::get (context, MMAIntrinsicAttr::get (context, type), layout.mSize ,
457458 layout.nSize , layout.kSize , layout.aType , layout.bType ,
458459 layout.cType );
@@ -466,9 +467,11 @@ std::tuple<int64_t, int64_t, int64_t> MMAAttr::getMNKShape() const {
466467 return {getMSize (), getNSize (), getKSize ()};
467468}
468469
469- static VectorType getVectorType (MLIRContext *context, MMAIntrinsic intrinsic,
470+ template <typename MMAIntrinsicType>
471+ static VectorType getVectorType (MLIRContext *context,
472+ MMAIntrinsicType intrinsic,
470473 MMAFragment fragment) {
471- auto o = getOpaqueMFMALayout (context, intrinsic);
474+ auto o = getOpaqueMMALayout (context, intrinsic);
472475 auto s = getSingleSubgroupLayout (intrinsic, fragment);
473476 Type elemType = (fragment == MMAFragment::Lhs) ? o.aType
474477 : (fragment == MMAFragment::Rhs) ? o.bType
@@ -491,7 +494,7 @@ FailureOr<std::tuple<VectorLayoutInterface, VectorLayoutInterface,
491494 VectorLayoutInterface>>
492495MMAAttr::getContractionLayout (vector::ContractionOp contract) const {
493496 ConcreteMmaLayout layout =
494- getConcreteMFMALayout (contract->getContext (), getIntrinsic ().getValue ());
497+ getConcreteMMALayout (contract->getContext (), getIntrinsic ().getValue ());
495498 return IREE::GPU::getContractionLayout (contract, layout);
496499}
497500
@@ -932,13 +935,13 @@ sliceSwizzledShape(const TileSwizzle &swizzle,
932935
933936std::tuple<Type, Type, Type> DataTiledMMAAttr::getABCElementTypes () const {
934937 MLIRContext *ctx = getContext ();
935- auto opaqueLayout = getOpaqueMFMALayout (ctx, getIntrinsic ().getValue ());
938+ auto opaqueLayout = getOpaqueMMALayout (ctx, getIntrinsic ().getValue ());
936939 return {opaqueLayout.aType , opaqueLayout.bType , opaqueLayout.cType };
937940}
938941
939942std::tuple<int64_t , int64_t , int64_t > DataTiledMMAAttr::getMNKShape () const {
940943 MLIRContext *ctx = getContext ();
941- auto opaqueLayout = getOpaqueMFMALayout (ctx, getIntrinsic ().getValue ());
944+ auto opaqueLayout = getOpaqueMMALayout (ctx, getIntrinsic ().getValue ());
942945 return {opaqueLayout.mSize * getUnrollM () * getSubgroupsM (),
943946 opaqueLayout.nSize * getUnrollN () * getSubgroupsN (),
944947 opaqueLayout.kSize * getUnrollK ()};
@@ -1228,68 +1231,47 @@ VirtualMMAAttr VirtualMMAAttr::get(MLIRContext *context,
12281231 return VirtualMMAAttr::get (context, intrinsicAttr);
12291232}
12301233
1231- static OpaqueMmaLayout getOpaqueVMMALayout (MLIRContext *context,
1232- VirtualMMAIntrinsic type) {
1234+ static std::tuple<Type, Type, Type>
1235+ getABCElementTypes (MLIRContext *context, VirtualMMAIntrinsic type) {
12331236 Type f8E4M3FNUZ = Float8E4M3FNUZType::get (context);
12341237 Type f16 = Float16Type::get (context);
12351238 Type f32 = Float32Type::get (context);
12361239
12371240 switch (type) {
1238- case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
1239- return OpaqueMmaLayout{16 , 16 , 32 , f8E4M3FNUZ, f8E4M3FNUZ, f32 };
1240- }
1241- case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: {
1242- return OpaqueMmaLayout{32 , 32 , 16 , f8E4M3FNUZ, f8E4M3FNUZ, f32 };
1243- }
1241+ case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
1242+ return {f8E4M3FNUZ, f8E4M3FNUZ, f32 };
1243+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1244+ return {f8E4M3FNUZ, f8E4M3FNUZ, f32 };
12441245 // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
12451246 // along the k dimension.
1246- case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: {
1247- return OpaqueMmaLayout{16 , 16 , 32 , f16 , f16 , f32 };
1248- }
1249- case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
1250- return OpaqueMmaLayout{32 , 32 , 16 , f16 , f16 , f32 };
1251- }
1247+ case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1248+ return {f16 , f16 , f32 };
1249+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16:
1250+ return {f16 , f16 , f32 };
12521251 }
12531252 assert (false && " unhandled virtual mma layout type." );
1254- return OpaqueMmaLayout {};
1253+ return {};
12551254}
12561255
12571256std::tuple<Type, Type, Type> VirtualMMAAttr::getABCElementTypes () const {
12581257 MLIRContext *ctx = getContext ();
1259- auto opaqueLayout = getOpaqueVMMALayout (ctx, getIntrinsic ().getValue ());
1258+ auto opaqueLayout = getOpaqueMMALayout (ctx, getIntrinsic ().getValue ());
12601259 return {opaqueLayout.aType , opaqueLayout.bType , opaqueLayout.cType };
12611260}
12621261
12631262std::tuple<VectorType, VectorType, VectorType>
12641263VirtualMMAAttr::getABCVectorTypes () const {
1265- // Check https://github.com/ROCm/amd_matrix_instruction_calculator for
1266- // instruction details. Note here we are returning the number elements, while
1267- // amd_matrix_instruction_calculator tells us about the number of 32-bit
1268- // registers. So need to adjust accordingly. All vectors should be 1-D.
1269- auto [A, B, C] = getABCElementTypes ();
1270- switch (getIntrinsic ().getValue ()) {
1271- case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
1272- case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: {
1273- auto aType = VectorType::get ({8 }, A);
1274- auto bType = VectorType::get ({8 }, B);
1275- auto cType = VectorType::get ({4 }, C);
1276- return {aType, bType, cType};
1277- }
1278- case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1279- case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
1280- auto aType = VectorType::get ({8 }, A);
1281- auto bType = VectorType::get ({8 }, B);
1282- auto cType = VectorType::get ({16 }, C);
1283- return {aType, bType, cType};
1284- }
1285- }
1286- assert (false && " unhandled virtual mma layout type." );
1287- return {VectorType{}, VectorType{}, VectorType{}};
1264+ MLIRContext *context = getContext ();
1265+ VirtualMMAIntrinsic intrinsic = getIntrinsic ().getValue ();
1266+ VectorType aVecType = getVectorType (context, intrinsic, MMAFragment::Lhs);
1267+ VectorType bVecType = getVectorType (context, intrinsic, MMAFragment::Rhs);
1268+ VectorType cVecType = getVectorType (context, intrinsic, MMAFragment::Acc);
1269+ return {aVecType, bVecType, cVecType};
12881270}
12891271
12901272std::tuple<int64_t , int64_t , int64_t > VirtualMMAAttr::getMNKShape () const {
12911273 MLIRContext *ctx = getContext ();
1292- auto opaqueLayout = getOpaqueVMMALayout (ctx, getIntrinsic ().getValue ());
1274+ auto opaqueLayout = getOpaqueMMALayout (ctx, getIntrinsic ().getValue ());
12931275 return {opaqueLayout.mSize , opaqueLayout.nSize , opaqueLayout.kSize };
12941276}
12951277
0 commit comments