Skip to content

Commit 126338c

Browse files
authored
[SYCL] Fix bug with introp buffer (#2637)
Fixed a problem which happened when a buffer constructed using interaperablity constructor with sycl::context A is used in a different context.
1 parent f189e41 commit 126338c

File tree

7 files changed

+93
-26
lines changed

7 files changed

+93
-26
lines changed

sycl/include/CL/sycl/detail/sycl_mem_obj_i.hpp

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,10 @@ class SYCLMemObjI {
6161
// Returns size of object in bytes
6262
virtual size_t getSize() const = 0;
6363

64+
// Returns the context which is passed if a memory object is created using
65+
// interoperability constructor, nullptr otherwise.
66+
virtual ContextImplPtr getInteropContext() const = 0;
67+
6468
protected:
6569
// Pointer to the record that contains the memory commands. This is managed
6670
// by the scheduler.

sycl/include/CL/sycl/detail/sycl_mem_obj_t.hpp

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,8 @@ class __SYCL_EXPORT SYCLMemObjT : public SYCLMemObjI {
287287

288288
DLL_LOCAL MemObjType getType() const override { return UNDEFINED; }
289289

290+
ContextImplPtr getInteropContext() const override { return MInteropContext; }
291+
290292
protected:
291293
// Allocator used for allocation memory on host.
292294
unique_ptr_class<SYCLMemObjAllocator> MAllocator;

sycl/source/detail/memory_manager.cpp

Lines changed: 9 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -125,21 +125,16 @@ void *MemoryManager::allocateInteropMemObject(
125125
ContextImplPtr TargetContext, void *UserPtr,
126126
const EventImplPtr &InteropEvent, const ContextImplPtr &InteropContext,
127127
const sycl::property_list &, RT::PiEvent &OutEventToWait) {
128-
// If memory object is created with interop c'tor.
129-
// Return cl_mem as is if contexts match.
130-
if (TargetContext == InteropContext) {
131-
OutEventToWait = InteropEvent->getHandleRef();
132-
// Retain the event since it will be released during alloca command
133-
// destruction
134-
if (nullptr != OutEventToWait) {
135-
const detail::plugin &Plugin = InteropEvent->getPlugin();
136-
Plugin.call<PiApiKind::piEventRetain>(OutEventToWait);
137-
}
138-
return UserPtr;
128+
// If memory object is created with interop c'tor return cl_mem as is.
129+
assert(TargetContext == InteropContext && "Expected matching contexts");
130+
OutEventToWait = InteropEvent->getHandleRef();
131+
// Retain the event since it will be released during alloca command
132+
// destruction
133+
if (nullptr != OutEventToWait) {
134+
const detail::plugin &Plugin = InteropEvent->getPlugin();
135+
Plugin.call<PiApiKind::piEventRetain>(OutEventToWait);
139136
}
140-
// Allocate new cl_mem and initialize from user provided one.
141-
assert(false && "Not implemented");
142-
return nullptr;
137+
return UserPtr;
143138
}
144139

145140
void *MemoryManager::allocateImageObject(ContextImplPtr TargetContext,

sycl/source/detail/queue_impl.hpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -142,7 +142,7 @@ class queue_impl {
142142

143143
const plugin &getPlugin() const { return MContext->getPlugin(); }
144144

145-
ContextImplPtr getContextImplPtr() const { return MContext; }
145+
const ContextImplPtr &getContextImplPtr() const { return MContext; }
146146

147147
/// \return an associated SYCL device.
148148
device get_device() const { return createSyclObjFromImpl<device>(MDevice); }

sycl/source/detail/scheduler/graph_builder.cpp

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -176,8 +176,29 @@ Scheduler::GraphBuilder::getOrInsertMemObjRecord(const QueueImplPtr &Queue,
176176
--(Dependency->MLeafCounter);
177177
};
178178

179-
MemObject->MRecord.reset(new MemObjRecord{Queue->getContextImplPtr(),
180-
LeafLimit, AllocateDependency});
179+
const ContextImplPtr &InteropCtxPtr = Req->MSYCLMemObj->getInteropContext();
180+
if (InteropCtxPtr) {
181+
// The memory object has been constructed using interoperability constructor
182+
// which means that there is already an allocation(cl_mem) in some context.
183+
// Registering this allocation in the SYCL graph.
184+
185+
sycl::vector_class<sycl::device> Devices =
186+
InteropCtxPtr->get_info<info::context::devices>();
187+
assert(Devices.size() != 0);
188+
DeviceImplPtr Dev = detail::getSyclObjImpl(Devices[0]);
189+
190+
// Since all the Scheduler commands require queue but we have only context
191+
// here, we need to create a dummy queue bound to the context and one of the
192+
// devices from the context.
193+
QueueImplPtr InteropQueuePtr{new detail::queue_impl{
194+
Dev, InteropCtxPtr, /*AsyncHandler=*/{}, /*PropertyList=*/{}}};
195+
196+
MemObject->MRecord.reset(
197+
new MemObjRecord{InteropCtxPtr, LeafLimit, AllocateDependency});
198+
getOrCreateAllocaForReq(MemObject->MRecord.get(), Req, InteropQueuePtr);
199+
} else
200+
MemObject->MRecord.reset(new MemObjRecord{Queue->getContextImplPtr(),
201+
LeafLimit, AllocateDependency});
181202

182203
MMemObjs.push_back(MemObject);
183204
return MemObject->MRecord.get();

sycl/test/basic_tests/buffer/buffer_interop.cpp

Lines changed: 53 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -13,13 +13,47 @@
1313
//
1414
//===----------------------------------------------------------------------===//
1515
#include <CL/sycl.hpp>
16+
1617
#include <cassert>
18+
#include <iostream>
1719
#include <memory>
1820

1921
using namespace cl::sycl;
2022

2123
int main() {
2224
bool Failed = false;
25+
{
26+
sycl::context Ctx;
27+
cl_context OCLCtx = Ctx.get();
28+
29+
cl_int Error = CL_SUCCESS;
30+
cl_mem OCLBuf =
31+
clCreateBuffer(OCLCtx, CL_MEM_READ_WRITE, sizeof(int), nullptr, &Error);
32+
CHECK_OCL_CODE(Error);
33+
Error = clReleaseContext(OCLCtx);
34+
CHECK_OCL_CODE(Error);
35+
36+
sycl::buffer<int, 1> Buf{OCLBuf, Ctx};
37+
38+
sycl::queue Q;
39+
40+
if (Ctx == Q.get_context()) {
41+
std::cerr << "Expected different contexts" << std::endl;
42+
Failed = true;
43+
}
44+
45+
Q.submit([&](sycl::handler &CGH) {
46+
auto Acc = Buf.get_access<access::mode::write>(CGH);
47+
CGH.single_task<class BufferInterop_DifferentContext>(
48+
[=]() { Acc[0] = 42; });
49+
});
50+
51+
auto Acc = Buf.get_access<sycl::access::mode::read>();
52+
if (Acc[0] != 42) {
53+
std::cerr << "Result is incorrect" << std::endl;
54+
Failed = true;
55+
}
56+
}
2357
{
2458
constexpr size_t Size = 32;
2559
int Init[Size] = {5};
@@ -29,9 +63,11 @@ int main() {
2963

3064
queue MyQueue;
3165

32-
cl_mem OpenCLBuffer = clCreateBuffer(
33-
MyQueue.get_context().get(), CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
34-
Size * sizeof(int), Init, &Error);
66+
cl_context OCLCtx = MyQueue.get_context().get();
67+
68+
cl_mem OpenCLBuffer =
69+
clCreateBuffer(OCLCtx, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
70+
Size * sizeof(int), Init, &Error);
3571
CHECK_OCL_CODE(Error);
3672
buffer<int, 1> Buffer{OpenCLBuffer, MyQueue.get_context()};
3773

@@ -79,6 +115,8 @@ int main() {
79115
Failed = true;
80116
}
81117
}
118+
Error = clReleaseContext(OCLCtx);
119+
CHECK_OCL_CODE(Error);
82120
}
83121
// Check set_final_data
84122
{
@@ -88,10 +126,11 @@ int main() {
88126
cl_int Error = CL_SUCCESS;
89127

90128
queue MyQueue;
129+
cl_context OCLCtx = MyQueue.get_context().get();
91130

92-
cl_mem OpenCLBuffer = clCreateBuffer(
93-
MyQueue.get_context().get(), CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
94-
Size * sizeof(int), Init, &Error);
131+
cl_mem OpenCLBuffer =
132+
clCreateBuffer(OCLCtx, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
133+
Size * sizeof(int), Init, &Error);
95134
CHECK_OCL_CODE(Error);
96135
{
97136
buffer<int, 1> Buffer{OpenCLBuffer, MyQueue.get_context()};
@@ -113,6 +152,8 @@ int main() {
113152
Failed = true;
114153
}
115154
}
155+
Error = clReleaseContext(OCLCtx);
156+
CHECK_OCL_CODE(Error);
116157
}
117158
// Check host accessor
118159
{
@@ -121,10 +162,11 @@ int main() {
121162
cl_int Error = CL_SUCCESS;
122163

123164
queue MyQueue;
165+
cl_context OCLCtx = MyQueue.get_context().get();
124166

125-
cl_mem OpenCLBuffer = clCreateBuffer(
126-
MyQueue.get_context().get(), CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
127-
Size * sizeof(int), Init, &Error);
167+
cl_mem OpenCLBuffer =
168+
clCreateBuffer(OCLCtx, CL_MEM_READ_WRITE | CL_MEM_COPY_HOST_PTR,
169+
Size * sizeof(int), Init, &Error);
128170
CHECK_OCL_CODE(Error);
129171
buffer<int, 1> Buffer{OpenCLBuffer, MyQueue.get_context()};
130172

@@ -144,6 +186,8 @@ int main() {
144186
}
145187
Error = clReleaseMemObject(OpenCLBuffer);
146188
CHECK_OCL_CODE(Error);
189+
Error = clReleaseContext(OCLCtx);
190+
CHECK_OCL_CODE(Error);
147191
}
148192
// Check interop constructor event
149193
{

sycl/unittests/scheduler/LinkedAllocaDependencies.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@ class MemObjMock : public cl::sycl::detail::SYCLMemObjI {
3333
void releaseMem(ContextImplPtr, void *) {}
3434
void releaseHostMem(void *) {}
3535
size_t getSize() const override { return 10; }
36+
detail::ContextImplPtr getInteropContext() const override { return nullptr; }
3637
};
3738

3839
TEST_F(SchedulerTest, LinkedAllocaDependencies) {

0 commit comments

Comments
 (0)