Skip to content

Commit d8cbc68

Browse files
committed
fix cmake check
Signed-off-by: jiqing-feng <[email protected]>
1 parent d7e981d commit d8cbc68

File tree

2 files changed

+50
-12
lines changed

2 files changed

+50
-12
lines changed

CMakeLists.txt

Lines changed: 40 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -85,6 +85,7 @@ endif()
 if (BUILD_CPU)
 set(CMAKE_CXX_STANDARD 17)
 set(CMAKE_CXX_STANDARD_REQUIRED ON)
+string(TOLOWER "${CMAKE_SYSTEM_PROCESSOR}" HOST_ARCH)
 find_package(OpenMP)
 endif()

@@ -270,18 +271,48 @@ target_compile_features(bitsandbytes PUBLIC cxx_std_17)
 target_include_directories(bitsandbytes PUBLIC csrc include)
 
 if (BUILD_CPU)
-target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX)
-include(CheckCXXCompilerFlag)
-
-check_cxx_compiler_flag(-mavx512f HAS_AVX512F)
-check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16)
+if (OpenMP_CXX_FOUND)
+target_link_libraries(bitsandbytes PRIVATE OpenMP::OpenMP_CXX)
+add_definitions(-DHAS_OPENMP)
+else()
+add_definitions(-DNO_OPENMP)
+endif()
 
-if(HAS_AVX512F)
-target_compile_options(bitsandbytes PRIVATE -mavx512f)
+if (HOST_ARCH MATCHES "x86_64|amd64")
+include(CheckCXXCompilerFlag)
+check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG)
+check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG)
+if (HAS_AVX512F_FLAG)
+target_compile_options(bitsandbytes PRIVATE -mavx512f)
+add_definitions(-DHAS_AVX512F)
+endif()
+if (HAS_AVX512BF16_FLAG)
+target_compile_options(bitsandbytes PRIVATE -mavx512bf16)
+add_definitions(-DHAS_AVX512BF16)
+else()
+add_definitions(-DNO_AVX512BF16)
+endif()
 endif()
+endif()
 
-if(HAS_AVX512BF16)
-target_compile_options(bitsandbytes PRIVATE -mavx512bf16)
+# --- Windows MSVC specific AVX512BF16 probe (after add_library) ---
+if (MSVC AND BUILD_CPU)
+include(CheckCXXSourceCompiles)
+set(_AVX512BF16_TEST "
+#include <immintrin.h>
+int main(){
+__m512bh a{}, b{};
+auto c = _mm512_dpbf16_ps(_mm512_setzero_ps(), a, b);
+(void)c;
+return 0;
+}")
+check_cxx_source_compiles("${_AVX512BF16_TEST}" MSVC_HAS_AVX512BF16)
+if (MSVC_HAS_AVX512BF16)
+# /arch:AVX512;
+target_compile_options(bitsandbytes PRIVATE /arch:AVX512)
+add_definitions(-DHAS_AVX512BF16)
+else()
+add_definitions(-DNO_AVX512BF16)
 endif()
 endif()

csrc/cpu_ops.cpp

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,13 @@
 #include <cpu_ops.h>
 #include <thread>
 
+#ifdef HAS_OPENMP
+#include <omp.h>
+#define BNB_OMP_PARALLEL_FOR _Pragma("omp parallel for")
+#else
+#define BNB_OMP_PARALLEL_FOR
+#endif
+
 using namespace BinSearch;
 

@@ -99,7 +106,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A,
 if (dim_1 % VEC_LEN == 0 && blocksize >= VEC_LEN) {
 __m512 lut = DATA_TYPE == 1 ? set_fp4_lut() : set_nf4_lut();
 constexpr auto k_step = VEC_LEN / 2; // 8
-#pragma omp parallel for
+BNB_OMP_PARALLEL_FOR
 for (int block_idx = 0; block_idx < dim_0; ++block_idx) {
 for (int k = 0; k < input_dim_1; k += k_step) {
 // Load 64 bits of nf4 data and a single scale data
@@ -141,7 +148,7 @@ void dequantizeBlockwise4bitCpu(unsigned char* A,
 #endif
 // Scalar fallback branch
 long long total = m * n;
-#pragma omp parallel for
+BNB_OMP_PARALLEL_FOR
 for (long long block_idx = 0; block_idx < total; block_idx += blocksize) {
 long long valid_items = (total - block_idx >= blocksize ? blocksize : total - block_idx);
 float scale = absmax[block_idx / blocksize];
@@ -187,7 +194,7 @@ void dequantizeBlockwise8bitCpu(float* code,
 long long n) {
 if (blocksize <= 0 || n <= 0) return;
 // 8-bit path
-#pragma omp parallel for
+BNB_OMP_PARALLEL_FOR
 for (long long block_idx = 0; block_idx < n; block_idx += blocksize) {
 long long valid_items = (n - block_idx >= blocksize ? blocksize : n - block_idx);
 long long block_end = block_idx + valid_items;

0 commit comments

Comments
 (0)