@@ -214,6 +214,7 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
214214 Type f8E4M3FNUZ = Float8E4M3FNUZType::get (context);
215215 Type f8E5M2FNUZ = Float8E5M2FNUZType::get (context);
216216 Type f16 = Float16Type::get (context);
217+ Type bf16 = BFloat16Type::get (context);
217218 Type f32 = Float32Type::get (context);
218219
219220 Type i8 = IntegerType::get (context, 8 );
@@ -229,6 +230,12 @@ static OpaqueMmaLayout getOpaqueMFMALayout(MLIRContext *context,
229230 case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
230231 return OpaqueMmaLayout{32 , 32 , 8 , f16 , f16 , f32 };
231232 }
233+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
234+ return OpaqueMmaLayout{16 , 16 , 16 , bf16 , bf16 , f32 };
235+ }
236+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
237+ return OpaqueMmaLayout{32 , 32 , 8 , bf16 , bf16 , f32 };
238+ }
232239 case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ: {
233240 return OpaqueMmaLayout{16 , 16 , 32 , f8E4M3FNUZ, f8E4M3FNUZ, f32 };
234241 }
@@ -336,6 +343,45 @@ static ConcreteMmaLayout getConcreteMFMALayout(MLIRContext *context,
336343 return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
337344 bNLayout, cMLayout, cNLayout};
338345 }
346+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
347+ // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
348+ // #inner = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [4, 4]>
349+ // #layout_a = #iree_vector_ext.layout<#outer, #inner>
350+ // #layout_b = #iree_vector_ext.layout<#inner, #outer>
351+ // #layout_c = #iree_vector_ext.layout<#inner, #outer>
352+
353+ auto outer = PerDimLayoutAttr::get (context, {laneX}, {16 });
354+ auto inner = PerDimLayoutAttr::get (context, {laneY, vectorX}, {4 , 4 });
355+ auto aMLayout = outer;
356+ auto aKLayout = inner;
357+ auto bKLayout = inner;
358+ auto bNLayout = outer;
359+ auto cMLayout = inner;
360+ auto cNLayout = outer;
361+ return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
362+ bNLayout, cMLayout, cNLayout};
363+ }
364+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
365+ // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [32]>
366+ // #inner1 = #iree_vector_ext.per_dim_layout<[LANEY, VECTORX], [2, 4]>
367+ // #inner2 = #iree_vector_ext.per_dim_layout<[VECTORY, LANEY, VECTORX],
368+ // [4, 2, 4]>
369+ // #layout_a = #iree_vector_ext.layout<#outer, #inner1>
370+ // #layout_b = #iree_vector_ext.layout<#inner1, #outer>
371+ // #layout_c = #iree_vector_ext.layout<#inner2, #outer>
372+
373+ auto outer = PerDimLayoutAttr::get (context, {laneX}, {32 });
374+ auto inner = PerDimLayoutAttr::get (context, {laneY, vectorX}, {2 , 4 });
375+ auto aMLayout = outer;
376+ auto aKLayout = inner;
377+ auto bKLayout = inner;
378+ auto bNLayout = outer;
379+ auto cMLayout =
380+ PerDimLayoutAttr::get (context, {vectorY, laneY, vectorX}, {4 , 2 , 4 });
381+ auto cNLayout = outer;
382+ return ConcreteMmaLayout{opaqueLayout, aMLayout, aKLayout, bKLayout,
383+ bNLayout, cMLayout, cNLayout};
384+ }
339385 case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
340386 case MMAIntrinsic::MFMA_I32_16x16x32_I8: {
341387 // #outer = #iree_vector_ext.per_dim_layout<[LANEX], [16]>
@@ -462,14 +508,16 @@ MMAAttr::getABCVectorTypes() const {
462508 return std::make_tuple (aType, bType, cType);
463509 }
464510 case MMAIntrinsic::MFMA_I32_16x16x16_I8:
465- case MMAIntrinsic::MFMA_F32_16x16x16_F16: {
511+ case MMAIntrinsic::MFMA_F32_16x16x16_F16:
512+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16: {
466513 auto aType = VectorType::get ({4 }, getAType ());
467514 auto bType = VectorType::get ({4 }, getBType ());
468515 auto cType = VectorType::get ({4 }, getCType ());
469516 return std::make_tuple (aType, bType, cType);
470517 }
471518 case MMAIntrinsic::MFMA_I32_32x32x8_I8:
472- case MMAIntrinsic::MFMA_F32_32x32x8_F16: {
519+ case MMAIntrinsic::MFMA_F32_32x32x8_F16:
520+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16: {
473521 auto aType = VectorType::get ({4 }, getAType ());
474522 auto bType = VectorType::get ({4 }, getBType ());
475523 auto cType = VectorType::get ({16 }, getCType ());
@@ -519,8 +567,10 @@ int64_t MMAAttr::getBlockSize() const {
519567 switch (getIntrinsic ().getValue ()) {
520568 case MMAIntrinsic::MFMA_F32_16x16x4_F32:
521569 case MMAIntrinsic::MFMA_F32_16x16x16_F16:
570+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
522571 case MMAIntrinsic::MFMA_I32_16x16x16_I8:
523572 case MMAIntrinsic::MFMA_F32_32x32x8_F16:
573+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
524574 case MMAIntrinsic::MFMA_I32_32x32x8_I8:
525575 case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
526576 case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
@@ -540,8 +590,10 @@ static int64_t getIntrinsicSubgroupSize(MMAIntrinsic intrinsic) {
540590 switch (intrinsic) {
541591 case MMAIntrinsic::MFMA_F32_16x16x4_F32:
542592 case MMAIntrinsic::MFMA_F32_16x16x16_F16:
593+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
543594 case MMAIntrinsic::MFMA_I32_16x16x16_I8:
544595 case MMAIntrinsic::MFMA_F32_32x32x8_F16:
596+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
545597 case MMAIntrinsic::MFMA_I32_32x32x8_I8:
546598 case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
547599 case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
@@ -584,6 +636,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
584636 }
585637 case MMAIntrinsic::MFMA_I32_16x16x16_I8:
586638 case MMAIntrinsic::MFMA_F32_16x16x16_F16:
639+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
587640 switch (fragment) {
588641 case MMAFragment::Lhs:
589642 return {/* outer=*/ {1 , 1 }, /* thread=*/ {16 , 4 }, /* tstrides=*/ {1 , 16 },
@@ -597,6 +650,7 @@ MMASingleSubgroupLayout getSingleSubgroupLayout(MMAIntrinsic intrinsic,
597650 }
598651 case MMAIntrinsic::MFMA_I32_32x32x8_I8:
599652 case MMAIntrinsic::MFMA_F32_32x32x8_F16:
653+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
600654 switch (fragment) {
601655 case MMAFragment::Lhs:
602656 return {/* outer=*/ {1 , 1 }, /* thread=*/ {32 , 2 }, /* tstrides=*/ {1 , 32 },
@@ -704,8 +758,10 @@ FailureOr<Value> MMAAttr::buildMmaOperation(OpBuilder &builder, Location loc,
704758 }
705759 case MMAIntrinsic::MFMA_I32_16x16x16_I8:
706760 case MMAIntrinsic::MFMA_F32_16x16x16_F16:
761+ case MMAIntrinsic::MFMA_F32_16x16x16_BF16:
707762 case MMAIntrinsic::MFMA_I32_32x32x8_I8:
708763 case MMAIntrinsic::MFMA_F32_32x32x8_F16:
764+ case MMAIntrinsic::MFMA_F32_32x32x8_BF16:
709765 case MMAIntrinsic::MFMA_F32_16x16x32_F8E4M3FNUZ:
710766 case MMAIntrinsic::MFMA_F32_16x16x32_F8E5M2FNUZ:
711767 case MMAIntrinsic::MFMA_I32_16x16x32_I8:
0 commit comments