Skip to content

Commit da0da2b

Browse files
authored
[DevMSAN] Handle builtins related to bf16 and complex operations (#20094)
1 parent ad08ff2 commit da0da2b

File tree

3 files changed

+116
-10
lines changed

3 files changed

+116
-10
lines changed

libdevice/sanitizer/msan_rtl.cpp

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -814,4 +814,42 @@ __msan_unpoison_strided_copy(uptr dest, uint32_t dest_as, uptr src,
814814
"__msan_unpoison_strided_copy"));
815815
}
816816

817+
static __SYCL_CONSTANT__ const char __msan_print_copy_unsupport_type[] =
818+
"[kernel] __msan_unpoison_copy: unsupported type(%d <- %d)\n";
819+
820+
DEVICE_EXTERN_C_NOINLINE void __msan_unpoison_copy(uptr dst, uint32_t dst_as,
821+
uptr src, uint32_t src_as,
822+
uint32_t dst_element_size,
823+
uint32_t src_element_size,
824+
uptr counts) {
825+
if (!GetMsanLaunchInfo)
826+
return;
827+
828+
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_beg, "__msan_unpoison_copy"));
829+
830+
uptr shadow_dst = MemToShadow(dst, dst_as);
831+
if (shadow_dst != GetMsanLaunchInfo->CleanShadow) {
832+
uptr shadow_src = MemToShadow(src, src_as);
833+
834+
if (dst_element_size == 1 && src_element_size == 1) {
835+
Memcpy<__SYCL_GLOBAL__ int8_t *, __SYCL_GLOBAL__ int8_t *>(
836+
(__SYCL_GLOBAL__ int8_t *)shadow_dst,
837+
(__SYCL_GLOBAL__ int8_t *)shadow_src, counts);
838+
} else if (dst_element_size == 4 && src_element_size == 2) {
839+
Memcpy<__SYCL_GLOBAL__ int32_t *, __SYCL_GLOBAL__ int16_t *>(
840+
(__SYCL_GLOBAL__ int32_t *)shadow_dst,
841+
(__SYCL_GLOBAL__ int16_t *)shadow_src, counts);
842+
} else if (dst_element_size == 2 && src_element_size == 4) {
843+
Memcpy<__SYCL_GLOBAL__ int16_t *, __SYCL_GLOBAL__ int32_t *>(
844+
(__SYCL_GLOBAL__ int16_t *)shadow_dst,
845+
(__SYCL_GLOBAL__ int32_t *)shadow_src, counts);
846+
} else {
847+
__spirv_ocl_printf(__msan_print_copy_unsupport_type, dst_element_size,
848+
src_element_size);
849+
}
850+
}
851+
852+
MSAN_DEBUG(__spirv_ocl_printf(__msan_print_func_end, "__msan_unpoison_copy"));
853+
}
854+
817855
#endif // __SPIR__ || __SPIRV__

llvm/lib/Transforms/Instrumentation/MemorySanitizer.cpp

Lines changed: 51 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -859,6 +859,7 @@ class MemorySanitizerOnSpirv {
859859
FunctionCallee MsanUnpoisonStackFunc;
860860
FunctionCallee MsanUnpoisonShadowFunc;
861861
FunctionCallee MsanSetPrivateBaseFunc;
862+
FunctionCallee MsanUnpoisonCopyFunc;
862863
FunctionCallee MsanUnpoisonStridedCopyFunc;
863864
};
864865

@@ -966,6 +967,18 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
966967
M.getOrInsertFunction("__msan_set_private_base", IRB.getVoidTy(),
967968
PointerType::get(C, kSpirOffloadPrivateAS));
968969

