Skip to content

Commit 346f48e

Browse files
authored
[Headers][X86] Convert bf16 to f32 conversions to constexpr implementations (#169841)
Fixes #154911
1 parent 63163b4 commit 346f48e

File tree

4 files changed

+45
-10
lines changed

4 files changed

+45
-10
lines changed

clang/lib/Headers/avx512bf16intrin.h

Lines changed: 15 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -25,6 +25,14 @@ typedef __bf16 __bfloat16 __attribute__((deprecated("use __bf16 instead")));
2525
#define __DEFAULT_FN_ATTRS \
2626
__attribute__((__always_inline__, __nodebug__, __target__("avx512bf16")))
2727

28+
#if defined(__cplusplus) && (__cplusplus >= 201103L)
29+
#define __DEFAULT_FN_ATTRS512_CONSTEXPR __DEFAULT_FN_ATTRS512 constexpr
30+
#define __DEFAULT_FN_ATTRS_CONSTEXPR __DEFAULT_FN_ATTRS constexpr
31+
#else
32+
#define __DEFAULT_FN_ATTRS512_CONSTEXPR __DEFAULT_FN_ATTRS512
33+
#define __DEFAULT_FN_ATTRS_CONSTEXPR __DEFAULT_FN_ATTRS
34+
#endif
35+
2836
/// Convert One BF16 Data to One Single Float Data.
2937
///
3038
/// \headerfile <x86intrin.h>
@@ -35,7 +43,7 @@ typedef __bf16 __bfloat16 __attribute__((deprecated("use __bf16 instead")));
3543
/// A bfloat data.
3644
/// \returns A float data whose sign field and exponent field keep unchanged,
3745
/// and fraction field is extended to 23 bits.
38-
static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bf16 __A) {
46+
static __inline__ float __DEFAULT_FN_ATTRS_CONSTEXPR _mm_cvtsbh_ss(__bf16 __A) {
3947
return (float)(__A);
4048
}
4149

@@ -235,7 +243,8 @@ _mm512_maskz_dpbf16_ps(__mmask16 __U, __m512 __D, __m512bh __A, __m512bh __B) {
235243
/// \param __A
236244
/// A 256-bit vector of [16 x bfloat].
237245
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
238-
static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) {
246+
static __inline__ __m512 __DEFAULT_FN_ATTRS512_CONSTEXPR
247+
_mm512_cvtpbh_ps(__m256bh __A) {
239248
return (__m512) __builtin_convertvector(__A, __v16sf);
240249
}
241250

@@ -249,7 +258,7 @@ static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) {
249258
/// \param __A
250259
/// A 256-bit vector of [16 x bfloat].
251260
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
252-
static __inline__ __m512 __DEFAULT_FN_ATTRS512
261+
static __inline__ __m512 __DEFAULT_FN_ATTRS512_CONSTEXPR
253262
_mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) {
254263
return (__m512)__builtin_ia32_selectps_512((__mmask16)__U,
255264
(__v16sf)_mm512_cvtpbh_ps(__A),
@@ -268,14 +277,16 @@ _mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) {
268277
/// \param __A
269278
/// A 256-bit vector of [16 x bfloat].
270279
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
271-
static __inline__ __m512 __DEFAULT_FN_ATTRS512
280+
static __inline__ __m512 __DEFAULT_FN_ATTRS512_CONSTEXPR
272281
_mm512_mask_cvtpbh_ps(__m512 __S, __mmask16 __U, __m256bh __A) {
273282
return (__m512)__builtin_ia32_selectps_512(
274283
(__mmask16)__U, (__v16sf)_mm512_cvtpbh_ps(__A), (__v16sf)__S);
275284
}
276285

277286
#undef __DEFAULT_FN_ATTRS
287+
#undef __DEFAULT_FN_ATTRS_CONSTEXPR
278288
#undef __DEFAULT_FN_ATTRS512
289+
#undef __DEFAULT_FN_ATTRS512_CONSTEXPR
279290

280291
#endif
281292
#endif

clang/lib/Headers/avx512vlbf16intrin.h

Lines changed: 18 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@
2424
__target__("avx512vl,avx512bf16"), \
2525
__min_vector_width__(256)))
2626

27+
#if defined(__cplusplus) && (__cplusplus >= 201103L)
28+
#define __DEFAULT_FN_ATTRS128_CONSTEXPR __DEFAULT_FN_ATTRS128 constexpr
29+
#define __DEFAULT_FN_ATTRS256_CONSTEXPR __DEFAULT_FN_ATTRS256 constexpr
30+
#else
31+
#define __DEFAULT_FN_ATTRS128_CONSTEXPR __DEFAULT_FN_ATTRS128
32+
#define __DEFAULT_FN_ATTRS256_CONSTEXPR __DEFAULT_FN_ATTRS256
33+
#endif
34+
2735
/// Convert Two Packed Single Data to One Packed BF16 Data.
2836
///
2937
/// \headerfile <x86intrin.h>
@@ -421,7 +429,8 @@ static __inline__ __bf16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
421429
/// \param __A
422430
/// A 128-bit vector of [4 x bfloat].
423431
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
424-
static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
432+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
433+
_mm_cvtpbh_ps(__m128bh __A) {
425434
return (__m128)_mm256_castps256_ps128(
426435
(__m256) __builtin_convertvector(__A, __v8sf));
427436
}
@@ -433,7 +442,8 @@ static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
433442
/// \param __A
434443
/// A 128-bit vector of [8 x bfloat].
435444
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
436-
static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
445+
static __inline__ __m256 __DEFAULT_FN_ATTRS256_CONSTEXPR
446+
_mm256_cvtpbh_ps(__m128bh __A) {
437447
return (__m256) __builtin_convertvector(__A, __v8sf);
438448
}
439449

@@ -447,7 +457,7 @@ static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
447457
/// \param __A
448458
/// A 128-bit vector of [4 x bfloat].
449459
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
450-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
460+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
451461
_mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
452462
return (__m128)__builtin_ia32_selectps_128(
453463
(__mmask8)__U, (__v4sf)_mm_cvtpbh_ps(__A), (__v4sf)_mm_setzero_ps());
@@ -463,7 +473,7 @@ _mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
463473
/// \param __A
464474
/// A 128-bit vector of [8 x bfloat].
465475
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
466-
static __inline__ __m256 __DEFAULT_FN_ATTRS256
476+
static __inline__ __m256 __DEFAULT_FN_ATTRS256_CONSTEXPR
467477
_mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
468478
return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
469479
(__v8sf)_mm256_cvtpbh_ps(__A),
@@ -483,7 +493,7 @@ _mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
483493
/// \param __A
484494
/// A 128-bit vector of [4 x bfloat].
485495
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
486-
static __inline__ __m128 __DEFAULT_FN_ATTRS128
496+
static __inline__ __m128 __DEFAULT_FN_ATTRS128_CONSTEXPR
487497
_mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
488498
return (__m128)__builtin_ia32_selectps_128(
489499
(__mmask8)__U, (__v4sf)_mm_cvtpbh_ps(__A), (__v4sf)__S);
@@ -502,14 +512,16 @@ _mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
502512
/// \param __A
503513
/// A 128-bit vector of [8 x bfloat].
504514
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
505-
static __inline__ __m256 __DEFAULT_FN_ATTRS256
515+
static __inline__ __m256 __DEFAULT_FN_ATTRS256_CONSTEXPR
506516
_mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __U, __m128bh __A) {
507517
return (__m256)__builtin_ia32_selectps_256(
508518
(__mmask8)__U, (__v8sf)_mm256_cvtpbh_ps(__A), (__v8sf)__S);
509519
}
510520

