@@ -69,6 +69,10 @@ def gen(context, builder, sig, args):
69
69
"--spirv-ext=+SPV_EXT_shader_atomic_float_add"
70
70
]
71
71
72
+ context .extra_compile_options [LLVM_SPIRV_ARGS ] = [
73
+ "--spirv-ext=+SPV_EXT_shader_atomic_float_min_max"
74
+ ]
75
+
72
76
ptr_type = retty .as_pointer ()
73
77
ptr_type .addrspace = atomic_ref_ty .address_space
74
78
@@ -118,6 +122,59 @@ def _intrinsic_fetch_add(ty_context, ty_atomic_ref, ty_val):
118
122
return _intrinsic_helper (ty_context , ty_atomic_ref , ty_val , "fetch_add" )
119
123
120
124
125
def _atomic_sub_float_wrapper(gen_fn):
    """Wrap a codegen function so that its value operand is negated first.

    Used to express a floating-point ``fetch_sub`` in terms of ``fetch_add``:
    subtracting ``val`` is emitted as adding ``-val``.

    Args:
        gen_fn: Codegen callable with the signature
            ``(context, builder, sig, args)``.

    Returns:
        A codegen callable with the same signature that replaces ``args[1]``
        with ``builder.fneg(args[1])``, delegates to ``gen_fn`` and returns
        whatever ``gen_fn`` returns.
    """

    def gen(context, builder, sig, args):
        # ``args`` is an immutable tuple, so build a new tuple in which only
        # the operand at index 1 is replaced by its floating-point negation.
        negated_val = builder.fneg(args[1])
        negated_args = (args[0], negated_val, *args[2:])

        # Propagate the wrapped codegen's result instead of dropping it, so
        # any value produced by ``gen_fn`` reaches the caller of this wrapper.
        return gen_fn(context, builder, sig, negated_args)

    return gen
137
+
138
+
139
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_sub(ty_context, ty_atomic_ref, ty_val):
    """Build the signature/codegen pair for ``fetch_sub`` on an atomic ref.

    For ``float32`` and ``float64`` references the ``fetch_add`` helper is
    used with a negating wrapper around its codegen; every other dtype is
    forwarded to ``_intrinsic_helper`` with the ``"fetch_sub"`` operation.
    """
    is_float_ref = ty_atomic_ref.dtype in (types.float32, types.float64)
    if not is_float_ref:
        return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, "fetch_sub")

    # dpcpp does not support ``__spirv_AtomicFSubEXT``, so a float fetch_sub
    # is emitted as a fetch_add of the negated value, i.e.
    # A.fetch_sub(A, val) becomes A.fetch_add(-val).
    add_sig, add_codegen = _intrinsic_helper(
        ty_context, ty_atomic_ref, ty_val, "fetch_add"
    )
    negating_codegen = _atomic_sub_float_wrapper(add_codegen)
    return add_sig, negating_codegen
151
+
152
+
153
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_min(ty_context, ty_atomic_ref, ty_val):
    """Forward to ``_intrinsic_helper`` with the ``"fetch_min"`` operation."""
    op_name = "fetch_min"
    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, op_name)
156
+
157
+
158
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_max(ty_context, ty_atomic_ref, ty_val):
    """Forward to ``_intrinsic_helper`` with the ``"fetch_max"`` operation."""
    op_name = "fetch_max"
    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, op_name)
161
+
162
+
163
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_and(ty_context, ty_atomic_ref, ty_val):
    """Forward to ``_intrinsic_helper`` with the ``"fetch_and"`` operation."""
    op_name = "fetch_and"
    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, op_name)
166
+
167
+
168
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_or(ty_context, ty_atomic_ref, ty_val):
    """Forward to ``_intrinsic_helper`` with the ``"fetch_or"`` operation."""
    op_name = "fetch_or"
    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, op_name)
171
+
172
+
173
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
def _intrinsic_fetch_xor(ty_context, ty_atomic_ref, ty_val):
    """Forward to ``_intrinsic_helper`` with the ``"fetch_xor"`` operation."""
    op_name = "fetch_xor"
    return _intrinsic_helper(ty_context, ty_atomic_ref, ty_val, op_name)
