Skip to content

Commit c34a4fc

Browse files
committed
[DevMSAN] Propagate shadow memory in buffer related APIs
1 parent ad288bb commit c34a4fc

File tree

4 files changed

+119
-57
lines changed

4 files changed

+119
-57
lines changed

source/loader/layers/sanitizer/msan/msan_buffer.cpp

Lines changed: 56 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -48,22 +48,67 @@ ur_result_t EnqueueMemCopyRectHelper(
4848
char *DstOrigin = pDst + DstOffset.x + DstRowPitch * DstOffset.y +
4949
DstSlicePitch * DstOffset.z;
5050

51+
const bool IsDstDeviceUSM = getMsanInterceptor()
52+
->findAllocInfoByAddress((uptr)DstOrigin)
53+
.has_value();
54+
const bool IsSrcDeviceUSM = getMsanInterceptor()
55+
->findAllocInfoByAddress((uptr)SrcOrigin)
56+
.has_value();
57+
58+
ur_device_handle_t Device = GetDevice(Queue);
59+
std::shared_ptr<DeviceInfo> DeviceInfo =
60+
getMsanInterceptor()->getDeviceInfo(Device);
5161
std::vector<ur_event_handle_t> Events;
52-
Events.reserve(Region.depth);
62+
5363
// For now, USM doesn't support 3D memory copy operation, so we can only
5464
// call the 2D memory copy function in a loop to implement it.
5565
for (size_t i = 0; i < Region.depth; i++) {
5666
ur_event_handle_t NewEvent{};
5767
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D(
58-
Queue, Blocking, DstOrigin + (i * DstSlicePitch), DstRowPitch,
68+
Queue, false, DstOrigin + (i * DstSlicePitch), DstRowPitch,
5969
SrcOrigin + (i * SrcSlicePitch), SrcRowPitch, Region.width,
6070
Region.height, NumEventsInWaitList, EventWaitList, &NewEvent));
61-
6271
Events.push_back(NewEvent);
72+
73+
// Update shadow memory
74+
if (IsDstDeviceUSM && IsSrcDeviceUSM) {
75+
NewEvent = nullptr;
76+
uptr DstShadowAddr = DeviceInfo->Shadow->MemToShadow(
77+
(uptr)DstOrigin + (i * DstSlicePitch));
78+
uptr SrcShadowAddr = DeviceInfo->Shadow->MemToShadow(
79+
(uptr)SrcOrigin + (i * SrcSlicePitch));
80+
getContext()->logger.always(
81+
"memcpy shadow, dst shadow {}, dst row pitch {}, src shadow "
82+
"{}, src row pitch {}, width {}, height {}",
83+
(void *)DstShadowAddr, DstRowPitch, (void *)SrcShadowAddr,
84+
SrcRowPitch, Region.width, Region.height);
85+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy2D(
86+
Queue, false, (void *)DstShadowAddr, DstRowPitch,
87+
(void *)SrcShadowAddr, SrcRowPitch, Region.width, Region.height,
88+
NumEventsInWaitList, EventWaitList, &NewEvent));
89+
Events.push_back(NewEvent);
90+
} else if (IsDstDeviceUSM && !IsSrcDeviceUSM) {
91+
NewEvent = nullptr;
92+
uptr DstShadowAddr = DeviceInfo->Shadow->MemToShadow(
93+
(uptr)DstOrigin + (i * DstSlicePitch));
94+
const char Val = 0;
95+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill2D(
96+
Queue, (void *)DstShadowAddr, DstRowPitch, 1, &Val,
97+
Region.width, Region.height, NumEventsInWaitList, EventWaitList,
98+
&NewEvent));
99+
Events.push_back(NewEvent);
100+
}
63101
}
64102

65-
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
66-
Queue, Events.size(), Events.data(), Event));
103+
if (Blocking) {
104+
UR_CALL(
105+
getContext()->urDdiTable.Event.pfnWait(Events.size(), &Events[0]));
106+
}
107+
108+
if (Event) {
109+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
110+
Queue, Events.size(), &Events[0], Event));
111+
}
67112

