Skip to content

Commit 8c17887

Browse files
poyenc and illsilin authored
[CK_TILE] Fix incompatible vector type arguments for the intrinsic calls (#3672)
* Change call to the intrinsics
* Fix clang format
* Undo changes under include/ck/utility
* Use named variable as vector size

---------

Co-authored-by: illsilin_amdeng <Illia.Silin@amd.com>
1 parent 70d71b1 commit 8c17887

File tree

1 file changed

+56
-12
lines changed

1 file changed

+56
-12
lines changed

include/ck_tile/ops/gemm/warp/warp_gemm_attribute_mfma_impl.hpp

Lines changed: 56 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -612,7 +612,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
612612
else
613613
{
614614
#if defined(__gfx90a__) || defined(__gfx94__)
615-
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
615+
c_vec = __builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
616+
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
617+
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
618+
c_vec,
619+
0,
620+
0,
621+
0);
616622
#elif defined(__gfx908__)
617623
static_for<0, 2, 1>{}([&](auto k) {
618624
c_vec = __builtin_amdgcn_mfma_f32_32x32x4bf16(
@@ -637,8 +643,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M32N32K8
637643
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
638644
{
639645
#if defined(__gfx90a__) || defined(__gfx94__)
640-
return bit_cast<CVecType>(
641-
__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(a_vec, b_vec, fp32x16_t{0.f}, 0, 0, 0));
646+
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_32x32x8bf16_1k(
647+
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
648+
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
649+
fp32x16_t{0.f},
650+
0,
651+
0,
652+
0));
642653
#elif defined(__gfx908__)
643654
CVecType c_vec{0.f};
644655
static_for<0, 2, 1>{}([&](auto k) {
@@ -700,7 +711,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
700711
DISPATCH_MFMA_CTRL_("v_mfma_f32_16x16x16bf16_1k", Ctrl)
701712
{
702713
#if defined(__gfx90a__) || defined(__gfx94__)
703-
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
714+
c_vec = __builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
715+
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
716+
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
717+
c_vec,
718+
0,
719+
0,
720+
0);
704721
#elif defined(__gfx908__)
705722
static_for<0, 2, 1>{}([&](auto k) {
706723
c_vec = __builtin_amdgcn_mfma_f32_16x16x8bf16(
@@ -725,8 +742,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M16N16K16
725742
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
726743
{
727744
#if defined(__gfx90a__) || defined(__gfx94__)
728-
return bit_cast<CVecType>(
729-
__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
745+
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_16x16x16bf16_1k(
746+
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
747+
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
748+
fp32x4_t{0.f},
749+
0,
750+
0,
751+
0));
730752
#elif defined(__gfx908__)
731753
CVecType c_vec{0.f};
732754
static_for<0, 2, 1>{}([&](auto k) {
@@ -790,7 +812,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
790812
else
791813
{
792814
#if defined(__gfx90a__) || defined(__gfx94__)
793-
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
815+
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
816+
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
817+
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
818+
c_vec,
819+
0,
820+
0,
821+
0);
794822
#elif defined(__gfx908__)
795823
static_for<0, 2, 1>{}([&](auto k) {
796824
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
@@ -815,8 +843,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M4N64K4
815843
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
816844
{
817845
#if defined(__gfx90a__) || defined(__gfx94__)
818-
return bit_cast<CVecType>(
819-
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
846+
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
847+
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
848+
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
849+
fp32x4_t{0.f},
850+
0,
851+
0,
852+
0));
820853
#elif defined(__gfx908__)
821854
CVecType c_vec{0.f};
822855
static_for<0, 2, 1>{}([&](auto k) {
@@ -880,7 +913,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
880913
else
881914
{
882915
#if defined(__gfx90a__) || defined(__gfx94__)
883-
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, c_vec, 0, 0, 0);
916+
c_vec = __builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
917+
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
918+
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
919+
c_vec,
920+
0,
921+
0,
922+
0);
884923
#elif defined(__gfx908__)
885924
static_for<0, 2, 1>{}([&](auto k) {
886925
c_vec = __builtin_amdgcn_mfma_f32_4x4x2bf16(
@@ -905,8 +944,13 @@ struct WarpGemmAttributeMfmaImplBf16Bf16F32M64N4K4
905944
CK_TILE_DEVICE CVecType operator()(const AVecType& a_vec, const BVecType& b_vec) const
906945
{
907946
#if defined(__gfx90a__) || defined(__gfx94__)
908-
return bit_cast<CVecType>(
909-
__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(a_vec, b_vec, fp32x4_t{0.f}, 0, 0, 0));
947+
return bit_cast<CVecType>(__builtin_amdgcn_mfma_f32_4x4x4bf16_1k(
948+
bit_cast<ext_vector_t<short, kABKPerLane>>(a_vec),
949+
bit_cast<ext_vector_t<short, kABKPerLane>>(b_vec),
950+
fp32x4_t{0.f},
951+
0,
952+
0,
953+
0));
910954
#elif defined(__gfx908__)
911955
CVecType c_vec{0.f};
912956
static_for<0, 2, 1>{}([&](auto k) {

0 commit comments

Comments
 (0)