Skip to content

Commit 0542f5a

Browse files
committed
[X86] Replace BF16 to F32 conversions with generic conversions
Let standard casting / __builtin_convertvector handle the conversions from BF16 to F32. My only query is how best to implement _mm_cvtpbh_ps - I went for the v8bf16 -> v8f32 conversion followed by subvector extraction in the end, but could just as easily have extracted a v4bf16 first - it makes no difference to the final codegen. First part of #154911.
1 parent c28c99f commit 0542f5a

File tree

6 files changed

+37
-52
lines changed

6 files changed

+37
-52
lines changed

clang/include/clang/Basic/BuiltinsX86.td

Lines changed: 0 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -3358,10 +3358,6 @@ let Features = "avx512bf16", Attributes = [NoThrow, Const, RequiredVectorWidth<5
33583358
def dpbf16ps_512 : X86Builtin<"_Vector<16, float>(_Vector<16, float>, _Vector<32, __bf16>, _Vector<32, __bf16>)">;
33593359
}
33603360

3361-
let Features = "avx512bf16", Attributes = [NoThrow, Const] in {
3362-
def cvtsbf162ss_32 : X86Builtin<"float(__bf16)">;
3363-
}
3364-
33653361
let Features = "avx512vp2intersect", Attributes = [NoThrow, RequiredVectorWidth<512>] in {
33663362
def vp2intersect_q_512 : X86Builtin<"void(_Vector<8, long long int>, _Vector<8, long long int>, unsigned char *, unsigned char *)">;
33673363
}

clang/lib/CodeGen/TargetBuiltins/X86.cpp

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2796,8 +2796,6 @@ Value *CodeGenFunction::EmitX86BuiltinExpr(unsigned BuiltinID,
27962796
Intrinsic::ID IID = Intrinsic::x86_avx512bf16_mask_cvtneps2bf16_128;
27972797
return Builder.CreateCall(CGM.getIntrinsic(IID), Ops);
27982798
}
2799-
case X86::BI__builtin_ia32_cvtsbf162ss_32:
2800-
return Builder.CreateFPExt(Ops[0], Builder.getFloatTy());
28012799

