@@ -862,61 +862,61 @@ static __device__ __forceinline__ void vec_dot_q8_0_q8_1_mma(
     const float * y_df = (const float *) y;
     const half2 * y_ds = (const half2 *) y;
 
-    tile_A A[ntx][WARP_SIZE/QI8_0];
-    float dA[ntx][tile_C::ne/2][WARP_SIZE/QI8_0];
+    tile_A A[ntx];
+    float dA[ntx][tile_C::ne/2];
 
     const int i0 = (threadIdx.y/ntx)*rows_per_warp;
 
-#pragma unroll
-    for (int n = 0; n < ntx; ++n) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            const int k0 = k00 + k01;
-
-            load_ldmatrix(A[n][k01/QI8_0], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
-        }
-
-#pragma unroll
+#pragma unroll
+    for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
+        const int k0 = k00 + k01;
+        tile_B B;
+        float dB[tile_C::ne/2];
+        load_generic(B, y_qs + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
+#pragma unroll
         for (int l = 0; l < tile_C::ne/2; ++l) {
-            const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
-
-#pragma unroll
-            for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-                const int k0 = k00 + k01;
-
-                dA[n][l][k01/QI8_0] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+            const int j = tile_C::get_j(l);
+            if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+            } else {
+                dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
             }
         }
-    }
-
-#pragma unroll
-    for (int j0 = 0; j0 < mmq_x; j0 += ntx*tile_C::J) {
-#pragma unroll
-        for (int k01 = 0; k01 < WARP_SIZE; k01 += QI8_0) {
-            tile_B B;
-            float dB[tile_C::ne/2];
-
+#pragma unroll
+        for (int n = 0; n < ntx; ++n) {
+            load_ldmatrix(A[n], x_qs + (i0 + n*tile_A::I)*MMQ_MMA_TILE_X_K_Q8_0 + k0, MMQ_MMA_TILE_X_K_Q8_0);
+#pragma unroll
+            for (int l = 0; l < tile_C::ne/2; ++l) {
+                const int i = i0 + n*tile_A::I + tile_C::get_i(2*l);
+                dA[n][l] = x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + k0/QI8_0];
+            }
+            tile_C C;
+            mma(C, A[n], B);
+#pragma unroll
+            for (int l = 0; l < tile_C::ne; ++l) {
+                sum[(n)*tile_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
+            }
+        }
+#pragma unroll
+        for (int j0 = ntx*tile_C::J; j0 < mmq_x; j0 += ntx*tile_C::J) {
             load_generic(B, y_qs + j0*MMQ_TILE_Y_K + k01, MMQ_TILE_Y_K); // faster than load_ldmatrix
-
-#pragma unroll
+#pragma unroll
             for (int l = 0; l < tile_C::ne/2; ++l) {
                 const int j = j0 + tile_C::get_j(l);
-
-                if (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
-                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
+                if constexpr (ds_layout == MMQ_Q8_1_DS_LAYOUT_D4) {
+                    dB[l] = y_df[j*MMQ_TILE_Y_K + k01/QI8_1];
                 } else {
                     dB[l] = __low2float(y_ds[j*MMQ_TILE_Y_K + k01/QI8_1]);
                 }
             }
+#pragma unroll
 
-#pragma unroll
             for (int n = 0; n < ntx; ++n) {
                 tile_C C;
-                mma(C, A[n][k01/QI8_0], B);
-
-#pragma unroll
+                mma(C, A[n], B);
+#pragma unroll
                 for (int l = 0; l < tile_C::ne; ++l) {
-                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2][k01/QI8_0]*dB[l%2];
+                    sum[(j0/tile_C::J + n)*tile_C::ne + l] += C.x[l]*dA[n][l/2]*dB[l%2];
                 }
             }
         }
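
What the rewritten loop nest above buys: `k01` is now the outermost loop, so only one `tile_A` per `n` is live at a time (`A[ntx]` instead of `A[ntx][WARP_SIZE/QI8_0]`, and `dA` loses a dimension too), and the `j0 == 0` iteration is peeled so its `B` tile and scales are loaded once per `k01` slice before the remaining `j0` values, which now start at `ntx*tile_C::J`. The apparent intent is to trade re-loading `A` per k-slice for lower register pressure. A minimal standalone sketch of this loop interchange, with made-up names (`load_a`, `load_b`) and toy sizes standing in for the tile types:

```cpp
// Illustrative sketch only, not the kernel's real code. NK plays the role of
// WARP_SIZE/QI8_0, NTX of ntx, NJ of mmq_x/tile_C::J.
#include <cstdio>

constexpr int NK = 4, NTX = 2, NJ = 3;

static float load_a(int n, int k) { return 1.0f + n + k; } // stands in for load_ldmatrix
static float load_b(int j, int k) { return 2.0f + j + k; } // stands in for load_generic

int main() {
    float sum[NJ * NTX] = {};
    for (int k01 = 0; k01 < NK; ++k01) {     // k-slices outermost: A cache is A[NTX], not A[NTX][NK]
        const float B0 = load_b(0, k01);     // peeled j0 == 0 iteration: its B is loaded first
        float A[NTX];
        for (int n = 0; n < NTX; ++n) {
            A[n] = load_a(n, k01);           // A loaded once per k-slice, reused by every j0 below
            sum[n] += A[n] * B0;
        }
        for (int j0 = 1; j0 < NJ; ++j0) {    // remaining j0 iterations reuse the cached A
            const float B = load_b(j0, k01);
            for (int n = 0; n < NTX; ++n) {
                sum[j0*NTX + n] += A[n] * B;
            }
        }
    }
    printf("sum[0] = %g\n", sum[0]);
    return 0;
}
```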
@@ -2784,6 +2784,64 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     }
 }
 
+// template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
+//     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
+//
+// #ifdef NEW_MMA_AVAILABLE
+//     int * x_qs = (int *) x_tile;
+//     float * x_df = (float *) (x_qs + WARP_SIZE*2);
+// #else
+//     constexpr tile_x_sizes txs = mmq_get_dp4a_tile_x_sizes(GGML_TYPE_IQ4_XS, mmq_y);
+//     int * x_qs = (int *) x_tile;
+//     float * x_df = (float *) (x_qs + txs.qs);
+// #endif // NEW_MMA_AVAILABLE
+//
+//     const int kbx = 0; // threadIdx.x / QI4_XS
+//     const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+//
+// #pragma unroll
+//     for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
+//         int i = i0 + threadIdx.y;
+//
+//         if (need_check) {
+//             i = min(i, i_max);
+//         }
+//
+//         const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+//
+//         auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
+//         const int aux_q4 = get_int_b4(bxi->qs, kqsx);
+//         const int2 v = get_int_from_table_16(aux_q4, values);
+//         const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+// #ifdef NEW_MMA_AVAILABLE
+//         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
+//         x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+// #else
+//         x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
+//         x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+// #endif // NEW_MMA_AVAILABLE
+//     }
+//
+// #pragma unroll
+//     for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
+//         int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
+//
+//         if (need_check) {
+//             i = min(i, i_max);
+//         }
+//
+//         const float * dptr = (const float *)(x + i*stride);
+//         const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+//         const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
+//
+// #ifdef NEW_MMA_AVAILABLE
+//         x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+// #else
+//         x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+// #endif // NEW_MMA_AVAILABLE
+//     }
+// }
+
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_ks(
     const char * __restrict__ x, int * __restrict__ x_tile, const int & kbx0, const int & i_max, const int & stride) {
 
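In the rewrite below, each lane handles one 32-value block of one row: `threadIdx.x%4` picks the row within a group of four, `kqsx = threadIdx.x/4` picks the block, and an inner `j` loop unpacks the block's four q4 words, so the old version's separate scale-loading loop disappears. The scale byte itself packs two fields, visible in the bit masks used by the kernel: bit 0 selects which 16-entry half of `iq4k_values` the block uses, and bits 1..7 carry the block scale biased by 127. A host-side sketch of that decode, with a made-up example byte:

```cpp
// Host-side sketch of the IQ4_KS per-block scale byte decode used below.
// The byte value is hypothetical; the bit layout is taken from the kernel:
// (scales[k] & 254) - 127 and (scales[k] & 1) << 4.
#include <cstdint>
#include <cstdio>

int main() {
    const float   d     = 0.25f; // stands in for dptr[0], the per-row scale
    const uint8_t scale = 0xC9;  // hypothetical packed scale byte

    const int ls         = (scale & 254) - 127; // bits 1..7: block scale, biased by 127
    const int lut_offset = (scale & 1) << 4;    // bit 0: offset 0 or 16 into iq4k_values

    printf("block scale = %g, LUT offset = %d\n", d * ls, lut_offset);
    return 0;
}
```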
@@ -2796,50 +2854,40 @@ template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinlin
     float * x_df = (float *) (x_qs + txs.qs);
 #endif // NEW_MMA_AVAILABLE
 
-    const int kbx = 0; // threadIdx.x / QI4_XS
-    const int kqsx = threadIdx.x; // threadIdx.x % QI4_XS
+    const int kqsx = threadIdx.x / 4;
 
 #pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps) {
-        int i = i0 + threadIdx.y;
+    for (int i0 = 0; i0 < mmq_y; i0 += 4*nwarps) {
+        int i = i0 + 4*threadIdx.y + threadIdx.x%4;
 
         if (need_check) {
             i = min(i, i_max);
         }
 
-        const block_iq4_ks * bxi = (const block_iq4_ks *)(x + i*stride + sizeof(float)) + kbx0 + kbx;
+        const float * dptr = (const float *)(x + i*stride);
+        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
+        const int ls = (bxi->scales[kqsx] & 254) - 127;
+        auto values = iq4k_values + ((bxi->scales[kqsx] & 1) << 4);
 
-        auto values = iq4k_values + ((bxi->scales[kqsx/4] & 1) << 4);
-        const int aux_q4 = get_int_b4(bxi->qs, kqsx);
-        const int2 v = get_int_from_table_16(aux_q4, values);
-        const int k0 = 8 * (threadIdx.x / 4) + threadIdx.x % 4;
+#pragma unroll
+        for (int j = 0; j < 4; ++j) {
+            const int aux_q4 = get_int_b4(bxi->qs, 4*kqsx+j);
+            const int2 v = get_int_from_table_16(aux_q4, values);
 #ifdef NEW_MMA_AVAILABLE
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 0] = v.x;
-        x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + k0 + 4] = v.y;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 0] = v.x;
+            x_qs[i*MMQ_MMA_TILE_X_K_Q8_0 + 8*kqsx + j + 4] = v.y;
 #else
-        x_qs[i*(2*WARP_SIZE + 1) + k0 + 0] = v.x;
-        x_qs[i*(2*WARP_SIZE + 1) + k0 + 4] = v.y;
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 0] = v.x;
+            x_qs[i*(2*WARP_SIZE + 1) + 8*kqsx + j + 4] = v.y;
 #endif // NEW_MMA_AVAILABLE
-    }
-
-#pragma unroll
-    for (int i0 = 0; i0 < mmq_y; i0 += nwarps * 4) {
-        int i = i0 + threadIdx.y * 4 + threadIdx.x / (WARP_SIZE/4);
-
-        if (need_check) {
-            i = min(i, i_max);
         }
-
-        const float * dptr = (const float *)(x + i*stride);
-        const block_iq4_ks * bxi = (const block_iq4_ks *)(dptr + 1) + kbx0;
-        const int ls = (bxi->scales[threadIdx.x % 8] & 254) - 127;
-
 #ifdef NEW_MMA_AVAILABLE
-        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + threadIdx.x % 8] = dptr[0] * ls;
+        x_df[i*MMQ_MMA_TILE_X_K_Q8_0 + kqsx] = dptr[0] * ls;
 #else
-        x_df[i*(WARP_SIZE/4) + i/4 + threadIdx.x % 8] = dptr[0] * ls;
+        x_df[i*(WARP_SIZE/4) + i/4 + kqsx] = dptr[0] * ls;
 #endif // NEW_MMA_AVAILABLE
     }
+
 }
 
 template <int mmq_y, int nwarps, bool need_check> static __device__ __forceinline__ void load_tiles_iq4_kt(
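
One assumption worth stating for the `v.x`/`v.y` stores above: `get_int_from_table_16` is taken here to map the eight 4-bit fields of `aux_q4` through a 16-entry `int8` LUT, packing the four low nibbles into `.x` and the four high nibbles into `.y`, which is consistent with the two stores landing four `int`s (16 quantized values) apart in `x_qs`. A host-side sketch under that assumption (the `_sketch` names are made up):

```cpp
// Sketch of the assumed nibble-to-LUT expansion behind get_int_from_table_16.
#include <cstdint>
#include <cstdio>
#include <cstring>

struct int2_sketch { int32_t x, y; };

static int2_sketch get_int_from_table_16_sketch(int32_t q4, const int8_t * values) {
    int8_t lo[4], hi[4];
    for (int b = 0; b < 4; ++b) {
        const uint8_t byte = (uint32_t(q4) >> (8*b)) & 0xFF;
        lo[b] = values[byte & 0x0F]; // low nibble of byte b
        hi[b] = values[byte >> 4];   // high nibble of byte b
    }
    int2_sketch v;
    memcpy(&v.x, lo, 4); // four LUT values packed back into one int32
    memcpy(&v.y, hi, 4);
    return v;
}

int main() {
    // Hypothetical 16-entry LUT half; the kernel indexes into iq4k_values instead.
    int8_t values[16];
    for (int i = 0; i < 16; ++i) values[i] = (int8_t)(8*i - 64);

    const int2_sketch v = get_int_from_table_16_sketch(0x76543210, values);
    printf("v.x = %08x, v.y = %08x\n", (unsigned)v.x, (unsigned)v.y);
    return 0;
}
```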