@@ -14,24 +14,31 @@ limitations under the License.
 ==============================================================================*/
 
 #include "tensorflow/core/kernels/tensor_buffer_ops.h"
+#include "tensorflow/core/util/work_sharder.h"
+
+#if GOOGLE_CUDA
+#include "tensorflow/core/common_runtime/gpu/gpu_device.h"
+#include "tensorflow/stream_executor/device_memory.h"
+#include "tensorflow/stream_executor/stream.h"
+#endif  // GOOGLE_CUDA
 
 namespace tensorflow {
 
 class TensorBufferOp : public OpKernel {
  public:
-  explicit TensorBufferOp(OpKernelConstruction* ctx) : OpKernel(ctx) {}
+  explicit TensorBufferOp(OpKernelConstruction *ctx) : OpKernel(ctx) {}
 
-  void Compute(OpKernelContext* ctx) override {
+  void Compute(OpKernelContext *ctx) override {
     auto rm = ctx->resource_manager();
     auto ndef = def();
 
     ContainerInfo cinfo;
     OP_REQUIRES_OK(ctx, cinfo.Init(rm, ndef, true /* use name() */));
 
-    TensorBuf* buffer = nullptr;
+    TensorBuf *buffer = nullptr;
     OP_REQUIRES_OK(ctx, rm->LookupOrCreate<TensorBuf>(
                             cinfo.container(), cinfo.name(), &buffer,
-                            [&ndef](TensorBuf** pbuf) -> Status {
+                            [&ndef](TensorBuf **pbuf) -> Status {
                               int64 capacity;
                               TF_RETURN_IF_ERROR(GetNodeAttr(
                                   ndef, "shared_capacity", &capacity));
@@ -43,33 +50,34 @@ class TensorBufferOp : public OpKernel {
   }
 
  protected:
-  virtual void ComputeWithTensorBuf(OpKernelContext* ctx, TensorBuf* buf) = 0;
+  virtual void ComputeWithTensorBuf(OpKernelContext *ctx, TensorBuf *buf) = 0;
 };
 
 class TensorBufferAsyncOp : public AsyncOpKernel {
  public:
-  explicit TensorBufferAsyncOp(OpKernelConstruction* ctx) : AsyncOpKernel(ctx) {
+  explicit TensorBufferAsyncOp(OpKernelConstruction *ctx) : AsyncOpKernel(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_name", &shared_name_));
     OP_REQUIRES_OK(ctx, ctx->GetAttr("shared_threads", &shared_threads_));
   }
 
-  void ComputeAsync(OpKernelContext* ctx,
+  void ComputeAsync(OpKernelContext *ctx,
                     AsyncOpKernel::DoneCallback done) override {
     auto rm = ctx->resource_manager();
     NodeDef ndef(def());
     ContainerInfo cinfo;
     OP_REQUIRES_OK_ASYNC(ctx, cinfo.Init(rm, ndef, true /* use name() */),
                          done);
-    TensorBuf* buffer = nullptr;
-    OP_REQUIRES_OK_ASYNC(ctx, rm->LookupOrCreate<TensorBuf>(
-                             cinfo.container(), cinfo.name(), &buffer,
-                             [&ndef](TensorBuf** resource) {
-                               int64 capacity;
-                               TF_RETURN_IF_ERROR(GetNodeAttr(
-                                   ndef, "shared_capacity", &capacity));
-                               *resource = new TensorBuf(capacity);
-                               return Status::OK();
-                             }),
+    TensorBuf *buffer = nullptr;
+    OP_REQUIRES_OK_ASYNC(ctx,
+                         rm->LookupOrCreate<TensorBuf>(
+                             cinfo.container(), cinfo.name(), &buffer,
+                             [&ndef](TensorBuf **resource) {
+                               int64 capacity;
+                               TF_RETURN_IF_ERROR(GetNodeAttr(
+                                   ndef, "shared_capacity", &capacity));
+                               *resource = new TensorBuf(capacity);
+                               return Status::OK();
+                             }),
                          done);
     core::ScopedUnref scoped_list(buffer);
     Schedule(buffer, [this, ctx, done, buffer]() {
@@ -78,26 +86,117 @@ class TensorBufferAsyncOp : public AsyncOpKernel {
   }
 
  protected:
-  virtual void ComputeAsyncWithTensorBuf(OpKernelContext* ctx,
+  virtual void ComputeAsyncWithTensorBuf(OpKernelContext *ctx,
                                          AsyncOpKernel::DoneCallback done,
-                                         TensorBuf* buffer) = 0;
+                                         TensorBuf *buffer) = 0;
 
  private:
   string shared_name_;
   int64 shared_threads_;
 
-  void Schedule(TensorBuf* buffer, std::function<void()> fn) {
+  void Schedule(TensorBuf *buffer, std::function<void()> fn) {
     buffer->Schedule(shared_name_, shared_threads_, fn);
   }
 };
 
-class TensorBufferPutOp : public TensorBufferOp {
+#if GOOGLE_CUDA
+class TensorBufferPutGpuOp : public TensorBufferOp {
+ public:
+  explicit TensorBufferPutGpuOp(OpKernelConstruction *ctx)
+      : TensorBufferOp(ctx) {
+    OP_REQUIRES_OK(ctx, ctx->GetAttr("timeout_millis", &timeout_millis_));
+  }
+
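+  // Round a byte count up to the Eigen allocation alignment so every tensor
+  // packed into the fused staging buffer starts at an aligned offset.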
+  inline size_t AlignBytes(size_t s) {
+#if EIGEN_MAX_ALIGN_BYTES == 0
+    return s;
+#else
+    return std::ceil(s * 1.0 / EIGEN_MAX_ALIGN_BYTES) * EIGEN_MAX_ALIGN_BYTES;
+#endif
+  }
+
+  void ComputeWithTensorBuf(OpKernelContext *ctx, TensorBuf *buf) override {
+    std::vector<int> input_offsets;
+    int total_bytes = 0;
+    int input_nums = ctx->num_inputs();
+    for (int i = 0; i < input_nums; ++i) {
+      auto &tensor_in = ctx->input(i);
+      input_offsets.push_back(total_bytes);
+      total_bytes += AlignBytes(tensor_in.TotalBytes());
+    }
+
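+    // Pack all inputs into one pinned (page-locked) host buffer so the
+    // host-to-device transfer below can be issued as a single async memcpy.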
+    Tensor fused_tensor;
+    // Allocate pinned memory.
+    AllocatorAttributes cpu_alloc_attr;
+    cpu_alloc_attr.set_on_host(true);
+    cpu_alloc_attr.set_gpu_compatible(true);
+    OP_REQUIRES_OK(ctx, ctx->allocate_temp(DT_INT8, {total_bytes},
+                                           &fused_tensor, cpu_alloc_attr));
+    auto fused_tensor_data =
+        const_cast<char *>(fused_tensor.tensor_data().data());
+
+    auto copy_task = [this, ctx, &input_offsets, &fused_tensor_data](
+                         int64 start, int64 end) {
+      for (auto i = start; i < end; ++i) {
+        const Tensor &input_tensor = ctx->input(i);
+        size_t tensor_bytes = input_tensor.TotalBytes();
+        std::copy_n(input_tensor.tensor_data().data(), tensor_bytes,
+                    fused_tensor_data + input_offsets[i]);
+      }
+    };
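+    // Shard the host-side packing across the CPU worker pool; `cost` is the
+    // per-tensor cost estimate that guides how the range is split.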
+    static const int cost = 1000;
+    auto worker_threads = *(ctx->device()->tensorflow_cpu_worker_threads());
+    Shard(worker_threads.num_threads, worker_threads.workers, input_nums, cost,
+          copy_task);
+
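+    // Order the streams: the H2D copy must wait for prior compute work on
+    // the inputs, and later compute must wait for the copy to finish.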
+    auto *d_context =
+        static_cast<const GPUDeviceContext *>(ctx->op_device_context());
+    se::Stream *copy_stream = d_context->host_to_device_stream();
+    se::Stream *compute_stream = d_context->stream();
+
+    Tensor gpu_tensor;
+    OP_REQUIRES_OK(ctx,
+                   ctx->allocate_temp(DT_INT8, {total_bytes}, &gpu_tensor));
+
+    copy_stream->ThenWaitFor(compute_stream);
+    se::DeviceMemoryBase wrapped_dst(
+        const_cast<char *>(gpu_tensor.tensor_data().data()),
+        gpu_tensor.TotalBytes());
+    copy_stream
+        ->ThenMemcpy(&wrapped_dst, const_cast<char *>(fused_tensor_data),
+                     gpu_tensor.TotalBytes())
+        .ok();
+    compute_stream->ThenWaitFor(copy_stream);
+
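+    // Carve the fused GPU buffer back into per-input tensors: each output
+    // aliases a slice of gpu_tensor, bitcast to the original dtype and shape.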
+    std::vector<Tensor> record;
+    record.reserve(input_nums);
+    for (int i = 0; i < input_nums; ++i) {
+      size_t bytes_tensor_offset = input_offsets[i];
+      Tensor tensor_slice =
+          gpu_tensor.Slice(bytes_tensor_offset,
+                           bytes_tensor_offset + ctx->input(i).TotalBytes());
+      Tensor output(ctx->input(i).dtype());
+      OP_REQUIRES_OK(ctx,
+                     output.BitcastFrom(tensor_slice, ctx->input(i).dtype(),
+                                        ctx->input(i).shape()));
+      record.emplace_back(output);
+    }
+    ctx->SetStatus(buf->Put(record, timeout_millis_));
+  }
+
+ private:
+  int64 timeout_millis_;
+};
+#endif  // GOOGLE_CUDA
+
+class TensorBufferPutCpuOp : public TensorBufferOp {
  public:
-  explicit TensorBufferPutOp(OpKernelConstruction* ctx) : TensorBufferOp(ctx) {
+  explicit TensorBufferPutCpuOp(OpKernelConstruction *ctx)
+      : TensorBufferOp(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("timeout_millis", &timeout_millis_));
   }
 
-  void ComputeWithTensorBuf(OpKernelContext* ctx, TensorBuf* buf) override {
+  void ComputeWithTensorBuf(OpKernelContext *ctx, TensorBuf *buf) override {
     std::vector<Tensor> record;
     record.reserve(ctx->num_inputs());
     for (int i = 0; i < ctx->num_inputs(); ++i) {
@@ -111,24 +210,26 @@ class TensorBufferPutOp : public TensorBufferOp {
 };
 
 REGISTER_KERNEL_BUILDER(Name("TensorBufferPut").Device(DEVICE_CPU),
-                        TensorBufferPutOp);
+                        TensorBufferPutCpuOp);
 #if GOOGLE_CUDA
-REGISTER_KERNEL_BUILDER(Name("TensorBufferPut").Device(DEVICE_GPU),
-                        TensorBufferPutOp);
+REGISTER_KERNEL_BUILDER(
+    Name("TensorBufferPut").Device(DEVICE_GPU).HostMemory("record"),
+    TensorBufferPutGpuOp);
 #endif  // GOOGLE_CUDA
+
 #ifdef TENSORFLOW_USE_SYCL
 REGISTER_KERNEL_BUILDER(Name("TensorBufferPut").Device(DEVICE_SYCL),
-                        TensorBufferPutOp);
+                        TensorBufferPutCpuOp);
 #endif  // TENSORFLOW_USE_SYCL
 
 class TensorBufferTakeOp : public TensorBufferAsyncOp {
  public:
-  explicit TensorBufferTakeOp(OpKernelConstruction* ctx)
+  explicit TensorBufferTakeOp(OpKernelConstruction *ctx)
       : TensorBufferAsyncOp(ctx) {}
 
-  void ComputeAsyncWithTensorBuf(OpKernelContext* ctx,
+  void ComputeAsyncWithTensorBuf(OpKernelContext *ctx,
                                  AsyncOpKernel::DoneCallback done,
-                                 TensorBuf* buf) override {
+                                 TensorBuf *buf) override {
     std::vector<Tensor> record;
     Status s = buf->Take(&record);
     if (TF_PREDICT_FALSE(!s.ok())) {
@@ -163,11 +264,12 @@ REGISTER_KERNEL_BUILDER(Name("TensorBufferTake").Device(DEVICE_SYCL),
 
 class TensorBufferCancelOp : public TensorBufferOp {
  public:
-  explicit TensorBufferCancelOp(OpKernelConstruction* ctx) : TensorBufferOp(ctx) {
+  explicit TensorBufferCancelOp(OpKernelConstruction *ctx)
+      : TensorBufferOp(ctx) {
     OP_REQUIRES_OK(ctx, ctx->GetAttr("is_cancelled", &is_cancelled_));
   }
 
-  void ComputeWithTensorBuf(OpKernelContext* ctx, TensorBuf* buf) override {
+  void ComputeWithTensorBuf(OpKernelContext *ctx, TensorBuf *buf) override {
     ctx->SetStatus(buf->Cancel(is_cancelled_));
   }
 
@@ -188,9 +290,10 @@ REGISTER_KERNEL_BUILDER(Name("TensorBufferCancel").Device(DEVICE_SYCL),
 
 class TensorBufferCloseOp : public TensorBufferOp {
  public:
-  explicit TensorBufferCloseOp(OpKernelConstruction* ctx) : TensorBufferOp(ctx) {}
+  explicit TensorBufferCloseOp(OpKernelConstruction *ctx)
+      : TensorBufferOp(ctx) {}
 
-  void ComputeWithTensorBuf(OpKernelContext* ctx, TensorBuf* buf) override {
+  void ComputeWithTensorBuf(OpKernelContext *ctx, TensorBuf *buf) override {
     ctx->SetStatus(buf->Close());
   }
 };
@@ -208,10 +311,11 @@ REGISTER_KERNEL_BUILDER(Name("TensorBufferClose").Device(DEVICE_SYCL),
 
 class TensorBufferSizeOp : public TensorBufferOp {
  public:
-  explicit TensorBufferSizeOp(OpKernelConstruction* ctx) : TensorBufferOp(ctx) {}
+  explicit TensorBufferSizeOp(OpKernelConstruction *ctx)
+      : TensorBufferOp(ctx) {}
 
-  void ComputeWithTensorBuf(OpKernelContext* ctx, TensorBuf* buf) override {
-    Tensor* size = nullptr;
+  void ComputeWithTensorBuf(OpKernelContext *ctx, TensorBuf *buf) override {
+    Tensor *size = nullptr;
     OP_REQUIRES_OK(ctx, ctx->allocate_output(0, TensorShape({}), &size));
     OP_REQUIRES_OK(ctx, buf->GetSize(size));
   }