@@ -59,6 +59,7 @@ from ._backend cimport ( # noqa: E211
59
59
DPCTLWorkGroupMemory_Delete,
60
60
_arg_data_type,
61
61
_backend_type,
62
+ _md_local_accessor,
62
63
_queue_property_type,
63
64
)
64
65
from .memory._memory cimport _Memory
@@ -125,6 +126,47 @@ cdef class kernel_arg_type_attribute:
125
126
return self .attr_value
126
127
127
128
129
+ cdef class LocalAccessor:
130
+ cdef _md_local_accessor lacc
131
+
132
+ def __cinit__ (self , size_t ndim , str type , size_t dim0 , size_t dim1 , size_t dim2 ):
133
+ self .lacc.ndim = ndim
134
+ self .lacc.dim0 = dim0
135
+ self .lacc.dim1 = dim1
136
+ self .lacc.dim2 = dim2
137
+
138
+ if ndim < 1 or ndim > 3 :
139
+ raise ValueError
140
+ if type == ' i1' :
141
+ self .lacc.dpctl_type_id = _arg_data_type._INT8_T
142
+ elif type == ' u1' :
143
+ self .lacc.dpctl_type_id = _arg_data_type._UINT8_T
144
+ elif type == ' i2' :
145
+ self .lacc.dpctl_type_id = _arg_data_type._INT16_T
146
+ elif type == ' u2' :
147
+ self .lacc.dpctl_type_id = _arg_data_type._UINT16_T
148
+ elif type == ' i4' :
149
+ self .lacc.dpctl_type_id = _arg_data_type._INT32_T
150
+ elif type == ' u4' :
151
+ self .lacc.dpctl_type_id = _arg_data_type._UINT32_T
152
+ elif type == ' i8' :
153
+ self .lacc.dpctl_type_id = _arg_data_type._INT64_T
154
+ elif type == ' u8' :
155
+ self .lacc.dpctl_type_id = _arg_data_type._UINT64_T
156
+ elif type == ' f4' :
157
+ self .lacc.dpctl_type_id = _arg_data_type._FLOAT
158
+ elif type == ' f8' :
159
+ self .lacc.dpctl_type_id = _arg_data_type._DOUBLE
160
+ else :
161
+ raise ValueError (f" Unrecornigzed type value: '{type}'" )
162
+
163
+ def __repr__ (self ):
164
+ return " LocalAccessor(" + self .ndim + " )"
165
+
166
+ cdef size_t addressof(self ):
167
+ return < size_t> & self .lacc
168
+
169
+
128
170
cdef class _kernel_arg_type:
129
171
"""
130
172
An enumeration of supported kernel argument types in
@@ -865,6 +907,9 @@ cdef class SyclQueue(_SyclQueue):
865
907
elif isinstance (arg, WorkGroupMemory):
866
908
kargs[idx] = < void * > (< size_t> arg._ref)
867
909
kargty[idx] = _arg_data_type._WORK_GROUP_MEMORY
910
+ elif isinstance (arg, LocalAccessor):
911
+ kargs[idx] = < void * > ((< LocalAccessor> arg).addressof())
912
+ kargty[idx] = _arg_data_type._LOCAL_ACCESSOR
868
913
else :
869
914
ret = - 1
870
915
return ret
0 commit comments