Skip to content

Commit 85a0804

Browse files
committed
fixing input dtypes for load, store, and exchange functions using bitcast;adding convergent, nounwind attributes
1 parent 59cd572 commit 85a0804

File tree

3 files changed

+161
-48
lines changed

3 files changed

+161
-48
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 108 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -231,27 +231,52 @@ def _intrinsic_load(
231231

232232
def _intrinsic_load_gen(context, builder, sig, args):
233233
atomic_ref_ty = sig.args[0]
234-
fn = get_or_insert_atomic_load_fn(
235-
context, builder.module, atomic_ref_ty
236-
)
237234

238-
spirv_memory_semantics_mask = get_memory_semantics_mask(
239-
atomic_ref_ty.memory_order
235+
atomic_ref_ptr = builder.extract_value(
236+
args[0],
237+
context.data_model_manager.lookup(atomic_ref_ty).get_field_position(
238+
"ref"
239+
),
240240
)
241-
spirv_scope = get_scope(atomic_ref_ty.memory_scope)
241+
if sig.args[0].dtype == types.float32:
242+
atomic_ref_ptr = builder.bitcast(
243+
atomic_ref_ptr,
244+
llvmir.PointerType(
245+
llvmir.IntType(32), addrspace=sig.args[0].address_space
246+
),
247+
)
248+
elif sig.args[0].dtype == types.float64:
249+
atomic_ref_ptr = builder.bitcast(
250+
atomic_ref_ptr,
251+
llvmir.PointerType(
252+
llvmir.IntType(64), addrspace=sig.args[0].address_space
253+
),
254+
)
242255

243256
fn_args = [
244-
builder.extract_value(
245-
args[0],
246-
context.data_model_manager.lookup(
247-
atomic_ref_ty
248-
).get_field_position("ref"),
257+
atomic_ref_ptr,
258+
context.get_constant(
259+
types.int32, get_scope(atomic_ref_ty.memory_scope)
260+
),
261+
context.get_constant(
262+
types.int32,
263+
get_memory_semantics_mask(atomic_ref_ty.memory_order),
249264
),
250-
context.get_constant(types.int32, spirv_scope),
251-
context.get_constant(types.int32, spirv_memory_semantics_mask),
252265
]
253266

254-
return builder.call(fn, fn_args)
267+
ret_val = builder.call(
268+
get_or_insert_atomic_load_fn(
269+
context, builder.module, atomic_ref_ty
270+
),
271+
fn_args,
272+
)
273+
274+
if sig.args[0].dtype == types.float32:
275+
ret_val = builder.bitcast(ret_val, llvmir.FloatType())
276+
elif sig.args[0].dtype == types.float64:
277+
ret_val = builder.bitcast(ret_val, llvmir.DoubleType())
278+
279+
return ret_val
255280

256281
return sig, _intrinsic_load_gen
257282

@@ -264,28 +289,49 @@ def _intrinsic_store(
264289

265290
def _intrinsic_store_gen(context, builder, sig, args):
266291
atomic_ref_ty = sig.args[0]
267-
atomic_store_fn = get_or_insert_spv_atomic_store_fn(
268-
context, builder.module, atomic_ref_ty
292+
293+
store_arg = args[1]
294+
atomic_ref_ptr = builder.extract_value(
295+
args[0],
296+
context.data_model_manager.lookup(atomic_ref_ty).get_field_position(
297+
"ref"
298+
),
269299
)
300+
if sig.args[0].dtype == types.float32:
301+
atomic_ref_ptr = builder.bitcast(
302+
atomic_ref_ptr,
303+
llvmir.PointerType(
304+
llvmir.IntType(32), addrspace=sig.args[0].address_space
305+
),
306+
)
307+
store_arg = builder.bitcast(store_arg, llvmir.IntType(32))
308+
elif sig.args[0].dtype == types.float64:
309+
atomic_ref_ptr = builder.bitcast(
310+
atomic_ref_ptr,
311+
llvmir.PointerType(
312+
llvmir.IntType(64), addrspace=sig.args[0].address_space
313+
),
314+
)
315+
store_arg = builder.bitcast(store_arg, llvmir.IntType(64))
270316

271317
atomic_store_fn_args = [
272-
builder.extract_value(
273-
args[0],
274-
context.data_model_manager.lookup(
275-
atomic_ref_ty
276-
).get_field_position("ref"),
277-
),
318+
atomic_ref_ptr,
278319
context.get_constant(
279320
types.int32, get_scope(atomic_ref_ty.memory_scope)
280321
),
281322
context.get_constant(
282323
types.int32,
283324
get_memory_semantics_mask(atomic_ref_ty.memory_order),
284325
),
285-
args[1],
326+
store_arg,
286327
]
287328

288-
builder.call(atomic_store_fn, atomic_store_fn_args)
329+
builder.call(
330+
get_or_insert_spv_atomic_store_fn(
331+
context, builder.module, atomic_ref_ty
332+
),
333+
atomic_store_fn_args,
334+
)
289335

290336
return sig, _intrinsic_store_gen
291337

@@ -298,28 +344,56 @@ def _intrinsic_exchange(
298344

299345
def _intrinsic_exchange_gen(context, builder, sig, args):
300346
atomic_ref_ty = sig.args[0]
301-
atomic_exchange_fn = get_or_insert_spv_atomic_exchange_fn(
302-
context, builder.module, atomic_ref_ty
347+
348+
exchange_arg = args[1]
349+
atomic_ref_ptr = builder.extract_value(
350+
args[0],
351+
context.data_model_manager.lookup(atomic_ref_ty).get_field_position(
352+
"ref"
353+
),
303354
)
355+
if sig.args[0].dtype == types.float32:
356+
atomic_ref_ptr = builder.bitcast(
357+
atomic_ref_ptr,
358+
llvmir.PointerType(
359+
llvmir.IntType(32), addrspace=sig.args[0].address_space
360+
),
361+
)
362+
exchange_arg = builder.bitcast(exchange_arg, llvmir.IntType(32))
363+
elif sig.args[0].dtype == types.float64:
364+
atomic_ref_ptr = builder.bitcast(
365+
atomic_ref_ptr,
366+
llvmir.PointerType(
367+
llvmir.IntType(64), addrspace=sig.args[0].address_space
368+
),
369+
)
370+
exchange_arg = builder.bitcast(exchange_arg, llvmir.IntType(64))
304371

305372
atomic_exchange_fn_args = [
306-
builder.extract_value(
307-
args[0],
308-
context.data_model_manager.lookup(
309-
atomic_ref_ty
310-
).get_field_position("ref"),
311-
),
373+
atomic_ref_ptr,
312374
context.get_constant(
313375
types.int32, get_scope(atomic_ref_ty.memory_scope)
314376
),
315377
context.get_constant(
316378
types.int32,
317379
get_memory_semantics_mask(atomic_ref_ty.memory_order),
318380
),
319-
args[1],
381+
exchange_arg,
320382
]
321383

322-
return builder.call(atomic_exchange_fn, atomic_exchange_fn_args)
384+
ret_val = builder.call(
385+
get_or_insert_spv_atomic_exchange_fn(
386+
context, builder.module, atomic_ref_ty
387+
),
388+
atomic_exchange_fn_args,
389+
)
390+
391+
if sig.args[0].dtype == types.float32:
392+
ret_val = builder.bitcast(ret_val, llvmir.FloatType())
393+
elif sig.args[0].dtype == types.float64:
394+
ret_val = builder.bitcast(ret_val, llvmir.DoubleType())
395+
396+
return ret_val
323397

324398
return sig, _intrinsic_exchange_gen
325399

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_group_barrier_overloads.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
Provides overloads for functions included in kernel_iface.barrier that
77
generate dpcpp SPIR-V LLVM IR intrinsic function calls.
88
"""
9-
import warnings
109

1110
from llvmlite import ir as llvmir
1211
from numba.core import cgutils, types
@@ -20,18 +19,7 @@
2019
from numba_dpex.kernel_api.memory_enums import MemoryOrder, MemoryScope
2120

2221
from ._spv_atomic_inst_helper import get_memory_semantics_mask, get_scope
23-
24-
_SUPPORT_CONVERGENT = True
25-
26-
try:
27-
llvmir.FunctionAttributes("convergent")
28-
except ValueError:
29-
warnings.warn(
30-
"convergent attribute is supported only starting llvmlite "
31-
+ "0.42. Not setting this attribute may result in unexpected behavior"
32-
+ "when using group_barrier"
33-
)
34-
_SUPPORT_CONVERGENT = False
22+
from .spv_atomic_fn_declarations import _SUPPORT_CONVERGENT
3523

3624

3725
def _get_memory_scope(fence_scope):

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/spv_atomic_fn_declarations.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,43 @@
77
functions and their use inside an LLVM module.
88
"""
99

10+
import warnings
11+
1012
from llvmlite import ir as llvmir
1113
from numba.core import cgutils, types
1214

1315
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
1416
from numba_dpex.kernel_api_impl.spirv.target import CC_SPIR_FUNC
1517

18+
_SUPPORT_CONVERGENT = True
19+
20+
try:
21+
llvmir.FunctionAttributes("convergent")
22+
except ValueError:
23+
warnings.warn(
24+
"convergent attribute is supported only starting llvmlite "
25+
+ "0.42. Not setting this attribute may result in unexpected behavior"
26+
+ "when using group_barrier"
27+
)
28+
_SUPPORT_CONVERGENT = False
29+
1630

1731
def get_or_insert_atomic_load_fn(context, module, atomic_ref_ty):
1832
"""
1933
Gets or inserts a declaration for a __spirv_AtomicLoad call into the
2034
specified LLVM IR module.
2135
"""
2236
atomic_ref_dtype = atomic_ref_ty.dtype
37+
38+
if atomic_ref_dtype == types.float32:
39+
atomic_ref_dtype = types.uint32
40+
elif atomic_ref_dtype == types.float64:
41+
atomic_ref_dtype = types.uint64
42+
2343
atomic_load_fn_retty = context.get_value_type(atomic_ref_dtype)
2444
ptr_type = atomic_load_fn_retty.as_pointer()
2545
ptr_type.addrspace = atomic_ref_ty.address_space
46+
2647
atomic_load_fn_arg_types = [
2748
ptr_type,
2849
llvmir.IntType(32),
@@ -44,6 +65,10 @@ def get_or_insert_atomic_load_fn(context, module, atomic_ref_ty):
4465
)
4566
fn.calling_convention = CC_SPIR_FUNC
4667

68+
if _SUPPORT_CONVERGENT:
69+
fn.attributes.add("convergent")
70+
fn.attributes.add("nounwind")
71+
4772
return fn
4873

4974

@@ -53,9 +78,16 @@ def get_or_insert_spv_atomic_store_fn(context, module, atomic_ref_ty):
5378
specified LLVM IR module.
5479
"""
5580
atomic_ref_dtype = atomic_ref_ty.dtype
81+
82+
if atomic_ref_dtype == types.float32:
83+
atomic_ref_dtype = types.uint32
84+
elif atomic_ref_dtype == types.float64:
85+
atomic_ref_dtype = types.uint64
86+
5687
ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
5788
ptr_type.addrspace = atomic_ref_ty.address_space
5889
atomic_store_fn_retty = llvmir.VoidType()
90+
5991
atomic_store_fn_arg_types = [
6092
ptr_type,
6193
llvmir.IntType(32),
@@ -80,6 +112,10 @@ def get_or_insert_spv_atomic_store_fn(context, module, atomic_ref_ty):
80112
)
81113
fn.calling_convention = CC_SPIR_FUNC
82114

115+
if _SUPPORT_CONVERGENT:
116+
fn.attributes.add("convergent")
117+
fn.attributes.add("nounwind")
118+
83119
return fn
84120

85121

@@ -89,9 +125,16 @@ def get_or_insert_spv_atomic_exchange_fn(context, module, atomic_ref_ty):
89125
specified LLVM IR module.
90126
"""
91127
atomic_ref_dtype = atomic_ref_ty.dtype
128+
129+
if atomic_ref_dtype == types.float32:
130+
atomic_ref_dtype = types.uint32
131+
elif atomic_ref_dtype == types.float64:
132+
atomic_ref_dtype = types.uint64
133+
92134
ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
93135
ptr_type.addrspace = atomic_ref_ty.address_space
94-
atomic_exchange_fn_retty = context.get_value_type(atomic_ref_ty.dtype)
136+
atomic_exchange_fn_retty = context.get_value_type(atomic_ref_dtype)
137+
95138
atomic_exchange_fn_arg_types = [
96139
ptr_type,
97140
llvmir.IntType(32),
@@ -118,6 +161,10 @@ def get_or_insert_spv_atomic_exchange_fn(context, module, atomic_ref_ty):
118161
)
119162
fn.calling_convention = CC_SPIR_FUNC
120163

164+
if _SUPPORT_CONVERGENT:
165+
fn.attributes.add("convergent")
166+
fn.attributes.add("nounwind")
167+
121168
return fn
122169

123170

@@ -174,4 +221,8 @@ def get_or_insert_spv_atomic_compare_exchange_fn(
174221
)
175222
fn.calling_convention = CC_SPIR_FUNC
176223

224+
if _SUPPORT_CONVERGENT:
225+
fn.attributes.add("convergent")
226+
fn.attributes.add("nounwind")
227+
177228
return fn

0 commit comments

Comments
 (0)