176
+
177
+
121
178
@intrinsic (target = DPEX_KERNEL_EXP_TARGET_NAME )
122
179
def _intrinsic_atomic_ref_ctor (
123
180
ty_context , ref , ty_index , ty_retty_ref # pylint: disable=unused-argument
@@ -294,3 +351,168 @@ def ol_fetch_add_impl(atomic_ref, val):
294
351
return _intrinsic_fetch_add (atomic_ref , val )
295
352
296
353
return ol_fetch_add_impl
354
+
355
+
356
@overload_method(AtomicRefType, "fetch_sub", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_sub(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_sub`.

    Typing-time entry point: validates the operand type and hands back an
    implementation that calls ``_intrinsic_fetch_sub``. The intent is to
    generate the same LLVM IR instruction as dpcpp does for
    `atomic_ref::fetch_sub`.

    Raises:
        TypingError: When the type of ``val`` differs from the dtype of the
            AtomicRef type.
    """
    ref_dtype = atomic_ref.dtype
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to sub: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def ol_fetch_sub_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_sub(atomic_ref, val)

    return ol_fetch_sub_impl
379
+
380
+
381
@overload_method(AtomicRefType, "fetch_min", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_min(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_min`.

    Typing-time entry point: validates the operand type and hands back an
    implementation that calls ``_intrinsic_fetch_min``. The intent is to
    generate the same LLVM IR instruction as dpcpp does for
    `atomic_ref::fetch_min`.

    Raises:
        TypingError: When the type of ``val`` differs from the dtype of the
            AtomicRef type.
    """
    ref_dtype = atomic_ref.dtype
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to find min: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def ol_fetch_min_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_min(atomic_ref, val)

    return ol_fetch_min_impl
404
+
405
+
406
@overload_method(AtomicRefType, "fetch_max", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_max(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_max`.

    Typing-time entry point: validates the operand type and hands back an
    implementation that calls ``_intrinsic_fetch_max``. The intent is to
    generate the same LLVM IR instruction as dpcpp does for
    `atomic_ref::fetch_max`.

    Raises:
        TypingError: When the type of ``val`` differs from the dtype of the
            AtomicRef type.
    """
    ref_dtype = atomic_ref.dtype
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to find max: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    def ol_fetch_max_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_max(atomic_ref, val)

    return ol_fetch_max_impl
429
+
430
+
431
@overload_method(AtomicRefType, "fetch_and", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_and(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_and`.

    Typing-time entry point: validates the operand type and the reference
    dtype, then hands back an implementation that calls
    ``_intrinsic_fetch_and``. The intent is to generate the same LLVM IR
    instruction as dpcpp does for `atomic_ref::fetch_and`.

    Raises:
        TypingError: When the type of ``val`` differs from the dtype of the
            AtomicRef type, or when that dtype is neither ``int32`` nor
            ``int64``.
    """
    ref_dtype = atomic_ref.dtype

    # The mismatch check runs first so that it takes precedence over the
    # unsupported-dtype error.
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to and: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    supported_dtypes = (types.int32, types.int64)
    if ref_dtype not in supported_dtypes:
        raise errors.TypingError(
            "fetch_and operation only supported on int32 and int64 dtypes."
        )

    def ol_fetch_and_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_and(atomic_ref, val)

    return ol_fetch_and_impl
459
+
460
+
461
@overload_method(AtomicRefType, "fetch_or", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_or(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_or`.

    Typing-time entry point: validates the operand type and the reference
    dtype, then hands back an implementation that calls
    ``_intrinsic_fetch_or``. The intent is to generate the same LLVM IR
    instruction as dpcpp does for `atomic_ref::fetch_or`.

    Raises:
        TypingError: When the type of ``val`` differs from the dtype of the
            AtomicRef type, or when that dtype is neither ``int32`` nor
            ``int64``.
    """
    ref_dtype = atomic_ref.dtype

    # The mismatch check runs first so that it takes precedence over the
    # unsupported-dtype error.
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to or: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    supported_dtypes = (types.int32, types.int64)
    if ref_dtype not in supported_dtypes:
        raise errors.TypingError(
            "fetch_or operation only supported on int32 and int64 dtypes."
        )

    def ol_fetch_or_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_or(atomic_ref, val)

    return ol_fetch_or_impl
489
+
490
+
491
@overload_method(AtomicRefType, "fetch_xor", target=DPEX_KERNEL_EXP_TARGET_NAME)
def ol_fetch_xor(atomic_ref, val):
    """SPIR-V overload for
    :meth:`numba_dpex.experimental.kernel_iface.AtomicRef.fetch_xor`.

    Typing-time entry point: validates the operand type and the reference
    dtype, then hands back an implementation that calls
    ``_intrinsic_fetch_xor``. The intent is to generate the same LLVM IR
    instruction as dpcpp does for `atomic_ref::fetch_xor`.

    Raises:
        TypingError: When the type of ``val`` differs from the dtype of the
            AtomicRef type, or when that dtype is neither ``int32`` nor
            ``int64``.
    """
    ref_dtype = atomic_ref.dtype

    # The mismatch check runs first so that it takes precedence over the
    # unsupported-dtype error.
    if ref_dtype != val:
        mismatch_msg = (
            f"Type of value to xor: {val} does not match the type of the "
            f"reference: {ref_dtype} stored in the atomic ref."
        )
        raise errors.TypingError(mismatch_msg)

    supported_dtypes = (types.int32, types.int64)
    if ref_dtype not in supported_dtypes:
        raise errors.TypingError(
            "fetch_xor operation only supported on int32 and int64 dtypes."
        )

    def ol_fetch_xor_impl(atomic_ref, val):
        # pylint: disable=no-value-for-parameter
        return _intrinsic_fetch_xor(atomic_ref, val)

    return ol_fetch_xor_impl
0 commit comments