970+
// __msan_unpoison_copy(
971+
// uptr dest, uint32_t dest_as,
972+
// uptr src, uint32_t src_as,
973+
// uint32_t dst_element_size,
974+
// uint32_t src_element_size,
975+
// uptr counts,
976+
// )
977+
MsanUnpoisonCopyFunc = M.getOrInsertFunction(
978+
"__msan_unpoison_copy", IRB.getVoidTy(), IntptrTy, IRB.getInt32Ty(),
979+
IntptrTy, IRB.getInt32Ty(), IRB.getInt32Ty(), IRB.getInt32Ty(),
980+
IRB.getInt64Ty());
981+
969982
// __msan_unpoison_strided_copy(
970983
// uptr dest, uint32_t dest_as,
971984
// uptr src, uint32_t src_as,
@@ -7024,24 +7037,53 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
70247037
IRB.getInt32(Src->getType()->getPointerAddressSpace()),
70257038
IRB.getInt32(ElementSize), NumElements, Stride});
70267039
} else if (FuncName.contains(
7027-
"__sycl_getComposite2020SpecConstantValue")) {
7040+
"__sycl_getComposite2020SpecConstantValue") ||
7041+
FuncName.contains("clog")) {
70287042
// clang-format off
7029-
// Handle builtin functions like "_Z40__sycl_getComposite2020SpecConstantValue"
7043+
// Handle builtin functions which have sret arguments.
70307044
// Structs which are larger than 64b will be returned via sret arguments
70317045
// and will be initialized inside the function. So we need to unpoison
70327046
// the sret arguments.
70337047
// clang-format on
70347048
if (Func->hasStructRetAttr()) {
70357049
Type *SCTy = Func->getParamStructRetType(0);
70367050
unsigned Size = Func->getDataLayout().getTypeStoreSize(SCTy);
7037-
auto *Addr = CB.getArgOperand(0);
7038-
IRB.CreateCall(
7039-
MS.Spirv.MsanUnpoisonShadowFunc,
7040-
{IRB.CreatePointerCast(Addr, MS.Spirv.IntptrTy),
7041-
ConstantInt::get(MS.Spirv.Int32Ty,
7042-
Addr->getType()->getPointerAddressSpace()),
7043-
ConstantInt::get(MS.Spirv.IntptrTy, Size)});
7051+
if (FuncName.contains("clog")) {
7052+
auto *Dest = CB.getArgOperand(0);
7053+
auto *Src = CB.getArgOperand(1);
7054+
IRB.CreateCall(
7055+
MS.Spirv.MsanUnpoisonCopyFunc,
7056+
{IRB.CreatePointerCast(Dest, MS.Spirv.IntptrTy),
7057+
IRB.getInt32(Dest->getType()->getPointerAddressSpace()),
7058+
IRB.CreatePointerCast(Src, MS.Spirv.IntptrTy),
7059+
IRB.getInt32(Src->getType()->getPointerAddressSpace()),
7060+
IRB.getInt32(1), IRB.getInt32(1),
7061+
ConstantInt::get(MS.Spirv.IntptrTy, Size)});
7062+
} else {
7063+
auto *Addr = CB.getArgOperand(0);
7064+
IRB.CreateCall(
7065+
MS.Spirv.MsanUnpoisonShadowFunc,
7066+
{IRB.CreatePointerCast(Addr, MS.Spirv.IntptrTy),
7067+
ConstantInt::get(MS.Spirv.Int32Ty,
7068+
Addr->getType()->getPointerAddressSpace()),
7069+
ConstantInt::get(MS.Spirv.IntptrTy, Size)});
7070+
}
70447071
}
7072+
} else if (FuncName.contains("__devicelib_ConvertBF16ToFINTELVec") ||
7073+
FuncName.contains("__devicelib_ConvertFToBF16INTELVec")) {
7074+
size_t NumElements;
7075+
bool IsBF16ToF = FuncName.contains("BF16ToF");
7076+
FuncName.take_back().getAsInteger(10, NumElements);
7077+
auto *Src = CB.getArgOperand(0);
7078+
auto *Dest = CB.getArgOperand(1);
7079+
IRB.CreateCall(
7080+
MS.Spirv.MsanUnpoisonCopyFunc,
7081+
{IRB.CreatePointerCast(Dest, MS.Spirv.IntptrTy),
7082+
IRB.getInt32(Dest->getType()->getPointerAddressSpace()),
7083+
IRB.CreatePointerCast(Src, MS.Spirv.IntptrTy),
7084+
IRB.getInt32(Src->getType()->getPointerAddressSpace()),
7085+
IRB.getInt32(IsBF16ToF ? 4 : 2), IRB.getInt32(IsBF16ToF ? 2 : 4),
7086+
ConstantInt::get(MS.Spirv.IntptrTy, NumElements)});
70457087
}
70467088
}
70477089
}

llvm/test/Instrumentation/MemorySanitizer/SPIRV/spirv_groupasynccopy.ll renamed to llvm/test/Instrumentation/MemorySanitizer/SPIRV/spirv_builtins.ll

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,8 +7,9 @@ declare spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyiPU3AS3iPU3AS
77
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS1iPU3AS3Kimm9ocl_event(i32, ptr addrspace(1), ptr addrspace(3), i64, i64, target("spirv.Event"))
88
declare dso_local spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32, ptr addrspace(3), ptr addrspace(1), i64, i64, target("spirv.Event"))
99

