Skip to content

Commit df1d669

Browse files
committed
fix position
Signed-off-by: jiqing-feng <[email protected]>
1 parent f2029c6 commit df1d669

File tree

2 files changed

+10
-22
lines changed

2 files changed

+10
-22
lines changed

csrc/cpu_ops.cpp

Lines changed: 2 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,4 @@
1 1
#include <BinSearch.h>
2-
#include <common.h>
3 2
#include <cpu_ops.h>
4 3
#include <thread>
5 4

@@ -529,19 +528,7 @@ template void dequantizeBlockwise4bitCpu<bf16_t, NF4>(
529 528
unsigned char* A, const float* absmax, bf16_t* out, long long blocksize, long long m, long long n
530 529
);
531 530

532-
// template void gemv_4bit_inference<float, FP4>(
533-
// int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float*
534-
// __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride);
535-
// template void gemv_4bit_inference<float, NF4>(
536-
// int64_t M, int64_t N, int64_t K, const float* __restrict__ x, const unsigned char* __restrict__ w, const float*
537-
// __restrict__ absmax, float* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride);
538-
//
539-
// template void gemv_4bit_inference<fp16_t, FP4>(
540-
// int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float*
541-
// __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride);
542-
// template void gemv_4bit_inference<fp16_t, NF4>(
543-
// int64_t M, int64_t N, int64_t K, const fp16_t* __restrict__ x, const unsigned char* __restrict__ w, const float*
544-
// __restrict__ absmax, fp16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride);
531+
#if defined(__AVX512F__) && defined(__AVX512BF16__)
545 532
template void gemv_4bit_inference<bf16_t, FP4>(
546 533
int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w,
547 534
const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
@@ -550,3 +537,4 @@ template void gemv_4bit_inference<bf16_t, NF4>(
550 537
int64_t M, int64_t N, int64_t K, const bf16_t* __restrict__ x, const unsigned char* __restrict__ w,
551 538
const bf16_t* __restrict__ absmax, bf16_t* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
552 539
);
540+
#endif

csrc/cpu_ops.h

Lines changed: 8 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -283,14 +283,6 @@ void dequantizeBlockwise4bitCpu(
283 283
unsigned char* A, const float* absmax, T* out, long long blocksize, long long m, long long n
284 284
);
285 285

286-
#if defined(__AVX512F__) && defined(__AVX512BF16__)
287-
template <typename T, int DATA_TYPE>
288-
void gemv_4bit_inference(
289-
int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w,
290-
const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
291-
);
292-
#endif
293-
294 286
#if defined(__AVX512F__)
295 287
#include <immintrin.h>
296 288

@@ -327,4 +319,12 @@ static inline bool has_avx512bf16() {
327 319
#endif
328 320
#endif
329 321

322+
#if defined(__AVX512F__) && defined(__AVX512BF16__)
323+
template <typename T, int DATA_TYPE>
324+
void gemv_4bit_inference(
325+
int64_t M, int64_t N, int64_t K, const T* __restrict__ x, const unsigned char* __restrict__ w,
326+
const T* __restrict__ absmax, T* __restrict__ out, int64_t blocksize, int64_t x_stride, int64_t out_stride
327+
);
328+
#endif
329+
330 330
#endif

0 commit comments

Comments
 (0)