Skip to content

Commit cc8825b

Browse files
committed
adding attributes to function decls and callsites of atomic fetch_phi fns; adding attributes to callsites of load, store, exchange, cmpexchg fns
1 parent ee8e080 commit cc8825b

File tree

1 file changed

+35
-8
lines changed

1 file changed

+35
-8
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 35 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,7 @@
3333
get_scope,
3434
)
3535
from .spv_atomic_fn_declarations import (
36+
_SUPPORT_CONVERGENT,
3637
get_or_insert_atomic_load_fn,
3738
get_or_insert_spv_atomic_compare_exchange_fn,
3839
get_or_insert_spv_atomic_exchange_fn,
@@ -114,19 +115,29 @@ def gen(context, builder, sig, args):
114115
mangled_fn_name,
115116
)
116117
func.calling_convention = CC_SPIR_FUNC
117-
spirv_memory_semantics_mask = get_memory_semantics_mask(
118-
atomic_ref_ty.memory_order
119-
)
120-
spirv_scope = get_scope(atomic_ref_ty.memory_scope)
118+
if _SUPPORT_CONVERGENT:
119+
func.attributes.add("convergent")
120+
func.attributes.add("nounwind")
121121

122122
fn_args = [
123123
builder.extract_value(args[0], data_attr_pos),
124-
context.get_constant(types.int32, spirv_scope),
125-
context.get_constant(types.int32, spirv_memory_semantics_mask),
124+
context.get_constant(
125+
types.int32, get_scope(atomic_ref_ty.memory_scope)
126+
),
127+
context.get_constant(
128+
types.int32,
129+
get_memory_semantics_mask(atomic_ref_ty.memory_order),
130+
),
126131
args[1],
127132
]
128133

129-
return builder.call(func, fn_args)
134+
callinst = builder.call(func, fn_args)
135+
136+
if _SUPPORT_CONVERGENT:
137+
callinst.attributes.add("convergent")
138+
callinst.attributes.add("nounwind")
139+
140+
return callinst
130141

131142
return sig, gen
132143

@@ -271,6 +282,10 @@ def _intrinsic_load_gen(context, builder, sig, args):
271282
fn_args,
272283
)
273284

285+
if _SUPPORT_CONVERGENT:
286+
ret_val.attributes.add("convergent")
287+
ret_val.attributes.add("nounwind")
288+
274289
if sig.args[0].dtype == types.float32:
275290
ret_val = builder.bitcast(ret_val, llvmir.FloatType())
276291
elif sig.args[0].dtype == types.float64:
@@ -326,13 +341,17 @@ def _intrinsic_store_gen(context, builder, sig, args):
326341
store_arg,
327342
]
328343

329-
builder.call(
344+
callinst = builder.call(
330345
get_or_insert_spv_atomic_store_fn(
331346
context, builder.module, atomic_ref_ty
332347
),
333348
atomic_store_fn_args,
334349
)
335350

351+
if _SUPPORT_CONVERGENT:
352+
callinst.attributes.add("convergent")
353+
callinst.attributes.add("nounwind")
354+
336355
return sig, _intrinsic_store_gen
337356

338357

@@ -388,6 +407,10 @@ def _intrinsic_exchange_gen(context, builder, sig, args):
388407
atomic_exchange_fn_args,
389408
)
390409

410+
if _SUPPORT_CONVERGENT:
411+
ret_val.attributes.add("convergent")
412+
ret_val.attributes.add("nounwind")
413+
391414
if sig.args[0].dtype == types.float32:
392415
ret_val = builder.bitcast(ret_val, llvmir.FloatType())
393416
elif sig.args[0].dtype == types.float64:
@@ -478,6 +501,10 @@ def _intrinsic_compare_exchange_gen(context, builder, sig, args):
478501
atomic_cmpexchg_fn_args,
479502
)
480503

504+
if _SUPPORT_CONVERGENT:
505+
ret_val.attributes.add("convergent")
506+
ret_val.attributes.add("nounwind")
507+
481508
# compare_exchange returns the old value stored in AtomicRef object.
482509
# If the return value is same as expected, then compare_exchange
483510
# succeeded in replacing AtomicRef object with desired.

0 commit comments

Comments
 (0)