@@ -21,6 +21,24 @@ cc_library(
21
21
],
22
22
)
23
23
24
+ cc_library (
25
+ name = "sycl_custom_call" ,
26
+ srcs = [
27
+ "sycl_custom_call.cc" ,
28
+ ],
29
+ visibility = ["//visibility:public" ],
30
+ deps = [
31
+ "//xla/service:onednn_util" ,
32
+ "//xla/service/gpu:sycl_onednn" ,
33
+ "@xla//xla/ffi" ,
34
+ "@xla//xla/ffi:ffi_api" ,
35
+ "@xla//xla/stream_executor" ,
36
+ "@com_google_absl//absl/status" ,
37
+ "@com_google_absl//absl/strings" ,
38
+ ],
39
+ alwayslink = 1 ,
40
+ )
41
+
24
42
xetla_library (
25
43
name = "onednn_matmul_utils" ,
26
44
srcs = ["onednn_matmul_utils.cc" ],
@@ -29,27 +47,7 @@ xetla_library(
29
47
":scratch_allocator" ,
30
48
"//xla/service:onednn_util" ,
31
49
"//xla/service/gpu/xetla/gemm:gemm_kernel" ,
32
- "//xla/stream_executor/sycl:sycl_executor" ,
33
- "@com_google_absl//absl/algorithm:container" ,
34
- "@com_google_absl//absl/types:span" ,
35
- "@onednn_gpu//:onednn_gpu" ,
36
- "@tsl//tsl/framework:numeric_types" ,
37
- "@tsl//tsl/platform:statusor" ,
38
- "@tsl//tsl/platform:types" ,
39
- "@xetla//:xetla_header" ,
40
- "@xla//xla:shape_util" ,
41
- "@xla//xla:status_macros" ,
42
- "@xla//xla:statusor" ,
43
- "@xla//xla:types" ,
44
- "@xla//xla:util" ,
45
- "@xla//xla:xla_data_proto_cc" ,
46
- "@xla//xla/hlo/ir:hlo" ,
47
- "@xla//xla/mlir_hlo" ,
48
- "@xla//xla/mlir_hlo:lhlo_gpu" ,
49
- "@xla//xla/service/gpu:backend_configs_cc" ,
50
- "@xla//xla/service/gpu:ir_emission_utils" ,
51
50
"@xla//xla/service/gpu:matmul_utils" ,
52
- "@xla//xla/stream_executor:stream_executor_headers" ,
53
51
],
54
52
)
55
53
@@ -58,7 +56,7 @@ cc_library(
58
56
srcs = ["gemm_impl_picker.cc" ,],
59
57
hdrs = ["gemm_impl_picker.h" ],
60
58
deps = [
61
- ":onednn_matmul_utils " ,
59
+ ":sycl_onednn " ,
62
60
"//xla/stream_executor/sycl:hw_info" ,
63
61
"@com_google_absl//absl/algorithm:container" ,
64
62
"@tsl//tsl/platform:errors" ,
@@ -179,31 +177,50 @@ xpu_library(
179
177
],
180
178
)
181
179
180
+ cc_import (
181
+ name = "sycl_onednn" ,
182
+ hdrs = [
183
+ "sycl_onednn.h" ,
184
+ "onednn_gpu_conv_runner.h" ,
185
+ "onednn_matmul_utils.h" ,
186
+ ],
187
+ shared_library = ":sycl_onednn.so" ,
188
+ visibility = ["//visibility:public" ],
189
+ deps = [
190
+ ":scratch_allocator" ,
191
+ "@xla//xla/service/gpu:gpu_conv_runner" ,
192
+ "@xla//xla/service/gpu:thunk" ,
193
+ "@xla//xla/service/gpu:matmul_utils" ,
194
+ ],
195
+ )
196
+
197
+ cc_binary (
198
+ name = "sycl_onednn.so" ,
199
+ srcs = [
200
+ "sycl_onednn.cc" ,
201
+ "sycl_onednn.h" ,
202
+ ],
203
+ linkshared = True ,
204
+ deps = [
205
+ ":onednn_gpu_conv_runner" ,
206
+ ":onednn_matmul_utils" ,
207
+ ],
208
+ )
209
+
182
210
cc_library (
183
211
name = "onednn_gpu_conv_runner" ,
184
- srcs = ["onednn_gpu_conv_runner.cc" ],
185
- hdrs = ["onednn_gpu_conv_runner.h" ],
212
+ srcs = [
213
+ "onednn_gpu_conv_runner.cc" ,
214
+ "onednn_gpu_conv_runner.h" ,
215
+ ],
186
216
deps = [
187
217
":scratch_allocator" ,
188
218
"//xla/service:onednn_util" ,
189
- "@com_google_absl//absl/strings" ,
190
- "@tsl//tsl/framework:numeric_types" ,
191
- "@xla//xla:shape_util" ,
192
- "@xla//xla:status" ,
193
- "@xla//xla:status_macros" ,
194
- "@xla//xla:statusor" ,
195
- "@xla//xla:types" ,
196
- "@xla//xla:util" ,
197
- "@xla//xla:xla_data_proto_cc" ,
198
- "@xla//xla/hlo/ir:hlo" ,
199
- "@xla//xla/service/gpu:backend_configs_cc" ,
200
- "@xla//xla/service/gpu:cublas_cudnn" ,
219
+ "@xla//xla/ffi" ,
220
+ "@xla//xla/ffi:ffi_api" ,
201
221
"@xla//xla/service/gpu:gpu_conv_runner" ,
202
222
"@xla//xla/service/gpu:stream_executor_util" ,
203
223
"@xla//xla/service/gpu:thunk" ,
204
- "@xla//xla/stream_executor" ,
205
- "@xla//xla/stream_executor/gpu:gpu_stream" ,
206
- "@xla//xla/stream_executor/gpu:gpu_types_header" ,
207
224
],
208
225
)
209
226
@@ -300,4 +317,4 @@ cc_library(
300
317
"@xla//xla/service:hlo_pass" ,
301
318
"@xla//xla/service:pattern_matcher" ,
302
319
],
303
- )
320
+ )
0 commit comments