10-
define spir_kernel void @kernel(ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc) sanitize_memory {
10+
define spir_kernel void @kernel1(ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc) sanitize_memory {
1111
entry:
12+
; CHECK-LABEL: define spir_kernel void @kernel1
1213
; CHECK: @__msan_barrier()
1314
; CHECK: [[REG1:%[0-9]+]] = ptrtoint ptr addrspace(3) %_arg_localAcc to i64
1415
; CHECK-NEXT: [[REG2:%[0-9]+]] = ptrtoint ptr addrspace(1) %_arg_globalAcc to i64
@@ -21,3 +22,28 @@ entry:
2122
%copy3 = call spir_func target("spirv.Event") @_Z22__spirv_GroupAsyncCopyjPU3AS3Dv4_aPU3AS1KS_mm9ocl_event(i32 2, ptr addrspace(3) %_arg_localAcc, ptr addrspace(1) %_arg_globalAcc, i64 512, i64 1, target("spirv.Event") zeroinitializer)
2223
ret void
2324
}
25+
26+
; kernel2: covers a call to @clogf, whose first operand is an sret pointer in
; addrspace(4) and whose second operand is a byval pointer in the default
; address space. The FileCheck directives below expect both pointers to be
; converted with ptrtoint and passed to @__msan_unpoison_copy immediately
; before the original call, with address spaces 4 and 0, element sizes 1 and 1,
; and a length of 8 (the sret type is { float, float }).
define spir_kernel void @kernel2(ptr addrspace(4) %tmp.ascast.i.i.i, ptr %byval-temp.i.i.i) {
27+
entry:
28+
; CHECK-LABEL: define spir_kernel void @kernel2
29+
; CHECK: [[REG3:%.*]] = ptrtoint ptr addrspace(4) [[REG4:%.*]] to i64
30+
; CHECK-NEXT: [[REG5:%.*]] = ptrtoint ptr [[REG6:%.*]] to i64
31+
; CHECK-NEXT: call void @__msan_unpoison_copy(i64 [[REG3]], i32 4, i64 [[REG5]], i32 0, i32 1, i32 1, i64 8)
32+
; CHECK-NEXT: call spir_func void @clogf(ptr addrspace(4) dead_on_unwind writable sret({ float, float }) align 4 [[REG4]], ptr noundef nonnull byval({ float, float }) align 4 [[REG6]])
33+
call spir_func void @clogf(ptr addrspace(4) dead_on_unwind writable sret({ float, float }) align 4 %tmp.ascast.i.i.i, ptr noundef nonnull byval({ float, float }) align 4 %byval-temp.i.i.i)
34+
ret void
35+
}
36+
37+
; kernel3: covers a call to @__devicelib_ConvertBF16ToFINTELVec4, passing the
; same addrspace(4) pointer as both operands. The FileCheck directives below
; expect the second call operand to be treated as the copy destination and the
; first as the copy source, and expect @__msan_unpoison_copy to be called
; immediately before the original call with destination element size 4, source
; element size 2, and an element count of 4.
define spir_kernel void @kernel3(ptr addrspace(4) %0) {
38+
entry:
39+
; CHECK-LABEL: define spir_kernel void @kernel3
40+
; CHECK: [[REG7:%.*]] = ptrtoint ptr addrspace(4) [[REG8:%.*]] to i64
41+
; CHECK-NEXT: [[REG9:%.*]] = ptrtoint ptr addrspace(4) [[REG10:%.*]] to i64
42+
; CHECK-NEXT: call void @__msan_unpoison_copy(i64 [[REG7]], i32 4, i64 [[REG9]], i32 4, i32 4, i32 2, i64 4)
43+
; CHECK-NEXT: call spir_func void @__devicelib_ConvertBF16ToFINTELVec4(ptr addrspace(4) noundef [[REG10]], ptr addrspace(4) noundef [[REG8]])
44+
call spir_func void @__devicelib_ConvertBF16ToFINTELVec4(ptr addrspace(4) noundef %0, ptr addrspace(4) noundef %0)
45+
ret void
46+
}
47+
48+
declare spir_func void @clogf(ptr addrspace(4) sret({ float, float }), ptr)
49+
declare spir_func void @__devicelib_ConvertBF16ToFINTELVec4(ptr addrspace(4), ptr addrspace(4))

0 commit comments

Comments
 (0)