@@ -56,6 +56,7 @@ from ._backend cimport ( # noqa: E211
5656 DPCTLSyclEventRef,
5757 _arg_data_type,
5858 _backend_type,
59+ _md_local_accessor,
5960 _queue_property_type,
6061)
6162from .memory._memory cimport _Memory
@@ -121,6 +122,47 @@ cdef class kernel_arg_type_attribute:
121122 return self .attr_value
122123
123124
125+ cdef class LocalAccessor:
126+ cdef _md_local_accessor lacc
127+
128+ def __cinit__ (self , size_t ndim , str type , size_t dim0 , size_t dim1 , size_t dim2 ):
129+ self .lacc.ndim = ndim
130+ self .lacc.dim0 = dim0
131+ self .lacc.dim1 = dim1
132+ self .lacc.dim2 = dim2
133+
134+ if ndim < 1 or ndim > 3 :
135+ raise ValueError
136+ if type == ' i1' :
137+ self .lacc.dpctl_type_id = _arg_data_type._INT8_T
138+ elif type == ' u1' :
139+ self .lacc.dpctl_type_id = _arg_data_type._UINT8_T
140+ elif type == ' i2' :
141+ self .lacc.dpctl_type_id = _arg_data_type._INT16_T
142+ elif type == ' u2' :
143+ self .lacc.dpctl_type_id = _arg_data_type._UINT16_T
144+ elif type == ' i4' :
145+ self .lacc.dpctl_type_id = _arg_data_type._INT32_T
146+ elif type == ' u4' :
147+ self .lacc.dpctl_type_id = _arg_data_type._UINT32_T
148+ elif type == ' i8' :
149+ self .lacc.dpctl_type_id = _arg_data_type._INT64_T
150+ elif type == ' u8' :
151+ self .lacc.dpctl_type_id = _arg_data_type._UINT64_T
152+ elif type == ' f4' :
153+ self .lacc.dpctl_type_id = _arg_data_type._FLOAT
154+ elif type == ' f8' :
155+ self .lacc.dpctl_type_id = _arg_data_type._DOUBLE
156+ else :
157+ raise ValueError (f" Unrecornigzed type value: '{type}'" )
158+
159+ def __repr__ (self ):
160+ return " LocalAccessor(" + self .ndim + " )"
161+
162+ cdef size_t addressof(self ):
163+ return < size_t> & self .lacc
164+
165+
124166cdef class _kernel_arg_type:
125167 """
126168 An enumeration of supported kernel argument types in
@@ -849,6 +891,9 @@ cdef class SyclQueue(_SyclQueue):
849891 elif isinstance (arg, _Memory):
850892 kargs[idx]= < void * > (< size_t> arg._pointer)
851893 kargty[idx] = _arg_data_type._VOID_PTR
894+ elif isinstance (arg, LocalAccessor):
895+ kargs[idx] = < void * > ((< LocalAccessor> arg).addressof())
896+ kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
852897 else :
853898 ret = - 1
854899 return ret
0 commit comments