511521
#undef __DEFAULT_FN_ATTRS128
512522
#undef __DEFAULT_FN_ATTRS256
523+
#undef __DEFAULT_FN_ATTRS128_CONSTEXPR
524+
#undef __DEFAULT_FN_ATTRS256_CONSTEXPR
513525

514526
#endif
515527
#endif

clang/test/CodeGen/X86/avx512bf16-builtins.c

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,13 +9,15 @@
99
// RUN: %clang_cc1 -x c++ -flax-vector-conversions=none -ffreestanding %s -triple=i386-apple-darwin -target-feature +avx512bf16 -emit-llvm -o - -Wall -Werror -fexperimental-new-constant-interpreter | FileCheck %s
1010

1111
#include <immintrin.h>
12+
#include "builtin_test_helpers.h"
1213

1314
float test_mm_cvtsbh_ss(__bf16 A) {
1415
// CHECK-LABEL: test_mm_cvtsbh_ss
1516
// CHECK: fpext bfloat %{{.*}} to float
1617
// CHECK: ret float %{{.*}}
1718
return _mm_cvtsbh_ss(A);
1819
}
20+
TEST_CONSTEXPR(_mm_cvtsbh_ss(-1.0f) == -1.0f);
1921

2022
__m512bh test_mm512_cvtne2ps_pbh(__m512 A, __m512 B) {
2123
// CHECK-LABEL: test_mm512_cvtne2ps_pbh
@@ -82,17 +84,20 @@ __m512 test_mm512_cvtpbh_ps(__m256bh A) {
8284
// CHECK: fpext <16 x bfloat> %{{.*}} to <16 x float>
8385
return _mm512_cvtpbh_ps(A);
8486
}
87+
TEST_CONSTEXPR(match_m512(_mm512_cvtpbh_ps((__m256bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f, -128.0f, -0.5f, 0.25f, -0.125f, -4.0f, 2.0f, -1.0f, 0.0f}), -0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f, -128.0f, -0.5f, 0.25f, -0.125f, -4.0f, 2.0f, -1.0f, 0.0f));
8588

8689
__m512 test_mm512_maskz_cvtpbh_ps(__mmask16 M, __m256bh A) {
8790
// CHECK-LABEL: test_mm512_maskz_cvtpbh_ps
8891
// CHECK: fpext <16 x bfloat> %{{.*}} to <16 x float>
8992
// CHECK: select <16 x i1> %{{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}
9093
return _mm512_maskz_cvtpbh_ps(M, A);
9194
}
95+
TEST_CONSTEXPR(match_m512(_mm512_maskz_cvtpbh_ps(0xA753, (__m256bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f, -128.0f, -0.5f, 0.25f, -0.125f, -4.0f, 2.0f, -1.0f, 0.0f}), -0.0f, 1.0f, 0.0f, 0.0f, -8.0f, 0.0f, -32.0f, 0.0f, -128.0f, -0.5f, 0.25f, 0.0f, 0.0f, 2.0f, 0.0f, 0.0f));
9296

9397
__m512 test_mm512_mask_cvtpbh_ps(__m512 S, __mmask16 M, __m256bh A) {
9498
// CHECK-LABEL: test_mm512_mask_cvtpbh_ps
9599
// CHECK: fpext <16 x bfloat> %{{.*}} to <16 x float>
96100
// CHECK: select <16 x i1> %{{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}
97101
return _mm512_mask_cvtpbh_ps(S, M, A);
98102
}
103+
TEST_CONSTEXPR(match_m512(_mm512_mask_cvtpbh_ps((__m512){ 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f }, 0xA753, (__m256bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f, -128.0f, -0.5f, 0.25f, -0.125f, -4.0f, 2.0f, -1.0f, 0.0f}), -0.0f, 1.0f, 99.0f, 99.0f, -8.0f, 99.0f, -32.0f, 99.0f, -128.0f, -0.5f, 0.25f, 99.0f, 99.0f, 2.0f, 99.0f, 0.0f));

clang/test/CodeGen/X86/avx512vlbf16-builtins.c

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
// RUN: %clang_cc1 -x c++ -flax-vector-conversions=none -ffreestanding %s -triple=i386-apple-darwin -target-feature +avx512bf16 -target-feature +avx512vl -emit-llvm -o - -Wall -Werror -fexperimental-new-constant-interpreter | FileCheck %s
1010

1111
#include <immintrin.h>
12+
#include "builtin_test_helpers.h"
1213

1314
__m128bh test_mm_cvtne2ps2bf16(__m128 A, __m128 B) {
1415
// CHECK-LABEL: test_mm_cvtne2ps2bf16
@@ -160,12 +161,14 @@ __m128 test_mm_cvtpbh_ps(__m128bh A) {
160161
// CHECK: shufflevector <8 x float> %{{.*}}, <8 x float> %{{.*}}, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
161162
return _mm_cvtpbh_ps(A);
162163
}
164+
TEST_CONSTEXPR(match_m128(_mm_cvtpbh_ps((__m128bh){-8.0f, 16.0f, -32.0f, 64.0f, -0.0f, 1.0f, -2.0f, 4.0f}), -8.0f, 16.0f, -32.0f, 64.0f));
163165

164166
__m256 test_mm256_cvtpbh_ps(__m128bh A) {
165167
// CHECK-LABEL: test_mm256_cvtpbh_ps
166168
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
167169
return _mm256_cvtpbh_ps(A);
168170
}
171+
TEST_CONSTEXPR(match_m256(_mm256_cvtpbh_ps((__m128bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f}), -0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f));
169172

170173
__m128 test_mm_maskz_cvtpbh_ps(__mmask8 M, __m128bh A) {
171174
// CHECK-LABEL: test_mm_maskz_cvtpbh_ps
@@ -174,13 +177,15 @@ __m128 test_mm_maskz_cvtpbh_ps(__mmask8 M, __m128bh A) {
174177
// CHECK: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}
175178
return _mm_maskz_cvtpbh_ps(M, A);
176179
}
180+
TEST_CONSTEXPR(match_m128(_mm_maskz_cvtpbh_ps(0x01, (__m128bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f}), -0.0f, 0.0f, 0.0f, 0.0f));
177181

178182
__m256 test_mm256_maskz_cvtpbh_ps(__mmask8 M, __m128bh A) {
179183
// CHECK-LABEL: test_mm256_maskz_cvtpbh_ps
180184
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
181185
// CHECK: select <8 x i1> %{{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}
182186
return _mm256_maskz_cvtpbh_ps(M, A);
183187
}
188+
TEST_CONSTEXPR(match_m256(_mm256_maskz_cvtpbh_ps(0x73, (__m128bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f}), -0.0f, 1.0f, 0.0f, 0.0f, -8.0f, 16.0f, -32.0f, 0.0f));
184189

185190
__m128 test_mm_mask_cvtpbh_ps(__m128 S, __mmask8 M, __m128bh A) {
186191
// CHECK-LABEL: test_mm_mask_cvtpbh_ps
@@ -189,10 +194,12 @@ __m128 test_mm_mask_cvtpbh_ps(__m128 S, __mmask8 M, __m128bh A) {
189194
// CHECK: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}
190195
return _mm_mask_cvtpbh_ps(S, M, A);
191196
}
197+
TEST_CONSTEXPR(match_m128(_mm_mask_cvtpbh_ps((__m128){ 99.0f, 99.0f, 99.0f, 99.0f }, 0x03, (__m128bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f}), -0.0f, 1.0f, 99.0f, 99.0f));
192198

193199
__m256 test_mm256_mask_cvtpbh_ps(__m256 S, __mmask8 M, __m128bh A) {
194200
// CHECK-LABEL: test_mm256_mask_cvtpbh_ps
195201
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
196202
// CHECK: select <8 x i1> %{{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}
197203
return _mm256_mask_cvtpbh_ps(S, M, A);
198204
}
205+
TEST_CONSTEXPR(match_m256(_mm256_mask_cvtpbh_ps((__m256){ 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f, 99.0f }, 0x37, (__m128bh){-0.0f, 1.0f, -2.0f, 4.0f, -8.0f, 16.0f, -32.0f, 64.0f}), -0.0f, 1.0f, -2.0f, 99.0f, -8.0f, 16.0f, 99.0f, 99.0f));

0 commit comments

Comments
 (0)