Skip to content

Commit 1cbefc2

Browse files
committed
[msan] Support x86_avx512bf16_dpbf16ps
Use the generalized handleVectorPmaddIntrinsic(), but with ZeroPurifies=false, because for floating-point operands multiplication by an initialized zero does not guarantee that the result is zero (counter-example: zero multiplied by NaN is NaN).
1 parent 3ad5765 commit 1cbefc2

File tree

3 files changed

+109
-153
lines changed

3 files changed

+109
-153
lines changed

llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp

Lines changed: 13 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5925,15 +5925,19 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
59255925
/*ZeroPurifies=*/true, /*EltSizeInBits=*/16);
59265926
break;
59275927

5928-
// TODO: Dot Product of BF16 Pairs Accumulated Into Packed Single
5929-
// Precision
5930-
// <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128
5931-
// (<4 x float>, <8 x bfloat>, <8 x bfloat>)
5932-
// <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256
5933-
// (<8 x float>, <16 x bfloat>, <16 x bfloat>)
5934-
// <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512
5935-
// (<16 x float>, <32 x bfloat>, <32 x bfloat>)
5936-
// handleVectorPmaddIntrinsic() currently only handles integer types.
5928+
// Dot Product of BF16 Pairs Accumulated Into Packed Single
5929+
// Precision
5930+
// <4 x float> @llvm.x86.avx512bf16.dpbf16ps.128
5931+
// (<4 x float>, <8 x bfloat>, <8 x bfloat>)
5932+
// <8 x float> @llvm.x86.avx512bf16.dpbf16ps.256
5933+
// (<8 x float>, <16 x bfloat>, <16 x bfloat>)
5934+
// <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512
5935+
// (<16 x float>, <32 x bfloat>, <32 x bfloat>)
5936+
case Intrinsic::x86_avx512bf16_dpbf16ps_128:
5937+
case Intrinsic::x86_avx512bf16_dpbf16ps_256:
5938+
case Intrinsic::x86_avx512bf16_dpbf16ps_512:
5939+
handleVectorPmaddIntrinsic(I, /*ReductionFactor=*/2, /*ZeroPurifies=*/false);
5940+
break;
59375941

59385942
case Intrinsic::x86_sse_cmp_ss:
59395943
case Intrinsic::x86_sse2_cmp_sd:

llvm/test/Instrumentation/MemorySanitizer/X86/avx512bf16-intrinsics.ll

Lines changed: 32 additions & 48 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
; Strictly handled:
77
; - llvm.x86.avx512bf16.cvtne2ps2bf16.512(<16 x float> %A, <16 x float> %B)
88
; - llvm.x86.avx512bf16.cvtneps2bf16.512(<16 x float> %A)
9-
; - llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> %E, <32 x bfloat> %A, <32 x bfloat> %B)
109
;
1110
; Heuristically handled: (none)
1211

