@@ -428,7 +428,7 @@ index e23dcc3a4c..aaaf22ed81 100644
system_link_files = {
"//third_party/systemlibs:BUILD": "bazel/BUILD",
diff --git a/xla/backends/profiler/plugin/BUILD b/xla/backends/profiler/plugin/BUILD
- index 169a4eaa4e..1b8c0bae04 100644
+ index 169a4eaa4e..161e4e0452 100644
--- a/xla/backends/profiler/plugin/BUILD
+++ b/xla/backends/profiler/plugin/BUILD
@@ -62,6 +62,10 @@ cc_library(
@@ -663,7 +663,7 @@ index e2c82bad04..8401ec77d8 100644
)

diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc
- index 0afd053313..f51f71b4f4 100644
+ index 0afd053313..f2116e7f5a 100644
--- a/xla/python/py_client.cc
+++ b/xla/python/py_client.cc
@@ -91,9 +91,9 @@ limitations under the License.
@@ -678,8 +678,17 @@ index 0afd053313..f51f71b4f4 100644

namespace xla {

+ @@ -670,7 +670,7 @@ PyClient::GetEmitPythonCallbackDescriptor(nb::callable callable,
+ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
+ &XlaPythonCpuCallback);
+
+ - #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
+ + #if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL
+ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
+ "xla_python_gpu_callback", &XlaPythonGpuCallback,
+ absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()));
diff --git a/xla/python/py_client_gpu.cc b/xla/python/py_client_gpu.cc
- index 100d9fd599..91df06ad4e 100644
+ index 100d9fd599..642828a9ce 100644
--- a/xla/python/py_client_gpu.cc
+++ b/xla/python/py_client_gpu.cc
@@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -690,105 +699,43 @@ index 100d9fd599..91df06ad4e 100644

#include <vector>

- @@ -20,7 +21,7 @@ limitations under the License.
+ @@ -20,6 +21,8 @@ limitations under the License.
#include "tsl/platform/errors.h"
#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
- - #else
- + #elif GOOGLE_CUDA
+ + #elif TENSORFLOW_USE_SYCL
+ + #include "xla/stream_executor/sycl/sycl_gpu_runtime.h"
+ #else
#include "third_party/gpus/cuda/include/cuda.h"
#include "third_party/gpus/cuda/include/cuda_runtime_api.h"
- #endif
- @@ -38,13 +39,15 @@ limitations under the License.
+ @@ -38,6 +41,13 @@ limitations under the License.
#define gpuStreamSynchronize hipStreamSynchronize
#define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
#define gpuMemcpyHostToDevice hipMemcpyHostToDevice
- - #else
- + #elif GOOGLE_CUDA
+ + #elif TENSORFLOW_USE_SYCL
+ + #define gpuSuccess SYCL_SUCCESS
+ + #define gpuStreamHandle ::sycl::queue*
+ + #define gpuMemcpyAsync SYCLMemcpyAsync
+ + #define gpuStreamSynchronize SYCLStreamSynchronize
+ + #define gpuMemcpyDeviceToHost SYCLMemcpyDtoHAsync
+ + #define gpuMemcpyHostToDevice SYCLMemcpyHtoDAsync
+ #else
#define gpuSuccess cudaSuccess
#define gpuStreamHandle CUstream
- #define gpuMemcpyAsync cudaMemcpyAsync
- #define gpuStreamSynchronize cudaStreamSynchronize
- #define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
- #define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
- + #else
- + #define gpuStreamHandle ::sycl::queue*
- #endif
-
- namespace nb = nanobind;
- @@ -74,13 +77,20 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers,
- }
- void* buf = new char[arg.size_in_bytes];
- host_input_buffers[i] = buf;
- + #if TENSORFLOW_USE_SYCL
- + auto event = stream->memcpy(buf, (const void*)(buffers[i]), arg.size_in_bytes);
- + event.wait();
- + #else
- // TODO(b/238441608): Use pinned memory here to speed up the transfer.
- auto gpu_res = gpuMemcpyAsync(buf, buffers[i], arg.size_in_bytes,
- gpuMemcpyDeviceToHost, stream);
- CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
- + #endif
- }
- + #ifndef TENSORFLOW_USE_SYCL
- CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess)
- << "Failed to gpuStreamSynchronize";
- + #endif
- nb::gil_scoped_acquire gil;
- nb::tuple host_input_arrays = nb::steal<nb::tuple>(PyTuple_New(arity));
- for (size_t i = 0; i < arity; ++i) {
- @@ -120,10 +130,15 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers,
- absl::Span<int64_t const> strides(
- reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
- if (strides == result.expected_strides) {
- + #ifdef TENSORFLOW_USE_SYCL
- + auto event = stream->memcpy(buffers[arity + i], array.data(), result.size_in_bytes);
- + event.wait();
- + #else
- auto gpu_res =
- gpuMemcpyAsync(buffers[arity + i], array.data(), result.size_in_bytes,
- gpuMemcpyHostToDevice, stream);
- CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
- + #endif
- } else {
- void* temp = new char[result.size_in_bytes];
- temp_buffers.push_back(temp);
- @@ -138,15 +153,22 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers,
- throw xla::XlaRuntimeError(plan.status().ToString());
- }
- plan.value()->Execute(array.data(), temp);
- + #ifdef TENSORFLOW_USE_SYCL
- + auto event = stream->memcpy(buffers[arity + i], temp, result.size_in_bytes);
- + event.wait();
- + #else
- auto gpu_res =
- gpuMemcpyAsync(buffers[arity + i], temp, result.size_in_bytes,
- gpuMemcpyHostToDevice, stream);
- CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
- + #endif
- }
- }
- nb::gil_scoped_release release;
- + #ifndef TENSORFLOW_USE_SYCL
- CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess)
- << "Failed to gpuStreamSynchronize";
- + #endif
- for (int i = 0; i < temp_buffers.size(); ++i) {
- delete[] static_cast<char*>(temp_buffers[i]);
- }
diff --git a/xla/python/py_client_gpu.h b/xla/python/py_client_gpu.h
- index d7675e1b6a..571a134431 100644
+ index d7675e1b6a..17da528bec 100644
--- a/xla/python/py_client_gpu.h
+++ b/xla/python/py_client_gpu.h
- @@ -18,17 +18,28 @@ limitations under the License.
+ @@ -18,6 +18,8 @@ limitations under the License.

#if TENSORFLOW_USE_ROCM
#include "rocm/include/hip/hip_runtime.h"
- - #else
- + #elif GOOGLE_CUDA
+ + #elif TENSORFLOW_USE_SYCL
+ + #include "xla/stream_executor/sycl/sycl_gpu_runtime.h"
+ #else
#include "third_party/gpus/cuda/include/cuda.h"
#endif
- #include "xla/service/custom_call_status.h"
+ @@ -25,8 +27,10 @@ limitations under the License.

#if TENSORFLOW_USE_ROCM
#define gpuStreamHandle hipStream_t
@@ -799,18 +746,7 @@ index d7675e1b6a..571a134431 100644
+ #define gpuStreamHandle ::sycl::queue*
#endif

- + #if TENSORFLOW_USE_SYCL
- + #if __has_include(<sycl/sycl.hpp>)
- + #include <sycl/sycl.hpp>
- + #elif __has_include(<CL/sycl.hpp>)
- + #include <CL/sycl.hpp>
- + #else
- + #error "Unsupported compiler"
- + #endif
- + #endif
namespace xla {
-
- void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers,
diff --git a/xla/service/BUILD b/xla/service/BUILD
index bcedb98906..952e4c5f6f 100644
--- a/xla/service/BUILD
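
Aside, not part of the patch: the SYCL branches above replace the asynchronous gpuMemcpyAsync + gpuStreamSynchronize pattern with a blocking stream->memcpy(...).wait() on a ::sycl::queue*. A minimal standalone SYCL 2020 sketch of that copy pattern (illustrative only; assumes a SYCL 2020 toolchain such as icpx -fsycl):

// Sketch of the blocking copy pattern used in the py_client_gpu.cc hunks:
// sycl::queue::memcpy() returns an event, and wait() makes the transfer
// synchronous. Not part of the patch; for illustration only.
#include <sycl/sycl.hpp>
#include <cstdio>
#include <vector>

int main() {
  sycl::queue stream;  // plays the role of the patch's gpuStreamHandle
  std::vector<int> host(16, 7);
  int* dev = sycl::malloc_device<int>(host.size(), stream);

  // Host-to-device, then device-to-host, each made blocking with wait(),
  // mirroring stream->memcpy(buf, buffers[i], size) + event.wait() above.
  stream.memcpy(dev, host.data(), host.size() * sizeof(int)).wait();
  std::vector<int> back(host.size(), 0);
  stream.memcpy(back.data(), dev, back.size() * sizeof(int)).wait();

  std::printf("round-trip value: %d\n", back[0]);
  sycl::free(dev, stream);
  return 0;
}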