11#include " quantize.cuh"
22#include < cstdint>
33
4+ #ifdef GGML_HIP_GFX906_OPTIMIZED
5+ #include " gfx906-config.cuh"
6+ #endif
7+
48static __global__ void quantize_q8_1 (
59 const float * __restrict__ x, void * __restrict__ vy,
610 const int64_t ne00, const int64_t s01, const int64_t s02, const int64_t s03,
@@ -35,7 +39,7 @@ static __global__ void quantize_q8_1(
3539 sum = warp_reduce_sum (sum);
3640
3741 const float d = amax / 127 ;
38- const int8_t q = amax == 0 .0f ? 0 : roundf (xi / d);
42+ const int8_t q = amax == 0 .0f ? 0 : __float2int_rn (xi / d);
3943
4044 y[ib].qs [iqs] = q;
4145
@@ -87,6 +91,7 @@ static __global__ void quantize_mmq_q8_1(
8791 amax = fmaxf (amax, fabsf (xi.w ));
8892
8993 // Exchange max. abs. value between vals_per_scale/4 threads.
94+ // Butterfly (XOR-shuffle) max reduction: the lane offset is halved on every step.
9095#pragma unroll
9196 for (int offset = vals_per_scale/8 ; offset > 0 ; offset >>= 1 ) {
9297 amax = fmaxf (amax, __shfl_xor_sync (0xFFFFFFFF , amax, offset, WARP_SIZE));
@@ -97,20 +102,25 @@ static __global__ void quantize_mmq_q8_1(
97102 sum = xi.x + xi.y + xi.z + xi.w ;
98103
99104 // Calculate sums across vals_per_sum/4 threads.
105+ // Butterfly (XOR-shuffle) sum reduction: the lane offset is halved on every step.
100106#pragma unroll
101107 for (int offset = vals_per_sum/8 ; offset > 0 ; offset >>= 1 ) {
102108 sum += __shfl_xor_sync (0xFFFFFFFF , sum, offset, WARP_SIZE);
103109 }
104110 }
105111
106112 const float d_inv = 127 .0f / amax;
107- char4 q;
108- q.x = roundf (xi.x *d_inv);
109- q.y = roundf (xi.y *d_inv);
110- q.z = roundf (xi.z *d_inv);
111- q.w = roundf (xi.w *d_inv);
112113
113- // Write back 4 int8 values as a single 32 bit value for better memroy bandwidth:
114+ // Quantize the four scaled values to int8 using the round-to-nearest conversion intrinsic.
115+ char4 q;
116+ // __float2int_rn rounds to nearest with ties to even; the roundf it replaces rounds ties away from zero.
117+ q.x = __float2int_rn (xi.x *d_inv);
118+ q.y = __float2int_rn (xi.y *d_inv);
119+ q.z = __float2int_rn (xi.z *d_inv);
120+ q.w = __float2int_rn (xi.w *d_inv);
121+
122+ // Write back 4 int8 values as a single 32-bit value for better memory bandwidth:
123+ // Standard vectorized store
114124 char4 * yqs4 = (char4 *) y[ib].qs ;
115125 yqs4[iqs/4 ] = q;
116126
0 commit comments