Commit 0708d7a

remove macro in py_client_gpu (#417)
1 parent: 73aab13

3 files changed: +49, -95 lines

third_party/openxla.patch

Lines changed: 31 additions & 95 deletions
@@ -428,7 +428,7 @@ index e23dcc3a4c..aaaf22ed81 100644
 system_link_files = {
 "//third_party/systemlibs:BUILD": "bazel/BUILD",
 diff --git a/xla/backends/profiler/plugin/BUILD b/xla/backends/profiler/plugin/BUILD
-index 169a4eaa4e..1b8c0bae04 100644
+index 169a4eaa4e..161e4e0452 100644
 --- a/xla/backends/profiler/plugin/BUILD
 +++ b/xla/backends/profiler/plugin/BUILD
 @@ -62,6 +62,10 @@ cc_library(
@@ -663,7 +663,7 @@ index e2c82bad04..8401ec77d8 100644
 )
 
 diff --git a/xla/python/py_client.cc b/xla/python/py_client.cc
-index 0afd053313..f51f71b4f4 100644
+index 0afd053313..f2116e7f5a 100644
 --- a/xla/python/py_client.cc
 +++ b/xla/python/py_client.cc
 @@ -91,9 +91,9 @@ limitations under the License.
@@ -678,8 +678,17 @@ index 0afd053313..f51f71b4f4 100644
 
 namespace xla {
 
+@@ -670,7 +670,7 @@ PyClient::GetEmitPythonCallbackDescriptor(nb::callable callable,
+ XLA_CPU_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM("xla_python_cpu_callback",
+                                              &XlaPythonCpuCallback);
+
+-#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM
++#if GOOGLE_CUDA || TENSORFLOW_USE_ROCM || TENSORFLOW_USE_SYCL
+ XLA_REGISTER_CUSTOM_CALL_TARGET_WITH_SYM(
+     "xla_python_gpu_callback", &XlaPythonGpuCallback,
+     absl::AsciiStrToUpper(PlatformUtil::CanonicalPlatformName("gpu").value()));
 diff --git a/xla/python/py_client_gpu.cc b/xla/python/py_client_gpu.cc
-index 100d9fd599..91df06ad4e 100644
+index 100d9fd599..642828a9ce 100644
 --- a/xla/python/py_client_gpu.cc
 +++ b/xla/python/py_client_gpu.cc
 @@ -12,6 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
@@ -690,105 +699,43 @@ index 100d9fd599..91df06ad4e 100644
 
 #include <vector>
 
-@@ -20,7 +21,7 @@ limitations under the License.
+@@ -20,6 +21,8 @@ limitations under the License.
 #include "tsl/platform/errors.h"
 #if TENSORFLOW_USE_ROCM
 #include "rocm/include/hip/hip_runtime.h"
--#else
-+#elif GOOGLE_CUDA
++#elif TENSORFLOW_USE_SYCL
++#include "xla/stream_executor/sycl/sycl_gpu_runtime.h"
+ #else
 #include "third_party/gpus/cuda/include/cuda.h"
 #include "third_party/gpus/cuda/include/cuda_runtime_api.h"
- #endif
-@@ -38,13 +39,15 @@ limitations under the License.
+@@ -38,6 +41,13 @@ limitations under the License.
 #define gpuStreamSynchronize hipStreamSynchronize
 #define gpuMemcpyDeviceToHost hipMemcpyDeviceToHost
 #define gpuMemcpyHostToDevice hipMemcpyHostToDevice
--#else
-+#elif GOOGLE_CUDA
++#elif TENSORFLOW_USE_SYCL
++#define gpuSuccess SYCL_SUCCESS
++#define gpuStreamHandle ::sycl::queue*
++#define gpuMemcpyAsync SYCLMemcpyAsync
++#define gpuStreamSynchronize SYCLStreamSynchronize
++#define gpuMemcpyDeviceToHost SYCLMemcpyDtoHAsync
++#define gpuMemcpyHostToDevice SYCLMemcpyHtoDAsync
+ #else
 #define gpuSuccess cudaSuccess
 #define gpuStreamHandle CUstream
- #define gpuMemcpyAsync cudaMemcpyAsync
- #define gpuStreamSynchronize cudaStreamSynchronize
- #define gpuMemcpyDeviceToHost cudaMemcpyDeviceToHost
- #define gpuMemcpyHostToDevice cudaMemcpyHostToDevice
-+#else
-+#define gpuStreamHandle ::sycl::queue*
- #endif
-
- namespace nb = nanobind;
-@@ -74,13 +77,20 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers,
- }
- void* buf = new char[arg.size_in_bytes];
- host_input_buffers[i] = buf;
-+#if TENSORFLOW_USE_SYCL
-+ auto event = stream->memcpy(buf, (const void*)(buffers[i]), arg.size_in_bytes);
-+ event.wait();
-+#else
- // TODO(b/238441608): Use pinned memory here to speed up the transfer.
- auto gpu_res = gpuMemcpyAsync(buf, buffers[i], arg.size_in_bytes,
- gpuMemcpyDeviceToHost, stream);
- CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
-+#endif
- }
-+#ifndef TENSORFLOW_USE_SYCL
- CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess)
- << "Failed to gpuStreamSynchronize";
-+#endif
- nb::gil_scoped_acquire gil;
- nb::tuple host_input_arrays = nb::steal<nb::tuple>(PyTuple_New(arity));
- for (size_t i = 0; i < arity; ++i) {
-@@ -120,10 +130,15 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers,
- absl::Span<int64_t const> strides(
- reinterpret_cast<const int64_t*>(array.strides()), array.ndim());
- if (strides == result.expected_strides) {
-+#ifdef TENSORFLOW_USE_SYCL
-+ auto event = stream->memcpy(buffers[arity + i], array.data(), result.size_in_bytes);
-+ event.wait();
-+#else
- auto gpu_res =
- gpuMemcpyAsync(buffers[arity + i], array.data(), result.size_in_bytes,
- gpuMemcpyHostToDevice, stream);
- CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
-+#endif
- } else {
- void* temp = new char[result.size_in_bytes];
- temp_buffers.push_back(temp);
-@@ -138,15 +153,22 @@ void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers,
- throw xla::XlaRuntimeError(plan.status().ToString());
- }
- plan.value()->Execute(array.data(), temp);
-+#ifdef TENSORFLOW_USE_SYCL
-+ auto event = stream->memcpy(buffers[arity + i], temp, result.size_in_bytes);
-+ event.wait();
-+#else
- auto gpu_res =
- gpuMemcpyAsync(buffers[arity + i], temp, result.size_in_bytes,
- gpuMemcpyHostToDevice, stream);
- CHECK_EQ(gpu_res, gpuSuccess) << "Failed to gpuMemcpyAsync";
-+#endif
- }
- }
- nb::gil_scoped_release release;
-+#ifndef TENSORFLOW_USE_SYCL
- CHECK_EQ(gpuStreamSynchronize(stream), gpuSuccess)
- << "Failed to gpuStreamSynchronize";
-+#endif
- for (int i = 0; i < temp_buffers.size(); ++i) {
- delete[] static_cast<char*>(temp_buffers[i]);
- }
 diff --git a/xla/python/py_client_gpu.h b/xla/python/py_client_gpu.h
-index d7675e1b6a..571a134431 100644
+index d7675e1b6a..17da528bec 100644
 --- a/xla/python/py_client_gpu.h
 +++ b/xla/python/py_client_gpu.h
-@@ -18,17 +18,28 @@ limitations under the License.
+@@ -18,6 +18,8 @@ limitations under the License.
 
 #if TENSORFLOW_USE_ROCM
 #include "rocm/include/hip/hip_runtime.h"
--#else
-+#elif GOOGLE_CUDA
++#elif TENSORFLOW_USE_SYCL
++#include "xla/stream_executor/sycl/sycl_gpu_runtime.h"
+ #else
 #include "third_party/gpus/cuda/include/cuda.h"
 #endif
- #include "xla/service/custom_call_status.h"
+@@ -25,8 +27,10 @@ limitations under the License.
 
 #if TENSORFLOW_USE_ROCM
 #define gpuStreamHandle hipStream_t
@@ -799,18 +746,7 @@ index d7675e1b6a..571a134431 100644
 +#define gpuStreamHandle ::sycl::queue*
 #endif
 
-+#if TENSORFLOW_USE_SYCL
-+#if __has_include(<sycl/sycl.hpp>)
-+#include <sycl/sycl.hpp>
-+#elif __has_include(<CL/sycl.hpp>)
-+#include <CL/sycl.hpp>
-+#else
-+#error "Unsupported compiler"
-+#endif
-+#endif
 namespace xla {
-
- void XlaPythonGpuCallback(gpuStreamHandle stream, void** buffers,
 diff --git a/xla/service/BUILD b/xla/service/BUILD
 index bcedb98906..952e4c5f6f 100644
 --- a/xla/service/BUILD
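Aside: the reworked patch drops the per-call-site `#ifdef TENSORFLOW_USE_SYCL` blocks from `XlaPythonGpuCallback` and instead gives SYCL its own `gpu*` alias table, so one callback body compiles against CUDA, ROCm, or SYCL. A minimal, self-contained sketch of that alias pattern follows; the stub runtime and names below are illustrative, not the real XLA code:

```cpp
// Sketch of the alias-table pattern: each backend maps the generic gpu*
// names onto its own runtime exactly once, so the callback body carries
// no per-call-site #ifdef blocks. The "SYCL" runtime here is a stub so
// the file compiles standalone; real builds alias the actual runtimes.
#include <cstdio>
#include <cstring>

#define TENSORFLOW_USE_SYCL 1  // pretend this is a SYCL build

using StubStream = int;  // stand-in for ::sycl::queue*

static int StubMemcpy(void* dst, const void* src, std::size_t n, StubStream*) {
  std::memcpy(dst, src, n);  // host-side stand-in for a device copy
  return 0;
}

#if TENSORFLOW_USE_SYCL
#define gpuSuccess 0
#define gpuStreamHandle StubStream*
#define gpuMemcpyAsync StubMemcpy
#else  // the CUDA branch in the real patch
#define gpuSuccess cudaSuccess
#define gpuStreamHandle CUstream
#define gpuMemcpyAsync cudaMemcpyAsync
#endif

int main() {
  char src[] = "tensor", dst[8] = {};
  StubStream backing = 0;
  gpuStreamHandle stream = &backing;
  // Written once against the gpu* names; no backend #ifdef needed here.
  if (gpuMemcpyAsync(dst, src, sizeof(src), stream) == gpuSuccess)
    std::printf("%s\n", dst);
  return 0;
}
```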

xla/stream_executor/sycl/sycl_gpu_runtime.cc

Lines changed: 12 additions & 0 deletions
@@ -470,6 +470,18 @@ SYCLError_t SYCLMemsetD32Async(void* dstDevice, unsigned int ui, size_t N,
   return SYCL_SUCCESS;
 }
 
+SYCLError_t SYCLMemcpyAsync(void* dst, const void* src, size_t ByteCount,
+                            SYCLError_t (*func)(void*, const void*, size_t, sycl::queue*),
+                            sycl::queue* stream){
+  return (*func)(dst, src, ByteCount, stream);
+}
+
+SYCLError_t SYCLStreamSynchronize(sycl::queue* stream){
+  stream->wait();
+  return SYCL_SUCCESS;
+}
+
+
 void* SYCLMalloc(sycl::device* device, size_t ByteCount) {
   sycl::queue* stream;
   SYCLStreamPool::getDefaultStream(device, &stream);
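Note the signature: the function-pointer parameter sits in the slot where CUDA passes a `cudaMemcpyKind`, so the shared call site `gpuMemcpyAsync(dst, src, n, gpuMemcpyDeviceToHost, stream)` expands correctly once `gpuMemcpyDeviceToHost` is defined as `SYCLMemcpyDtoHAsync`. A standalone sketch of that trick, with stub types standing in for the real runtime:

```cpp
// Self-contained illustration of the SYCLMemcpyAsync signature trick:
// the fourth argument, a direction enum under CUDA, is a function
// pointer here, so shared call sites expand without edits.
// All names below are stand-ins, not the real XLA/SYCL runtime.
#include <cstdio>
#include <cstring>

using Stream = int;  // stand-in for sycl::queue
using Err = int;     // stand-in for SYCLError_t

// Stand-in for SYCLMemcpyDtoHAsync: performs the actual copy.
static Err StubDtoH(void* dst, const void* src, std::size_t n, Stream*) {
  std::memcpy(dst, src, n);  // a real build would call stream->memcpy()
  return 0;
}

// Mirrors the shape of SYCLMemcpyAsync from this commit: forward to
// whichever copy routine arrives in the "kind" slot.
static Err StubMemcpyAsync(void* dst, const void* src, std::size_t n,
                           Err (*func)(void*, const void*, std::size_t, Stream*),
                           Stream* stream) {
  return (*func)(dst, src, n, stream);
}

#define gpuMemcpyAsync StubMemcpyAsync
#define gpuMemcpyDeviceToHost StubDtoH

int main() {
  char src[] = "xla", dst[4] = {};
  Stream s = 0;
  // Identical shape to the CUDA call site in py_client_gpu.cc.
  gpuMemcpyAsync(dst, src, sizeof(src), gpuMemcpyDeviceToHost, &s);
  std::printf("%s\n", dst);
  return 0;
}
```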

xla/stream_executor/sycl/sycl_gpu_runtime.h

Lines changed: 6 additions & 0 deletions
@@ -99,6 +99,12 @@ void* SYCLMallocShared(sycl::device* device, size_t ByteCount);
 
 void SYCLFree(sycl::device* device, void* ptr);
 
+SYCLError_t SYCLMemcpyAsync(void* dst, const void* src, size_t ByteCount,
+                            SYCLError_t (*func)(void*, const void*, size_t, sycl::queue*),
+                            sycl::queue* stream);
+
+SYCLError_t SYCLStreamSynchronize(sycl::queue* stream);
+
 sycl::event SYCLGetEventFromStream(sycl::queue* stream);
 
 void SYCLStreamDependOnEvents(sycl::queue* stream,
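For reference, these helpers wrap standard SYCL queue operations: `queue::memcpy` returns an event that can be waited on, and `queue::wait()` blocks on all previously submitted work, which is what `SYCLStreamSynchronize` does. A minimal sketch assuming a SYCL 2020 toolchain such as DPC++:

```cpp
// Sketch of the underlying SYCL semantics the new helpers wrap:
// queue::memcpy returns an event; wait() blocks until the copy lands,
// and queue::wait() drains everything submitted to the queue.
// Requires a SYCL compiler (e.g. DPC++); illustrative only.
#include <sycl/sycl.hpp>
#include <cstdio>

int main() {
  sycl::queue q;  // default-selected device
  int host_src[4] = {1, 2, 3, 4}, host_dst[4] = {};
  int* dev = sycl::malloc_device<int>(4, q);

  q.memcpy(dev, host_src, sizeof(host_src)).wait();  // HtoD, blocking
  q.memcpy(host_dst, dev, sizeof(host_dst));         // DtoH, asynchronous
  q.wait();  // what SYCLStreamSynchronize does for the whole queue

  std::printf("%d %d %d %d\n", host_dst[0], host_dst[1], host_dst[2],
              host_dst[3]);
  sycl::free(dev, q);
  return 0;
}
```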
