
Commit 7ad70d2

[msan] Handle AVX512/AVX10 vrndscale (#160624)
Uses the updated handleAVX512VectorGenericMaskedFP() from #159966
1 parent 2810a48 commit 7ad70d2
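
For context, the shared helper from #159966 gives these masked rndscale intrinsics per-element shadow propagation instead of a single strict check on the whole vector. The sketch below is reconstructed from the CHECK lines in the test updates further down; the function name, signature, and IRBuilder plumbing are illustrative assumptions, not the actual handleAVX512VectorGenericMaskedFP() code.

// Illustrative sketch only (assumed name/signature), reconstructed from the
// instrumented IR in the tests below; not the real MemorySanitizer helper.
// For  Dst = intrinsic(A, Imm, WriteThru, Mask[, Rounding]):
//   * an element of Dst is poisoned iff the mask selects A and that element
//     of A is poisoned; otherwise it inherits WriteThru's element shadow;
//   * the scalar Imm/Mask/Rounding operands still get a strict check.
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
using namespace llvm;

static Value *propagateMaskedFPShadow(IRBuilder<> &IRB, Value *ShadowA,
                                      Value *ShadowWriteThru, Value *Mask,
                                      unsigned NumElts) {
  // Per-element "is any shadow bit of A set?", widened back to the shadow
  // element type (all-ones = poisoned, all-zeros = clean).
  Value *Poisoned =
      IRB.CreateICmpNE(ShadowA, Constant::getNullValue(ShadowA->getType()));
  Value *ElemShadow = IRB.CreateSExt(Poisoned, ShadowA->getType());
  // The mask arrives as an i8/i16/i32; reinterpret it as <N x i1>.
  Value *MaskVec = IRB.CreateBitCast(
      Mask, FixedVectorType::get(IRB.getInt1Ty(), NumElts));
  // Masked lanes take A's shadow, unmasked lanes keep WriteThru's shadow.
  return IRB.CreateSelect(MaskVec, ElemShadow, ShadowWriteThru);
}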

4 files changed: +133 −131 lines changed

llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp

Lines changed: 56 additions & 0 deletions
@@ -6278,6 +6278,62 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
                                         /*MaskIndex=*/2);
       break;
 
+    // <32 x half> @llvm.x86.avx512fp16.mask.rndscale.ph.512
+    //                  (<32 x half>, i32, <32 x half>, i32, i32)
+    // <16 x half> @llvm.x86.avx512fp16.mask.rndscale.ph.256
+    //                  (<16 x half>, i32, <16 x half>, i32, i16)
+    // <8 x half> @llvm.x86.avx512fp16.mask.rndscale.ph.128
+    //                  (<8 x half>, i32, <8 x half>, i32, i8)
+    //
+    // <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512
+    //                  (<16 x float>, i32, <16 x float>, i16, i32)
+    // <8 x float> @llvm.x86.avx512.mask.rndscale.ps.256
+    //                  (<8 x float>, i32, <8 x float>, i8)
+    // <4 x float> @llvm.x86.avx512.mask.rndscale.ps.128
+    //                  (<4 x float>, i32, <4 x float>, i8)
+    //
+    // <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512
+    //                  (<8 x double>, i32, <8 x double>, i8, i32)
+    //                   A             Imm  WriteThru     Mask Rounding
+    // <4 x double> @llvm.x86.avx512.mask.rndscale.pd.256
+    //                  (<4 x double>, i32, <4 x double>, i8)
+    // <2 x double> @llvm.x86.avx512.mask.rndscale.pd.128
+    //                  (<2 x double>, i32, <2 x double>, i8)
+    //                   A             Imm  WriteThru     Mask
+    //
+    // <32 x bfloat> @llvm.x86.avx10.mask.rndscale.bf16.512
+    //                  (<32 x bfloat>, i32, <32 x bfloat>, i32)
+    // <16 x bfloat> @llvm.x86.avx10.mask.rndscale.bf16.256
+    //                  (<16 x bfloat>, i32, <16 x bfloat>, i16)
+    // <8 x bfloat> @llvm.x86.avx10.mask.rndscale.bf16.128
+    //                  (<8 x bfloat>, i32, <8 x bfloat>, i8)
+    //
+    // Not supported: three vectors
+    // - <8 x half> @llvm.x86.avx512fp16.mask.rndscale.sh
+    //                  (<8 x half>, <8 x half>, <8 x half>, i8, i32, i32)
+    // - <4 x float> @llvm.x86.avx512.mask.rndscale.ss
+    //                  (<4 x float>, <4 x float>, <4 x float>, i8, i32, i32)
+    // - <2 x double> @llvm.x86.avx512.mask.rndscale.sd
+    //                  (<2 x double>, <2 x double>, <2 x double>, i8, i32,
+    //                   i32)
+    //                   A             B             WriteThru     Mask Imm
+    //                   Rounding
+    case Intrinsic::x86_avx512fp16_mask_rndscale_ph_512:
+    case Intrinsic::x86_avx512fp16_mask_rndscale_ph_256:
+    case Intrinsic::x86_avx512fp16_mask_rndscale_ph_128:
+    case Intrinsic::x86_avx512_mask_rndscale_ps_512:
+    case Intrinsic::x86_avx512_mask_rndscale_ps_256:
+    case Intrinsic::x86_avx512_mask_rndscale_ps_128:
+    case Intrinsic::x86_avx512_mask_rndscale_pd_512:
+    case Intrinsic::x86_avx512_mask_rndscale_pd_256:
+    case Intrinsic::x86_avx512_mask_rndscale_pd_128:
+    case Intrinsic::x86_avx10_mask_rndscale_bf16_512:
+    case Intrinsic::x86_avx10_mask_rndscale_bf16_256:
+    case Intrinsic::x86_avx10_mask_rndscale_bf16_128:
+      handleAVX512VectorGenericMaskedFP(I, /*AIndex=*/0, /*WriteThruIndex=*/2,
+                                        /*MaskIndex=*/3);
+      break;
+
     // AVX512 FP16 Arithmetic
     case Intrinsic::x86_avx512fp16_mask_add_sh_round:
     case Intrinsic::x86_avx512fp16_mask_sub_sh_round:
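
As a usage-level illustration (not part of the commit), the 512-bit double case above is what clang reaches for the _mm512_mask_roundscale_pd intrinsic, assuming a build along the lines of clang++ -O1 -mavx512f -fsanitize=memory. The snippet below sketches the behavior the patch enables: masked-off lanes stay clean, while lanes computed from uninitialized data carry poison to the point of use instead of triggering an immediate blanket warning at the intrinsic call.

// Hypothetical repro sketch; the flags and exact report locations are
// assumptions, but the intrinsics used are the standard AVX-512 ones.
#include <immintrin.h>
#include <cstdio>
#include <cstdlib>

int main() {
  double init[8] = {1.5, 2.5, 3.5, 4.5, 5.5, 6.5, 7.5, 8.5};
  double *uninit = static_cast<double *>(std::malloc(8 * sizeof(double)));
  __m512d src = _mm512_loadu_pd(init);
  __m512d a = _mm512_loadu_pd(uninit);   // uninitialized heap memory
  // Round only the low four lanes of 'a'; the upper four come from 'src'.
  __m512d r = _mm512_mask_roundscale_pd(src, 0x0F, a, 0);
  double out[8];
  _mm512_storeu_pd(out, r);
  if (out[7] > 0.0)                      // lane taken from 'src': no report
    std::puts("upper lane ok");
  if (out[0] > 0.0)                      // lane derived from 'uninit':
    std::puts("lower lane ok");          // MSan should flag this branch
  std::free(uninit);
  return 0;
}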

llvm/test/Instrumentation/MemorySanitizer/X86/avx512-intrinsics.ll

Lines changed: 9 additions & 23 deletions
@@ -21,7 +21,7 @@
 ; - llvm.x86.avx512.mask.pmov.db.mem.512, llvm.x86.avx512.mask.pmov.dw.mem.512, llvm.x86.avx512.mask.pmov.qb.mem.512, llvm.x86.avx512.mask.pmov.qd.mem.512llvm.x86.avx512.mask.pmov.qw.mem.512
 ; - llvm.x86.avx512.mask.pmovs.db.mem.512, llvm.x86.avx512.mask.pmovs.dw.mem.512, llvm.x86.avx512.mask.pmovs.qb.mem.512, llvm.x86.avx512.mask.pmovs.qd.mem.512, llvm.x86.avx512.mask.pmovs.qw.mem.512
 ; - llvm.x86.avx512.mask.pmovus.db.mem.512, llvm.x86.avx512.mask.pmovus.dw.mem.512, llvm.x86.avx512.mask.pmovus.qb.mem.512, llvm.x86.avx512.mask.pmovus.qd.mem.512, llvm.x86.avx512.mask.pmovus.qw.mem.512
-; - llvm.x86.avx512.mask.rndscale.pd.512, llvm.x86.avx512.mask.rndscale.ps.512, llvm.x86.avx512.mask.rndscale.sd, llvm.x86.avx512.mask.rndscale.ss
+; - llvm.x86.avx512.mask.rndscale.sd, llvm.x86.avx512.mask.rndscale.ss
 ; - llvm.x86.avx512.mask.scalef.pd.512, llvm.x86.avx512.mask.scalef.ps.512
 ; - llvm.x86.avx512.mask.sqrt.sd, llvm.x86.avx512.mask.sqrt.ss
 ; - llvm.x86.avx512.maskz.fixupimm.pd.512, llvm.x86.avx512.maskz.fixupimm.ps.512, llvm.x86.avx512.maskz.fixupimm.sd, llvm.x86.avx512.maskz.fixupimm.ss

@@ -965,18 +965,11 @@ define <8 x double> @test7(<8 x double> %a) #0 {
 ; CHECK-LABEL: @test7(
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <8 x i64>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <8 x i64> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <8 x i64> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP1:%.*]] = icmp ne i512 [[TMP3]], 0
-; CHECK-NEXT:    [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
-; CHECK-NEXT:    br i1 [[_MSOR]], label [[TMP4:%.*]], label [[TMP5:%.*]], !prof [[PROF1]]
-; CHECK:       4:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT:    unreachable
-; CHECK:       5:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <8 x i64> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <8 x i1> [[TMP2]] to <8 x i64>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <8 x i1> splat (i1 true), <8 x i64> [[TMP3]], <8 x i64> [[TMP1]]
 ; CHECK-NEXT:    [[RES:%.*]] = call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double> [[A:%.*]], i32 11, <8 x double> [[A]], i8 -1, i32 4)
-; CHECK-NEXT:    store <8 x i64> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <8 x i64> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <8 x double> [[RES]]
 ;
   %res = call <8 x double> @llvm.x86.avx512.mask.rndscale.pd.512(<8 x double> %a, i32 11, <8 x double> %a, i8 -1, i32 4)

@@ -989,18 +982,11 @@ define <16 x float> @test8(<16 x float> %a) #0 {
 ; CHECK-LABEL: @test8(
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <16 x i32>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP2:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP2]], 0
-; CHECK-NEXT:    [[TMP3:%.*]] = bitcast <16 x i32> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP1:%.*]] = icmp ne i512 [[TMP3]], 0
-; CHECK-NEXT:    [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
-; CHECK-NEXT:    br i1 [[_MSOR]], label [[TMP4:%.*]], label [[TMP5:%.*]], !prof [[PROF1]]
-; CHECK:       4:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR10]]
-; CHECK-NEXT:    unreachable
-; CHECK:       5:
+; CHECK-NEXT:    [[TMP2:%.*]] = icmp ne <16 x i32> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP3:%.*]] = sext <16 x i1> [[TMP2]] to <16 x i32>
+; CHECK-NEXT:    [[TMP4:%.*]] = select <16 x i1> splat (i1 true), <16 x i32> [[TMP3]], <16 x i32> [[TMP1]]
 ; CHECK-NEXT:    [[RES:%.*]] = call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float> [[A:%.*]], i32 11, <16 x float> [[A]], i16 -1, i32 4)
-; CHECK-NEXT:    store <16 x i32> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <16 x i32> [[TMP4]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <16 x float> [[RES]]
 ;
   %res = call <16 x float> @llvm.x86.avx512.mask.rndscale.ps.512(<16 x float> %a, i32 11, <16 x float> %a, i16 -1, i32 4)

llvm/test/Instrumentation/MemorySanitizer/X86/avx512fp16-intrinsics.ll

Lines changed: 13 additions & 22 deletions
@@ -17,7 +17,6 @@
 ; - llvm.x86.avx512fp16.mask.rcp.sh
 ; - llvm.x86.avx512fp16.mask.reduce.ph.512
 ; - llvm.x86.avx512fp16.mask.reduce.sh
-; - llvm.x86.avx512fp16.mask.rndscale.ph.512
 ; - llvm.x86.avx512fp16.mask.rndscale.sh
 ; - llvm.x86.avx512fp16.mask.rsqrt.sh
 ; - llvm.x86.avx512fp16.mask.scalef.ph.512

@@ -868,36 +867,28 @@ declare <32 x half> @llvm.x86.avx512fp16.mask.rndscale.ph.512(<32 x half>, i32,
 define <32 x half>@test_int_x86_avx512_mask_rndscale_ph_512(<32 x half> %x0, <32 x half> %x2, i32 %x3) #0 {
 ; CHECK-LABEL: define <32 x half> @test_int_x86_avx512_mask_rndscale_ph_512(
 ; CHECK-SAME: <32 x half> [[X0:%.*]], <32 x half> [[X2:%.*]], i32 [[X3:%.*]]) #[[ATTR1]] {
+; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
 ; CHECK-NEXT:    [[TMP1:%.*]] = load <32 x i16>, ptr @__msan_param_tls, align 8
 ; CHECK-NEXT:    [[TMP2:%.*]] = load <32 x i16>, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 64) to ptr), align 8
-; CHECK-NEXT:    [[TMP3:%.*]] = load i32, ptr inttoptr (i64 add (i64 ptrtoint (ptr @__msan_param_tls to i64), i64 128) to ptr), align 8
 ; CHECK-NEXT:    call void @llvm.donothing()
-; CHECK-NEXT:    [[TMP4:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP:%.*]] = icmp ne i512 [[TMP4]], 0
-; CHECK-NEXT:    [[TMP5:%.*]] = bitcast <32 x i16> [[TMP2]] to i512
-; CHECK-NEXT:    [[_MSCMP1:%.*]] = icmp ne i512 [[TMP5]], 0
-; CHECK-NEXT:    [[_MSOR:%.*]] = or i1 [[_MSCMP]], [[_MSCMP1]]
+; CHECK-NEXT:    [[TMP4:%.*]] = bitcast i32 [[X3]] to <32 x i1>
+; CHECK-NEXT:    [[TMP5:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP6:%.*]] = sext <32 x i1> [[TMP5]] to <32 x i16>
+; CHECK-NEXT:    [[TMP7:%.*]] = select <32 x i1> [[TMP4]], <32 x i16> [[TMP6]], <32 x i16> [[TMP2]]
 ; CHECK-NEXT:    [[_MSCMP2:%.*]] = icmp ne i32 [[TMP3]], 0
-; CHECK-NEXT:    [[_MSOR3:%.*]] = or i1 [[_MSOR]], [[_MSCMP2]]
-; CHECK-NEXT:    br i1 [[_MSOR3]], label %[[BB6:.*]], label %[[BB7:.*]], !prof [[PROF1]]
-; CHECK:       [[BB6]]:
+; CHECK-NEXT:    br i1 [[_MSCMP2]], label %[[BB8:.*]], label %[[BB9:.*]], !prof [[PROF1]]
+; CHECK:       [[BB8]]:
 ; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR8]]
 ; CHECK-NEXT:    unreachable
-; CHECK:       [[BB7]]:
+; CHECK:       [[BB9]]:
 ; CHECK-NEXT:    [[RES:%.*]] = call <32 x half> @llvm.x86.avx512fp16.mask.rndscale.ph.512(<32 x half> [[X0]], i32 8, <32 x half> [[X2]], i32 [[X3]], i32 4)
-; CHECK-NEXT:    [[TMP8:%.*]] = bitcast <32 x i16> [[TMP1]] to i512
-; CHECK-NEXT:    [[_MSCMP4:%.*]] = icmp ne i512 [[TMP8]], 0
-; CHECK-NEXT:    [[TMP9:%.*]] = bitcast <32 x i16> [[TMP2]] to i512
-; CHECK-NEXT:    [[_MSCMP5:%.*]] = icmp ne i512 [[TMP9]], 0
-; CHECK-NEXT:    [[_MSOR6:%.*]] = or i1 [[_MSCMP4]], [[_MSCMP5]]
-; CHECK-NEXT:    br i1 [[_MSOR6]], label %[[BB10:.*]], label %[[BB11:.*]], !prof [[PROF1]]
-; CHECK:       [[BB10]]:
-; CHECK-NEXT:    call void @__msan_warning_noreturn() #[[ATTR8]]
-; CHECK-NEXT:    unreachable
-; CHECK:       [[BB11]]:
+; CHECK-NEXT:    [[TMP10:%.*]] = icmp ne <32 x i16> [[TMP1]], zeroinitializer
+; CHECK-NEXT:    [[TMP11:%.*]] = sext <32 x i1> [[TMP10]] to <32 x i16>
+; CHECK-NEXT:    [[TMP12:%.*]] = select <32 x i1> splat (i1 true), <32 x i16> [[TMP11]], <32 x i16> [[TMP2]]
 ; CHECK-NEXT:    [[RES1:%.*]] = call <32 x half> @llvm.x86.avx512fp16.mask.rndscale.ph.512(<32 x half> [[X0]], i32 4, <32 x half> [[X2]], i32 -1, i32 8)
+; CHECK-NEXT:    [[_MSPROP:%.*]] = or <32 x i16> [[TMP7]], [[TMP12]]
 ; CHECK-NEXT:    [[RES2:%.*]] = fadd <32 x half> [[RES]], [[RES1]]
-; CHECK-NEXT:    store <32 x i16> zeroinitializer, ptr @__msan_retval_tls, align 8
+; CHECK-NEXT:    store <32 x i16> [[_MSPROP]], ptr @__msan_retval_tls, align 8
 ; CHECK-NEXT:    ret <32 x half> [[RES2]]
 ;
   %res = call <32 x half> @llvm.x86.avx512fp16.mask.rndscale.ph.512(<32 x half> %x0, i32 8, <32 x half> %x2, i32 %x3, i32 4)
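
A side note on the fp16 test above: only the vector operands switch to shadow propagation; the scalar mask %x3 keeps a strict check (the [[_MSCMP2]]/[[BB8]] branch). Conceptually that check is just the pattern below, a sketch with an assumed helper name rather than the pass's real check-materialization code (insertShadowCheck in MemorySanitizer.cpp).

// Sketch of a strict operand check (assumed name/signature; illustrative only).
#include "llvm/IR/Constants.h"
#include "llvm/IR/DerivedTypes.h"
#include "llvm/IR/IRBuilder.h"
#include "llvm/Transforms/Utils/BasicBlockUtils.h"
using namespace llvm;

static void emitStrictOperandCheck(IRBuilder<> &IRB, Value *OperandShadow,
                                   FunctionCallee WarningFn,
                                   Instruction *InsertBefore) {
  // Any set shadow bit means the operand may be uninitialized: branch to a
  // noreturn warning block instead of trying to propagate per element.
  Value *NotClean = IRB.CreateICmpNE(
      OperandShadow, Constant::getNullValue(OperandShadow->getType()));
  Instruction *Then =
      SplitBlockAndInsertIfThen(NotClean, InsertBefore, /*Unreachable=*/true);
  IRBuilder<> ThenB(Then);
  ThenB.CreateCall(WarningFn);   // e.g. __msan_warning_noreturn()
}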
