Skip to content

Commit fd01ba5

Browse files
oleksandr-pavlykndgrigorian
authored andcommitted
Define LocalAccessor type to use to specify local accessor kernel arguments
LocalAccessor(ndim, elemental_type_str, dim0, dim1, dim2) The elemental type can be one of the following: "i1", "u1", "i2", "u2", "i4", "u4", "i8", "u8", "f4", "f8"
1 parent b6d639a commit fd01ba5

File tree

1 file changed

+45
-0
lines changed

1 file changed

+45
-0
lines changed

dpctl/_sycl_queue.pyx

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ from ._backend cimport ( # noqa: E211
5959
DPCTLWorkGroupMemory_Delete,
6060
_arg_data_type,
6161
_backend_type,
62+
_md_local_accessor,
6263
_queue_property_type,
6364
)
6465
from .memory._memory cimport _Memory
@@ -125,6 +126,47 @@ cdef class kernel_arg_type_attribute:
125126
return self.attr_value
126127

127128

129+
cdef class LocalAccessor:
130+
cdef _md_local_accessor lacc
131+
132+
def __cinit__(self, size_t ndim, str type, size_t dim0, size_t dim1, size_t dim2):
133+
self.lacc.ndim = ndim
134+
self.lacc.dim0 = dim0
135+
self.lacc.dim1 = dim1
136+
self.lacc.dim2 = dim2
137+
138+
if ndim < 1 or ndim > 3:
139+
raise ValueError
140+
if type == 'i1':
141+
self.lacc.dpctl_type_id = _arg_data_type._INT8_T
142+
elif type == 'u1':
143+
self.lacc.dpctl_type_id = _arg_data_type._UINT8_T
144+
elif type == 'i2':
145+
self.lacc.dpctl_type_id = _arg_data_type._INT16_T
146+
elif type == 'u2':
147+
self.lacc.dpctl_type_id = _arg_data_type._UINT16_T
148+
elif type == 'i4':
149+
self.lacc.dpctl_type_id = _arg_data_type._INT32_T
150+
elif type == 'u4':
151+
self.lacc.dpctl_type_id = _arg_data_type._UINT32_T
152+
elif type == 'i8':
153+
self.lacc.dpctl_type_id = _arg_data_type._INT64_T
154+
elif type == 'u8':
155+
self.lacc.dpctl_type_id = _arg_data_type._UINT64_T
156+
elif type == 'f4':
157+
self.lacc.dpctl_type_id = _arg_data_type._FLOAT
158+
elif type == 'f8':
159+
self.lacc.dpctl_type_id = _arg_data_type._DOUBLE
160+
else:
161+
raise ValueError(f"Unrecornigzed type value: '{type}'")
162+
163+
def __repr__(self):
164+
return "LocalAccessor(" + self.ndim + ")"
165+
166+
cdef size_t addressof(self):
167+
return <size_t>&self.lacc
168+
169+
128170
cdef class _kernel_arg_type:
129171
"""
130172
An enumeration of supported kernel argument types in
@@ -865,6 +907,9 @@ cdef class SyclQueue(_SyclQueue):
865907
elif isinstance(arg, WorkGroupMemory):
866908
kargs[idx] = <void*>(<size_t>arg._ref)
867909
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
910+
elif isinstance(arg, LocalAccessor):
911+
kargs[idx] = <void*>((<LocalAccessor>arg).addressof())
912+
kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
868913
else:
869914
ret = -1
870915
return ret

0 commit comments

Comments
 (0)