1212
1313namespace embree
1414{
15- size_t total_allocations = 0 ;
16-
17- #if defined(EMBREE_SYCL_SUPPORT)
18-
19- __thread sycl::context* tls_context_tutorial = nullptr ;
20- __thread sycl::device* tls_device_tutorial = nullptr ;
21-
22- __thread sycl::context* tls_context_embree = nullptr ;
23- __thread sycl::device* tls_device_embree = nullptr ;
24-
25- void enableUSMAllocEmbree (sycl::context* context, sycl::device* device)
26- {
27- // if (tls_context_embree != nullptr) throw std::runtime_error("USM allocation already enabled");
28- // if (tls_device_embree != nullptr) throw std::runtime_error("USM allocation already enabled");
29- if (tls_context_embree != nullptr ) {
30- abort ();
31- }
32- if (tls_device_embree != nullptr ) {
33- abort ();
34- }
35- tls_context_embree = context;
36- tls_device_embree = device;
37- }
38-
39- void disableUSMAllocEmbree ()
40- {
41- // if (tls_context_embree == nullptr) throw std::runtime_error("USM allocation not enabled");
42- // if (tls_device_embree == nullptr) throw std::runtime_error("USM allocation not enabled");
43- if (tls_context_embree == nullptr ) {
44- abort ();
45- }
46- if (tls_device_embree == nullptr ) {
47- abort ();
48- }
49- tls_context_embree = nullptr ;
50- tls_device_embree = nullptr ;
51- }
52-
53- void enableUSMAllocTutorial (sycl::context* context, sycl::device* device)
54- {
55- // if (tls_context_tutorial != nullptr) throw std::runtime_error("USM allocation already enabled");
56- // if (tls_device_tutorial != nullptr) throw std::runtime_error("USM allocation already enabled");
57- tls_context_tutorial = context;
58- tls_device_tutorial = device;
59- }
60-
61- void disableUSMAllocTutorial ()
62- {
63- // if (tls_context_tutorial == nullptr) throw std::runtime_error("USM allocation not enabled");
64- // if (tls_device_tutorial == nullptr) throw std::runtime_error("USM allocation not enabled");
65- if (tls_context_tutorial == nullptr ) {
66- abort ();
67- }
68- if (tls_device_tutorial == nullptr ) {
69- abort ();
70- }
71-
72- tls_context_tutorial = nullptr ;
73- tls_device_tutorial = nullptr ;
74- }
75-
76- #endif
77-
/*! Allocates `size` bytes whose address is a multiple of `align` using _mm_malloc.
 *
 *  \param size   number of bytes requested; a request of 0 bytes yields nullptr
 *  \param align  required alignment in bytes, asserted to be a power of two
 *  \return       pointer to the allocation, or nullptr when size == 0
 *
 *  Allocation failure terminates the process via abort() rather than
 *  throwing std::bad_alloc. Memory returned here must be released with
 *  _mm_free (see alignedFree). */
void* alignedMalloc(size_t size, size_t align)
{
  if (size == 0)
    return nullptr;

  assert((align & (align-1)) == 0);

  void* ptr = _mm_malloc(size,align);
  if (ptr == nullptr) // size is known to be non-zero at this point
    abort(); // throw std::bad_alloc();

  return ptr;
}
9226
/*! Releases memory with _mm_free, the counterpart of _mm_malloc.
 *
 *  \param ptr  pointer to release; a null pointer is ignored */
void alignedFree(void* ptr)
{
  if (ptr) {
    _mm_free(ptr);
  }
}
9833
9934#if defined(EMBREE_SYCL_SUPPORT)
@@ -107,67 +42,66 @@ namespace embree
10742 return nullptr ;
10843
10944 assert ((align & (align-1 )) == 0 );
110- total_allocations++;
11145
11246 void * ptr = nullptr ;
113- if (mode == EMBREE_USM_SHARED_DEVICE_READ_ONLY )
47+ if (mode == EmbreeUSMMode::DEVICE_READ_ONLY )
11448 ptr = sycl::aligned_alloc_shared (align,size,*device,*context,sycl::ext::oneapi::property::usm::device_read_only ());
11549 else
11650 ptr = sycl::aligned_alloc_shared (align,size,*device,*context);
117-
118- // if (size != 0 && ptr == nullptr)
119- // throw std::bad_alloc();
120- if (size != 0 && ptr == nullptr ) {
121- abort ();
122- }
51+
52+ if (size != 0 && ptr == nullptr )
53+ abort (); // throw std::bad_alloc();
12354
12455 return ptr;
12556 }
126-
127- static MutexSys g_alloc_mutex;
128-
129- void * alignedSYCLMalloc (size_t size, size_t align, EmbreeUSMMode mode)
130- {
131- if (tls_context_tutorial) return alignedSYCLMalloc (tls_context_tutorial, tls_device_tutorial, size, align, mode);
132- if (tls_context_embree ) return alignedSYCLMalloc (tls_context_embree, tls_device_embree, size, align, mode);
133- return nullptr ;
134- }
13557
136- void alignedSYCLFree (sycl::context* context, void * ptr )
58+ void * alignedSYCLMalloc (sycl::context* context, sycl::device* device, size_t size, size_t align, EmbreeUSMMode mode, EmbreeMemoryType type )
13759 {
13860 assert (context);
139- if (ptr) {
140- sycl::free (ptr,*context);
141- }
142- }
61+ assert (device);
62+
63+ if (size == 0 )
64+ return nullptr ;
14365
144- void alignedSYCLFree (void * ptr)
145- {
146- if (tls_context_tutorial) return alignedSYCLFree (tls_context_tutorial, ptr);
147- if (tls_context_embree ) return alignedSYCLFree (tls_context_embree, ptr);
148- }
66+ assert ((align & (align-1 )) == 0 );
14967
150- #endif
68+ void * ptr = nullptr ;
69+ if (type == EmbreeMemoryType::USM_SHARED) {
70+ if (mode == EmbreeUSMMode::DEVICE_READ_ONLY)
71+ ptr = sycl::aligned_alloc_shared (align,size,*device,*context,sycl::ext::oneapi::property::usm::device_read_only ());
72+ else
73+ ptr = sycl::aligned_alloc_shared (align,size,*device,*context);
74+ }
75+ else if (type == EmbreeMemoryType::USM_HOST) {
76+ ptr = sycl::aligned_alloc_host (align,size,*context);
77+ }
78+ else if (type == EmbreeMemoryType::USM_DEVICE) {
79+ ptr = sycl::aligned_alloc_device (align,size,*device,*context);
80+ }
81+ else {
82+ ptr = alignedMalloc (size,align);
83+ }
15184
152- void * alignedUSMMalloc (size_t size, size_t align, EmbreeUSMMode mode)
85+ if (size != 0 && ptr == nullptr )
86+ abort (); // throw std::bad_alloc();
87+
88+ return ptr;
89+ }
90+
91+ void alignedSYCLFree (sycl::context* context, void * ptr)
15392 {
154- #if defined(EMBREE_SYCL_SUPPORT)
155- if (tls_context_embree || tls_context_tutorial)
156- return alignedSYCLMalloc (size,align,mode);
157- else
158- #endif
159- return alignedMalloc (size,align);
93+ assert (context);
94+ if (ptr) {
95+ sycl::usm::alloc type = sycl::get_pointer_type (ptr, *context);
96+ if (type == sycl::usm::alloc::host || type == sycl::usm::alloc::device || type == sycl::usm::alloc::shared)
97+ sycl::free (ptr,*context);
98+ else {
99+ alignedFree (ptr);
100+ }
101+ }
160102 }
161103
162- void alignedUSMFree (void * ptr)
163- {
164- #if defined(EMBREE_SYCL_SUPPORT)
165- if (tls_context_embree || tls_context_tutorial)
166- return alignedSYCLFree (ptr);
167- else
168104#endif
169- return alignedFree (ptr);
170- }
171105
172106 static bool huge_pages_enabled = false ;
173107 static MutexSys os_init_mutex;
@@ -265,10 +199,7 @@ namespace embree
265199 /* fall back to 4k pages */
266200 int flags = MEM_COMMIT | MEM_RESERVE;
267201 char * ptr = (char *) VirtualAlloc (nullptr ,bytes,flags,PAGE_READWRITE);
268- // if (ptr == nullptr) throw std::bad_alloc();
269- if (ptr == nullptr ) {
270- abort ();
271- }
202+ if (ptr == nullptr ) abort (); // throw std::bad_alloc();
272203 hugepages = false ;
273204 return ptr;
274205 }
@@ -284,11 +215,8 @@ namespace embree
284215 if (bytesNew >= bytesOld)
285216 return bytesOld;
286217
287- // if (!VirtualFree((char*)ptr+bytesNew,bytesOld-bytesNew,MEM_DECOMMIT))
288- // throw std::bad_alloc();
289- if (!VirtualFree ((char *)ptr+bytesNew,bytesOld-bytesNew,MEM_DECOMMIT)) {
290- abort ();
291- }
218+ if (!VirtualFree ((char *)ptr+bytesNew,bytesOld-bytesNew,MEM_DECOMMIT))
219+ abort (); // throw std::bad_alloc();
292220
293221 return bytesNew;
294222 }
@@ -298,11 +226,8 @@ namespace embree
298226 if (bytes == 0 )
299227 return ;
300228
301- // if (!VirtualFree(ptr,0,MEM_RELEASE))
302- // throw std::bad_alloc();
303- if (!VirtualFree (ptr,0 ,MEM_RELEASE)) {
304- abort ();
305- }
229+ if (!VirtualFree (ptr,0 ,MEM_RELEASE))
230+ abort (); // throw std::bad_alloc();
306231 }
307232
308233 void os_advise (void *ptr, size_t bytes)
@@ -406,10 +331,7 @@ namespace embree
406331
407332 /* fallback to 4k pages */
408333 void * ptr = (char *) mmap (0 , bytes, PROT_READ | PROT_WRITE, MAP_PRIVATE | MAP_ANON, -1 , 0 );
409- // if (ptr == MAP_FAILED) throw std::bad_alloc();
410- if (ptr == MAP_FAILED) {
411- abort ();
412- }
334+ if (ptr == MAP_FAILED) abort (); // throw std::bad_alloc();
413335 hugepages = false ;
414336
415337 /* advise huge page hint for THP */
@@ -425,11 +347,8 @@ namespace embree
425347 if (bytesNew >= bytesOld)
426348 return bytesOld;
427349
428- // if (munmap((char*)ptr+bytesNew,bytesOld-bytesNew) == -1)
429- // throw std::bad_alloc();
430- if (munmap ((char *)ptr+bytesNew,bytesOld-bytesNew) == -1 ) {
431- abort ();
432- }
350+ if (munmap ((char *)ptr+bytesNew,bytesOld-bytesNew) == -1 )
351+ abort (); // throw std::bad_alloc();
433352
434353 return bytesNew;
435354 }
@@ -442,11 +361,8 @@ namespace embree
442361 /* for hugepages we need to also align the size */
443362 const size_t pageSize = hugepages ? PAGE_SIZE_2M : PAGE_SIZE_4K;
444363 bytes = (bytes+pageSize-1 ) & ~(pageSize-1 );
445- // if (munmap(ptr,bytes) == -1)
446- // throw std::bad_alloc();
447- if (munmap (ptr,bytes) == -1 ) {
448- abort ();
449- }
364+ if (munmap (ptr,bytes) == -1 )
365+ abort (); // throw std::bad_alloc();
450366 }
451367
452368 /* hint for transparent huge pages (THP) */
0 commit comments