@@ -612,7 +612,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
612612 else
613613 {
614614#if defined(__gfx90a__) || defined(__gfx94__)
615- c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k (a_vec, b_vec, c_vec, 0 , 0 , 0 );
615+ c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k (
616+ bit_cast<ext_vector_t <short , kABKPerLane >>(a_vec),
617+ bit_cast<ext_vector_t <short , kABKPerLane >>(b_vec),
618+ c_vec,
619+ 0 ,
620+ 0 ,
621+ 0 );
616622#elif defined(__gfx908__)
617623 static_for<0 , 2 , 1 >{}([&](auto k) {
618624 c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16 (
@@ -637,8 +643,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
637643 CK_TILE_DEVICE CVecType operator ()(const AVecType& a_vec, const BVecType& b_vec) const
638644 {
639645#if defined(__gfx90a__) || defined(__gfx94__)
640- return bit_cast<CVecType>(
641- __builtin_amdgcn_mfma_f32_32x32x8bf16_1k (a_vec, b_vec, fp32x16_t {0 .f }, 0 , 0 , 0 ));
646+ return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x8bf16_1k (
647+ bit_cast<ext_vector_t <short , kABKPerLane >>(a_vec),
648+ bit_cast<ext_vector_t <short , kABKPerLane >>(b_vec),
649+ fp32x16_t {0 .f },
650+ 0 ,
651+ 0 ,
652+ 0 ));
642653#elif defined(__gfx908__)
643654 CVecType c_vec{0 .f };
644655 static_for<0 , 2 , 1 >{}([&](auto k) {
@@ -700,7 +711,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
700711 DISPATCH_MFMA_CTRL_ (" v_mfma_f32_16x16x16bf16_1k" , Ctrl)
701712 {
702713#if defined(__gfx90a__) || defined(__gfx94__)
703- c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k (a_vec, b_vec, c_vec, 0 , 0 , 0 );
714+ c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k (
715+ bit_cast<ext_vector_t <short , kABKPerLane >>(a_vec),
716+ bit_cast<ext_vector_t <short , kABKPerLane >>(b_vec),
717+ c_vec,
718+ 0 ,
719+ 0 ,
720+ 0 );
704721#elif defined(__gfx908__)
705722 static_for<0 , 2 , 1 >{}([&](auto k) {
706723 c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16 (
@@ -725,8 +742,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
725742 CK_TILE_DEVICE CVecType operator ()(const AVecType& a_vec, const BVecType& b_vec) const
726743 {
727744#if defined(__gfx90a__) || defined(__gfx94__)
728- return bit_cast<CVecType>(
729- __builtin_amdgcn_mfma_f32_16x16x16bf16_1k (a_vec, b_vec, fp32x4_t {0 .f }, 0 , 0 , 0 ));
745+ return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x16bf16_1k (
746+ bit_cast<ext_vector_t <short , kABKPerLane >>(a_vec),
747+ bit_cast<ext_vector_t <short , kABKPerLane >>(b_vec),
748+ fp32x4_t {0 .f },
749+ 0 ,
750+ 0 ,
751+ 0 ));
730752#elif defined(__gfx908__)
731753 CVecType c_vec{0 .f };
732754 static_for<0 , 2 , 1 >{}([&](auto k) {
@@ -790,7 +812,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
790812 else
791813 {
792814#if defined(__gfx90a__) || defined(__gfx94__)
793- c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (a_vec, b_vec, c_vec, 0 , 0 , 0 );
815+ c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (
816+ bit_cast<ext_vector_t <short , kABKPerLane >>(a_vec),
817+ bit_cast<ext_vector_t <short , kABKPerLane >>(b_vec),
818+ c_vec,
819+ 0 ,
820+ 0 ,
821+ 0 );
794822#elif defined(__gfx908__)
795823 static_for<0 , 2 , 1 >{}([&](auto k) {
796824 c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16 (
@@ -815,8 +843,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
815843 CK_TILE_DEVICE CVecType operator ()(const AVecType& a_vec, const BVecType& b_vec) const
816844 {
817845#if defined(__gfx90a__) || defined(__gfx94__)
818- return bit_cast<CVecType>(
819- __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (a_vec, b_vec, fp32x4_t {0 .f }, 0 , 0 , 0 ));
846+ return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_4x4x4bf16_1k (
847+ bit_cast<ext_vector_t <short , kABKPerLane >>(a_vec),
848+ bit_cast<ext_vector_t <short , kABKPerLane >>(b_vec),
849+ fp32x4_t {0 .f },
850+ 0 ,
851+ 0 ,
852+ 0 ));
820853#elif defined(__gfx908__)
821854 CVecType c_vec{0 .f };
822855 static_for<0 , 2 , 1 >{}([&](auto k) {
@@ -880,7 +913,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
880913 else
881914 {
882915#if defined(__gfx90a__) || defined(__gfx94__)
883- c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (a_vec, b_vec, c_vec, 0 , 0 , 0 );
916+ c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (
917+ bit_cast<ext_vector_t <short , kABKPerLane >>(a_vec),
918+ bit_cast<ext_vector_t <short , kABKPerLane >>(b_vec),
919+ c_vec,
920+ 0 ,
921+ 0 ,
922+ 0 );
884923#elif defined(__gfx908__)
885924 static_for<0 , 2 , 1 >{}([&](auto k) {
886925 c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16 (
@@ -905,8 +944,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
905944 CK_TILE_DEVICE CVecType operator ()(const AVecType& a_vec, const BVecType& b_vec) const
906945 {
907946#if defined(__gfx90a__) || defined(__gfx94__)
908- return bit_cast<CVecType>(
909- __builtin_amdgcn_mfma_f32_4x4x4bf16_1k (a_vec, b_vec, fp32x4_t {0 .f }, 0 , 0 , 0 ));
947+ return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_4x4x4bf16_1k (
948+ bit_cast<ext_vector_t <short , kABKPerLane >>(a_vec),
949+ bit_cast<ext_vector_t <short , kABKPerLane >>(b_vec),
950+ fp32x4_t {0 .f },
951+ 0 ,
952+ 0 ,
953+ 0 ));
910954#elif defined(__gfx908__)
911955 CVecType c_vec{0 .f };
912956 static_for<0 , 2 , 1 >{}([&](auto k) {
0 commit comments