Skip to content

Commit dbe7a5b

Browse files
authored
[UT] Fix test_print UTs (#2867)
`print` now works on DLE 2025.0.0, but the `test_print` UTs still fail. There are two reasons: 1. We use a subprocess to call `print` while using the `sycl::queue` from `torch`, so the `queue` cannot be synced before the subprocess exits. 2. `torch.xpu.synchronize()` does not work because it only syncs on `reserved streams` (see [the comment](https://github.com/pytorch/pytorch/blob/19d01a1ef0c0d65768eb0a5c97a25328eec57fbd/c10/xpu/XPUStream.cpp#L249)). According to my tests, our print kernels were not waited on, so I added a `launch_exit_hook` to wait on that queue.
1 parent f3ad673 commit dbe7a5b

File tree

6 files changed

+31
-31
lines changed

6 files changed

+31
-31
lines changed

python/test/unit/language/print_helper.py

Lines changed: 12 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -99,6 +99,16 @@ def test_print(func: str, data_type: str, device: str):
9999

100100
x = torch.arange(0, N, dtype=torch.int32, device=device).to(getattr(torch, data_type))
101101
y = torch.zeros((N, ), dtype=x.dtype, device=device)
102+
103+
if device == "xpu":
104+
105+
def exit_hook(lazy_dict: triton.compiler.LazyDict):
106+
# Need this for xpu device to capture print results before child process exit
107+
# torch.xpu.synchronize() does not work because it just sync on reserved stream
108+
triton.runtime.driver.active.utils.wait()
109+
110+
triton.compiler.CompiledKernel.launch_exit_hook = exit_hook
111+
102112
if func == "device_print":
103113
kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N)
104114
elif func == "device_print_scalar":
@@ -130,20 +140,15 @@ def test_print(func: str, data_type: str, device: str):
130140
kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N)
131141
else:
132142
assert f"Unknown kernel: {func}"
133-
134-
if device == "xpu":
135-
# FIXME: remove trigger to get output from kernel
136-
repr(x)
137-
repr(y)
138-
139143
if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \
140144
func != "print_multiple_args" and func != "device_print_multiple_args" and \
141145
func != "device_print_pointer" and func != "device_print_scalar":
142146
assert_close(y, x)
143147

144148
# Wait until driver complete all the jobs for the device_print, especially test_subprocess
145149
# require this which captures stdout when child exits.
146-
getattr(torch, device).synchronize()
150+
if device != "xpu":
151+
getattr(torch, device).synchronize()
147152

148153

149154
if __name__ == "__main__":
Lines changed: 0 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +0,0 @@
1-
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32]
2-
test/unit/language/test_subprocess.py::test_print[device_print-float32]
3-
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16]
4-
test/unit/language/test_subprocess.py::test_print[device_print-float16]
5-
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64]
6-
test/unit/language/test_subprocess.py::test_print[device_print-float64]
Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +0,0 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/800
2-
test/unit/language/test_subprocess.py::test_print[device_print-float16]
3-
test/unit/language/test_subprocess.py::test_print[device_print-float32]
4-
test/unit/language/test_subprocess.py::test_print[device_print-float64]
5-
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16]
6-
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64]
7-
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32]
8-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1704
9-
test/unit/language/test_subprocess.py::test_print[device_print_uint-uint32]
Lines changed: 0 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +0,0 @@
1-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/800
2-
test/unit/language/test_subprocess.py::test_print[device_print-float16]
3-
test/unit/language/test_subprocess.py::test_print[device_print-float32]
4-
test/unit/language/test_subprocess.py::test_print[device_print-float64]
5-
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16]
6-
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64]
7-
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32]
8-
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1704
9-
test/unit/language/test_subprocess.py::test_print[device_print_uint-uint32]

third_party/intel/backend/driver.c

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -346,6 +346,19 @@ static PyObject *initDevices(PyObject *self, PyObject *args) {
346346
return Py_BuildValue("(i)", deviceCount);
347347
}
348348

349+
// Wait on the sycl::queue whose address is passed in from Python.
//
// Python signature: wait_on_sycl_queue(queue_ptr) -> None
//   queue_ptr: Python integer holding the address of a `sycl::queue`.
//
// Returns None on success. Returns NULL with a Python exception set when the
// argument cannot be parsed or does not convert to a non-null pointer.
static PyObject *waitOnSYCLQueue(PyObject *self, PyObject *args) {
  PyObject *cap;
  if (!PyArg_ParseTuple(args, "O", &cap))
    return NULL;

  void *queue = PyLong_AsVoidPtr(cap);
  if (!queue) {
    // PyLong_AsVoidPtr also yields NULL for a plain 0 value without raising;
    // returning NULL with no exception set would surface as a SystemError,
    // so make sure the caller always sees a meaningful error.
    if (!PyErr_Occurred())
      PyErr_SetString(PyExc_ValueError, "SYCL queue pointer must not be null");
    return NULL;
  }

  static_cast<sycl::queue *>(queue)->wait();

  // Py_RETURN_NONE hands back a properly counted reference; returning a bare
  // Py_None would under-count the singleton's refcount.
  Py_RETURN_NONE;
}
361+
349362
static PyMethodDef ModuleMethods[] = {
350363
{"load_binary", loadBinary, METH_VARARGS,
351364
"Load provided SPV into ZE driver"},
@@ -355,6 +368,8 @@ static PyMethodDef ModuleMethods[] = {
355368
"Initialize the ZE GPU context"},
356369
{"init_devices", initDevices, METH_VARARGS,
357370
"Initialize the ZE GPU devices and return device count"},
371+
{"wait_on_sycl_queue", waitOnSYCLQueue, METH_VARARGS,
372+
"call wait on a specific sycl::queue"},
358373
{NULL, NULL, 0, NULL} // sentinel
359374
};
360375

third_party/intel/backend/driver.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -159,6 +159,7 @@ def __init__(self):
159159
self.context = mod.init_context(self.get_sycl_queue())
160160
self.device_count = mod.init_devices(self.get_sycl_queue())
161161
self.current_device = 0 if self.device_count[0] > 0 else -1
162+
self.wait_on_sycl_queue = mod.wait_on_sycl_queue
162163

163164
def get_current_device(self):
164165
return self.current_device
@@ -167,6 +168,9 @@ def get_sycl_queue(self):
167168
import torch
168169
return torch.xpu.current_stream().sycl_queue
169170

171+
def wait(self):
    """Call ``wait_on_sycl_queue`` on the queue from ``get_sycl_queue()``."""
    queue = self.get_sycl_queue()
    self.wait_on_sycl_queue(queue)
173+
170174

171175
# ------------------------
172176
# Launcher

0 commit comments

Comments
 (0)