28022800
case X86::BI__builtin_ia32_cvtneps2bf16_256_mask:
28032801
case X86::BI__builtin_ia32_cvtneps2bf16_512_mask: {

clang/lib/Headers/avx512bf16intrin.h

Lines changed: 7 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,7 @@ typedef __bf16 __bfloat16 __attribute__((deprecated("use __bf16 instead")));
3636
/// \returns A float data whose sign field and exponent field keep unchanged,
3737
/// and fraction field is extended to 23 bits.
3838
static __inline__ float __DEFAULT_FN_ATTRS _mm_cvtsbh_ss(__bf16 __A) {
39-
return __builtin_ia32_cvtsbf162ss_32(__A);
39+
return (float)(__A);
4040
}
4141

4242
/// Convert Two Packed Single Data to One Packed BF16 Data.
@@ -236,8 +236,7 @@ _mm512_maskz_dpbf16_ps(__mmask16 __U, __m512 __D, __m512bh __A, __m512bh __B) {
236236
/// A 256-bit vector of [16 x bfloat].
237237
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
238238
static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) {
239-
return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32(
240-
(__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16));
239+
return (__m512) __builtin_convertvector(__A, __v16sf);
241240
}
242241

243242
/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
@@ -252,8 +251,9 @@ static __inline__ __m512 __DEFAULT_FN_ATTRS512 _mm512_cvtpbh_ps(__m256bh __A) {
252251
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
253252
static __inline__ __m512 __DEFAULT_FN_ATTRS512
254253
_mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) {
255-
return _mm512_castsi512_ps((__m512i)_mm512_slli_epi32(
256-
(__m512i)_mm512_maskz_cvtepi16_epi32((__mmask16)__U, (__m256i)__A), 16));
254+
return (__m512)__builtin_ia32_selectps_512((__mmask16)__U,
255+
(__v16sf)_mm512_cvtpbh_ps(__A),
256+
(__v16sf)_mm512_setzero_ps());
257257
}
258258

259259
/// Convert Packed BF16 Data to Packed float Data using merging mask.
@@ -270,9 +270,8 @@ _mm512_maskz_cvtpbh_ps(__mmask16 __U, __m256bh __A) {
270270
/// \returns A 512-bit vector of [16 x float] come from conversion of __A
271271
static __inline__ __m512 __DEFAULT_FN_ATTRS512
272272
_mm512_mask_cvtpbh_ps(__m512 __S, __mmask16 __U, __m256bh __A) {
273-
return _mm512_castsi512_ps((__m512i)_mm512_mask_slli_epi32(
274-
(__m512i)__S, (__mmask16)__U,
275-
(__m512i)_mm512_cvtepi16_epi32((__m256i)__A), 16));
273+
return (__m512)__builtin_ia32_selectps_512(
274+
(__mmask16)__U, (__v16sf)_mm512_cvtpbh_ps(__A), (__v16sf)__S);
276275
}
277276

278277
#undef __DEFAULT_FN_ATTRS

clang/lib/Headers/avx512vlbf16intrin.h

Lines changed: 12 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -422,8 +422,8 @@ static __inline__ __bf16 __DEFAULT_FN_ATTRS128 _mm_cvtness_sbh(float __A) {
422422
/// A 128-bit vector of [4 x bfloat].
423423
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
424424
static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
425-
return _mm_castsi128_ps(
426-
(__m128i)_mm_slli_epi32((__m128i)_mm_cvtepi16_epi32((__m128i)__A), 16));
425+
return (__m128)_mm256_castps256_ps128(
426+
(__m256) __builtin_convertvector(__A, __v8sf));
427427
}
428428

429429
/// Convert Packed BF16 Data to Packed float Data.
@@ -434,8 +434,7 @@ static __inline__ __m128 __DEFAULT_FN_ATTRS128 _mm_cvtpbh_ps(__m128bh __A) {
434434
/// A 128-bit vector of [8 x bfloat].
435435
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
436436
static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
437-
return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
438-
(__m256i)_mm256_cvtepi16_epi32((__m128i)__A), 16));
437+
return (__m256) __builtin_convertvector(__A, __v8sf);
439438
}
440439

