@@ -18,10 +18,11 @@ limitations under the License.
18
18
#include < string>
19
19
20
20
#include " xla/service/gpu/scratch_allocator.h"
21
- #include " xla/service/gpu/stream_executor_util .h"
21
+ #include " xla/service/onednn_util .h"
22
22
23
23
namespace xla {
24
24
namespace gpu {
25
+
25
26
using se::DeviceMemory;
26
27
using se::DeviceMemoryBase;
27
28
using se::Stream;
@@ -39,6 +40,29 @@ using ConvBwdInputPd = dnnl::convolution_backward_data::primitive_desc;
39
40
using ConvBwdFilterPd = dnnl::convolution_backward_weights::primitive_desc;
40
41
using ConvBwdFilterPrimitive = dnnl::convolution_backward_weights;
41
42
43
+ typedef struct OneDnnConvPrimitive {
44
+ dnnl::memory src_memory;
45
+ dnnl::memory filter_memory;
46
+ dnnl::memory dst_memory;
47
+ dnnl::memory internal_filter_memory;
48
+ dnnl::memory scratchpad_memory;
49
+ dnnl::memory bias_memory;
50
+ dnnl::convolution_forward fwd_primitive;
51
+ dnnl::convolution_backward_data bwd_input_primitive;
52
+ dnnl::convolution_backward_weights bwd_filter_primitive;
53
+ dnnl::reorder filter_reorder_primitive;
54
+
55
+ std::unordered_map<int , dnnl::memory> fwd_primitives_args;
56
+ std::unordered_map<int , dnnl::memory> bwd_input_primitive_args;
57
+ std::unordered_map<int , dnnl::memory> bwd_filter_primitive_args;
58
+
59
+ std::unordered_map<int , dnnl::memory> reorder_args;
60
+
61
+ dnnl::engine engine;
62
+ dnnl::stream stream;
63
+ bool has_reorder = false ;
64
+ } OneDnnConvPrimitive;
65
+
42
66
namespace {
43
67
44
68
int64_t GetVectCSize (DataLayout layout) {
@@ -67,7 +91,7 @@ absl::Status CreateOneDnnPrimitive(
67
91
OneDnnConvPrimitive* onednn_primitive, // NOLINT
68
92
const ffi::Dictionary& dict,
69
93
absl::Span<const ffi::BufferBase> operand_buffers,
70
- ffi::BufferBase result_buffer, se::Stream* stream,
94
+ const ffi::BufferBase& result_buffer, se::Stream* stream,
71
95
se::ScratchAllocator* scratch_allocator, CudnnConvKind conv_kind) {
72
96
sycl::queue* dpcpp_stream = se::gpu::AsGpuStreamValue (stream);
73
97
onednn_primitive->engine = FindOrCreateEngine (dpcpp_stream);
@@ -456,7 +480,8 @@ absl::Status CreateOneDnnPrimitive(
456
480
onednn_primitive->bias_memory });
457
481
}
458
482
if (conv_kind == CudnnConvKind::kForwardActivation ) {
459
- auto activation_mode = static_cast <stream_executor::dnn::ActivationMode>(*dict.get <int32_t >(" activation_mode" ));
483
+ auto activation_mode = static_cast <stream_executor::dnn::ActivationMode>(
484
+ *dict.get <int32_t >(" activation_mode" ));
460
485
switch (activation_mode) {
461
486
case stream_executor::dnn::kSigmoid :
462
487
po.append_eltwise (dnnl::algorithm::eltwise_logistic, 1 , 0 );
@@ -474,7 +499,8 @@ absl::Status CreateOneDnnPrimitive(
474
499
po.append_eltwise (dnnl::algorithm::eltwise_elu, 1 , 0 );
475
500
break ;
476
501
case stream_executor::dnn::kLeakyRelu :
477
- po.append_eltwise (dnnl::algorithm::eltwise_relu, *dict.get <float >(" leakyrelu_alpha" ), 0 );
502
+ po.append_eltwise (dnnl::algorithm::eltwise_relu,
503
+ *dict.get <float >(" leakyrelu_alpha" ), 0 );
478
504
break ;
479
505
case stream_executor::dnn::kNone :
480
506
break ;
@@ -680,30 +706,35 @@ absl::Status CreateOneDnnPrimitive(
680
706
681
707
// Builds a OneDnnConvPrimitive for the convolution described by `dict`, the
// operand buffers and the result buffer.
//
// All construction work is delegated to CreateOneDnnPrimitive. If that call
// reports a non-OK status, the status is returned unchanged; otherwise the
// populated primitive is returned by value.
// NOTE(review): despite the "GetOrCreate" name, no lookup of a previously
// built primitive is visible in this function — each call constructs a new
// one.
absl::StatusOr<OneDnnConvPrimitive> GetOrCreateOneDnnConvPrimitive(
    se::Stream* stream, const ffi::Dictionary& dict,
    absl::Span<const ffi::BufferBase> operand_buffers,
    const ffi::BufferBase& result_buffer,
    se::ScratchAllocator* scratch_allocator, CudnnConvKind conv_kind) {
  OneDnnConvPrimitive conv_primitive;
  absl::Status creation_status = CreateOneDnnPrimitive(
      &conv_primitive, dict, operand_buffers, result_buffer, stream,
      scratch_allocator, conv_kind);
  // Failure is the unlikely path; propagate it to the caller as-is.
  if (TF_PREDICT_FALSE(!creation_status.ok())) {
    return creation_status;
  }
  return conv_primitive;
}
696
721
697
- absl::Status RunGpuConv (const OneDnnConvPrimitive& onednn_primitive,
698
- const ffi::Dictionary& dict,
722
+ absl::Status RunGpuConv (se::Stream* stream, const ffi::Dictionary& dict,
699
723
absl::Span<const ffi::BufferBase> operand_buffers,
700
- ffi::BufferBase result_buffer, CudnnConvKind conv_kind) {
724
+ ffi::BufferBase& result_buffer,
725
+ se::ScratchAllocator* allocator,
726
+ CudnnConvKind conv_kind) {
701
727
void * input_data;
702
728
void * filter_data;
703
729
void * output_data;
704
730
void * bias_data = nullptr ;
705
731
void * side_input_data = nullptr ;
706
732
733
+ TF_ASSIGN_OR_RETURN (
734
+ auto onednn_primitive,
735
+ GetOrCreateOneDnnConvPrimitive (stream, dict, operand_buffers,
736
+ result_buffer, allocator, conv_kind));
737
+
707
738
switch (conv_kind) {
708
739
case CudnnConvKind::kForward :
709
740
case CudnnConvKind::kForwardActivation :
@@ -776,4 +807,4 @@ absl::Status RunGpuConv(const OneDnnConvPrimitive& onednn_primitive,
776
807
}
777
808
778
809
} // namespace gpu
779
- } // namespace xla
810
+ } // namespace xla
0 commit comments