@@ -213,54 +213,20 @@ These translate any kernel dimensions from one convention to the other. An
213213example of an equivalent SYCL call for a 3D kernel using ` compat ` is
214214` syclcompat::global_id::x() == get_global_id(2) ` .
215215
216- ### Local Memory
216+ ### ptr_to_int
217217
218- When using ` compat ` functions, there are two distinct interfaces to allocate
219- device local memory. The first interface uses the _ sycl_ext_oneapi_local_memory_
220- extension to leverage local memory defined at compile time.
221- _ sycl_ext_oneapi_local_memory_ is accessed through the following wrapper:
222-
223- ``` c++
224- namespace syclcompat {
225-
226- template <typename AllocT > auto * local_mem();
227-
228- } // syclcompat
229- ```
230-
231- `syclcompat::local_mem<AllocT>()` can be used as illustrated in the example
232- below.
233-
234- ```c++
235- // Sample kernel
236- using namespace syclcompat;
237- template <int BLOCK_SIZE>
238- void local_mem_2d(int *d_A) {
239- // Local memory extension wrapper, size defined at compile-time
240- auto As = local_mem<int[BLOCK_SIZE][BLOCK_SIZE]>();
241- int id_x = local_id::x();
242- int id_y = local_id::y();
243- As[id_y][id_x] = id_x * BLOCK_SIZE + id_y;
244- wg_barrier();
245- int val = As[BLOCK_SIZE - id_y - 1][BLOCK_SIZE - id_x - 1];
246- d_A[global_id::y() * BLOCK_SIZE + global_id::x()] = val;
247- }
248- ```
249-
250- The second interface allows users to allocate device local memory at runtime.
251- SYCLcompat provides this functionality through its kernel launch interface,
252- ` launch<function> ` , defined in the following section.
253-
254- The following cuda backend specific functions are introduced in order
218+ The following cuda backend specific function is introduced in order
255219to translate from the local memory pointers introduced above to ` uint32_t ` or
256220` size_t ` variables that contain a byte address to the local
257221( local refers to`.shared` in nvptx) memory state space.
258222
259223``` c++
260224namespace syclcompat {
261- __ syclcompat_inline__ uint32_t nvvm_get_smem_pointer(void * ptr);
262-
263- __ syclcompat_inline__ size_t cvta_generic_to_shared(void * ptr);
225+ template <typename T >
226+ __ syclcompat_inline__
227+ std::enable_if_t<std::is_same_v<T, uint32_t> || std::is_same_v<T, size_t>,
228+ T>
229+ ptr_to_int(void * ptr)
264230} // syclcompat
265231```
266232
@@ -272,14 +238,8 @@ A simplified example usage of the above functions is as follows:
272238 half *data = syclcompat::local_mem<half[NUM_ELEMENTS]>();
273239 // ...
274240 // ...
275- T addr;
276- if constexpr (std::is_same_v<size_t, T>) {
277- addr = syclcompat::cvta_generic_to_shared(
278- reinterpret_cast<char *>(data) + (id % 8) * 16);
279- } else { // T == uint32_t
280- addr = syclcompat::nvvm_get_smem_pointer(
281- reinterpret_cast<char *>(data) + (id % 8) * 16);
282- }
241+ T addr =
242+ syclcompat::ptr_to_int<T>(reinterpret_cast<char *>(data) + (id % 8) * 16);
283243
284244uint32_t fragment;
285245#if defined(__NVPTX__)
0 commit comments