68113
return UR_RESULT_SUCCESS;
69114
}
@@ -112,6 +157,12 @@ ur_result_t MemBuffer::getHandle(ur_device_handle_t Device, char *&Handle) {
112157
Size, HostPtr, this);
113158
return URes;
114159
}
160+
161+
// Update shadow memory
162+
std::shared_ptr<DeviceInfo> DeviceInfo =
163+
getMsanInterceptor()->getDeviceInfo(Device);
164+
UR_CALL(DeviceInfo->Shadow->EnqueuePoisonShadow(
165+
Queue, (uptr)Allocation, Size, 0));
115166
}
116167
}
117168

source/loader/layers/sanitizer/msan/msan_ddi.cpp

Lines changed: 57 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,6 @@ ur_result_t setupContext(ur_context_handle_t Context, uint32_t numDevices,
4545
UR_CALL(DI->allocShadowMemory(Context));
4646
}
4747
CI->DeviceList.emplace_back(hDevice);
48-
CI->AllocInfosMap[hDevice];
4948
}
5049
return UR_RESULT_SUCCESS;
5150
}
@@ -517,6 +516,12 @@ ur_result_t urMemBufferCreate(
517516
UR_CALL(pMemBuffer->getHandle(hDevice, Handle));
518517
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
519518
InternalQueue, true, Handle, Host, size, 0, nullptr, nullptr));
519+
520+
// Update shadow memory
521+
std::shared_ptr<DeviceInfo> DeviceInfo =
522+
getMsanInterceptor()->getDeviceInfo(hDevice);
523+
UR_CALL(DeviceInfo->Shadow->EnqueuePoisonShadow(
524+
InternalQueue, (uptr)Handle, size, 0));
520525
}
521526
}
522527

