@@ -54,7 +54,6 @@ __device__ half_2 packFp32s( float a, float b ) { return __builtin_amdgcn_cvt_pk
5454
5555extern "C" __global__ void wmma_matmul ( __fp16 * a , __fp16 * b , __fp16 * c )
5656{
57- const int gIdx = blockIdx .x * blockDim .x + threadIdx .x ;
5857 const int lIdx = threadIdx .x ;
5958
6059 // a and b fragments are stored in 8 VGPRs each, in packed format, so 16 elements each for a and b
@@ -65,14 +64,14 @@ extern "C" __global__ void wmma_matmul( __fp16* a, __fp16* b, __fp16* c )
6564 // initialize c fragment to 0
6665 frag_type_c c_frag = {};
6766
68- const int lane = lIdx % 16 ;
67+ const int laneWrapped = lIdx % 16 ;
6968 const int laneGroup = lIdx / 16 ;
7069#if defined( __gfx12__ )
7170#if 1
7271 for ( int ele = 0 ; ele < WMMA_DATA_WIDTH ; ++ ele )
7372 {
74- b_frag [ele ] = b [16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + lane ];
75- a_frag [ele ] = a [16 * lane + ( ele + laneGroup * WMMA_DATA_WIDTH )];
73+ b_frag [ele ] = b [16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + laneWrapped ];
74+ a_frag [ele ] = a [16 * laneWrapped + ( ele + laneGroup * WMMA_DATA_WIDTH )];
7675 }
7776#else
7877 {//with __builtin_amdgcn_cvt_pkrtz
@@ -82,17 +81,17 @@ extern "C" __global__ void wmma_matmul( __fp16* a, __fp16* b, __fp16* c )
8281 {
8382 const int e0 = ele * 2 + 0 ;
8483 const int e1 = ele * 2 + 1 ;
85- b_ptr [ele ] = packFp32s ( b [16 * ( e0 + laneGroup * WMMA_DATA_WIDTH ) + lane ], b [16 * ( e1 + laneGroup * WMMA_DATA_WIDTH ) + lane ] );
86- a_ptr [ele ] = packFp32s ( a [16 * lane + ( e0 + laneGroup * WMMA_DATA_WIDTH )], a [16 * lane + ( e1 + laneGroup * WMMA_DATA_WIDTH )] );
84+ b_ptr [ele ] = packFp32s ( b [16 * ( e0 + laneGroup * WMMA_DATA_WIDTH ) + laneWrapped ], b [16 * ( e1 + laneGroup * WMMA_DATA_WIDTH ) + laneWrapped ] );
85+ a_ptr [ele ] = packFp32s ( a [16 * laneWrapped + ( e0 + laneGroup * WMMA_DATA_WIDTH )], a [16 * laneWrapped + ( e1 + laneGroup * WMMA_DATA_WIDTH )] );
8786 }
8887 }
8988#endif
9089#else
9291 // laneWrapped is lIdx (0-31) mod 16 instead of 0-31 due to matrix replication in RDNA3
9291 for ( int ele = 0 ; ele < WMMA_DATA_WIDTH ; ++ ele )
9392 {
94- b_frag [ele ] = b [16 * ele + lane ];
95- a_frag [ele ] = a [16 * lane + ele ];
93+ b_frag [ele ] = b [16 * ele + laneWrapped ];
94+ a_frag [ele ] = a [16 * laneWrapped + ele ];
9695 }
9796#endif
9897 // call the WMMA compiler intrinsic
@@ -107,16 +106,16 @@ extern "C" __global__ void wmma_matmul( __fp16* a, __fp16* b, __fp16* c )
107106#if defined( __gfx12__ )
108107 for ( int ele = 0 ; ele < WMMA_DATA_WIDTH ; ++ ele )
109108 {
110- c [16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + lane ] = c_frag [ele ];
109+ c [16 * ( ele + laneGroup * WMMA_DATA_WIDTH ) + laneWrapped ] = c_frag [ele ];
111110 }
112111#else
113112 for ( int ele = 0 ; ele < 8 ; ++ ele )
114113 {
115114 const int r = ele * 2 + ( lIdx / 16 );
116115 // store results from unpacked c_frag output
117- c [16 * r + lane ] = c_frag [ele * 2 ];
116+ c [16 * r + laneWrapped ] = c_frag [ele * 2 ];
118117 // if OPSEL was set to "true", the line above would instead be
119- // c[16 * r + lane ] = c_frag[ele*2 + 1];
118+ // c[16 * r + laneWrapped ] = c_frag[ele*2 + 1];
120119 }
121120#endif
122121}
0 commit comments