19 changes: 12 additions & 7 deletions python/test/unit/language/print_helper.py
@@ -99,6 +99,16 @@ def test_print(func: str, data_type: str, device: str):

    x = torch.arange(0, N, dtype=torch.int32, device=device).to(getattr(torch, data_type))
    y = torch.zeros((N, ), dtype=x.dtype, device=device)

if device == "xpu":

def exit_hook(lazy_dict: triton.compiler.LazyDict):
# Need this for xpu device to capture print results before child process exit
# torch.xpu.synchronize() does not work because it just sync on reserved stream
triton.runtime.driver.active.utils.wait()

triton.compiler.CompiledKernel.launch_exit_hook = exit_hook

if func == "device_print":
kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N)
elif func == "device_print_scalar":
@@ -130,20 +140,15 @@ def test_print(func: str, data_type: str, device: str):
        kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N)
    else:
        assert False, f"Unknown kernel: {func}"

if device == "xpu":
# FIXME: remove trigger to get output from kernel
repr(x)
repr(y)

if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \
func != "print_multiple_args" and func != "device_print_multiple_args" and \
func != "device_print_pointer" and func != "device_print_scalar":
assert_close(y, x)

    # Wait until the driver has completed all jobs for device_print; in particular,
    # test_subprocess depends on this because it captures stdout when the child exits.
    getattr(torch, device).synchronize()
    if device != "xpu":
        getattr(torch, device).synchronize()


if __name__ == "__main__":
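For context on the hunk above: Triton invokes CompiledKernel.launch_exit_hook (when set) after every kernel launch, passing it a LazyDict of launch metadata. A minimal, self-contained sketch of the same pattern, assuming an xpu device and the intel backend's utils.wait() from this PR — the demo kernel and tensor are illustrative, only the hook wiring and the driver wait call come from the diff:

import torch
import triton
import triton.language as tl


@triton.jit
def demo_kernel(X, BLOCK: tl.constexpr):
    # Device-side print; the output is buffered by the runtime.
    tl.device_print("x: ", tl.load(X + tl.arange(0, BLOCK)))


def exit_hook(lazy_dict: triton.compiler.LazyDict):
    # Drain the active queue so device_print output is flushed even if the
    # process exits right after the launch (the test_subprocess scenario).
    triton.runtime.driver.active.utils.wait()


triton.compiler.CompiledKernel.launch_exit_hook = exit_hook

x = torch.arange(0, 16, dtype=torch.float32, device="xpu")
demo_kernel[(1, )](x, BLOCK=16)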
6 changes: 0 additions & 6 deletions scripts/skiplist/conda/subprocess.txt
@@ -1,6 +0,0 @@
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32]
test/unit/language/test_subprocess.py::test_print[device_print-float32]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16]
test/unit/language/test_subprocess.py::test_print[device_print-float16]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64]
test/unit/language/test_subprocess.py::test_print[device_print-float64]
9 changes: 0 additions & 9 deletions scripts/skiplist/default/subprocess.txt
@@ -1,9 +0,0 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/800
test/unit/language/test_subprocess.py::test_print[device_print-float16]
test/unit/language/test_subprocess.py::test_print[device_print-float32]
test/unit/language/test_subprocess.py::test_print[device_print-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1704
test/unit/language/test_subprocess.py::test_print[device_print_uint-uint32]
9 changes: 0 additions & 9 deletions scripts/skiplist/lts/subprocess.txt
@@ -1,9 +0,0 @@
# https://github.com/intel/intel-xpu-backend-for-triton/issues/800
test/unit/language/test_subprocess.py::test_print[device_print-float16]
test/unit/language/test_subprocess.py::test_print[device_print-float32]
test/unit/language/test_subprocess.py::test_print[device_print-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64]
test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32]
# https://github.com/intel/intel-xpu-backend-for-triton/issues/1704
test/unit/language/test_subprocess.py::test_print[device_print_uint-uint32]
15 changes: 15 additions & 0 deletions third_party/intel/backend/driver.c
@@ -338,6 +338,19 @@ static PyObject *initDevices(PyObject *self, PyObject *args) {
  return Py_BuildValue("(i)", deviceCount);
}

static PyObject *waitOnSYCLQueue(PyObject *self, PyObject *args) {
  PyObject *cap;
  void *queue = NULL;
  if (!PyArg_ParseTuple(args, "O", &cap))
    return NULL;
  if (!(queue = PyLong_AsVoidPtr(cap)))
    return NULL;
  sycl::queue *sycl_queue = static_cast<sycl::queue *>(queue);
  // Block until every command submitted to this queue has completed.
  sycl_queue->wait();

  // Py_RETURN_NONE increments None's refcount before returning it;
  // a bare `return Py_None;` would corrupt the refcount.
  Py_RETURN_NONE;
}

static PyMethodDef ModuleMethods[] = {
    {"load_binary", loadBinary, METH_VARARGS,
     "Load provided SPV into ZE driver"},
@@ -347,6 +360,8 @@ static PyMethodDef ModuleMethods[] = {
     "Initialize the ZE GPU context"},
    {"init_devices", initDevices, METH_VARARGS,
     "Initialize the ZE GPU devices and return device count"},
    {"wait_on_sycl_queue", waitOnSYCLQueue, METH_VARARGS,
     "Wait on a specific sycl::queue"},
    {NULL, NULL, 0, NULL} // sentinel

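On the Python side, the queue argument arrives as a plain integer carrying the sycl::queue address, hence PyLong_AsVoidPtr rather than capsule unwrapping. A sketch of the call path, assuming the extension is built for the xpu backend and torch exposes the handle as that integer:

import torch
import triton

utils = triton.runtime.driver.active.utils   # wraps the driver.c extension module
q = torch.xpu.current_stream().sycl_queue    # address of the underlying sycl::queue
utils.wait_on_sycl_queue(q)                  # blocks until all work on the queue completes

Per the test comment above, torch.xpu.synchronize() only synchronizes the reserved stream, so a wait targeted at the exact queue the kernel was launched on is needed here.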
4 changes: 4 additions & 0 deletions third_party/intel/backend/driver.py
@@ -161,6 +161,7 @@ def __init__(self):
        self.context = mod.init_context(self.get_sycl_queue())
        self.device_count = mod.init_devices(self.get_sycl_queue())
        self.current_device = 0 if self.device_count[0] > 0 else -1
        self.wait_on_sycl_queue = mod.wait_on_sycl_queue

    def get_current_device(self):
        return self.current_device
@@ -172,6 +173,9 @@ def get_sycl_queue(self):
        import torch
        return torch.xpu.current_stream().sycl_queue

    def wait(self):
        self.wait_on_sycl_queue(self.get_sycl_queue())
Contributor:
Can we use torch.xpu.current_stream().wait()? That would make it easier to decouple the SYCL runtime in Triton from torch.

Contributor Author:
It seems torch does not provide such a wait() yet.

Contributor:
Let's revisit the runtime code in the future; I think this is OK for now.


# ------------------------
# Launcher
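The skiplist removals above are backed by an end-to-end flow exercised in test_subprocess.py: a parent process runs a print test in a child and asserts on the captured stdout. A condensed, hypothetical sketch of that flow — the real harness and its CLI arguments live in test_subprocess.py and may differ:

import subprocess
import sys

# Run the print helper in a child process and capture everything it writes.
proc = subprocess.run(
    [sys.executable, "python/test/unit/language/print_helper.py", "device_print", "float32"],
    capture_output=True,
    text=True,
)
# Without the exit hook's driver wait, the child can exit before the XPU
# runtime flushes device_print output, and the captured stdout is empty.
assert proc.stdout, "no device_print output captured from child process"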