@@ -241,25 +240,20 @@ define <16 x float> @test_mm512_dpbf16ps_512(<16 x float> %E, <32 x bfloat> %A,
241240
; CHECK-LABEL: define <16 x float> @test_mm512_dpbf16ps_512(
242241
; CHECK-SAME: <16 x float> [[E:%.*]], <32 x bfloat> [[A:%.*]], <32 x bfloat> [[B:%.*]]) local_unnamed_addr #[[ATTR1]] {
243242
; CHECK-NEXT: [[ENTRY:.*:]]
244-
; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
245243
; CHECK-NEXT: [[TMP1:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 64), align 8
246244
; CHECK-NEXT: [[TMP2:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 128), align 8
245+
; CHECK-NEXT: [[TMP11:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
247246
; CHECK-NEXT: call void @llvm.donothing()
248-
; CHECK-NEXT: [[TMP3:%.*]] = bitcast <16 x i32> [[TMP0]] to i512
249-
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP3]], 0
250-
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
251-
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i512 [[TMP4]], 0
252-
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
253-
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <32 x i16> [[TMP2]] to i512
254-
; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i512 [[TMP5]], 0
255-
; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]]
256-
; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB6:.*]], label %[[BB7:.*]], !prof [[PROF1]]
257-
; CHECK: [[BB6]]:
258-
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR4]]
259-
; CHECK-NEXT: unreachable
260-
; CHECK: [[BB7]]:
247+
; CHECK-NEXT: [[TMP3:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
248+
; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <32 x i16> [[TMP2]], zeroinitializer
249+
; CHECK-NEXT: [[TMP5:%.*]] = or <32 x i1> [[TMP3]], [[TMP4]]
250+
; CHECK-NEXT: [[TMP6:%.*]] = sext <32 x i1> [[TMP5]] to <32 x i16>
251+
; CHECK-NEXT: [[TMP7:%.*]] = bitcast <32 x i16> [[TMP6]] to <16 x i32>
252+
; CHECK-NEXT: [[TMP12:%.*]] = icmp ne <16 x i32> [[TMP7]], zeroinitializer
253+
; CHECK-NEXT: [[TMP9:%.*]] = sext <16 x i1> [[TMP12]] to <16 x i32>
254+
; CHECK-NEXT: [[TMP10:%.*]] = or <16 x i32> [[TMP9]], [[TMP11]]
261255
; CHECK-NEXT: [[TMP8:%.*]] = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> [[E]], <32 x bfloat> [[A]], <32 x bfloat> [[B]])
262-
; CHECK-NEXT: store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
256+
; CHECK-NEXT: store <16 x i32> [[TMP10]], ptr @__msan_retval_tls, align 8
263257
; CHECK-NEXT: ret <16 x float> [[TMP8]]
264258
;
265259
entry:
@@ -271,31 +265,26 @@ define <16 x float> @test_mm512_maskz_dpbf16ps_512(<16 x float> %E, <32 x bfloat
271265
; CHECK-LABEL: define <16 x float> @test_mm512_maskz_dpbf16ps_512(
272266
; CHECK-SAME: <16 x float> [[E:%.*]], <32 x bfloat> [[A:%.*]], <32 x bfloat> [[B:%.*]], i16 zeroext [[U:%.*]]) local_unnamed_addr #[[ATTR1]] {
273267
; CHECK-NEXT: [[ENTRY:.*:]]
274-
; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
275268
; CHECK-NEXT: [[TMP1:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 64), align 8
276269
; CHECK-NEXT: [[TMP2:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 128), align 8
270+
; CHECK-NEXT: [[TMP18:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
277271
; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr getelementptr (i8, ptr @__msan_param_tls, i64 192), align 8
278272
; CHECK-NEXT: call void @llvm.donothing()
279-
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <16 x i32> [[TMP0]] to i512
280-
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP4]], 0
281-
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
282-
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i512 [[TMP5]], 0
283-
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
284-
; CHECK-NEXT: [[TMP6:%.*]] = bitcast <32 x i16> [[TMP2]] to i512
285-
; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i512 [[TMP6]], 0
286-
; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]]
287-
; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB7:.*]], label %[[BB8:.*]], !prof [[PROF1]]
288-
; CHECK: [[BB7]]:
289-
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR4]]
290-
; CHECK-NEXT: unreachable
291-
; CHECK: [[BB8]]:
273+
; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
274+
; CHECK-NEXT: [[TMP5:%.*]] = icmp ne <32 x i16> [[TMP2]], zeroinitializer
275+
; CHECK-NEXT: [[TMP6:%.*]] = or <32 x i1> [[TMP4]], [[TMP5]]
276+
; CHECK-NEXT: [[TMP7:%.*]] = sext <32 x i1> [[TMP6]] to <32 x i16>
277+
; CHECK-NEXT: [[TMP8:%.*]] = bitcast <32 x i16> [[TMP7]] to <16 x i32>
278+
; CHECK-NEXT: [[TMP19:%.*]] = icmp ne <16 x i32> [[TMP8]], zeroinitializer
279+
; CHECK-NEXT: [[TMP20:%.*]] = sext <16 x i1> [[TMP19]] to <16 x i32>
280+
; CHECK-NEXT: [[TMP21:%.*]] = or <16 x i32> [[TMP20]], [[TMP18]]
292281
; CHECK-NEXT: [[TMP9:%.*]] = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> [[E]], <32 x bfloat> [[A]], <32 x bfloat> [[B]])
293282
; CHECK-NEXT: [[TMP10:%.*]] = bitcast i16 [[TMP3]] to <16 x i1>
294283
; CHECK-NEXT: [[TMP11:%.*]] = bitcast i16 [[U]] to <16 x i1>
295-
; CHECK-NEXT: [[TMP12:%.*]] = select <16 x i1> [[TMP11]], <16 x i32> zeroinitializer, <16 x i32> zeroinitializer
284+
; CHECK-NEXT: [[TMP12:%.*]] = select <16 x i1> [[TMP11]], <16 x i32> [[TMP21]], <16 x i32> zeroinitializer
296285
; CHECK-NEXT: [[TMP13:%.*]] = bitcast <16 x float> [[TMP9]] to <16 x i32>
297286
; CHECK-NEXT: [[TMP14:%.*]] = xor <16 x i32> [[TMP13]], zeroinitializer
298-
; CHECK-NEXT: [[TMP15:%.*]] = or <16 x i32> [[TMP14]], zeroinitializer
287+
; CHECK-NEXT: [[TMP15:%.*]] = or <16 x i32> [[TMP14]], [[TMP21]]
299288
; CHECK-NEXT: [[TMP16:%.*]] = or <16 x i32> [[TMP15]], zeroinitializer
300289
; CHECK-NEXT: [[_MSPROP_SELECT:%.*]] = select <16 x i1> [[TMP10]], <16 x i32> [[TMP16]], <16 x i32> [[TMP12]]
301290
; CHECK-NEXT: [[TMP17:%.*]] = select <16 x i1> [[TMP11]], <16 x float> [[TMP9]], <16 x float> zeroinitializer
@@ -312,32 +301,27 @@ define <16 x float> @test_mm512_mask_dpbf16ps_512(i16 zeroext %U, <16 x float> %
312301
; CHECK-LABEL: define <16 x float> @test_mm512_mask_dpbf16ps_512(
313302
; CHECK-SAME: i16 zeroext [[U:%.*]], <16 x float> [[E:%.*]], <32 x bfloat> [[A:%.*]], <32 x bfloat> [[B:%.*]]) local_unnamed_addr #[[ATTR1]] {
314303
; CHECK-NEXT: [[ENTRY:.*:]]
315-
; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i32>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 8), align 8
316304
; CHECK-NEXT: [[TMP1:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 72), align 8
317305
; CHECK-NEXT: [[TMP2:%.*]] = load <32 x i16>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 136), align 8
306+
; CHECK-NEXT: [[TMP0:%.*]] = load <16 x i32>, ptr getelementptr (i8, ptr @__msan_param_tls, i64 8), align 8
318307
; CHECK-NEXT: [[TMP3:%.*]] = load i16, ptr @__msan_param_tls, align 8
319308
; CHECK-NEXT: call void @llvm.donothing()
320-
; CHECK-NEXT: [[TMP4:%.*]] = bitcast <16 x i32> [[TMP0]] to i512
321-
; CHECK-NEXT: [[_MSCMP:%.*]] = icmp ne i512 [[TMP4]], 0
322-
; CHECK-NEXT: [[TMP5:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
323-
; CHECK-NEXT: [[_MSCMP1:%.*]] = icmp ne i512 [[TMP5]], 0
324-
; CHECK-NEXT: [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
325-
; CHECK-NEXT: [[TMP6:%.*]] = bitcast <32 x i16> [[TMP2]] to i512
326-
; CHECK-NEXT: [[_MSCMP2:%.*]] = icmp ne i512 [[TMP6]], 0
327-
; CHECK-NEXT: [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]]
328-
; CHECK-NEXT: br i1 [[_MSOR3]], label %[[BB7:.*]], label %[[BB8:.*]], !prof [[PROF1]]
329-
; CHECK: [[BB7]]:
330-
; CHECK-NEXT: call void @__msan_warning_noreturn() #[[ATTR4]]
331-
; CHECK-NEXT: unreachable
332-
; CHECK: [[BB8]]:
309+
; CHECK-NEXT: [[TMP4:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
310+
; CHECK-NEXT: [[TMP5:%.*]] = icmp ne <32 x i16> [[TMP2]], zeroinitializer
311+
; CHECK-NEXT: [[TMP6:%.*]] = or <32 x i1> [[TMP4]], [[TMP5]]
312+
; CHECK-NEXT: [[TMP7:%.*]] = sext <32 x i1> [[TMP6]] to <32 x i16>
313+
; CHECK-NEXT: [[TMP8:%.*]] = bitcast <32 x i16> [[TMP7]] to <16 x i32>
314+
; CHECK-NEXT: [[TMP19:%.*]] = icmp ne <16 x i32> [[TMP8]], zeroinitializer
315+
; CHECK-NEXT: [[TMP20:%.*]] = sext <16 x i1> [[TMP19]] to <16 x i32>
316+
; CHECK-NEXT: [[TMP21:%.*]] = or <16 x i32> [[TMP20]], [[TMP0]]
333317
; CHECK-NEXT: [[TMP9:%.*]] = tail call <16 x float> @llvm.x86.avx512bf16.dpbf16ps.512(<16 x float> [[E]], <32 x bfloat> [[A]], <32 x bfloat> [[B]])
334318
; CHECK-NEXT: [[TMP10:%.*]] = bitcast i16 [[TMP3]] to <16 x i1>
335319
; CHECK-NEXT: [[TMP11:%.*]] = bitcast i16 [[U]] to <16 x i1>
336-
; CHECK-NEXT: [[TMP12:%.*]] = select <16 x i1> [[TMP11]], <16 x i32> zeroinitializer, <16 x i32> [[TMP0]]
320+
; CHECK-NEXT: [[TMP12:%.*]] = select <16 x i1> [[TMP11]], <16 x i32> [[TMP21]], <16 x i32> [[TMP0]]
337321
; CHECK-NEXT: [[TMP13:%.*]] = bitcast <16 x float> [[TMP9]] to <16 x i32>
338322
; CHECK-NEXT: [[TMP14:%.*]] = bitcast <16 x float> [[E]] to <16 x i32>
339323
; CHECK-NEXT: [[TMP15:%.*]] = xor <16 x i32> [[TMP13]], [[TMP14]]
340-
; CHECK-NEXT: [[TMP16:%.*]] = or <16 x i32> [[TMP15]], zeroinitializer
324+
; CHECK-NEXT: [[TMP16:%.*]] = or <16 x i32> [[TMP15]], [[TMP21]]
341325
; CHECK-NEXT: [[TMP17:%.*]] = or <16 x i32> [[TMP16]], [[TMP0]]
342326
; CHECK-NEXT: [[_MSPROP_SELECT:%.*]] = select <16 x i1> [[TMP10]], <16 x i32> [[TMP17]], <16 x i32> [[TMP12]]
343327
; CHECK-NEXT: [[TMP18:%.*]] = select <16 x i1> [[TMP11]], <16 x float> [[TMP9]], <16 x float> [[E]]

0 commit comments

Comments
 (0)