@@ -143,63 +143,6 @@ def sub_group_barrier_impl(context, builder, sig, args):
143
143
return _void_value
144
144
145
145
146
def insert_and_call_atomic_fn(
    context, builder, sig, fn_type, dtype, ptr, val, addrspace
):
    """Insert a declaration for a numba_dpex SPIR-V atomic helper and call it.

    Args:
        context: the numba lowering context.
        builder: the llvmlite IR builder.
        sig: signature of the atomic operation being lowered.
        fn_type: operation name, either "add" or "sub".
        dtype: numba dtype of the array element ("float32" or "float64").
        ptr: pointer to the element the atomic operates on.
        val: value operand of the atomic operation.
        addrspace: address space of *ptr* (address_space.LOCAL or GLOBAL).

    Returns:
        The llvm call instruction's result value.

    Raises:
        TypeError: if *fn_type* or *dtype* is not supported.
    """
    ll_p = None
    name = ""
    if dtype.name == "float32":
        ll_val = llvmir.FloatType()
        ll_p = ll_val.as_pointer()
        if fn_type == "add":
            name = "numba_dpex_atomic_add_f32"
        elif fn_type == "sub":
            name = "numba_dpex_atomic_sub_f32"
        else:
            raise TypeError("Operation type is not supported %s" % (fn_type))
    elif dtype.name == "float64":
        # NOTE: the original wrapped this branch in a redundant `if True:`;
        # removed as dead scaffolding — behavior is unchanged.
        ll_val = llvmir.DoubleType()
        ll_p = ll_val.as_pointer()
        if fn_type == "add":
            name = "numba_dpex_atomic_add_f64"
        elif fn_type == "sub":
            name = "numba_dpex_atomic_sub_f64"
        else:
            raise TypeError("Operation type is not supported %s" % (fn_type))
    else:
        raise TypeError(
            "Atomic operation is not supported for type %s" % (dtype.name)
        )

    # Helper symbols are suffixed with the memory scope they target.
    if addrspace == address_space.LOCAL:
        name = name + "_local"
    else:
        name = name + "_global"

    assert ll_p is not None
    assert name != ""
    ll_p.addrspace = address_space.GENERIC

    mod = builder.module
    if sig.return_type == types.void:
        llretty = llvmir.VoidType()
    else:
        llretty = context.get_value_type(sig.return_type)

    llargs = [ll_p, context.get_value_type(sig.args[2])]
    fnty = llvmir.FunctionType(llretty, llargs)

    fn = cgutils.get_or_insert_function(mod, fnty, name)
    fn.calling_convention = kernel_target.CC_SPIR_FUNC

    # The helpers take a generic-address-space pointer; cast before calling.
    generic_ptr = context.addrspacecast(builder, ptr, address_space.GENERIC)

    return builder.call(fn, [generic_ptr, val])
203
146
def native_atomic_add (context , builder , sig , args ):
204
147
aryty , indty , valty = sig .args
205
148
ary , inds , val = args
@@ -282,27 +225,29 @@ def native_atomic_add(context, builder, sig, args):
282
225
return builder .call (fn , fn_args )
283
226
284
227
228
def support_atomic(dtype: types.Type) -> bool:
    """Return True if *dtype* has native device atomic support.

    This check should be the same as described in the sycl documentation:
    https://registry.khronos.org/SYCL/specs/sycl-2020/html/sycl-2020.html#sec:atomic-references
    If atomic is not supported, it will be emulated by the sycl compiler.

    Args:
        dtype: the numba type of the array element being atomically updated.

    Returns:
        bool: whether native atomics exist for *dtype*.
    """
    # A single membership test replaces the original chain of six
    # `or`-joined equality comparisons; `in` uses `==` per element,
    # so the result is identical.
    return dtype in (
        types.int32,
        types.uint32,
        types.float32,
        types.int64,
        types.uint64,
        types.float64,
    )
