@@ -454,8 +454,8 @@ __global__ void kQuantizeBlockwise(float * code, T * __restrict__ const A, float
454454 __shared__ float smem_code[256 ];
455455 __shared__ float smem_absmax_value[1 ];
456456
457- if ( threadIdx .x < 256 )
458- smem_code[threadIdx . x ] = code[threadIdx . x ];
457+ for ( int i = threadIdx .x ; i < 256 ; i+= blockDim . x )
458+ smem_code[i ] = code[i ];
459459
460460 for (unsigned int i = base_idx; i < n_full; i += gridDim .x *BLOCK_SIZE)
461461 {
@@ -2799,6 +2799,12 @@ template __global__ void kQuantizeBlockwise<half, 1024, 4, 0>(float * code, half
// Explicit instantiations of the kQuantizeBlockwise kernel template for half
// and float inputs. The entries marked `+` extend the list with the
// template-argument triples <256, 2, 0>, <128, 2, 0> and <64, 1, 0>.
// NOTE(review): the template parameter list is outside this view; the second
// argument presumably corresponds to BLOCK_SIZE used in the kernel body and the
// third to a per-thread element count -- confirm against the kernel definition.
27992799template __global__ void kQuantizeBlockwise <float , 1024 , 4 , 0 >(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
28002800template __global__ void kQuantizeBlockwise <half, 512 , 2 , 0 >(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
28012801template __global__ void kQuantizeBlockwise <float , 512 , 2 , 0 >(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2802+ template __global__ void kQuantizeBlockwise <half, 256 , 2 , 0 >(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2803+ template __global__ void kQuantizeBlockwise <float , 256 , 2 , 0 >(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2804+ template __global__ void kQuantizeBlockwise <half, 128 , 2 , 0 >(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2805+ template __global__ void kQuantizeBlockwise <float , 128 , 2 , 0 >(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2806+ template __global__ void kQuantizeBlockwise <half, 64 , 1 , 0 >(float * code, half * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
2807+ template __global__ void kQuantizeBlockwise <float , 64 , 1 , 0 >(float * code, float * __restrict__ const A, float *absmax, unsigned char *out, float * __restrict__ const rand, const int rand_offset, const int n);
28022808
// Explicit instantiations of the kDequantizeBlockwise kernel template with
// template arguments <4096, 1024, 4>, once writing half output and once
// writing float output (the type of `out` follows the first template argument).
28032809template __global__ void kDequantizeBlockwise <half, 4096 , 1024 , 4 >(float *code, unsigned char * A, float * absmax, half *out, const int n);
28042810template __global__ void kDequantizeBlockwise <float , 4096 , 1024 , 4 >(float *code, unsigned char * A, float * absmax, float *out, const int n);
@@ -2808,6 +2814,12 @@ template __global__ void kDequantizeBlockwise<half, 1024, 256, 4>(float *code, u
// More kDequantizeBlockwise explicit instantiations for half and float output.
// The entries marked `+` add the triples <256, 128, 2>, <128, 64, 2> and
// <64, 64, 1>. In every triple listed here the first number equals the product
// of the other two.
// NOTE(review): presumably tile size = threads * values-per-thread; the
// template parameter names are not visible in this view -- confirm.
28082814template __global__ void kDequantizeBlockwise <float , 1024 , 256 , 4 >(float *code, unsigned char * A, float * absmax, float *out, const int n);
28092815template __global__ void kDequantizeBlockwise <half, 512 , 256 , 2 >(float *code, unsigned char * A, float * absmax, half *out, const int n);
28102816template __global__ void kDequantizeBlockwise <float , 512 , 256 , 2 >(float *code, unsigned char * A, float * absmax, float *out, const int n);
2817+ template __global__ void kDequantizeBlockwise <half, 256 , 128 , 2 >(float *code, unsigned char * A, float * absmax, half *out, const int n);
2818+ template __global__ void kDequantizeBlockwise <float , 256 , 128 , 2 >(float *code, unsigned char * A, float * absmax, float *out, const int n);
2819+ template __global__ void kDequantizeBlockwise <half, 128 , 64 , 2 >(float *code, unsigned char * A, float * absmax, half *out, const int n);
2820+ template __global__ void kDequantizeBlockwise <float , 128 , 64 , 2 >(float *code, unsigned char * A, float * absmax, float *out, const int n);
2821+ template __global__ void kDequantizeBlockwise <half, 64 , 64 , 1 >(float *code, unsigned char * A, float * absmax, half *out, const int n);
2822+ template __global__ void kDequantizeBlockwise <float , 64 , 64 , 1 >(float *code, unsigned char * A, float * absmax, float *out, const int n);
28112823
28122824
28132825
0 commit comments