From 8918091b9deb115b5c6891247d4f0af7aa8aa406 Mon Sep 17 00:00:00 2001 From: "Wang, Quintin" Date: Thu, 28 Nov 2024 14:32:50 +0000 Subject: [PATCH 1/3] Fix print UTs --- python/test/unit/language/print_helper.py | 19 ++++++++++++------- third_party/intel/backend/driver.c | 15 +++++++++++++++ third_party/intel/backend/driver.py | 4 ++++ 3 files changed, 31 insertions(+), 7 deletions(-) diff --git a/python/test/unit/language/print_helper.py b/python/test/unit/language/print_helper.py index b89494be7e..9a3250ebaa 100644 --- a/python/test/unit/language/print_helper.py +++ b/python/test/unit/language/print_helper.py @@ -99,6 +99,16 @@ def test_print(func: str, data_type: str, device: str): x = torch.arange(0, N, dtype=torch.int32, device=device).to(getattr(torch, data_type)) y = torch.zeros((N, ), dtype=x.dtype, device=device) + + if device == "xpu": + + def exit_hook(lazy_dict: triton.compiler.LazyDict): + # Need this for xpu device to capture print results before child process exit + # torch.xpu.synchronize() does not work because it just sync on reserved stream + triton.runtime.driver.active.utils.wait() + + triton.compiler.CompiledKernel.launch_exit_hook = exit_hook + if func == "device_print": kernel_device_print[(1, )](x, y, num_warps=num_warps, BLOCK=N) elif func == "device_print_scalar": @@ -130,12 +140,6 @@ def test_print(func: str, data_type: str, device: str): kernel_print_pointer[(1, )](x, y, num_warps=num_warps, BLOCK=N) else: assert f"Unknown kernel: {func}" - - if device == "xpu": - # FIXME: remove trigger to get output from kernel - repr(x) - repr(y) - if func != "print_no_arg" and func != "no_arg_print" and func != "device_print_large" and \ func != "print_multiple_args" and func != "device_print_multiple_args" and \ func != "device_print_pointer" and func != "device_print_scalar": @@ -143,7 +147,8 @@ def test_print(func: str, data_type: str, device: str): # Wait until driver complete all the jobs for the device_print, especially test_subprocess # 
require this which captures stdout when child exits. - getattr(torch, device).synchronize() + if device != "xpu": + getattr(torch, device).synchronize() if __name__ == "__main__": diff --git a/third_party/intel/backend/driver.c b/third_party/intel/backend/driver.c index 295af3e3fe..d55bc4496d 100644 --- a/third_party/intel/backend/driver.c +++ b/third_party/intel/backend/driver.c @@ -338,6 +338,19 @@ static PyObject *initDevices(PyObject *self, PyObject *args) { return Py_BuildValue("(i)", deviceCount); } +static PyObject *waitOnSYCLQueue(PyObject *self, PyObject *args) { + PyObject *cap; + void *queue = NULL; + if (!PyArg_ParseTuple(args, "O", &cap)) + return NULL; + if (!(queue = PyLong_AsVoidPtr(cap))) + return NULL; + sycl::queue *sycl_queue = static_cast<sycl::queue *>(queue); + sycl_queue->wait(); + + Py_RETURN_NONE; // returns a new (owned) reference to None +} + static PyMethodDef ModuleMethods[] = { {"load_binary", loadBinary, METH_VARARGS, "Load provided SPV into ZE driver"}, @@ -347,6 +360,8 @@ static PyMethodDef ModuleMethods[] = { "Initialize the ZE GPU context"}, {"init_devices", initDevices, METH_VARARGS, "Initialize the ZE GPU devices and return device count"}, + {"wait_on_sycl_queue", waitOnSYCLQueue, METH_VARARGS, + "call wait on a specific sycl::queue"}, {NULL, NULL, 0, NULL} // sentinel }; diff --git a/third_party/intel/backend/driver.py b/third_party/intel/backend/driver.py index f372091e56..488b6828fe 100644 --- a/third_party/intel/backend/driver.py +++ b/third_party/intel/backend/driver.py @@ -161,6 +161,7 @@ def __init__(self): self.context = mod.init_context(self.get_sycl_queue()) self.device_count = mod.init_devices(self.get_sycl_queue()) self.current_device = 0 if self.device_count[0] > 0 else -1 + self.wait_on_sycl_queue = mod.wait_on_sycl_queue def get_current_device(self): return self.current_device @@ -172,6 +173,9 @@ def get_sycl_queue(self): import torch return torch.xpu.current_stream().sycl_queue + def wait(self): + self.wait_on_sycl_queue(self.get_sycl_queue()) + # 
------------------------ # Launcher From a875d57306939add5efcfcb3912237905efc9c30 Mon Sep 17 00:00:00 2001 From: "Wang, Quintin" Date: Thu, 28 Nov 2024 14:35:56 +0000 Subject: [PATCH 2/3] Remove test_print from skiplists --- scripts/skiplist/conda/subprocess.txt | 7 +------ scripts/skiplist/default/subprocess.txt | 9 +-------- scripts/skiplist/lts/subprocess.txt | 9 +-------- 3 files changed, 3 insertions(+), 22 deletions(-) diff --git a/scripts/skiplist/conda/subprocess.txt b/scripts/skiplist/conda/subprocess.txt index 8895b4638b..8b13789179 100644 --- a/scripts/skiplist/conda/subprocess.txt +++ b/scripts/skiplist/conda/subprocess.txt @@ -1,6 +1 @@ -test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32] -test/unit/language/test_subprocess.py::test_print[device_print-float32] -test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16] -test/unit/language/test_subprocess.py::test_print[device_print-float16] -test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64] -test/unit/language/test_subprocess.py::test_print[device_print-float64] + diff --git a/scripts/skiplist/default/subprocess.txt b/scripts/skiplist/default/subprocess.txt index 6f0ada5109..7bcee0b047 100644 --- a/scripts/skiplist/default/subprocess.txt +++ b/scripts/skiplist/default/subprocess.txt @@ -1,9 +1,2 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/800 -test/unit/language/test_subprocess.py::test_print[device_print-float16] -test/unit/language/test_subprocess.py::test_print[device_print-float32] -test/unit/language/test_subprocess.py::test_print[device_print-float64] -test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16] -test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64] -test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32] # https://github.com/intel/intel-xpu-backend-for-triton/issues/1704 
-test/unit/language/test_subprocess.py::test_print[device_print_uint-uint32] + diff --git a/scripts/skiplist/lts/subprocess.txt b/scripts/skiplist/lts/subprocess.txt index 6f0ada5109..7bcee0b047 100644 --- a/scripts/skiplist/lts/subprocess.txt +++ b/scripts/skiplist/lts/subprocess.txt @@ -1,9 +1,2 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/800 -test/unit/language/test_subprocess.py::test_print[device_print-float16] -test/unit/language/test_subprocess.py::test_print[device_print-float32] -test/unit/language/test_subprocess.py::test_print[device_print-float64] -test/unit/language/test_subprocess.py::test_print[device_print_scalar-float16] -test/unit/language/test_subprocess.py::test_print[device_print_scalar-float64] -test/unit/language/test_subprocess.py::test_print[device_print_scalar-float32] # https://github.com/intel/intel-xpu-backend-for-triton/issues/1704 -test/unit/language/test_subprocess.py::test_print[device_print_uint-uint32] + From bd25bba3b7a3cdf096bb999cda83f42c37808f6d Mon Sep 17 00:00:00 2001 From: "Wang, Quintin" Date: Thu, 28 Nov 2024 14:52:32 +0000 Subject: [PATCH 3/3] Fix pre-commit errors --- scripts/skiplist/conda/subprocess.txt | 1 - scripts/skiplist/default/subprocess.txt | 2 -- scripts/skiplist/lts/subprocess.txt | 2 -- 3 files changed, 5 deletions(-) diff --git a/scripts/skiplist/conda/subprocess.txt b/scripts/skiplist/conda/subprocess.txt index 8b13789179..e69de29bb2 100644 --- a/scripts/skiplist/conda/subprocess.txt +++ b/scripts/skiplist/conda/subprocess.txt @@ -1 +0,0 @@ - diff --git a/scripts/skiplist/default/subprocess.txt b/scripts/skiplist/default/subprocess.txt index 7bcee0b047..e69de29bb2 100644 --- a/scripts/skiplist/default/subprocess.txt +++ b/scripts/skiplist/default/subprocess.txt @@ -1,2 +0,0 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/1704 - diff --git a/scripts/skiplist/lts/subprocess.txt b/scripts/skiplist/lts/subprocess.txt index 7bcee0b047..e69de29bb2 100644 --- 
a/scripts/skiplist/lts/subprocess.txt +++ b/scripts/skiplist/lts/subprocess.txt @@ -1,2 +0,0 @@ -# https://github.com/intel/intel-xpu-backend-for-triton/issues/1704 -