Skip to content

Commit 18e0e04

Browse files
mayuyuaceLu Teng
andauthored
Support FFI (main) (#357)
Co-authored-by: Lu Teng <[email protected]>
1 parent b32771d commit 18e0e04

File tree

14 files changed

+1132
-549
lines changed

14 files changed

+1132
-549
lines changed

third_party/openxla.patch

Lines changed: 313 additions & 244 deletions
Large diffs are not rendered by default.

xla/BUILD

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,6 @@
11
cc_binary(
22
name = "pjrt_plugin_xpu.so",
3+
linkopts = ["-Wl,-rpath,$$ORIGIN/../intel_extension_for_openxla/service/gpu"],
34
linkshared = True,
45
visibility = ["//visibility:public"],
56
deps = [

xla/service/gpu/BUILD

Lines changed: 56 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,24 @@ cc_library(
2121
],
2222
)
2323

24+
cc_library(
25+
name = "sycl_custom_call",
26+
srcs = [
27+
"sycl_custom_call.cc",
28+
],
29+
visibility = ["//visibility:public"],
30+
deps = [
31+
"//xla/service:onednn_util",
32+
"//xla/service/gpu:sycl_onednn",
33+
"@xla//xla/ffi",
34+
"@xla//xla/ffi:ffi_api",
35+
"@xla//xla/stream_executor",
36+
"@com_google_absl//absl/status",
37+
"@com_google_absl//absl/strings",
38+
],
39+
alwayslink = 1,
40+
)
41+
2442
xetla_library(
2543
name = "onednn_matmul_utils",
2644
srcs = ["onednn_matmul_utils.cc"],
@@ -29,27 +47,7 @@ xetla_library(
2947
":scratch_allocator",
3048
"//xla/service:onednn_util",
3149
"//xla/service/gpu/xetla/gemm:gemm_kernel",
32-
"//xla/stream_executor/sycl:sycl_executor",
33-
"@com_google_absl//absl/algorithm:container",
34-
"@com_google_absl//absl/types:span",
35-
"@onednn_gpu//:onednn_gpu",
36-
"@tsl//tsl/framework:numeric_types",
37-
"@tsl//tsl/platform:statusor",
38-
"@tsl//tsl/platform:types",
39-
"@xetla//:xetla_header",
40-
"@xla//xla:shape_util",
41-
"@xla//xla:status_macros",
42-
"@xla//xla:statusor",
43-
"@xla//xla:types",
44-
"@xla//xla:util",
45-
"@xla//xla:xla_data_proto_cc",
46-
"@xla//xla/hlo/ir:hlo",
47-
"@xla//xla/mlir_hlo",
48-
"@xla//xla/mlir_hlo:lhlo_gpu",
49-
"@xla//xla/service/gpu:backend_configs_cc",
50-
"@xla//xla/service/gpu:ir_emission_utils",
5150
"@xla//xla/service/gpu:matmul_utils",
52-
"@xla//xla/stream_executor:stream_executor_headers",
5351
],
5452
)
5553

@@ -58,7 +56,7 @@ cc_library(
5856
srcs = ["gemm_impl_picker.cc",],
5957
hdrs = ["gemm_impl_picker.h"],
6058
deps = [
61-
":onednn_matmul_utils",
59+
":sycl_onednn",
6260
"//xla/stream_executor/sycl:hw_info",
6361
"@com_google_absl//absl/algorithm:container",
6462
"@tsl//tsl/platform:errors",
@@ -179,31 +177,50 @@ xpu_library(
179177
],
180178
)
181179

180+
cc_import(
181+
name = "sycl_onednn",
182+
hdrs = [
183+
"sycl_onednn.h",
184+
"onednn_gpu_conv_runner.h",
185+
"onednn_matmul_utils.h",
186+
],
187+
shared_library = ":sycl_onednn.so",
188+
visibility = ["//visibility:public"],
189+
deps = [
190+
":scratch_allocator",
191+
"@xla//xla/service/gpu:gpu_conv_runner",
192+
"@xla//xla/service/gpu:thunk",
193+
"@xla//xla/service/gpu:matmul_utils",
194+
],
195+
)
196+
197+
cc_binary(
198+
name = "sycl_onednn.so",
199+
srcs = [
200+
"sycl_onednn.cc",
201+
"sycl_onednn.h",
202+
],
203+
linkshared = True,
204+
deps = [
205+
":onednn_gpu_conv_runner",
206+
":onednn_matmul_utils",
207+
],
208+
)
209+
182210
cc_library(
183211
name = "onednn_gpu_conv_runner",
184-
srcs = ["onednn_gpu_conv_runner.cc"],
185-
hdrs = ["onednn_gpu_conv_runner.h"],
212+
srcs = [
213+
"onednn_gpu_conv_runner.cc",
214+
"onednn_gpu_conv_runner.h",
215+
],
186216
deps = [
187217
":scratch_allocator",
188218
"//xla/service:onednn_util",
189-
"@com_google_absl//absl/strings",
190-
"@tsl//tsl/framework:numeric_types",
191-
"@xla//xla:shape_util",
192-
"@xla//xla:status",
193-
"@xla//xla:status_macros",
194-
"@xla//xla:statusor",
195-
"@xla//xla:types",
196-
"@xla//xla:util",
197-
"@xla//xla:xla_data_proto_cc",
198-
"@xla//xla/hlo/ir:hlo",
199-
"@xla//xla/service/gpu:backend_configs_cc",
200-
"@xla//xla/service/gpu:cublas_cudnn",
219+
"@xla//xla/ffi",
220+
"@xla//xla/ffi:ffi_api",
201221
"@xla//xla/service/gpu:gpu_conv_runner",
202222
"@xla//xla/service/gpu:stream_executor_util",
203223
"@xla//xla/service/gpu:thunk",
204-
"@xla//xla/stream_executor",
205-
"@xla//xla/stream_executor/gpu:gpu_stream",
206-
"@xla//xla/stream_executor/gpu:gpu_types_header",
207224
],
208225
)
209226

@@ -300,4 +317,4 @@ cc_library(
300317
"@xla//xla/service:hlo_pass",
301318
"@xla//xla/service:pattern_matcher",
302319
],
303-
)
320+
)

xla/service/gpu/gemm_impl_picker.cc

Lines changed: 2 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -55,26 +55,6 @@ bool IsXetlaSupport(const GemmConfig& config) {
5555
return xetla_support;
5656
}
5757

58-
absl::StatusOr<se::gpu::BlasLt::Epilogue> AsBlasLtEpilogue(
59-
GemmBackendConfig_Epilogue epilogue) {
60-
switch (epilogue) {
61-
case GemmBackendConfig::DEFAULT:
62-
return se::gpu::BlasLt::Epilogue::kDefault;
63-
case GemmBackendConfig::RELU:
64-
return se::gpu::BlasLt::Epilogue::kReLU;
65-
case GemmBackendConfig::GELU:
66-
return se::gpu::BlasLt::Epilogue::kGELU;
67-
case GemmBackendConfig::BIAS:
68-
return se::gpu::BlasLt::Epilogue::kBias;
69-
case GemmBackendConfig::BIAS_RELU:
70-
return se::gpu::BlasLt::Epilogue::kBiasThenReLU;
71-
case GemmBackendConfig::BIAS_GELU:
72-
return se::gpu::BlasLt::Epilogue::kBiasThenGELU;
73-
default:
74-
return absl::InternalError("Unsupported Epilogue.");
75-
}
76-
}
77-
7858
absl::StatusOr<absl::Duration> GetExecuteTime(
7959
const HloInstruction* gemm, const AutotuneConfig& autotune_config) {
8060
se::DeviceMemoryAllocator* allocator = autotune_config.GetAllocator();
@@ -122,7 +102,7 @@ absl::StatusOr<absl::Duration> GetExecuteTime(
122102
autotune_config, rng_state));
123103
}
124104

125-
TF_ASSIGN_OR_RETURN(auto epilogue, AsBlasLtEpilogue(gemm_config.epilogue()));
105+
TF_ASSIGN_OR_RETURN(auto epilogue, SYCLGemm::AsSYCLEpilogue(gemm_config.epilogue()));
126106
se::OwningScratchAllocator<> scratch_allocator(
127107
stream->parent()->device_ordinal(), autotune_config.GetAllocator());
128108

@@ -264,4 +244,4 @@ absl::StatusOr<bool> GemmAlgorithmPicker::Run(
264244
}
265245

266246
} // namespace gpu
267-
} // namespace xla
247+
} // namespace xla

0 commit comments

Comments
 (0)