@@ -53,26 +53,48 @@ ze_result_t DriverHandleImp::createContext(const ze_context_desc_t *desc,
5353 }
5454 }
5555
56+ bool multiOsContextDriver = false ;
5657 for (auto devicePair : context->getDevices ()) {
5758 auto neoDevice = devicePair.second ->getNEODevice ();
59+ multiOsContextDriver |= devicePair.second ->isMultiDeviceCapable ();
5860 context->rootDeviceIndices .insert (neoDevice->getRootDeviceIndex ());
5961 context->deviceBitfields .insert ({neoDevice->getRootDeviceIndex (),
6062 neoDevice->getDeviceBitfield ()});
6163 }
6264
65+ if (this ->mainContext == nullptr ) {
66+ this ->mainContext = context;
67+
68+ if (this ->getMemoryManager () == nullptr ) {
69+ this ->setMemoryManager (context->getDevices ().begin ()->second ->getNEODevice ()->getMemoryManager ());
70+ }
71+
72+ this ->setSvmAllocsManager (new NEO::SVMAllocsManager (this ->getMemoryManager (), multiOsContextDriver));
73+
74+ this ->getMemoryManager ()->setForceNonSvmForExternalHostPtr (true );
75+
76+ if (NEO::DebugManager.flags .EnableHostPointerImport .get () == 1 ) {
77+ createHostPointerManager ();
78+ }
79+ }
80+
6381 return ZE_RESULT_SUCCESS;
6482}
6583
6684NEO::MemoryManager *DriverHandleImp::getMemoryManager () {
67- return this ->memoryManager ;
85+ return this ->mainContext -> getMemoryManager () ;
6886}
6987
7088void DriverHandleImp::setMemoryManager (NEO::MemoryManager *memoryManager) {
71- this ->memoryManager = memoryManager ;
89+ this ->mainContext -> setMemoryManager ( memoryManager) ;
7290}
7391
7492NEO::SVMAllocsManager *DriverHandleImp::getSvmAllocsManager () {
75- return this ->svmAllocsManager ;
93+ return this ->mainContext ->getSvmAllocsManager ();
94+ }
95+
96+ void DriverHandleImp::setSvmAllocsManager (NEO::SVMAllocsManager *svmManager) {
97+ this ->mainContext ->setSvmAllocsManager (svmManager);
7698}
7799
78100ze_result_t DriverHandleImp::getApiVersion (ze_api_version_t *version) {
@@ -133,10 +155,6 @@ DriverHandleImp::~DriverHandleImp() {
133155 for (auto &device : this ->devices ) {
134156 delete device;
135157 }
136- if (this ->svmAllocsManager ) {
137- delete this ->svmAllocsManager ;
138- this ->svmAllocsManager = nullptr ;
139- }
140158}
141159
142160ze_result_t DriverHandleImp::initialize (std::vector<std::unique_ptr<NEO::Device>> neoDevices) {
@@ -151,13 +169,6 @@ ze_result_t DriverHandleImp::initialize(std::vector<std::unique_ptr<NEO::Device>
151169 continue ;
152170 }
153171
154- if (this ->memoryManager == nullptr ) {
155- this ->memoryManager = neoDevice->getMemoryManager ();
156- if (this ->memoryManager == nullptr ) {
157- return ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY;
158- }
159- }
160-
161172 const auto rootDeviceIndex = neoDevice->getRootDeviceIndex ();
162173 auto rootDeviceEnvironment = neoDevice->getExecutionEnvironment ()->rootDeviceEnvironments [rootDeviceIndex].get ();
163174
@@ -189,21 +200,12 @@ ze_result_t DriverHandleImp::initialize(std::vector<std::unique_ptr<NEO::Device>
189200 return ZE_RESULT_ERROR_UNINITIALIZED;
190201 }
191202
192- this ->svmAllocsManager = new NEO::SVMAllocsManager (memoryManager, multiOsContextDriver);
193- if (this ->svmAllocsManager == nullptr ) {
194- return ZE_RESULT_ERROR_OUT_OF_HOST_MEMORY;
195- }
196-
197203 this ->numDevices = static_cast <uint32_t >(this ->devices .size ());
198204
199205 extensionFunctionsLookupMap = getExtensionFunctionsLookupMap ();
200206
201207 uuidTimestamp = static_cast <uint64_t >(std::chrono::system_clock::now ().time_since_epoch ().count ());
202208
203- if (NEO::DebugManager.flags .EnableHostPointerImport .get () == 1 ) {
204- createHostPointerManager ();
205- }
206-
207209 return ZE_RESULT_SUCCESS;
208210}
209211
@@ -223,8 +225,6 @@ DriverHandle *DriverHandle::create(std::vector<std::unique_ptr<NEO::Device>> dev
223225
224226 GlobalDriver = driverHandle;
225227
226- driverHandle->getMemoryManager ()->setForceNonSvmForExternalHostPtr (true );
227-
228228 return driverHandle;
229229}
230230
@@ -250,8 +250,8 @@ bool DriverHandleImp::findAllocationDataForRange(const void *buffer,
250250 NEO::SvmAllocationData **allocData) {
251251 // Make sure the host buffer does not overlap any existing allocation
252252 const char *baseAddress = reinterpret_cast <const char *>(buffer);
253- NEO::SvmAllocationData *beginAllocData = svmAllocsManager ->getSVMAlloc (baseAddress);
254- NEO::SvmAllocationData *endAllocData = svmAllocsManager ->getSVMAlloc (baseAddress + size - 1 );
253+ NEO::SvmAllocationData *beginAllocData = getSvmAllocsManager () ->getSVMAlloc (baseAddress);
254+ NEO::SvmAllocationData *endAllocData = getSvmAllocsManager () ->getSVMAlloc (baseAddress + size - 1 );
255255
256256 if (allocData) {
257257 if (beginAllocData) {
@@ -275,8 +275,8 @@ std::vector<NEO::SvmAllocationData *> DriverHandleImp::findAllocationsWithinRang
275275 std::vector<NEO::SvmAllocationData *> allocDataArray;
276276 const char *baseAddress = reinterpret_cast <const char *>(buffer);
277277 // Check if the host buffer overlaps any existing allocation
278- NEO::SvmAllocationData *beginAllocData = svmAllocsManager ->getSVMAlloc (baseAddress);
279- NEO::SvmAllocationData *endAllocData = svmAllocsManager ->getSVMAlloc (baseAddress + size - 1 );
278+ NEO::SvmAllocationData *beginAllocData = this -> getSvmAllocsManager () ->getSVMAlloc (baseAddress);
279+ NEO::SvmAllocationData *endAllocData = this -> getSvmAllocsManager () ->getSVMAlloc (baseAddress + size - 1 );
280280
281281 // Add the allocation that matches the beginning address
282282 if (beginAllocData) {
0 commit comments