@@ -732,10 +737,25 @@ ur_result_t urEnqueueMemBufferWrite(
732737
if (auto MemBuffer = getMsanInterceptor()->getMemBuffer(hBuffer)) {
733738
ur_device_handle_t Device = GetDevice(hQueue);
734739
char *pDst = nullptr;
740+
ur_event_handle_t Events[2];
735741
UR_CALL(MemBuffer->getHandle(Device, pDst));
736742
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
737743
hQueue, blockingWrite, pDst + offset, pSrc, size,
738-
numEventsInWaitList, phEventWaitList, phEvent));
744+
numEventsInWaitList, phEventWaitList, &Events[0]));
745+
746+
// Update shadow memory
747+
std::shared_ptr<DeviceInfo> DeviceInfo =
748+
getMsanInterceptor()->getDeviceInfo(Device);
749+
const char Val = 0;
750+
uptr ShadowAddr = DeviceInfo->Shadow->MemToShadow((uptr)pDst + offset);
751+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
752+
hQueue, (void *)ShadowAddr, 1, &Val, size, numEventsInWaitList,
753+
phEventWaitList, &Events[1]));
754+
755+
if (phEvent) {
756+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
757+
hQueue, 2, Events, phEvent));
758+
}
739759
} else {
740760
UR_CALL(pfnMemBufferWrite(hQueue, hBuffer, blockingWrite, offset, size,
741761
pSrc, numEventsInWaitList, phEventWaitList,
@@ -895,15 +915,32 @@ ur_result_t urEnqueueMemBufferCopy(
895915

896916
if (SrcBuffer && DstBuffer) {
897917
ur_device_handle_t Device = GetDevice(hQueue);
918+
std::shared_ptr<DeviceInfo> DeviceInfo =
919+
getMsanInterceptor()->getDeviceInfo(Device);
898920
char *SrcHandle = nullptr;
899921
UR_CALL(SrcBuffer->getHandle(Device, SrcHandle));
900922

901923
char *DstHandle = nullptr;
902924
UR_CALL(DstBuffer->getHandle(Device, DstHandle));
903925

926+
ur_event_handle_t Events[2];
904927
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
905928
hQueue, false, DstHandle + dstOffset, SrcHandle + srcOffset, size,
906-
numEventsInWaitList, phEventWaitList, phEvent));
929+
numEventsInWaitList, phEventWaitList, &Events[0]));
930+
931+
// Update shadow memory
932+
uptr DstShadowAddr =
933+
DeviceInfo->Shadow->MemToShadow((uptr)DstHandle + dstOffset);
934+
uptr SrcShadowAddr =
935+
DeviceInfo->Shadow->MemToShadow((uptr)SrcHandle + srcOffset);
936+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMMemcpy(
937+
hQueue, false, (void *)DstShadowAddr, (void *)SrcShadowAddr, size,
938+
numEventsInWaitList, phEventWaitList, &Events[1]));
939+
940+
if (phEvent) {
941+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
942+
hQueue, 2, Events, phEvent));
943+
}
907944
} else {
908945
UR_CALL(pfnMemBufferCopy(hQueue, hBufferSrc, hBufferDst, srcOffset,
909946
dstOffset, size, numEventsInWaitList,
@@ -1002,11 +1039,27 @@ ur_result_t urEnqueueMemBufferFill(
10021039

10031040
if (auto MemBuffer = getMsanInterceptor()->getMemBuffer(hBuffer)) {
10041041
char *Handle = nullptr;
1042+
ur_event_handle_t Events[2];
10051043
ur_device_handle_t Device = GetDevice(hQueue);
10061044
UR_CALL(MemBuffer->getHandle(Device, Handle));
10071045
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
10081046
hQueue, Handle + offset, patternSize, pPattern, size,
1009-
numEventsInWaitList, phEventWaitList, phEvent));
1047+
numEventsInWaitList, phEventWaitList, &Events[0]));
1048+
1049+
// Update shadow memory
1050+
std::shared_ptr<DeviceInfo> DeviceInfo =
1051+
getMsanInterceptor()->getDeviceInfo(Device);
1052+
const char Val = 0;
1053+
uptr ShadowAddr =
1054+
DeviceInfo->Shadow->MemToShadow((uptr)Handle + offset);
1055+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnUSMFill(
1056+
hQueue, (void *)ShadowAddr, 1, &Val, size, numEventsInWaitList,
1057+
phEventWaitList, &Events[1]));
1058+
1059+
if (phEvent) {
1060+
UR_CALL(getContext()->urDdiTable.Enqueue.pfnEventsWait(
1061+
hQueue, 2, Events, phEvent));
1062+
}
10101063
} else {
10111064
UR_CALL(pfnMemBufferFill(hQueue, hBuffer, pPattern, patternSize, offset,
10121065
size, numEventsInWaitList, phEventWaitList,

source/loader/layers/sanitizer/msan/msan_interceptor.cpp

Lines changed: 6 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -70,15 +70,17 @@ ur_result_t MsanInterceptor::allocateMemory(ur_context_handle_t Context,
7070

7171
AI->print();
7272

73-
// For updating shadow memory
74-
ContextInfo->insertAllocInfo({Device}, AI);
75-
7673
// For memory release
7774
{
7875
std::scoped_lock<ur_shared_mutex> Guard(m_AllocationMapMutex);
79-
m_AllocationMap.emplace(AI->AllocBegin, std::move(AI));
76+
m_AllocationMap.emplace(AI->AllocBegin, AI);
8077
}
8178

79+
// Update shadow memory
80+
ManagedQueue InternalQueue{Context, Device};
81+
UR_CALL(DeviceInfo->Shadow->EnqueuePoisonShadow(
82+
InternalQueue, (uptr)Allocated, Size, 0xff));
83+
8284
return UR_RESULT_SUCCESS;
8385
}
8486

@@ -98,8 +100,6 @@ ur_result_t MsanInterceptor::preLaunchKernel(ur_kernel_handle_t Kernel,
98100

99101
UR_CALL(prepareLaunch(DeviceInfo, InternalQueue, Kernel, LaunchInfo));
100102

101-
UR_CALL(updateShadowMemory(ContextInfo, DeviceInfo, InternalQueue));
102-
103103
return UR_RESULT_SUCCESS;
104104
}
105105

@@ -124,29 +124,6 @@ ur_result_t MsanInterceptor::postLaunchKernel(ur_kernel_handle_t Kernel,
124124
return Result;
125125
}
126126

127-
ur_result_t
128-
MsanInterceptor::enqueueAllocInfo(std::shared_ptr<DeviceInfo> &DeviceInfo,
129-
ur_queue_handle_t Queue,
130-
std::shared_ptr<MsanAllocInfo> &AI) {
131-
return DeviceInfo->Shadow->EnqueuePoisonShadow(Queue, AI->AllocBegin,
132-
AI->AllocSize, 0xff);
133-
}
134-
135-
ur_result_t
136-
MsanInterceptor::updateShadowMemory(std::shared_ptr<ContextInfo> &ContextInfo,
137-
std::shared_ptr<DeviceInfo> &DeviceInfo,
138-
ur_queue_handle_t Queue) {
139-
auto &AllocInfos = ContextInfo->AllocInfosMap[DeviceInfo->Handle];
140-
std::scoped_lock<ur_shared_mutex> Guard(AllocInfos.Mutex);
141-
142-
for (auto &AI : AllocInfos.List) {
143-
UR_CALL(enqueueAllocInfo(DeviceInfo, Queue, AI));
144-
}
145-
AllocInfos.List.clear();
146-
147-
return UR_RESULT_SUCCESS;
148-
}
149-
150127
ur_result_t MsanInterceptor::registerProgram(ur_program_handle_t Program) {
151128
ur_result_t Result = UR_RESULT_SUCCESS;
152129

source/loader/layers/sanitizer/msan/msan_interceptor.hpp

Lines changed: 0 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -120,7 +120,6 @@ struct ContextInfo {
120120
std::atomic<int32_t> RefCount = 1;
121121

122122
std::vector<ur_device_handle_t> DeviceList;
123-
std::unordered_map<ur_device_handle_t, AllocInfoList> AllocInfosMap;
124123

125124
explicit ContextInfo(ur_context_handle_t Context) : Handle(Context) {
126125
[[maybe_unused]] auto Result =
@@ -129,15 +128,6 @@ struct ContextInfo {
129128
}
130129

131130
~ContextInfo();
132-
133-
void insertAllocInfo(const std::vector<ur_device_handle_t> &Devices,
134-
std::shared_ptr<MsanAllocInfo> &AI) {
135-
for (auto Device : Devices) {
136-
auto &AllocInfos = AllocInfosMap[Device];
137-
std::scoped_lock<ur_shared_mutex> Guard(AllocInfos.Mutex);
138-
AllocInfos.List.emplace_back(AI);
139-
}
140-
}
141131
};
142132

143133
struct USMLaunchInfo {
@@ -263,15 +253,6 @@ class MsanInterceptor {
263253
bool isNormalExit() { return m_NormalExit; }
264254

265255
private:
266-
ur_result_t
267-
updateShadowMemory(std::shared_ptr<msan::ContextInfo> &ContextInfo,
268-
std::shared_ptr<msan::DeviceInfo> &DeviceInfo,
269-
ur_queue_handle_t Queue);
270-
271-
ur_result_t enqueueAllocInfo(std::shared_ptr<msan::DeviceInfo> &DeviceInfo,
272-
ur_queue_handle_t Queue,
273-
std::shared_ptr<MsanAllocInfo> &AI);
274-
275256
/// Initialize Global Variables & Kernel Name at first Launch
276257
ur_result_t prepareLaunch(std::shared_ptr<msan::DeviceInfo> &DeviceInfo,
277258
ur_queue_handle_t Queue,

0 commit comments

Comments
 (0)