@@ -271,6 +271,30 @@ target_compile_features(bitsandbytes PUBLIC cxx_std_17)
271271target_include_directories (bitsandbytes PUBLIC csrc include )
272272
273273if (BUILD_CPU)
274+ include (CheckCXXSourceRuns)
275+ set (AVX512F_TEST_CODE "
276+ #include <immintrin.h>
277+ int main() {
278+ __m512 a = _mm512_setzero_ps();
279+ __m512 b = _mm512_add_ps(a, a);
280+ return 0;
281+ }
282+ " )
283+ set (AVX512BF16_TEST_CODE "
284+ #include <immintrin.h>
285+ int main() {
286+ __m512 a = _mm512_setzero_ps();
287+ __m256bh b = _mm512_cvtneps_pbh(a);
288+ return 0;
289+ }
290+ " )
291+ set (CMAKE_REQUIRED_FLAGS "-mavx512f" )
292+ check_cxx_source_runs("${AVX512F_TEST_CODE} " HOST_HAS_AVX512F)
293+ unset (CMAKE_REQUIRED_FLAGS)
294+ set (CMAKE_REQUIRED_FLAGS "-mavx512bf16" )
295+ check_cxx_source_runs("${AVX512BF16_TEST_CODE} " HOST_HAS_AVX512BF16)
296+ unset (CMAKE_REQUIRED_FLAGS)
297+
274298 if (OpenMP_CXX_FOUND)
275299 target_link_libraries (bitsandbytes PRIVATE OpenMP::OpenMP_CXX)
276300 add_definitions (-DHAS_OPENMP)
@@ -280,13 +304,13 @@ if (BUILD_CPU)
280304 include (CheckCXXCompilerFlag)
281305 check_cxx_compiler_flag(-mavx512f HAS_AVX512F_FLAG)
282306 check_cxx_compiler_flag(-mavx512bf16 HAS_AVX512BF16_FLAG)
283- if (HAS_AVX512F_FLAG)
307+ if (HAS_AVX512F_FLAG AND HOST_HAS_AVX512F )
284308 target_compile_options (bitsandbytes PRIVATE -mavx512f)
285309 target_compile_options (bitsandbytes PRIVATE -mavx512dq)
286310 target_compile_options (bitsandbytes PRIVATE -mavx512bw)
287311 target_compile_options (bitsandbytes PRIVATE -mavx512vl)
288312 endif ()
289- if (HAS_AVX512BF16_FLAG)
313+ if (HAS_AVX512BF16_FLAG AND HOST_HAS_AVX512BF16 )
290314 target_compile_options (bitsandbytes PRIVATE -mavx512bf16)
291315 endif ()
292316 target_compile_options (
0 commit comments