285
242
@lower(stubs.atomic.add, types.Array, types.intp, types.Any)
@lower(stubs.atomic.add, types.Array, types.UniTuple, types.Any)
@lower(stubs.atomic.add, types.Array, types.Tuple, types.Any)
def atomic_add_tuple(context, builder, sig, args):
    """Lower ``atomic.add`` on an array element.

    Delegates to ``native_atomic_add`` when the element dtype has native
    atomic support, and rejects every other dtype with a ``TypeError``.
    """
    dtype = sig.args[0].dtype
    # Guard clause: bail out early on unsupported element types.
    if not support_atomic(dtype):
        raise TypeError(f"Atomic operation on unsupported type {dtype}")
    return native_atomic_add(context, builder, sig, args)
def atomic_sub_wrapper (context , builder , sig , args ):
@@ -337,81 +282,11 @@ def atomic_sub_wrapper(context, builder, sig, args):
337
282
@lower(stubs.atomic.sub, types.Array, types.UniTuple, types.Any)
@lower(stubs.atomic.sub, types.Array, types.Tuple, types.Any)
def atomic_sub_tuple(context, builder, sig, args):
    """Lower ``atomic.sub`` on an array element.

    Delegates to ``atomic_sub_wrapper`` when the element dtype has native
    atomic support, and rejects every other dtype with a ``TypeError``.
    """
    dtype = sig.args[0].dtype
    # Guard clause: bail out early on unsupported element types.
    if not support_atomic(dtype):
        raise TypeError(f"Atomic operation on unsupported type {dtype}")
    return atomic_sub_wrapper(context, builder, sig, args)
359
def atomic_add(context, builder, sig, args, name):
    """Lower an atomic add/sub through the linked numba_dpex atomic helpers.

    Args:
        context: the numba lowering context.
        builder: the llvmlite IR builder.
        sig: signature (array type, index type, value type).
        args: lowered (array, indices, value) values.
        name: operation name, "add" or "sub", forwarded to
            ``insert_and_call_atomic_fn``.

    Returns:
        The result of the generated helper call.

    Raises:
        ImportError: if the atomic helper library is not available.
        TypeError: on a dtype/value mismatch or a rank/index mismatch.
    """
    from .atomics import atomic_support_present

    # Guard clause instead of wrapping the whole body in the `if`.
    if not atomic_support_present():
        raise ImportError(
            "Atomic support is not present, can not perform atomic_add"
        )

    context.extra_compile_options[kernel_target.LINK_ATOMIC] = True
    aryty, indty, valty = sig.args
    ary, inds, val = args
    dtype = aryty.dtype

    if indty == types.intp:
        indices = [inds]  # just a single integer
        indty = [indty]
    else:
        indices = cgutils.unpack_tuple(builder, inds, count=len(indty))
        indices = [
            context.cast(builder, i, t, types.intp)
            for t, i in zip(indty, indices)
        ]

    if dtype != valty:
        raise TypeError("expecting %s but got %s" % (dtype, valty))

    if aryty.ndim != len(indty):
        raise TypeError(
            "indexing %d-D array with %d-D index" % (aryty.ndim, len(indty))
        )

    lary = context.make_array(aryty)(context, builder, ary)
    ptr = cgutils.get_item_pointer(context, builder, aryty, lary, indices)

    # The original duplicated the entire call in two branches that differed
    # only in the address-space argument; compute it once instead.
    if isinstance(aryty, Array) and aryty.addrspace == address_space.LOCAL:
        addrspace = address_space.LOCAL
    else:
        addrspace = address_space.GLOBAL

    return insert_and_call_atomic_fn(
        context, builder, sig, name, dtype, ptr, val, addrspace
    )
289
+ raise TypeError (f"Atomic operation on unsupported type { dtype } " )
415
290
416
291
417
292
@lower (stubs .private .array , types .IntegerLiteral , types .Any )
0 commit comments