@@ -859,6 +859,7 @@ class MemorySanitizerOnSpirv {
859
859
FunctionCallee MsanUnpoisonStackFunc;
860
860
FunctionCallee MsanUnpoisonShadowFunc;
861
861
FunctionCallee MsanSetPrivateBaseFunc;
862
+ FunctionCallee MsanUnpoisonCopyFunc;
862
863
FunctionCallee MsanUnpoisonStridedCopyFunc;
863
864
};
864
865
@@ -966,6 +967,18 @@ void MemorySanitizerOnSpirv::initializeCallbacks() {
966
967
M.getOrInsertFunction (" __msan_set_private_base" , IRB.getVoidTy (),
967
968
PointerType::get (C, kSpirOffloadPrivateAS ));
968
969
970
+ // __msan_unpoison_copy(
971
+ // uptr dest, uint32_t dest_as,
972
+ // uptr src, uint32_t src_as,
973
+ // uint32_t dst_element_size,
974
+ // uint32_t src_element_size,
975
+ // uptr counts,
976
+ // )
977
+ MsanUnpoisonCopyFunc = M.getOrInsertFunction (
978
+ " __msan_unpoison_copy" , IRB.getVoidTy (), IntptrTy, IRB.getInt32Ty (),
979
+ IntptrTy, IRB.getInt32Ty (), IRB.getInt32Ty (), IRB.getInt32Ty (),
980
+ IRB.getInt64Ty ());
981
+
969
982
// __msan_unpoison_strided_copy(
970
983
// uptr dest, uint32_t dest_as,
971
984
// uptr src, uint32_t src_as,
@@ -7024,24 +7037,53 @@ struct MemorySanitizerVisitor : public InstVisitor<MemorySanitizerVisitor> {
7024
7037
IRB.getInt32 (Src->getType ()->getPointerAddressSpace ()),
7025
7038
IRB.getInt32 (ElementSize), NumElements, Stride});
7026
7039
} else if (FuncName.contains (
7027
- " __sycl_getComposite2020SpecConstantValue" )) {
7040
+ " __sycl_getComposite2020SpecConstantValue" ) ||
7041
+ FuncName.contains (" clog" )) {
7028
7042
// clang-format off
7029
- // Handle builtin functions like "_Z40__sycl_getComposite2020SpecConstantValue"
7043
+ // Handle builtin functions which have sret arguments.
7030
7044
// Structs which are larger than 64b will be returned via sret arguments
7031
7045
// and will be initialized inside the function. So we need to unpoison
7032
7046
// the sret arguments.
7033
7047
// clang-format on
7034
7048
if (Func->hasStructRetAttr ()) {
7035
7049
Type *SCTy = Func->getParamStructRetType (0 );
7036
7050
unsigned Size = Func->getDataLayout ().getTypeStoreSize (SCTy);
7037
- auto *Addr = CB.getArgOperand (0 );
7038
- IRB.CreateCall (
7039
- MS.Spirv .MsanUnpoisonShadowFunc ,
7040
- {IRB.CreatePointerCast (Addr, MS.Spirv .IntptrTy ),
7041
- ConstantInt::get (MS.Spirv .Int32Ty ,
7042
- Addr->getType ()->getPointerAddressSpace ()),
7043
- ConstantInt::get (MS.Spirv .IntptrTy , Size)});
7051
+ if (FuncName.contains (" clog" )) {
7052
+ auto *Dest = CB.getArgOperand (0 );
7053
+ auto *Src = CB.getArgOperand (1 );
7054
+ IRB.CreateCall (
7055
+ MS.Spirv .MsanUnpoisonCopyFunc ,
7056
+ {IRB.CreatePointerCast (Dest, MS.Spirv .IntptrTy ),
7057
+ IRB.getInt32 (Dest->getType ()->getPointerAddressSpace ()),
7058
+ IRB.CreatePointerCast (Src, MS.Spirv .IntptrTy ),
7059
+ IRB.getInt32 (Src->getType ()->getPointerAddressSpace ()),
7060
+ IRB.getInt32 (1 ), IRB.getInt32 (1 ),
7061
+ ConstantInt::get (MS.Spirv .IntptrTy , Size)});
7062
+ } else {
7063
+ auto *Addr = CB.getArgOperand (0 );
7064
+ IRB.CreateCall (
7065
+ MS.Spirv .MsanUnpoisonShadowFunc ,
7066
+ {IRB.CreatePointerCast (Addr, MS.Spirv .IntptrTy ),
7067
+ ConstantInt::get (MS.Spirv .Int32Ty ,
7068
+ Addr->getType ()->getPointerAddressSpace ()),
7069
+ ConstantInt::get (MS.Spirv .IntptrTy , Size)});
7070
+ }
7044
7071
}
7072
+ } else if (FuncName.contains (" __devicelib_ConvertBF16ToFINTELVec" ) ||
7073
+ FuncName.contains (" __devicelib_ConvertFToBF16INTELVec" )) {
7074
+ size_t NumElements;
7075
+ bool IsBF16ToF = FuncName.contains (" BF16ToF" );
7076
+ FuncName.take_back ().getAsInteger (10 , NumElements);
7077
+ auto *Src = CB.getArgOperand (0 );
7078
+ auto *Dest = CB.getArgOperand (1 );
7079
+ IRB.CreateCall (
7080
+ MS.Spirv .MsanUnpoisonCopyFunc ,
7081
+ {IRB.CreatePointerCast (Dest, MS.Spirv .IntptrTy ),
7082
+ IRB.getInt32 (Dest->getType ()->getPointerAddressSpace ()),
7083
+ IRB.CreatePointerCast (Src, MS.Spirv .IntptrTy ),
7084
+ IRB.getInt32 (Src->getType ()->getPointerAddressSpace ()),
7085
+ IRB.getInt32 (IsBF16ToF ? 4 : 2 ), IRB.getInt32 (IsBF16ToF ? 2 : 4 ),
7086
+ ConstantInt::get (MS.Spirv .IntptrTy , NumElements)});
7045
7087
}
7046
7088
}
7047
7089
}
0 commit comments