@@ -79,21 +79,21 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
7979 __m256 accum256_1 = _mm256_setzero_ps ();
8080 int tail_index_32 = n & (~31 );
8181 for (int j = 0 ; j < tail_index_32 ; j += 32 ) {
82- accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (& x [j + 0 ]), (__m256bh ) _mm256_loadu_si256 (& y [j + 0 ]));
83- accum256_1 = _mm256_dpbf16_ps (accum256_1 , (__m256bh ) _mm256_loadu_si256 (& x [j + 16 ]), (__m256bh ) _mm256_loadu_si256 (& y [j + 16 ]));
82+ accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & x [j + 0 ]), (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & y [j + 0 ]));
83+ accum256_1 = _mm256_dpbf16_ps (accum256_1 , (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & x [j + 16 ]), (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & y [j + 16 ]));
8484 }
8585 accum256 = _mm256_add_ps (accum256 , accum256_1 );
8686
8787 /* Processing the remaining <32 chunk with 16-elements processing */
8888 if ((n & 16 ) != 0 ) {
89- accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (& x [tail_index_32 ]), (__m256bh ) _mm256_loadu_si256 (& y [tail_index_32 ]));
89+ accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & x [tail_index_32 ]), (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & y [tail_index_32 ]));
9090 }
9191 accum128 = _mm_add_ps (_mm256_castps256_ps128 (accum256 ), _mm256_extractf128_ps (accum256 , 1 ));
9292
9393 /* Processing the remaining <16 chunk with 8-elements processing */
9494 if ((n & 8 ) != 0 ) {
9595 int tail_index_16 = n & (~15 );
96- accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (& x [tail_index_16 ]), (__m128bh ) _mm_loadu_si128 (& y [tail_index_16 ]));
96+ accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (( __m128i * ) & x [tail_index_16 ]), (__m128bh ) _mm_loadu_si128 (( __m128i * ) & y [tail_index_16 ]));
9797 }
9898
9999 /* Processing the remaining <8 chunk with masked 8-elements processing */
@@ -108,13 +108,13 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
108108 } else if (n > 15 ) { /* n range from 16 to 31 */
109109 /* Processing <32 chunk with 16-elements processing */
110110 __m256 accum256 = _mm256_setzero_ps ();
111- accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (& x [0 ]), (__m256bh ) _mm256_loadu_si256 (& y [0 ]));
111+ accum256 = _mm256_dpbf16_ps (accum256 , (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & x [0 ]), (__m256bh ) _mm256_loadu_si256 (( __m256i * ) & y [0 ]));
112112 accum128 += _mm_add_ps (_mm256_castps256_ps128 (accum256 ), _mm256_extractf128_ps (accum256 , 1 ));
113113
114114 /* Processing the remaining <16 chunk with 8-elements processing */
115115 if ((n & 8 ) != 0 ) {
116116 int tail_index_16 = n & (~15 );
117- accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (& x [tail_index_16 ]), (__m128bh ) _mm_loadu_si128 (& y [tail_index_16 ]));
117+ accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (( __m128i * ) & x [tail_index_16 ]), (__m128bh ) _mm_loadu_si128 (( __m128i * ) & y [tail_index_16 ]));
118118 }
119119
120120 /* Processing the remaining <8 chunk with masked 8-elements processing */
@@ -128,7 +128,7 @@ static float sbdot_accl_kernel(BLASLONG n, bfloat16 *x, bfloat16 *y)
128128 }
129129 } else if (n > 7 ) { /* n range from 8 to 15 */
130130 /* Processing <16 chunk with 8-elements processing */
131- accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (& x [0 ]), (__m128bh ) _mm_loadu_si128 (& y [0 ]));
131+ accum128 = _mm_dpbf16_ps (accum128 , (__m128bh ) _mm_loadu_si128 (( __m128i * ) & x [0 ]), (__m128bh ) _mm_loadu_si128 (( __m128i * ) & y [0 ]));
132132
133133 /* Processing the remaining <8 chunk with masked 8-elements processing */
134134 if ((n & 7 ) != 0 ) {
0 commit comments