441440
/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
@@ -450,8 +449,8 @@ static __inline__ __m256 __DEFAULT_FN_ATTRS256 _mm256_cvtpbh_ps(__m128bh __A) {
450449
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
451450
static __inline__ __m128 __DEFAULT_FN_ATTRS128
452451
_mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
453-
return _mm_castsi128_ps((__m128i)_mm_slli_epi32(
454-
(__m128i)_mm_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
452+
return (__m128)__builtin_ia32_selectps_128(
453+
(__mmask8)__U, (__v4sf)_mm_cvtpbh_ps(__A), (__v4sf)_mm_setzero_ps());
455454
}
456455

457456
/// Convert Packed BF16 Data to Packed float Data using zeroing mask.
@@ -466,8 +465,9 @@ _mm_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
466465
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
467466
static __inline__ __m256 __DEFAULT_FN_ATTRS256
468467
_mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
469-
return _mm256_castsi256_ps((__m256i)_mm256_slli_epi32(
470-
(__m256i)_mm256_maskz_cvtepi16_epi32((__mmask8)__U, (__m128i)__A), 16));
468+
return (__m256)__builtin_ia32_selectps_256((__mmask8)__U,
469+
(__v8sf)_mm256_cvtpbh_ps(__A),
470+
(__v8sf)_mm256_setzero_ps());
471471
}
472472

473473
/// Convert Packed BF16 Data to Packed float Data using merging mask.
@@ -485,9 +485,8 @@ _mm256_maskz_cvtpbh_ps(__mmask8 __U, __m128bh __A) {
485485
/// \returns A 128-bit vector of [4 x float] come from conversion of __A
486486
static __inline__ __m128 __DEFAULT_FN_ATTRS128
487487
_mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
488-
return _mm_castsi128_ps((__m128i)_mm_mask_slli_epi32(
489-
(__m128i)__S, (__mmask8)__U, (__m128i)_mm_cvtepi16_epi32((__m128i)__A),
490-
16));
488+
return (__m128)__builtin_ia32_selectps_128(
489+
(__mmask8)__U, (__v4sf)_mm_cvtpbh_ps(__A), (__v4sf)__S);
491490
}
492491

493492
/// Convert Packed BF16 Data to Packed float Data using merging mask.
@@ -505,9 +504,8 @@ _mm_mask_cvtpbh_ps(__m128 __S, __mmask8 __U, __m128bh __A) {
505504
/// \returns A 256-bit vector of [8 x float] come from conversion of __A
506505
static __inline__ __m256 __DEFAULT_FN_ATTRS256
507506
_mm256_mask_cvtpbh_ps(__m256 __S, __mmask8 __U, __m128bh __A) {
508-
return _mm256_castsi256_ps((__m256i)_mm256_mask_slli_epi32(
509-
(__m256i)__S, (__mmask8)__U, (__m256i)_mm256_cvtepi16_epi32((__m128i)__A),
510-
16));
507+
return (__m256)__builtin_ia32_selectps_256(
508+
(__mmask8)__U, (__v8sf)_mm256_cvtpbh_ps(__A), (__v8sf)__S);
511509
}
512510

513511
#undef __DEFAULT_FN_ATTRS128

clang/test/CodeGen/X86/avx512bf16-builtins.c

Lines changed: 5 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -79,23 +79,20 @@ __m512 test_mm512_mask_dpbf16_ps(__m512 D, __m512bh A, __m512bh B, __mmask16 U)
7979

8080
__m512 test_mm512_cvtpbh_ps(__m256bh A) {
8181
// CHECK-LABEL: test_mm512_cvtpbh_ps
82-
// CHECK: sext <16 x i16> %{{.*}} to <16 x i32>
83-
// CHECK: call <16 x i32> @llvm.x86.avx512.pslli.d.512(<16 x i32> %{{.*}}, i32 %{{.*}})
82+
// CHECK: fpext <16 x bfloat> %{{.*}} to <16 x float>
8483
return _mm512_cvtpbh_ps(A);
8584
}
8685

8786
__m512 test_mm512_maskz_cvtpbh_ps(__mmask16 M, __m256bh A) {
8887
// CHECK-LABEL: test_mm512_maskz_cvtpbh_ps
89-
// CHECK: sext <16 x i16> %{{.*}} to <16 x i32>
90-
// CHECK: select <16 x i1> %{{.*}}, <16 x i32> %{{.*}}, <16 x i32> %{{.*}}
91-
// CHECK: call <16 x i32> @llvm.x86.avx512.pslli.d.512(<16 x i32> %{{.*}}, i32 %{{.*}})
88+
// CHECK: fpext <16 x bfloat> %{{.*}} to <16 x float>
89+
// CHECK: select <16 x i1> %{{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}
9290
return _mm512_maskz_cvtpbh_ps(M, A);
9391
}
9492

9593
__m512 test_mm512_mask_cvtpbh_ps(__m512 S, __mmask16 M, __m256bh A) {
9694
// CHECK-LABEL: test_mm512_mask_cvtpbh_ps
97-
// CHECK: sext <16 x i16> %{{.*}} to <16 x i32>
98-
// CHECK: call <16 x i32> @llvm.x86.avx512.pslli.d.512(<16 x i32> %{{.*}}, i32 %{{.*}})
99-
// CHECK: select <16 x i1> %{{.*}}, <16 x i32> %{{.*}}, <16 x i32> %{{.*}}
95+
// CHECK: fpext <16 x bfloat> %{{.*}} to <16 x float>
96+
// CHECK: select <16 x i1> %{{.*}}, <16 x float> %{{.*}}, <16 x float> %{{.*}}
10097
return _mm512_mask_cvtpbh_ps(S, M, A);
10198
}

clang/test/CodeGen/X86/avx512vlbf16-builtins.c

Lines changed: 13 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -156,46 +156,43 @@ __bf16 test_mm_cvtness_sbh(float A) {
156156

157157
__m128 test_mm_cvtpbh_ps(__m128bh A) {
158158
// CHECK-LABEL: test_mm_cvtpbh_ps
159-
// CHECK: sext <4 x i16> %{{.*}} to <4 x i32>
160-
// CHECK: call <4 x i32> @llvm.x86.sse2.pslli.d(<4 x i32> %{{.*}}, i32 %{{.*}})
159+
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
160+
// CHECK: shufflevector <8 x float> %{{.*}}, <8 x float> %{{.*}}, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
161161
return _mm_cvtpbh_ps(A);
162162
}
163163

164164
__m256 test_mm256_cvtpbh_ps(__m128bh A) {
165165
// CHECK-LABEL: test_mm256_cvtpbh_ps
166-
// CHECK: sext <8 x i16> %{{.*}} to <8 x i32>
167-
// CHECK: call <8 x i32> @llvm.x86.avx2.pslli.d(<8 x i32> %{{.*}}, i32 %{{.*}})
166+
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
168167
return _mm256_cvtpbh_ps(A);
169168
}
170169

171170
__m128 test_mm_maskz_cvtpbh_ps(__mmask8 M, __m128bh A) {
172171
// CHECK-LABEL: test_mm_maskz_cvtpbh_ps
173-
// CHECK: sext <4 x i16> %{{.*}} to <4 x i32>
174-
// CHECK: select <4 x i1> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}
175-
// CHECK: call <4 x i32> @llvm.x86.sse2.pslli.d(<4 x i32> %{{.*}}, i32 %{{.*}})
172+
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
173+
// CHECK: shufflevector <8 x float> %{{.*}}, <8 x float> %{{.*}}, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
174+
// CHECK: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}
176175
return _mm_maskz_cvtpbh_ps(M, A);
177176
}
178177

179178
__m256 test_mm256_maskz_cvtpbh_ps(__mmask8 M, __m128bh A) {
180179
// CHECK-LABEL: test_mm256_maskz_cvtpbh_ps
181-
// CHECK: sext <8 x i16> %{{.*}} to <8 x i32>
182-
// CHECK: select <8 x i1> %{{.*}}, <8 x i32> %{{.*}}, <8 x i32> %{{.*}}
183-
// CHECK: call <8 x i32> @llvm.x86.avx2.pslli.d(<8 x i32> %{{.*}}, i32 %{{.*}})
180+
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
181+
// CHECK: select <8 x i1> %{{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}
184182
return _mm256_maskz_cvtpbh_ps(M, A);
185183
}
186184

187185
__m128 test_mm_mask_cvtpbh_ps(__m128 S, __mmask8 M, __m128bh A) {
188186
// CHECK-LABEL: test_mm_mask_cvtpbh_ps
189-
// CHECK: sext <4 x i16> %{{.*}} to <4 x i32>
190-
// CHECK: call <4 x i32> @llvm.x86.sse2.pslli.d(<4 x i32> %{{.*}}, i32 %{{.*}})
191-
// CHECK: select <4 x i1> %{{.*}}, <4 x i32> %{{.*}}, <4 x i32> %{{.*}}
187+
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
188+
// CHECK: shufflevector <8 x float> %{{.*}}, <8 x float> %{{.*}}, <4 x i32> <i32 0, i32 1, i32 2, i32 3>
189+
// CHECK: select <4 x i1> %{{.*}}, <4 x float> %{{.*}}, <4 x float> %{{.*}}
192190
return _mm_mask_cvtpbh_ps(S, M, A);
193191
}
194192

195193
__m256 test_mm256_mask_cvtpbh_ps(__m256 S, __mmask8 M, __m128bh A) {
196194
// CHECK-LABEL: test_mm256_mask_cvtpbh_ps
197-
// CHECK: sext <8 x i16> %{{.*}} to <8 x i32>
198-
// CHECK: call <8 x i32> @llvm.x86.avx2.pslli.d(<8 x i32> %{{.*}}, i32 %{{.*}})
199-
// CHECK: select <8 x i1> %{{.*}}, <8 x i32> %{{.*}}, <8 x i32> %{{.*}}
195+
// CHECK: fpext <8 x bfloat> %{{.*}} to <8 x float>
196+
// CHECK: select <8 x i1> %{{.*}}, <8 x float> %{{.*}}, <8 x float> %{{.*}}
200197
return _mm256_mask_cvtpbh_ps(S, M, A);
201198
}

0 commit comments

Comments
 (0)