Skip to content

Commit ae7d68e

Browse files
author
Diptorup Deb
committed
Refactors the intrinsic functions for load, store, and exchange.
- Removes the helper function for intrinsic codegen for atomic store and atomic exchange. - Adds a new module that has helper functions for inserting the LLVM IR module-level declaration for individual SPV functions.
1 parent a1148f1 commit ae7d68e

File tree

2 files changed

+178
-102
lines changed

2 files changed

+178
-102
lines changed

numba_dpex/experimental/_kernel_dpcpp_spirv_overloads/_atomic_ref_overloads.py

Lines changed: 57 additions & 102 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,11 @@
2424
get_memory_semantics_mask,
2525
get_scope,
2626
)
27+
from .spv_fn_generator import (
28+
get_or_insert_atomic_load_fn,
29+
get_or_insert_spv_atomic_exchange_fn,
30+
get_or_insert_spv_atomic_store_fn,
31+
)
2732

2833

2934
def _parse_enum_or_int_literal_(literal_int) -> int:
@@ -217,44 +222,22 @@ def _intrinsic_load(
217222

218223
def _intrinsic_load_gen(context, builder, sig, args):
219224
atomic_ref_ty = sig.args[0]
220-
atomic_ref_dtype = atomic_ref_ty.dtype
221-
retty = context.get_value_type(atomic_ref_dtype)
222-
223-
data_attr_pos = context.data_model_manager.lookup(
224-
atomic_ref_ty
225-
).get_field_position("ref")
226-
227-
ptr_type = retty.as_pointer()
228-
ptr_type.addrspace = atomic_ref_ty.address_space
229-
230-
spirv_fn_arg_types = [
231-
ptr_type,
232-
llvmir.IntType(32),
233-
llvmir.IntType(32),
234-
]
235-
236-
mangled_fn_name = ext_itanium_mangler.mangle_ext(
237-
"__spirv_AtomicLoad",
238-
[
239-
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
240-
"__spv.Scope.Flag",
241-
"__spv.MemorySemanticsMask.Flag",
242-
],
225+
fn = get_or_insert_atomic_load_fn(
226+
context, builder.module, atomic_ref_ty
243227
)
244228

245-
fn = cgutils.get_or_insert_function(
246-
builder.module,
247-
llvmir.FunctionType(retty, spirv_fn_arg_types),
248-
mangled_fn_name,
249-
)
250-
fn.calling_convention = CC_SPIR_FUNC
251229
spirv_memory_semantics_mask = get_memory_semantics_mask(
252230
atomic_ref_ty.memory_order
253231
)
254232
spirv_scope = get_scope(atomic_ref_ty.memory_scope)
255233

256234
fn_args = [
257-
builder.extract_value(args[0], data_attr_pos),
235+
builder.extract_value(
236+
args[0],
237+
context.data_model_manager.lookup(
238+
atomic_ref_ty
239+
).get_field_position("ref"),
240+
),
258241
context.get_constant(types.int32, spirv_scope),
259242
context.get_constant(types.int32, spirv_memory_semantics_mask),
260243
]
@@ -264,76 +247,37 @@ def _intrinsic_load_gen(context, builder, sig, args):
264247
return sig, _intrinsic_load_gen
265248

266249

267-
def _store_exchange_intrisic_helper(context, builder, sig, ol_info: dict):
268-
atomic_ref_ty = sig.args[0]
269-
atomic_ref_dtype = atomic_ref_ty.dtype
270-
271-
ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
272-
ptr_type.addrspace = atomic_ref_ty.address_space
273-
274-
spirv_fn_arg_types = [
275-
ptr_type,
276-
llvmir.IntType(32),
277-
llvmir.IntType(32),
278-
context.get_value_type(atomic_ref_dtype),
279-
]
280-
281-
mangled_fn_name = ext_itanium_mangler.mangle_ext(
282-
ol_info["name"],
283-
[
284-
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
285-
"__spv.Scope.Flag",
286-
"__spv.MemorySemanticsMask.Flag",
287-
atomic_ref_dtype,
288-
],
289-
)
290-
291-
fn = cgutils.get_or_insert_function(
292-
builder.module,
293-
llvmir.FunctionType(ol_info["retty"], spirv_fn_arg_types),
294-
mangled_fn_name,
295-
)
296-
fn.calling_convention = CC_SPIR_FUNC
297-
298-
fn_args = [
299-
builder.extract_value(
300-
ol_info["args"][0],
301-
context.data_model_manager.lookup(atomic_ref_ty).get_field_position(
302-
"ref"
303-
),
304-
),
305-
context.get_constant(
306-
types.int32, get_scope(atomic_ref_ty.memory_scope)
307-
),
308-
context.get_constant(
309-
types.int32, get_memory_semantics_mask(atomic_ref_ty.memory_order)
310-
),
311-
ol_info["args"][1],
312-
]
313-
314-
return builder.call(fn, fn_args)
315-
316-
317250
@intrinsic(target=DPEX_KERNEL_EXP_TARGET_NAME)
318251
def _intrinsic_store(
319252
ty_context, ty_atomic_ref, ty_val
320253
): # pylint: disable=unused-argument
321254
sig = types.void(ty_atomic_ref, ty_val)
322255

323256
def _intrinsic_store_gen(context, builder, sig, args):
324-
_store_exchange_intrisic_helper(
325-
context,
326-
builder,
327-
sig,
328-
# dict containing arguments, return type,
329-
# spirv fn name driven by pylint too-many-args
330-
{
331-
"args": args,
332-
"retty": llvmir.VoidType(),
333-
"name": "__spirv_AtomicStore",
334-
},
257+
atomic_ref_ty = sig.args[0]
258+
atomic_store_fn = get_or_insert_spv_atomic_store_fn(
259+
context, builder.module, atomic_ref_ty
335260
)
336261

262+
atomic_store_fn_args = [
263+
builder.extract_value(
264+
args[0],
265+
context.data_model_manager.lookup(
266+
atomic_ref_ty
267+
).get_field_position("ref"),
268+
),
269+
context.get_constant(
270+
types.int32, get_scope(atomic_ref_ty.memory_scope)
271+
),
272+
context.get_constant(
273+
types.int32,
274+
get_memory_semantics_mask(atomic_ref_ty.memory_order),
275+
),
276+
args[1],
277+
]
278+
279+
builder.call(atomic_store_fn, atomic_store_fn_args)
280+
337281
return sig, _intrinsic_store_gen
338282

339283

@@ -344,19 +288,30 @@ def _intrinsic_exchange(
344288
sig = ty_atomic_ref.dtype(ty_atomic_ref, ty_val)
345289

346290
def _intrinsic_exchange_gen(context, builder, sig, args):
347-
return _store_exchange_intrisic_helper(
348-
context,
349-
builder,
350-
sig,
351-
# dict containing arguments, return type,
352-
# spirv fn name driven by pylint too-many-args
353-
{
354-
"args": args,
355-
"retty": context.get_value_type(sig.args[0].dtype),
356-
"name": "__spirv_AtomicExchange",
357-
},
291+
atomic_ref_ty = sig.args[0]
292+
atomic_exchange_fn = get_or_insert_spv_atomic_exchange_fn(
293+
context, builder.module, atomic_ref_ty
358294
)
359295

296+
atomic_exchange_fn_args = [
297+
builder.extract_value(
298+
args[0],
299+
context.data_model_manager.lookup(
300+
atomic_ref_ty
301+
).get_field_position("ref"),
302+
),
303+
context.get_constant(
304+
types.int32, get_scope(atomic_ref_ty.memory_scope)
305+
),
306+
context.get_constant(
307+
types.int32,
308+
get_memory_semantics_mask(atomic_ref_ty.memory_order),
309+
),
310+
args[1],
311+
]
312+
313+
return builder.call(atomic_exchange_fn, atomic_exchange_fn_args)
314+
360315
return sig, _intrinsic_exchange_gen
361316

362317

Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
# SPDX-FileCopyrightText: 2023 - 2024 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
"""
6+
Implements a set of helper functions to generate the LLVM IR for SPIR-V
7+
functions and their use inside an LLVM module.
8+
"""
9+
10+
from llvmlite import ir as llvmir
11+
from numba.core import cgutils, types
12+
13+
from numba_dpex.core import itanium_mangler as ext_itanium_mangler
14+
from numba_dpex.core.targets.kernel_target import CC_SPIR_FUNC
15+
16+
17+
def get_or_insert_atomic_load_fn(context, module, atomic_ref_ty):
18+
"""
19+
Gets or inserts a declaration for a __spirv_AtomicLoad call into the
20+
specified LLVM IR module.
21+
"""
22+
atomic_ref_dtype = atomic_ref_ty.dtype
23+
atomic_load_fn_retty = context.get_value_type(atomic_ref_dtype)
24+
ptr_type = atomic_load_fn_retty.as_pointer()
25+
ptr_type.addrspace = atomic_ref_ty.address_space
26+
atomic_load_fn_arg_types = [
27+
ptr_type,
28+
llvmir.IntType(32),
29+
llvmir.IntType(32),
30+
]
31+
mangled_fn_name = ext_itanium_mangler.mangle_ext(
32+
"__spirv_AtomicLoad",
33+
[
34+
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
35+
"__spv.Scope.Flag",
36+
"__spv.MemorySemanticsMask.Flag",
37+
],
38+
)
39+
40+
fn = cgutils.get_or_insert_function(
41+
module,
42+
llvmir.FunctionType(atomic_load_fn_retty, atomic_load_fn_arg_types),
43+
mangled_fn_name,
44+
)
45+
fn.calling_convention = CC_SPIR_FUNC
46+
47+
return fn
48+
49+
50+
def get_or_insert_spv_atomic_store_fn(context, module, atomic_ref_ty):
51+
"""
52+
Gets or inserts a declaration for a __spirv_AtomicStore call into the
53+
specified LLVM IR module.
54+
"""
55+
atomic_ref_dtype = atomic_ref_ty.dtype
56+
ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
57+
ptr_type.addrspace = atomic_ref_ty.address_space
58+
atomic_store_fn_retty = llvmir.VoidType()
59+
atomic_store_fn_arg_types = [
60+
ptr_type,
61+
llvmir.IntType(32),
62+
llvmir.IntType(32),
63+
context.get_value_type(atomic_ref_dtype),
64+
]
65+
66+
mangled_fn_name = ext_itanium_mangler.mangle_ext(
67+
"__spirv_AtomicStore",
68+
[
69+
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
70+
"__spv.Scope.Flag",
71+
"__spv.MemorySemanticsMask.Flag",
72+
atomic_ref_dtype,
73+
],
74+
)
75+
76+
fn = cgutils.get_or_insert_function(
77+
module,
78+
llvmir.FunctionType(atomic_store_fn_retty, atomic_store_fn_arg_types),
79+
mangled_fn_name,
80+
)
81+
fn.calling_convention = CC_SPIR_FUNC
82+
83+
return fn
84+
85+
86+
def get_or_insert_spv_atomic_exchange_fn(context, module, atomic_ref_ty):
87+
"""
88+
Gets or inserts a declaration for a __spirv_AtomicExchange call into the
89+
specified LLVM IR module.
90+
"""
91+
atomic_ref_dtype = atomic_ref_ty.dtype
92+
ptr_type = context.get_value_type(atomic_ref_dtype).as_pointer()
93+
ptr_type.addrspace = atomic_ref_ty.address_space
94+
atomic_exchange_fn_retty = context.get_value_type(atomic_ref_ty.dtype)
95+
atomic_exchange_fn_arg_types = [
96+
ptr_type,
97+
llvmir.IntType(32),
98+
llvmir.IntType(32),
99+
context.get_value_type(atomic_ref_dtype),
100+
]
101+
102+
mangled_fn_name = ext_itanium_mangler.mangle_ext(
103+
"__spirv_AtomicExchange",
104+
[
105+
types.CPointer(atomic_ref_dtype, addrspace=ptr_type.addrspace),
106+
"__spv.Scope.Flag",
107+
"__spv.MemorySemanticsMask.Flag",
108+
atomic_ref_dtype,
109+
],
110+
)
111+
112+
fn = cgutils.get_or_insert_function(
113+
module,
114+
llvmir.FunctionType(
115+
atomic_exchange_fn_retty, atomic_exchange_fn_arg_types
116+
),
117+
mangled_fn_name,
118+
)
119+
fn.calling_convention = CC_SPIR_FUNC
120+
121+
return fn

0 commit comments

Comments
 (0)