11#include < BinSearch.h>
2- #include < common.h>
32#include < cpu_ops.h>
43#include < thread>
54
@@ -529,19 +528,7 @@ template void dequantizeBlockwise4bitCpu<bf16_t, NF4>(
529528 unsigned char * A, const float * absmax, bf16_t * out, long long blocksize, long long m, long long n
530529);
531530
532- // template void gemv_4bit_inference<float, FP4>(
533- // int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float*
534- // __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride);
535- // template void gemv_4bit_inference<float, NF4>(
536- // int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float*
537- // __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride);
538- //
539- // template void gemv_4bit_inference<fp16_t, FP4>(
540- // int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float*
541- // __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride);
542- // template void gemv_4bit_inference<fp16_t, NF4>(
543- // int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float*
544- // __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride);
531+ #if defined(__AVX512F__) && defined(__AVX512BF16__)
545532template void gemv_4bit_inference<bf16_t , FP4>(
546533 int64_t M, int64_t N, int64_t K, const bf16_t * __restrict__ x, const unsigned char * __restrict__ w,
547534 const bf16_t * __restrict__ absmax, bf16_t * __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
@@ -550,3 +537,4 @@ template void gemv_4bit_inference<bf16_t, NF4>(
550537 int64_t M, int64_t N, int64_t K, const bf16_t * __restrict__ x, const unsigned char * __restrict__ w,
551538 const bf16_t * __restrict__ absmax, bf16_t * __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
552539);
540+ #endif
0 commit comments