Skip to content

Commit 99f3efc

Browse files
author
Lu Teng
authored
Temporarily skip linking warning after enabling FFI (#369)
1 parent 4168481 commit 99f3efc

File tree

9 files changed

+223
-223
lines changed

9 files changed

+223
-223
lines changed

third_party/openxla.patch

Lines changed: 14 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -618,7 +618,7 @@ index 8acdeb102..50c7efe11 100644
618618

619619
protected:
620620
diff --git a/xla/service/computation_placer.cc b/xla/service/computation_placer.cc
621-
index b896c7d10..02c5a642a 100644
621+
index b896c7d10..bd5dcea2b 100644
622622
--- a/xla/service/computation_placer.cc
623623
+++ b/xla/service/computation_placer.cc
624624
@@ -31,6 +31,7 @@ limitations under the License.
@@ -629,7 +629,19 @@ index b896c7d10..02c5a642a 100644
629629
#include "xla/types.h"
630630
#include "xla/util.h"
631631
#include "tsl/platform/errors.h"
632-
@@ -216,6 +217,8 @@ static bool InitModule() {
632+
@@ -164,6 +165,11 @@ absl::StatusOr<DeviceAssignment> ComputationPlacer::AssignDevices(
633+
absl::MutexLock lock(&ComputationPlacer::platform_computation_placer_mutex_);
634+
auto* computation_placers = GetPlatformComputationPlacers();
635+
if (computation_placers->find(platform_id) != computation_placers->end()) {
636+
+ // FIXME(intel): Temporarily skip the registry to avoid linking warning.
637+
+ // Will reopen this check once the oneDNN custom call code is refined.
638+
+#ifdef TENSORFLOW_USE_SYCL
639+
+ return;
640+
+#endif
641+
// TODO(b/282059652): Consider logging the platform name using
642+
// PlatformManager::PlatformWithId(). No doing that for now to avoid
643+
// introducing unwanted dependency.
644+
@@ -216,6 +222,8 @@ static bool InitModule() {
633645
stream_executor::cuda::kCudaPlatformId, &CreateComputationPlacer);
634646
xla::ComputationPlacer::RegisterComputationPlacer(
635647
stream_executor::rocm::kROCmPlatformId, &CreateComputationPlacer);

xla/service/gpu/BUILD

Lines changed: 9 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -28,13 +28,10 @@ cc_library(
2828
],
2929
visibility = ["//visibility:public"],
3030
deps = [
31-
"//xla/service:onednn_util",
32-
"//xla/service/gpu:sycl_onednn",
33-
"@xla//xla/ffi",
31+
":sycl_onednn_header",
32+
"@xla//xla/ffi:ffi",
3433
"@xla//xla/ffi:ffi_api",
35-
"@xla//xla/stream_executor",
3634
"@com_google_absl//absl/status",
37-
"@com_google_absl//absl/strings",
3835
],
3936
alwayslink = 1,
4037
)
@@ -56,7 +53,7 @@ cc_library(
5653
srcs = ["gemm_impl_picker.cc",],
5754
hdrs = ["gemm_impl_picker.h"],
5855
deps = [
59-
":sycl_onednn",
56+
":sycl_onednn_header",
6057
"//xla/stream_executor/sycl:hw_info",
6158
"@com_google_absl//absl/algorithm:container",
6259
"@tsl//tsl/platform:errors",
@@ -72,7 +69,6 @@ cc_library(
7269
"@xla//xla/stream_executor",
7370
"@xla//xla/stream_executor:device_description",
7471
"@xla//xla/stream_executor:device_memory_allocator",
75-
"@xla//xla/stream_executor/gpu:gpu_stream",
7672
"@xla//xla/stream_executor/gpu:gpu_timer",
7773
"@xla//xla/service/gpu:ir_emission_utils",
7874
"@xla//xla/service/gpu:matmul_utils",
@@ -178,7 +174,7 @@ xpu_library(
178174
)
179175

180176
cc_import(
181-
name = "sycl_onednn",
177+
name = "sycl_onednn_header",
182178
hdrs = [
183179
"sycl_onednn.h",
184180
"onednn_gpu_conv_runner.h",
@@ -188,9 +184,8 @@ cc_import(
188184
visibility = ["//visibility:public"],
189185
deps = [
190186
":scratch_allocator",
191-
"@xla//xla/service/gpu:gpu_conv_runner",
192-
"@xla//xla/service/gpu:thunk",
193187
"@xla//xla/service/gpu:matmul_utils",
188+
"@xla//xla/stream_executor/gpu:gpu_stream",
194189
],
195190
)
196191

@@ -216,11 +211,9 @@ cc_library(
216211
deps = [
217212
":scratch_allocator",
218213
"//xla/service:onednn_util",
219-
"@xla//xla/ffi",
220-
"@xla//xla/ffi:ffi_api",
221-
"@xla//xla/service/gpu:gpu_conv_runner",
222-
"@xla//xla/service/gpu:stream_executor_util",
223-
"@xla//xla/service/gpu:thunk",
214+
"@xla//xla/ffi:ffi",
215+
"@xla//xla/service/gpu:cublas_cudnn",
216+
"@xla//xla/stream_executor/gpu:gpu_stream",
224217
],
225218
)
226219

@@ -317,4 +310,4 @@ cc_library(
317310
"@xla//xla/service:hlo_pass",
318311
"@xla//xla/service:pattern_matcher",
319312
],
320-
)
313+
)

xla/service/gpu/onednn_gpu_conv_runner.cc

Lines changed: 44 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -18,10 +18,11 @@ limitations under the License.
1818
#include <string>
1919

2020
#include "xla/service/gpu/scratch_allocator.h"
21-
#include "xla/service/gpu/stream_executor_util.h"
21+
#include "xla/service/onednn_util.h"
2222

2323
namespace xla {
2424
namespace gpu {
25+
2526
using se::DeviceMemory;
2627
using se::DeviceMemoryBase;
2728
using se::Stream;
@@ -39,6 +40,29 @@ using ConvBwdInputPd = dnnl::convolution_backward_data::primitive_desc;
3940
using ConvBwdFilterPd = dnnl::convolution_backward_weights::primitive_desc;
4041
using ConvBwdFilterPrimitive = dnnl::convolution_backward_weights;
4142

43+
typedef struct OneDnnConvPrimitive {
44+
dnnl::memory src_memory;
45+
dnnl::memory filter_memory;
46+
dnnl::memory dst_memory;
47+
dnnl::memory internal_filter_memory;
48+
dnnl::memory scratchpad_memory;
49+
dnnl::memory bias_memory;
50+
dnnl::convolution_forward fwd_primitive;
51+
dnnl::convolution_backward_data bwd_input_primitive;
52+
dnnl::convolution_backward_weights bwd_filter_primitive;
53+
dnnl::reorder filter_reorder_primitive;
54+
55+
std::unordered_map<int, dnnl::memory> fwd_primitives_args;
56+
std::unordered_map<int, dnnl::memory> bwd_input_primitive_args;
57+
std::unordered_map<int, dnnl::memory> bwd_filter_primitive_args;
58+
59+
std::unordered_map<int, dnnl::memory> reorder_args;
60+
61+
dnnl::engine engine;
62+
dnnl::stream stream;
63+
bool has_reorder = false;
64+
} OneDnnConvPrimitive;
65+
4266
namespace {
4367

4468
int64_t GetVectCSize(DataLayout layout) {
@@ -67,7 +91,7 @@ absl::Status CreateOneDnnPrimitive(
6791
OneDnnConvPrimitive* onednn_primitive, // NOLINT
6892
const ffi::Dictionary& dict,
6993
absl::Span<const ffi::BufferBase> operand_buffers,
70-
ffi::BufferBase result_buffer, se::Stream* stream,
94+
const ffi::BufferBase& result_buffer, se::Stream* stream,
7195
se::ScratchAllocator* scratch_allocator, CudnnConvKind conv_kind) {
7296
sycl::queue* dpcpp_stream = se::gpu::AsGpuStreamValue(stream);
7397
onednn_primitive->engine = FindOrCreateEngine(dpcpp_stream);
@@ -456,7 +480,8 @@ absl::Status CreateOneDnnPrimitive(
456480
onednn_primitive->bias_memory});
457481
}
458482
if (conv_kind == CudnnConvKind::kForwardActivation) {
459-
auto activation_mode = static_cast<stream_executor::dnn::ActivationMode>(*dict.get<int32_t>("activation_mode"));
483+
auto activation_mode = static_cast<stream_executor::dnn::ActivationMode>(
484+
*dict.get<int32_t>("activation_mode"));
460485
switch (activation_mode) {
461486
case stream_executor::dnn::kSigmoid:
462487
po.append_eltwise(dnnl::algorithm::eltwise_logistic, 1, 0);
@@ -474,7 +499,8 @@ absl::Status CreateOneDnnPrimitive(
474499
po.append_eltwise(dnnl::algorithm::eltwise_elu, 1, 0);
475500
break;
476501
case stream_executor::dnn::kLeakyRelu:
477-
po.append_eltwise(dnnl::algorithm::eltwise_relu, *dict.get<float>("leakyrelu_alpha"), 0);
502+
po.append_eltwise(dnnl::algorithm::eltwise_relu,
503+
*dict.get<float>("leakyrelu_alpha"), 0);
478504
break;
479505
case stream_executor::dnn::kNone:
480506
break;
@@ -680,30 +706,35 @@ absl::Status CreateOneDnnPrimitive(
680706

681707
absl::StatusOr<OneDnnConvPrimitive> GetOrCreateOneDnnConvPrimitive(
682708
se::Stream* stream, const ffi::Dictionary& dict,
683-
const std::vector<ffi::BufferBase>& operand_se_buffers,
709+
absl::Span<const ffi::BufferBase> operand_buffers,
684710
const ffi::BufferBase& result_buffer,
685711
se::ScratchAllocator* scratch_allocator, CudnnConvKind conv_kind) {
686712
OneDnnConvPrimitive primitive;
687-
auto status = CreateOneDnnPrimitive(&primitive, dict,
688-
absl::MakeSpan(operand_se_buffers),
689-
result_buffer, stream, scratch_allocator,
690-
conv_kind);
713+
auto status =
714+
CreateOneDnnPrimitive(&primitive, dict, operand_buffers, result_buffer,
715+
stream, scratch_allocator, conv_kind);
691716
if (TF_PREDICT_FALSE(!status.ok())) {
692717
return status;
693718
}
694719
return primitive;
695720
}
696721

697-
absl::Status RunGpuConv(const OneDnnConvPrimitive& onednn_primitive,
698-
const ffi::Dictionary& dict,
722+
absl::Status RunGpuConv(se::Stream* stream, const ffi::Dictionary& dict,
699723
absl::Span<const ffi::BufferBase> operand_buffers,
700-
ffi::BufferBase result_buffer, CudnnConvKind conv_kind) {
724+
ffi::BufferBase& result_buffer,
725+
se::ScratchAllocator* allocator,
726+
CudnnConvKind conv_kind) {
701727
void* input_data;
702728
void* filter_data;
703729
void* output_data;
704730
void* bias_data = nullptr;
705731
void* side_input_data = nullptr;
706732

733+
TF_ASSIGN_OR_RETURN(
734+
auto onednn_primitive,
735+
GetOrCreateOneDnnConvPrimitive(stream, dict, operand_buffers,
736+
result_buffer, allocator, conv_kind));
737+
707738
switch (conv_kind) {
708739
case CudnnConvKind::kForward:
709740
case CudnnConvKind::kForwardActivation:
@@ -776,4 +807,4 @@ absl::Status RunGpuConv(const OneDnnConvPrimitive& onednn_primitive,
776807
}
777808

778809
} // namespace gpu
779-
} // namespace xla
810+
} // namespace xla

xla/service/gpu/onednn_gpu_conv_runner.h

Lines changed: 6 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -19,50 +19,18 @@ limitations under the License.
1919
#include <optional>
2020

2121
#include "xla/ffi/ffi.h"
22-
#include "xla/ffi/ffi_api.h"
23-
#include "xla/service/gpu/gpu_conv_runner.h"
24-
#include "xla/service/gpu/thunk.h"
25-
#include "xla/service/onednn_util.h"
22+
#include "xla/service/gpu/cublas_cudnn.h"
23+
#include "xla/stream_executor/gpu/gpu_stream.h"
2624

2725
namespace xla {
2826

2927
namespace gpu {
3028

31-
typedef struct OneDnnConvPrimitive {
32-
dnnl::memory src_memory;
33-
dnnl::memory filter_memory;
34-
dnnl::memory dst_memory;
35-
dnnl::memory internal_filter_memory;
36-
dnnl::memory scratchpad_memory;
37-
dnnl::memory bias_memory;
38-
dnnl::convolution_forward fwd_primitive;
39-
dnnl::convolution_backward_data bwd_input_primitive;
40-
dnnl::convolution_backward_weights bwd_filter_primitive;
41-
dnnl::reorder filter_reorder_primitive;
42-
43-
std::unordered_map<int, dnnl::memory> fwd_primitives_args;
44-
std::unordered_map<int, dnnl::memory> bwd_input_primitive_args;
45-
std::unordered_map<int, dnnl::memory> bwd_filter_primitive_args;
46-
47-
std::unordered_map<int, dnnl::memory> reorder_args;
48-
49-
dnnl::engine engine;
50-
dnnl::stream stream;
51-
bool has_reorder = false;
52-
} OneDnnConvPrimitive;
53-
54-
absl::StatusOr<OneDnnConvPrimitive> GetOrCreateOneDnnConvPrimitive(
55-
se::Stream*, const ffi::Dictionary& dict,
56-
const std::vector<ffi::BufferBase>& operand_se_buffers,
57-
const ffi::BufferBase& result_buffer,
58-
se::ScratchAllocator* scratch_allocator, CudnnConvKind conv_kind);
59-
60-
absl::Status RunGpuConv(const OneDnnConvPrimitive& onednn_primitive,
61-
const ffi::Dictionary& dict,
62-
absl::Span<const ffi::BufferBase> operand_buffers,
63-
ffi::BufferBase result_buffer, CudnnConvKind conv_kind);
29+
absl::Status RunGpuConv(se::Stream*, const ffi::Dictionary&,
30+
absl::Span<const ffi::BufferBase>, ffi::BufferBase&,
31+
se::ScratchAllocator*, CudnnConvKind);
6432

6533
} // namespace gpu
6634
} // namespace xla
6735

68-
#endif // XLA_SERVICE_GPU_ONEDNN_GPU_CONV_RUNNER_H_
36+
#endif // XLA_SERVICE_GPU_ONEDNN_GPU_CONV_RUNNER_H_

0 commit comments

Comments
 (0)