@@ -31,6 +31,123 @@ def get_spirv_abspath(fn):
31
31
return spirv_file
32
32
33
33
34
+ def _check_cpython_api_SyclProgram_GetKernelBundleRef (sycl_prog ):
35
+ """Checks Cython-generated C-API function
36
+ `SyclProgram_GetKernelBundleRef` defined in _program.pyx"""
37
+ import ctypes
38
+ import sys
39
+
40
+ assert type (sycl_prog ) is dpctl_prog .SyclProgram
41
+ mod = sys .modules [sycl_prog .__class__ .__module__ ]
42
+ # get capsule storing SyclProgram_GetKernelBundleRef function ptr
43
+ kb_ref_fn_cap = mod .__pyx_capi__ ["SyclProgram_GetKernelBundleRef" ]
44
+ # construct Python callable to invoke "SyclProgram_GetKernelBundleRef"
45
+ cap_ptr_fn = ctypes .pythonapi .PyCapsule_GetPointer
46
+ cap_ptr_fn .restype = ctypes .c_void_p
47
+ cap_ptr_fn .argtypes = [ctypes .py_object , ctypes .c_char_p ]
48
+ kb_ref_fn_ptr = cap_ptr_fn (
49
+ kb_ref_fn_cap ,
50
+ b"DPCTLSyclKernelBundleRef (struct PySyclProgramObject *)" ,
51
+ )
52
+ # PYFUNCTYPE(result_type, *arg_types)
53
+ callable_maker = ctypes .PYFUNCTYPE (ctypes .c_void_p , ctypes .py_object )
54
+ get_kernel_bundle_ref_fn = callable_maker (kb_ref_fn_ptr )
55
+
56
+ r2 = sycl_prog .addressof_ref ()
57
+ r1 = get_kernel_bundle_ref_fn (sycl_prog )
58
+ assert r1 == r2
59
+
60
+
61
+ def _check_cpython_api_SyclProgram_Make (sycl_prog ):
62
+ """Checks Cython-generated C-API function
63
+ `SyclProgram_Make` defined in _program.pyx"""
64
+ import ctypes
65
+ import sys
66
+
67
+ assert type (sycl_prog ) is dpctl_prog .SyclProgram
68
+ mod = sys .modules [sycl_prog .__class__ .__module__ ]
69
+ # get capsule storing SyclProgram_Make function ptr
70
+ make_prog_fn_cap = mod .__pyx_capi__ ["SyclProgram_Make" ]
71
+ # construct Python callable to invoke "SyclProgram_Make"
72
+ cap_ptr_fn = ctypes .pythonapi .PyCapsule_GetPointer
73
+ cap_ptr_fn .restype = ctypes .c_void_p
74
+ cap_ptr_fn .argtypes = [ctypes .py_object , ctypes .c_char_p ]
75
+ make_prog_fn_ptr = cap_ptr_fn (
76
+ make_prog_fn_cap ,
77
+ b"struct PySyclProgramObject *(DPCTLSyclKernelBundleRef)" ,
78
+ )
79
+ # PYFUNCTYPE(result_type, *arg_types)
80
+ callable_maker = ctypes .PYFUNCTYPE (ctypes .py_object , ctypes .c_void_p )
81
+ make_prog_fn = callable_maker (make_prog_fn_ptr )
82
+
83
+ p2 = make_prog_fn (sycl_prog .addressof_ref ())
84
+ assert p2 .has_sycl_kernel ("add" )
85
+ assert p2 .has_sycl_kernel ("axpy" )
86
+
87
+
88
+ def _check_cpython_api_SyclKernel_GetKernelRef (krn ):
89
+ """Checks Cython-generated C-API function
90
+ `SyclKernel_GetKernelRef` defined in _program.pyx"""
91
+ import ctypes
92
+ import sys
93
+
94
+ assert type (krn ) is dpctl_prog .SyclKernel
95
+ mod = sys .modules [krn .__class__ .__module__ ]
96
+ # get capsule storing SyclKernel_GetKernelRef function ptr
97
+ k_ref_fn_cap = mod .__pyx_capi__ ["SyclKernel_GetKernelRef" ]
98
+ # construct Python callable to invoke "SyclKernel_GetKernelRef"
99
+ cap_ptr_fn = ctypes .pythonapi .PyCapsule_GetPointer
100
+ cap_ptr_fn .restype = ctypes .c_void_p
101
+ cap_ptr_fn .argtypes = [ctypes .py_object , ctypes .c_char_p ]
102
+ k_ref_fn_ptr = cap_ptr_fn (
103
+ k_ref_fn_cap , b"DPCTLSyclKernelRef (struct PySyclKernelObject *)"
104
+ )
105
+ # PYFUNCTYPE(result_type, *arg_types)
106
+ callable_maker = ctypes .PYFUNCTYPE (ctypes .c_void_p , ctypes .py_object )
107
+ get_kernel_ref_fn = callable_maker (k_ref_fn_ptr )
108
+
109
+ r2 = krn .addressof_ref ()
110
+ r1 = get_kernel_ref_fn (krn )
111
+ assert r1 == r2
112
+
113
+
114
+ def _check_cpython_api_SyclKernel_Make (krn ):
115
+ """Checks Cython-generated C-API function
116
+ `SyclKernel_Make` defined in _program.pyx"""
117
+ import ctypes
118
+ import sys
119
+
120
+ assert type (krn ) is dpctl_prog .SyclKernel
121
+ mod = sys .modules [krn .__class__ .__module__ ]
122
+ # get capsule storing SyclKernel_Make function ptr
123
+ k_make_fn_cap = mod .__pyx_capi__ ["SyclKernel_Make" ]
124
+ # construct Python callable to invoke "SyclKernel_Make"
125
+ cap_ptr_fn = ctypes .pythonapi .PyCapsule_GetPointer
126
+ cap_ptr_fn .restype = ctypes .c_void_p
127
+ cap_ptr_fn .argtypes = [ctypes .py_object , ctypes .c_char_p ]
128
+ k_make_fn_ptr = cap_ptr_fn (
129
+ k_make_fn_cap ,
130
+ b"struct PySyclKernelObject *(DPCTLSyclKernelRef, char const *)" ,
131
+ )
132
+ # PYFUNCTYPE(result_type, *arg_types)
133
+ callable_maker = ctypes .PYFUNCTYPE (
134
+ ctypes .py_object , ctypes .c_void_p , ctypes .c_void_p
135
+ )
136
+ make_kernel_fn = callable_maker (k_make_fn_ptr )
137
+
138
+ k2 = make_kernel_fn (
139
+ krn .addressof_ref (), bytes (krn .get_function_name (), "utf-8" )
140
+ )
141
+ assert krn .get_function_name () == k2 .get_function_name ()
142
+ assert krn .get_num_args () == k2 .get_num_args ()
143
+ assert krn .work_group_size == k2 .work_group_size
144
+
145
+ k3 = make_kernel_fn (krn .addressof_ref (), ctypes .c_void_p (None ))
146
+ assert k3 .get_function_name () == "default_name"
147
+ assert krn .get_num_args () == k3 .get_num_args ()
148
+ assert krn .work_group_size == k3 .work_group_size
149
+
150
+
34
151
def _check_multi_kernel_program (prog ):
35
152
assert type (prog ) is dpctl_prog .SyclProgram
36
153
@@ -49,6 +166,9 @@ def _check_multi_kernel_program(prog):
49
166
assert type (axpyKernel .addressof_ref ()) is int
50
167
51
168
for krn in [addKernel , axpyKernel ]:
169
+ _check_cpython_api_SyclKernel_GetKernelRef (krn )
170
+ _check_cpython_api_SyclKernel_Make (krn )
171
+
52
172
na = krn .num_args
53
173
assert na == krn .get_num_args ()
54
174
wgsz = krn .work_group_size
@@ -68,6 +188,9 @@ def _check_multi_kernel_program(prog):
68
188
cmsgsz = krn .compile_sub_group_size
69
189
assert type (cmsgsz ) is int
70
190
191
+ _check_cpython_api_SyclProgram_GetKernelBundleRef (prog )
192
+ _check_cpython_api_SyclProgram_Make (prog )
193
+
71
194
72
195
def test_create_program_from_source_ocl ():
73
196
oclSrc = " \
0 commit comments