Skip to content

Commit 59d5238

Browse files
author
Diptorup Deb
authored
Fix for atomic load, store, exchange failure on ocl (#1336)
This PR fixes the failure that was occurring on `opencl:cpu` when atomic load, store, and exchange operations were being used. SPIR-V expects the arguments to always be of integer type. This PR introduces a `bitcast` to integer types when the arguments are of floating-point type. This PR also adds additional function attributes to these atomic function declarations, keeping them the same as the declarations generated for DPC++. It further updates the test cases to remove the expected failure on `opencl:cpu` executions.
2 parents 59cd572 + 8be6115 commit 59d5238

File tree

4 files changed

+167
-86
lines changed

4 files changed

+167
-86
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 108 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -231,27 +231,52 @@ def _intrinsic_load(
231231

232232
def _intrinsic_load_gen(context, builder, sig, args):
233233
atomic_ref_ty = sig.args[0]
234-
fn = get_or_insert_atomic_load_fn(
235-
context, builder.module, atomic_ref_ty
236-
)
237234

238-
spirv_memory_semantics_mask = get_memory_semantics_mask(
239-
atomic_ref_ty.memory_order
235+
atomic_ref_ptr = builder.extract_value(
236+
args[0],
237+
context.data_model_manager.lookup(atomic_ref_ty).get_field_position(
238+
"ref"
239+
),
240240
)
241-
spirv_scope = get_scope(atomic_ref_ty.memory_scope)
241+
if sig.args[0].dtype == types.float32:
242+
atomic_ref_ptr = builder.bitcast(
243+
atomic_ref_ptr,
244+
llvmir.PointerType(
245+
llvmir.IntType(32), addrspace=sig.args[0].address_space
246+
),
247+
)
248+
elif sig.args[0].dtype == types.float64:
249+
atomic_ref_ptr = builder.bitcast(
250+
atomic_ref_ptr,
251+
llvmir.PointerType(
252+
llvmir.IntType(64), addrspace=sig.args[0].address_space
253+
),
254+
)
242255

243256
fn_args = [
244-
builder.extract_value(
245-
args[0],
246-
context.data_model_manager.lookup(
247-
atomic_ref_ty
248-
).get_field_position("ref"),
257+
atomic_ref_ptr,
258+
context.get_constant(
259+
types.int32, get_scope(atomic_ref_ty.memory_scope)
260+
),
261+
context.get_constant(
262+
types.int32,
263+
get_memory_semantics_mask(atomic_ref_ty.memory_order),
249264
),
250-
context.get_constant(types.int32, spirv_scope),
251-
context.get_constant(types.int32, spirv_memory_semantics_mask),
252265
]
253266

254-
return builder.call(fn, fn_args)
267+
ret_val = builder.call(
268+
get_or_insert_atomic_load_fn(
269+
context, builder.module, atomic_ref_ty
270+
),
271+
fn_args,
272+
)
273+
274+
if sig.args[0].dtype == types.float32:
275+
ret_val = builder.bitcast(ret_val, llvmir.FloatType())
276+
elif sig.args[0].dtype == types.float64:
277+
ret_val = builder.bitcast(ret_val, llvmir.DoubleType())
278+
279+
return ret_val
255280

256281
return sig, _intrinsic_load_gen
257282

@@ -264,28 +289,49 @@ def _intrinsic_store(
264289

265290
def _intrinsic_store_gen(context, builder, sig, args):
266291
atomic_ref_ty = sig.args[0]
267-
atomic_store_fn = get_or_insert_spv_atomic_store_fn(
268-
context, builder.module, atomic_ref_ty
292+
293+
store_arg = args[1]
294+
atomic_ref_ptr = builder.extract_value(
295+
args[0],
296+
context.data_model_manager.lookup(atomic_ref_ty).get_field_position(
297+
"ref"
298+
),
269299
)
300+
if sig.args[0].dtype == types.float32:
301+
atomic_ref_ptr = builder.bitcast(
302+
atomic_ref_ptr,
303+
llvmir.PointerType(
304+
llvmir.IntType(32), addrspace=sig.args[0].address_space
305+
),
306+
)
307+
store_arg = builder.bitcast(store_arg, llvmir.IntType(32))
308+
elif sig.args[0].dtype == types.float64:
309+
atomic_ref_ptr = builder.bitcast(
310+
atomic_ref_ptr,
311+
llvmir.PointerType(
312+
llvmir.IntType(64), addrspace=sig.args[0].address_space
313+
),
314+
)
315+
store_arg = builder.bitcast(store_arg, llvmir.IntType(64))
270316

271317
atomic_store_fn_args = [
272-
builder.extract_value(
273-
args[0],
274-
context.data_model_manager.lookup(
275-
atomic_ref_ty
276-
).get_field_position("ref"),
277-
),
318+
atomic_ref_ptr,
278319
context.get_constant(
279320
types.int32, get_scope(atomic_ref_ty.memory_scope)
280321
),
281322
context.get_constant(
282323
types.int32,
283324
get_memory_semantics_mask(atomic_ref_ty.memory_order),
284325
),
285-
args[1],
326+
store_arg,
286327
]
287328

288-
builder.call(atomic_store_fn, atomic_store_fn_args)
329+
builder.call(
330+
get_or_insert_spv_atomic_store_fn(
331+
context, builder.module, atomic_ref_ty
332+
),
333+
atomic_store_fn_args,
334+
)
289335

290336
return sig, _intrinsic_store_gen
291337

@@ -298,28 +344,56 @@ def _intrinsic_exchange(
298344

299345
def _intrinsic_exchange_gen(context, builder, sig, args):
300346
atomic_ref_ty = sig.args[0]
301-
atomic_exchange_fn = get_or_insert_spv_atomic_exchange_fn(
302-
context, builder.module, atomic_ref_ty
347+
348+
exchange_arg = args[1]
349+
atomic_ref_ptr = builder.extract_value(
350+
args[0],
351+
context.data_model_manager.lookup(atomic_ref_ty).get_field_position(
352+
"ref"
353+
),
303354
)
355+
if sig.args[0].dtype == types.float32:
356+
atomic_ref_ptr = builder.bitcast(
357+
atomic_ref_ptr,
358+
llvmir.PointerType(
359+
llvmir.IntType(32), addrspace=sig.args[0].address_space
360+
),
361+
)
362+
exchange_arg = builder.bitcast(exchange_arg, llvmir.IntType(32))
363+
elif sig.args[0].dtype == types.float64:
364+
atomic_ref_ptr = builder.bitcast(
365+
atomic_ref_ptr,
366+
llvmir.PointerType(
367+
llvmir.IntType(64), addrspace=sig.args[0].address_space
368+
),
369+
)
370+
exchange_arg = builder.bitcast(exchange_arg, llvmir.IntType(64))
304371

305372
atomic_exchange_fn_args = [
306-
builder.extract_value(
307-
args[0],
308-
context.data_model_manager.lookup(
309-
atomic_ref_ty
310-
).get_field_position("ref"),
311-
),
373+
atomic_ref_ptr,
312374
context.get_constant(
313375
types.int32, get_scope(atomic_ref_ty.memory_scope)
314376
),
315377
context.get_constant(
316378
types.int32,
317379
get_memory_semantics_mask(atomic_ref_ty.memory_order),
318380
),
319-
args[1],
381+
exchange_arg,
320382
]
321383

322-
return builder.call(atomic_exchange_fn, atomic_exchange_fn_args)
384+
ret_val = builder.call(
385+
get_or_insert_spv_atomic_exchange_fn(
386+
context, builder.module, atomic_ref_ty
387+
),
388+
atomic_exchange_fn_args,
389+
)
390+
391+
if sig.args[0].dtype == types.float32:
392+
ret_val = builder.bitcast(ret_val, llvmir.FloatType())
393+
elif sig.args[0].dtype == types.float64:
394+
ret_val = builder.bitcast(ret_val, llvmir.DoubleType())
395+
396+
return ret_val
323397

324398
return sig, _intrinsic_exchange_gen
325399

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_group_barrier_overloads.py

Lines changed: 1 addition & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
Provides overloads for functions included in kernel_iface.barrier that
77
generate dpcpp SPIR-V LLVM IR intrinsic function calls.
88
"""
9-
import warnings
109

1110
from llvmlite import ir as llvmir
1211
from numba.core import cgutils, types
@@ -20,18 +19,7 @@
2019
from numba_dpex.kernel_api.memory_enums import MemoryOrder, MemoryScope
2120

2221
from ._spv_atomic_inst_helper import get_memory_semantics_mask, get_scope
23-
24-
_SUPPORT_CONVERGENT = True
25-
26-
try:
27-
llvmir.FunctionAttributes("convergent")
28-
except ValueError:
29-
warnings.warn(
30-
"convergent attribute is supported only starting llvmlite "
31-
+ "0.42. Not setting this attribute may result in unexpected behavior"
32-
+ "when using group_barrier"
33-
)
34-
_SUPPORT_CONVERGENT = False
22+
from .spv_atomic_fn_declarations import _SUPPORT_CONVERGENT
3523

3624

3725
def _get_memory_scope(fence_scope):

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/spv_atomic_fn_declarations.py

Lines changed: 52 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,22 +7,43 @@
77
functions and their use inside an LLVM module.
88
"""
99

10+
import warnings
11+
1012
from llvmlite import ir as llvmir
1113
from numba.core import cgutils, types
1214

1315
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
1416
from numba_dpex.kernel_api_impl.spirv.target import CC_SPIR_FUNC
1517

18+
_SUPPORT_CONVERGENT = True
19+
20+
try:
21+
llvmir.FunctionAttributes("convergent")
22+
except ValueError:
23+
warnings.warn(
24+
"convergent attribute is supported only starting llvmlite "
25+
+ "0.42. Not setting this attribute may result in unexpected behavior"
26+
+ "when using group_barrier"
27+
)
28+
_SUPPORT_CONVERGENT = False
29+
1630

1731
def get_or_insert_atomic_load_fn(context, module, atomic_ref_ty):
1832
"""
1933
Gets or inserts a declaration for a __spirv_AtomicLoad call into the
2034
specified LLVM IR module.
2135
"""
2236
atomic_ref_dtype = atomic_ref_ty.dtype
37+
38+
if atomic_ref_dtype == types.float32:
39+
atomic_ref_dtype = types.uint32
40+
elif atomic_ref_dtype == types.float64:
41+
atomic_ref_dtype = types.uint64
42+
2343
atomic_load_fn_retty = context.get_value_type(atomic_ref_dtype)
2444
ptr_type = atomic_load_fn_retty.as_pointer()
2545
ptr_type.addrspace = atomic_ref_ty.address_space
46+
2647
atomic_load_fn_arg_types = [
2748
ptr_type,
2849
llvmir.IntType(32),
@@ -44,6 +65,10 @@ def get_or_insert_atomic_load_fn(context, module, atomic_ref_ty):
4465
)
4566
fn.calling_convention = CC_SPIR_FUNC
4667

68+
if _SUPPORT_CONVERGENT:
69+
fn.attributes.add("convergent")
70+
fn.attributes.add("nounwind")
71+
4772
return fn
4873

4974

@@ -53,9 +78,16 @@ def get_or_insert_spv_atomic_store_fn(context, module, atomic_ref_ty):
5378
specified LLVM IR module.
5479
"""
5580
atomic_ref_dtype = atomic_ref_ty.dtype
81+
82+
if atomic_ref_dtype == types.float32:
83+
atomic_ref_dtype = types.uint32
84+
elif atomic_ref_dtype == types.float64:
85+
atomic_ref_dtype = types.uint64
86+
5687
ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
5788
ptr_type.addrspace = atomic_ref_ty.address_space
5889
atomic_store_fn_retty = llvmir.VoidType()
90+
5991
atomic_store_fn_arg_types = [
6092
ptr_type,
6193
llvmir.IntType(32),
@@ -80,6 +112,10 @@ def get_or_insert_spv_atomic_store_fn(context, module, atomic_ref_ty):
80112
)
81113
fn.calling_convention = CC_SPIR_FUNC
82114

115+
if _SUPPORT_CONVERGENT:
116+
fn.attributes.add("convergent")
117+
fn.attributes.add("nounwind")
118+
83119
return fn
84120

85121

@@ -89,9 +125,16 @@ def get_or_insert_spv_atomic_exchange_fn(context, module, atomic_ref_ty):
89125
specified LLVM IR module.
90126
"""
91127
atomic_ref_dtype = atomic_ref_ty.dtype
128+
129+
if atomic_ref_dtype == types.float32:
130+
atomic_ref_dtype = types.uint32
131+
elif atomic_ref_dtype == types.float64:
132+
atomic_ref_dtype = types.uint64
133+
92134
ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
93135
ptr_type.addrspace = atomic_ref_ty.address_space
94-
atomic_exchange_fn_retty = context.get_value_type(atomic_ref_ty.dtype)
136+
atomic_exchange_fn_retty = context.get_value_type(atomic_ref_dtype)
137+
95138
atomic_exchange_fn_arg_types = [
96139
ptr_type,
97140
llvmir.IntType(32),
@@ -118,6 +161,10 @@ def get_or_insert_spv_atomic_exchange_fn(context, module, atomic_ref_ty):
118161
)
119162
fn.calling_convention = CC_SPIR_FUNC
120163

164+
if _SUPPORT_CONVERGENT:
165+
fn.attributes.add("convergent")
166+
fn.attributes.add("nounwind")
167+
121168
return fn
122169

123170

@@ -174,4 +221,8 @@ def get_or_insert_spv_atomic_compare_exchange_fn(
174221
)
175222
fn.calling_convention = CC_SPIR_FUNC
176223

224+
if _SUPPORT_CONVERGENT:
225+
fn.attributes.add("convergent")
226+
fn.attributes.add("nounwind")
227+
177228
return fn

0 commit comments

Comments
 (0)