@@ -406,10 +406,10 @@ void ggml_fp32_to_bf16_row(const float * x, ggml_bf16_t * y, int64_t n) {
406406 int i = 0;
407407#if defined(__AVX512BF16__)
408408 for (; i + 32 <= n; i += 32) {
409- _mm512_storeu_ps (
410- (__m512 *)(y + i),
411- (__m512) _mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
412- _mm512_loadu_ps(x + i)));
409+ _mm512_storeu_si512 (
410+ (__m512i *)(y + i),
411+ m512i( _mm512_cvtne2ps_pbh(_mm512_loadu_ps(x + i + 16),
412+ _mm512_loadu_ps(x + i) )));
413413 }
414414#endif
415415 for (; i < n; i++) {
@@ -1666,10 +1666,10 @@ static void ggml_vec_dot_bf16(int n, float * restrict s, size_t bs, ggml_bf16_t
16661666 __m512 c1 = _mm512_setzero_ps();
16671667 __m512 c2 = _mm512_setzero_ps();
16681668 for (; i + 64 <= n; i += 64) {
1669- c1 = _mm512_dpbf16_ps(c1, (__m512bh)_mm512_loadu_ps((const float *)( x + i)),
1670- (__m512bh)_mm512_loadu_ps((const float *)( y + i)));
1671- c2 = _mm512_dpbf16_ps(c2, (__m512bh)_mm512_loadu_ps((const float *)( x + i + 32)),
1672- (__m512bh)_mm512_loadu_ps((const float *)( y + i + 32)));
1669+ c1 = _mm512_dpbf16_ps(c1, m512bh(_mm512_loadu_si512(( x + i) )),
1670+ m512bh(_mm512_loadu_si512(( y + i) )));
1671+ c2 = _mm512_dpbf16_ps(c2, m512bh(_mm512_loadu_si512(( x + i + 32) )),
1672+ m512bh(_mm512_loadu_si512(( y + i + 32) )));
16731673 }
16741674 sumf += (ggml_float)_mm512_reduce_add_ps(c1);
16751675 sumf += (ggml_float)_mm512_reduce_add_ps(c2);
@@ -23137,6 +23137,14 @@ int ggml_cpu_has_avx512_vnni(void) {
2313723137#endif
2313823138}
2313923139
23140+ int ggml_cpu_has_avx512_bf16(void) {
23141+ #if defined(__AVX512BF16__)
23142+ return 1;
23143+ #else
23144+ return 0;
23145+ #endif
23146+ }
23147+
2314023148int ggml_cpu_has_fma(void) {
2314123149#if defined(__FMA__)
2314223150 return 1;
0 commit comments