@@ -256,6 +256,18 @@ static std::tuple<Type, Type, Type> getABCElementTypes(MLIRContext *context,
256256 case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ_F8E4M3FNUZ: {
257257 return {f8E5M2FNUZ, f8E4M3FNUZ, f32 };
258258 }
259+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ: {
260+ return {f8E4M3FNUZ, f8E4M3FNUZ, f32 };
261+ }
262+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ: {
263+ return {f8E5M2FNUZ, f8E5M2FNUZ, f32 };
264+ }
265+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ: {
266+ return {f8E4M3FNUZ, f8E5M2FNUZ, f32 };
267+ }
268+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ: {
269+ return {f8E5M2FNUZ, f8E4M3FNUZ, f32 };
270+ }
259271 case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
260272 return {i8 , i8 , i32 };
261273 }
@@ -608,6 +620,10 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
608620 return {/* outer=*/ {1 , 1 }, /* thread=*/ {4 , 16 }, /* tstrides=*/ {16 , 1 },
609621 /* element=*/ {4 , 1 }};
610622 }
623+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
624+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ:
625+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ_F8E5M2FNUZ:
626+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E5M2FNUZ_F8E4M3FNUZ:
611627 case MMAIntrinsic::MFMA_I32_32x32x16_I8:
612628 switch (fragment) {
613629 case MMAFragment::Lhs:
@@ -675,6 +691,8 @@ SmallVector<VirtualMMAIntrinsic> MMAAttr::getVirtualIntrinsics() const {
675691 return {VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16};
676692 case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
677693 return {VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ};
694+ case MMAIntrinsic::MFMA_F32_32x32x16_F8E4M3FNUZ:
695+ return {VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ};
678696 default :
679697 return {};
680698 }
@@ -1218,6 +1236,9 @@ static OpaqueMmaLayout getOpaqueVMMALayout(MLIRContext *context,
12181236 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
12191237 return OpaqueMmaLayout{16 , 16 , 32 , f8E4M3FNUZ, f8E4M3FNUZ, f32 };
12201238 }
1239+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: {
1240+ return OpaqueMmaLayout{32 , 32 , 16 , f8E4M3FNUZ, f8E4M3FNUZ, f32 };
1241+ }
12211242 // V(Virtual)MFMA instructions which have 2 mfma instructions interleaved
12221243 // along the k dimension.
12231244 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16: {
@@ -1252,6 +1273,7 @@ VirtualMMAAttr::getABCVectorTypes() const {
12521273 auto cType = VectorType::get ({4 }, C);
12531274 return {aType, bType, cType};
12541275 }
1276+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
12551277 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
12561278 auto aType = VectorType::get ({8 }, A);
12571279 auto bType = VectorType::get ({8 }, B);
@@ -1274,6 +1296,7 @@ int64_t VirtualMMAAttr::getSubgroupSize() const {
12741296 switch (getIntrinsic ().getValue ()) {
12751297 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
12761298 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1299+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
12771300 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
12781301 return 64 ;
12791302 }
@@ -1328,7 +1351,8 @@ int64_t VirtualMMAAttr::getUnrollK() const {
13281351 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
13291352 return 2 ;
13301353 }
1331- case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ: {
1354+ case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
1355+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ: {
13321356 return 1 ;
13331357 }
13341358 }
@@ -1356,6 +1380,7 @@ FailureOr<Value> VirtualMMAAttr::buildMmaOperation(OpBuilder &builder,
13561380 switch (getIntrinsic ().getValue ()) {
13571381 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
13581382 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1383+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
13591384 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
13601385 // Generate mfma's for K with unrolled kernels.
13611386 const int64_t unrollKFactor = getUnrollK ();
@@ -1394,6 +1419,7 @@ int64_t VirtualMMAAttr::getBlockSize() const {
13941419 switch (getIntrinsic ().getValue ()) {
13951420 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F8E4M3FNUZ:
13961421 case VirtualMMAIntrinsic::VMFMA_F32_16x16x32_F16:
1422+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
13971423 case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F16: {
13981424 return 1 ;
13991425 }
@@ -1442,6 +1468,18 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(VirtualMMAIntrinsic intrinsic,
14421468 return {/* outer=*/ {4 , 1 }, /* thread=*/ {2 , 32 }, /* tstrides=*/ {32 , 1 },
14431469 /* element=*/ {4 , 1 }};
14441470 }
1471+ case VirtualMMAIntrinsic::VMFMA_F32_32x32x16_F8E4M3FNUZ:
1472+ switch (fragment) {
1473+ case MMAFragment::Lhs:
1474+ return {/* outer=*/ {1 , 2 }, /* thread=*/ {32 , 2 }, /* tstrides=*/ {1 , 32 },
1475+ /* element=*/ {1 , 4 }};
1476+ case MMAFragment::Rhs:
1477+ return {/* outer=*/ {2 , 1 }, /* thread=*/ {2 , 32 }, /* tstrides=*/ {32 , 1 },
1478+ /* element=*/ {4 , 1 }};
1479+ case MMAFragment::Acc:
1480+ return {/* outer=*/ {4 , 1 }, /* thread=*/ {2 , 32 }, /* tstrides=*/ {32 , 1 },
1481+ /* element=*/ {4 , 1 }};
1482+ }
14451483 }
14461484 assert (false && " unhandled virtual mma layout type." );
14471485 return {};
0 commit comments