Skip to content

Commit dc592f0

Browse files
Merge pull request #448 from IntelPython/remove-context-manager-in-test-sycl-program
used dpctl.SyclQueue instead of manager and get current queue
2 parents 5a5170b + 73764ae commit dc592f0

File tree

1 file changed

+26
-30
lines changed

1 file changed

+26
-30
lines changed

dpctl/tests/test_sycl_program.py

Lines changed: 26 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -38,21 +38,20 @@ def test_create_program_from_source(self):
3838
size_t index = get_global_id(0); \
3939
c[index] = a[index] + d*b[index]; \
4040
}"
41-
with dpctl.device_context("opencl:gpu:0"):
42-
q = dpctl.get_current_queue()
43-
prog = dpctl_prog.create_program_from_source(q, oclSrc)
44-
self.assertIsNotNone(prog)
41+
q = dpctl.SyclQueue("opencl:gpu")
42+
prog = dpctl_prog.create_program_from_source(q, oclSrc)
43+
self.assertIsNotNone(prog)
4544

46-
self.assertTrue(prog.has_sycl_kernel("add"))
47-
self.assertTrue(prog.has_sycl_kernel("axpy"))
45+
self.assertTrue(prog.has_sycl_kernel("add"))
46+
self.assertTrue(prog.has_sycl_kernel("axpy"))
4847

49-
addKernel = prog.get_sycl_kernel("add")
50-
axpyKernel = prog.get_sycl_kernel("axpy")
48+
addKernel = prog.get_sycl_kernel("add")
49+
axpyKernel = prog.get_sycl_kernel("axpy")
5150

52-
self.assertEqual(addKernel.get_function_name(), "add")
53-
self.assertEqual(axpyKernel.get_function_name(), "axpy")
54-
self.assertEqual(addKernel.get_num_args(), 3)
55-
self.assertEqual(axpyKernel.get_num_args(), 4)
51+
self.assertEqual(addKernel.get_function_name(), "add")
52+
self.assertEqual(axpyKernel.get_function_name(), "axpy")
53+
self.assertEqual(addKernel.get_num_args(), 3)
54+
self.assertEqual(axpyKernel.get_num_args(), 4)
5655

5756

5857
@unittest.skipUnless(has_gpu(), "No OpenCL GPU queues available")
@@ -63,20 +62,19 @@ def test_create_program_from_spirv(self):
6362
spirv_file = os.path.join(CURR_DIR, "input_files/multi_kernel.spv")
6463
with open(spirv_file, "rb") as fin:
6564
spirv = fin.read()
66-
with dpctl.device_context("opencl:gpu:0"):
67-
q = dpctl.get_current_queue()
68-
prog = dpctl_prog.create_program_from_spirv(q, spirv)
69-
self.assertIsNotNone(prog)
70-
self.assertTrue(prog.has_sycl_kernel("add"))
71-
self.assertTrue(prog.has_sycl_kernel("axpy"))
65+
q = dpctl.SyclQueue("opencl:gpu")
66+
prog = dpctl_prog.create_program_from_spirv(q, spirv)
67+
self.assertIsNotNone(prog)
68+
self.assertTrue(prog.has_sycl_kernel("add"))
69+
self.assertTrue(prog.has_sycl_kernel("axpy"))
7270

73-
addKernel = prog.get_sycl_kernel("add")
74-
axpyKernel = prog.get_sycl_kernel("axpy")
71+
addKernel = prog.get_sycl_kernel("add")
72+
axpyKernel = prog.get_sycl_kernel("axpy")
7573

76-
self.assertEqual(addKernel.get_function_name(), "add")
77-
self.assertEqual(axpyKernel.get_function_name(), "axpy")
78-
self.assertEqual(addKernel.get_num_args(), 3)
79-
self.assertEqual(axpyKernel.get_num_args(), 4)
74+
self.assertEqual(addKernel.get_function_name(), "add")
75+
self.assertEqual(axpyKernel.get_function_name(), "axpy")
76+
self.assertEqual(addKernel.get_num_args(), 3)
77+
self.assertEqual(axpyKernel.get_num_args(), 4)
8078

8179

8280
@unittest.skipUnless(
@@ -98,9 +96,8 @@ def test_create_program_from_spirv(self):
9896
spirv_file = os.path.join(CURR_DIR, "input_files/multi_kernel.spv")
9997
with open(spirv_file, "rb") as fin:
10098
spirv = fin.read()
101-
with dpctl.device_context("level_zero:gpu:0"):
102-
q = dpctl.get_current_queue()
103-
dpctl_prog.create_program_from_spirv(q, spirv)
99+
q = dpctl.SyclQueue("level_zero:gpu")
100+
dpctl_prog.create_program_from_spirv(q, spirv)
104101

105102
@unittest.expectedFailure
106103
def test_create_program_from_source(self):
@@ -113,9 +110,8 @@ def test_create_program_from_source(self):
113110
size_t index = get_global_id(0); \
114111
c[index] = a[index] + d*b[index]; \
115112
}"
116-
with dpctl.device_context("level_zero:gpu:0"):
117-
q = dpctl.get_current_queue()
118-
dpctl_prog.create_program_from_source(q, oclSrc)
113+
q = dpctl.SyclQueue("level_zero:gpu")
114+
dpctl_prog.create_program_from_source(q, oclSrc)
119115

120116

121117
if __name__ == "__main__":

0 commit comments

Comments
 (0)