Skip to content

Commit b1ac8d6

Browse files
author
Diptorup Deb
committed
Unit test checking if FlagEnum values are lowered as constants in LLVM IR.
1 parent 9ce9927 commit b1ac8d6

File tree

1 file changed

+53
-0
lines changed

1 file changed

+53
-0
lines changed
Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,53 @@
1+
# SPDX-FileCopyrightText: 2023 Intel Corporation
2+
#
3+
# SPDX-License-Identifier: Apache-2.0
4+
5+
import re
6+
7+
import dpctl
8+
from numba.core import types
9+
10+
import numba_dpex.experimental as exp_dpex
11+
from numba_dpex import DpctlSyclQueue, DpnpNdArray, int64
12+
from numba_dpex.experimental.flag_enum import FlagEnum
13+
14+
15+
def test_compilation_as_literal_constant():
16+
"""Tests if FlagEnum objects are treaded as scalar constants inside
17+
numba-dpex generated code.
18+
19+
The test case compiles the kernel `pass_flags_to_func` that includes a
20+
call to the device_func `bitwise_or_flags`. The `bitwise_or_flags` function
21+
is passed two FlagEnum arguments. The test case evaluates the generated
22+
LLVM IR for `pass_flags_to_func` to see if the call to `bitwise_or_flags`
23+
has the scalar arguments `i64 1` and `i64 2`.
24+
"""
25+
26+
class PseudoFlags(FlagEnum):
27+
FLAG1 = 1
28+
FLAG2 = 2
29+
30+
@exp_dpex.device_func
31+
def bitwise_or_flags(flag1, flag2):
32+
return flag1 | flag2
33+
34+
def pass_flags_to_func(a):
35+
f1 = PseudoFlags.FLAG1
36+
f2 = PseudoFlags.FLAG2
37+
a[0] = bitwise_or_flags(f1, f2)
38+
39+
queue_ty = DpctlSyclQueue(dpctl.SyclQueue())
40+
i64arr_ty = DpnpNdArray(ndim=1, dtype=int64, layout="C", queue=queue_ty)
41+
kernel_sig = types.void(i64arr_ty)
42+
43+
disp = exp_dpex.kernel(pass_flags_to_func)
44+
disp.compile(kernel_sig)
45+
kcres = disp.overloads[kernel_sig.args]
46+
llvm_ir_mod = kcres.library._final_module.__str__()
47+
48+
pattern = re.compile(
49+
r"call spir_func i32 @\_Z.*bitwise\_or"
50+
r"\_flags.*\(i64\* nonnull %.*, i64 1, i64 2\)"
51+
)
52+
53+
assert re.search(pattern, llvm_ir_mod) is not None

0 commit comments

Comments
 (0)