
Commit 1d91a49

Author: chengduo

Some trivial optimization (#13530)

* some trivial opt
* remove the fix of lod_tensor and shrink_rnn_memory_op
* refine ShrinkRNNMemoryOp

test=develop
1 parent 5093afc commit 1d91a49

File tree: 10 files changed, +116 -44 lines


paddle/fluid/framework/op_info.h

Lines changed: 11 additions & 6 deletions
@@ -38,27 +38,31 @@ struct OpInfo {
   OpAttrChecker* checker_{nullptr};
   InferVarTypeFN infer_var_type_;
   InferShapeFN infer_shape_;
+  std::string op_type_;
 
   bool HasOpProtoAndChecker() const {
     return proto_ != nullptr && checker_ != nullptr;
   }
 
   const proto::OpProto& Proto() const {
-    PADDLE_ENFORCE_NOT_NULL(proto_, "Operator Proto has not been registered");
+    PADDLE_ENFORCE_NOT_NULL(proto_, "Operator %s Proto has not been registered",
+                            op_type_);
     PADDLE_ENFORCE(proto_->IsInitialized(),
-                   "Operator Proto must be initialized in op info");
+                   "Operator %s Proto must be initialized in op info",
+                   op_type_);
     return *proto_;
   }
 
   const OpCreator& Creator() const {
-    PADDLE_ENFORCE_NOT_NULL(creator_,
-                            "Operator Creator has not been registered");
+    PADDLE_ENFORCE_NOT_NULL(
+        creator_, "Operator %s Creator has not been registered", op_type_);
     return creator_;
   }
 
   const GradOpMakerFN& GradOpMaker() const {
     PADDLE_ENFORCE_NOT_NULL(grad_op_maker_,
-                            "Operator GradOpMaker has not been registered.");
+                            "Operator %s GradOpMaker has not been registered.",
+                            op_type_);
     return grad_op_maker_;
   }
 
@@ -73,8 +77,9 @@ class OpInfoMap {
     return map_.find(op_type) != map_.end();
   }
 
-  void Insert(const std::string& type, const OpInfo& info) {
+  void Insert(const std::string& type, OpInfo info) {
     PADDLE_ENFORCE(!Has(type), "Operator %s has been registered", type);
+    info.op_type_ = type;
     map_.insert({type, info});
   }
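Reduced to a hedged standalone sketch (illustrative names, not PaddlePaddle types): the registry takes the info struct by value and stamps the key into it at Insert time, so accessors can name the offending operator in their errors without the caller passing the type again.

// Minimal sketch of the Insert-time stamping pattern; Info and Registry
// are hypothetical stand-ins, not PaddlePaddle APIs.
#include <map>
#include <stdexcept>
#include <string>

struct Info {
  const int* proto = nullptr;
  std::string op_type;  // filled in by the registry, like op_type_ above

  const int& Proto() const {
    if (proto == nullptr) {
      // The stored key makes the failure self-identifying.
      throw std::runtime_error("Operator " + op_type +
                               " Proto has not been registered");
    }
    return *proto;
  }
};

class Registry {
 public:
  // Pass-by-value (as in the new Insert signature) lets the registry
  // mutate the stored copy without touching the caller's object.
  void Insert(const std::string& type, Info info) {
    info.op_type = type;
    map_.insert({type, info});
  }

 private:
  std::map<std::string, Info> map_;
};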

paddle/fluid/operators/read_op.cc

Lines changed: 2 additions & 0 deletions
@@ -45,10 +45,12 @@ class ReadInferVarType : public framework::VarTypeInference {
     framework::VarDesc* reader = block->FindVarRecursive(reader_name);
     auto dtypes = reader->GetDataTypes();
     PADDLE_ENFORCE_EQ(dtypes.size(), out_names.size());
+    auto lod_levels = reader->GetLoDLevels();
     for (size_t i = 0; i < dtypes.size(); ++i) {
       framework::VarDesc& out = block->FindRecursiveOrCreateVar(out_names[i]);
       out.SetType(framework::proto::VarType::LOD_TENSOR);
       out.SetDataType(dtypes[i]);
+      out.SetLoDLevel(lod_levels[i]);
     }
   }
 };

paddle/fluid/operators/sgd_op.cu

Lines changed: 21 additions & 20 deletions
@@ -12,7 +12,7 @@ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 See the License for the specific language governing permissions and
 limitations under the License. */
 
-#define EIGEN_USE_GPU
+#include <algorithm>
 #include "paddle/fluid/operators/sgd_op.h"
 #include "paddle/fluid/platform/cuda_primitives.h"
 
@@ -33,22 +33,21 @@ __global__ void SGDKernel(const T* g, const T* p, const T* learning_rate,
   }
 }
 
-template <typename T, int block_size>
+template <typename T>
 __global__ void SparseSGDFunctorKernel(const T* selected_rows,
                                        const int64_t* rows,
                                        const T* learning_rate, T* tensor_out,
-                                       int64_t row_numel) {
-  const int ty = blockIdx.y;
-  int tid = threadIdx.x;
-
-  selected_rows += ty * row_numel;
-  tensor_out += rows[ty] * row_numel;
-
-  for (int index = tid; index < row_numel; index += block_size) {
-    // Since index in rows of SelectedRows can be duplicate, we have to use
-    // Atomic Operation to avoid concurrent write error.
-    paddle::platform::CudaAtomicAdd(
-        tensor_out + index, -1.0 * learning_rate[0] * selected_rows[index]);
+                                       int64_t row_numel, int64_t limit) {
+  for (int64_t i = blockIdx.x; i < limit; i += gridDim.x) {
+    const T* selected_rows_ptr = selected_rows + i * row_numel;
+    T* tensor_out_ptr = tensor_out + rows[i] * row_numel;
+    for (int64_t index = threadIdx.x; index < row_numel; index += blockDim.x) {
+      // Since index in rows of SelectedRows can be duplicate, we have to use
+      // Atomic Operation to avoid concurrent write error.
+      paddle::platform::CudaAtomicAdd(
+          tensor_out_ptr + index,
+          -1.0 * learning_rate[0] * selected_rows_ptr[index]);
+    }
   }
 }
 }  // namespace
@@ -97,13 +96,15 @@ class SGDOpCUDAKernel : public framework::OpKernel<T> {
     auto* in_data = in_value.data<T>();
     auto* out_data = param_out->data<T>();
 
-    const int block_size = 256;
-    dim3 threads(block_size, 1);
-    dim3 grid(1, in_rows.size());
-    SparseSGDFunctorKernel<
-        T, 256><<<grid, threads, 0, ctx.cuda_device_context().stream()>>>(
+    const int kThreadsPerBlock = 256;
+    int thread_x = kThreadsPerBlock;
+    int max_threads = ctx.cuda_device_context().GetMaxPhysicalThreadCount();
+    int max_blocks = std::max(max_threads / kThreadsPerBlock, 1);
+
+    SparseSGDFunctorKernel<<<max_blocks, thread_x, 0,
+                             ctx.cuda_device_context().stream()>>>(
        in_data, in_rows.CUDAData(ctx.GetPlace()), learning_rate->data<T>(),
-        out_data, in_row_numel);
+        out_data, in_row_numel, in_rows.size());
 
   } else {
     PADDLE_THROW("Unsupported Variable Type of Grad");
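The rewrite replaces the old one-block-per-row launch (dim3 grid(1, in_rows.size()), which is bounded by the grid's maximum y dimension of 65535 rows) with a loop where blocks stride over rows and threads stride over the elements of a row, so the block count can be capped at the device's physical thread capacity. A minimal standalone CUDA sketch of that pattern, with illustrative names rather than the Paddle kernel itself:

// Blocks stride over rows (blockIdx.x, step gridDim.x); threads stride over
// elements within a row (threadIdx.x, step blockDim.x). Any grid size works.
__global__ void ScatterScaledRows(const float* src, const int64_t* row_ids,
                                  const float* lr, float* dst,
                                  int64_t row_numel, int64_t num_rows) {
  for (int64_t i = blockIdx.x; i < num_rows; i += gridDim.x) {
    const float* src_row = src + i * row_numel;
    float* dst_row = dst + row_ids[i] * row_numel;
    for (int64_t j = threadIdx.x; j < row_numel; j += blockDim.x) {
      // row_ids may contain duplicates, so the update must be atomic.
      atomicAdd(dst_row + j, -lr[0] * src_row[j]);
    }
  }
}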

paddle/fluid/operators/shrink_rnn_memory_op.cc

Lines changed: 21 additions & 8 deletions
@@ -52,16 +52,26 @@ class ShrinkRNNMemoryOp : public ArrayOp {
     size_t height = dst_num_rows;
 
     // do shrink for the top level LoD
+
     if (x_tensor.lod().size() > 0 &&
         x_tensor.lod()[0].size() > static_cast<size_t>(dst_num_rows)) {
-      auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(x_tensor.lod(), 0,
-                                                              dst_num_rows, 0);
-      height = lod_offset.second.second;
-      auto out_lod = out_tensor.mutable_lod();
-      framework::AppendLoD(out_lod, lod_offset.first);
+      if (x_tensor.lod().size() > 1) {  // MultiLevel LoD
+        auto lod_offset = framework::GetSubLoDAndAbsoluteOffset(
+            x_tensor.lod(), 0, dst_num_rows, 0);
+        height = lod_offset.second.second;
+        auto out_lod = out_tensor.mutable_lod();
+        framework::AppendLoD(out_lod, lod_offset.first);
+      } else {
+        // Shrink LoD
+        auto lod_item = x_tensor.lod()[0];
+        lod_item.resize(dst_num_rows + 1);
+        out_tensor.set_lod({lod_item});
+        const auto &const_lod_item = lod_item;
+        height = const_lod_item.back();
+      }
     }
 
-    if (dst_num_rows != 0) {
+    if (height != 0) {
       out_tensor.mutable_data(place, x_tensor.type());
       auto dev_ctx = platform::DeviceContextPool::Instance().Get(place);
       framework::TensorCopy(x_tensor.Slice(0, height), place, *dev_ctx,
@@ -134,8 +144,11 @@ class ShrinkRNNMemoryGradOp : public ArrayOp {
     } else {
       auto &dout_tensor = dout_var->Get<framework::LoDTensor>();
       auto height = dout_tensor.dims()[0];
-      auto slice = dx_tensor.Slice(0, static_cast<int>(height));
-      framework::TensorCopy(dout_tensor, dout_tensor.place(), dev_ctx, &slice);
+      if (height != 0) {
+        auto slice = dx_tensor.Slice(0, static_cast<int>(height));
+        framework::TensorCopy(dout_tensor, dout_tensor.place(), dev_ctx,
+                              &slice);
+      }
       if (dx_tensor.dims()[0] > height) {
         auto rest_tensor = dx_tensor.Slice(
            static_cast<int>(height), static_cast<int>(dx_tensor.dims()[0]));
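For the new single-level branch, a hedged worked example of what "Shrink LoD" computes, using a plain std::vector in place of framework::LoD:

// Offsets {0, 3, 5, 9} describe 3 sequences covering 9 rows. Keeping the
// first dst_num_rows = 2 sequences truncates the offsets to {0, 3, 5},
// and the new height (rows to copy) is the last offset, 5.
#include <cassert>
#include <vector>

int main() {
  std::vector<size_t> lod_item = {0, 3, 5, 9};
  size_t dst_num_rows = 2;
  lod_item.resize(dst_num_rows + 1);  // -> {0, 3, 5}
  size_t height = lod_item.back();    // 5 rows survive the shrink
  assert(height == 5);
  return 0;
}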

paddle/fluid/platform/device_context.cc

Lines changed: 5 additions & 0 deletions
@@ -201,6 +201,7 @@ CUDADeviceContext::CUDADeviceContext(CUDAPlace place)
   compute_capability = GetCUDAComputeCapability(place_.device);
   multi_process = GetCUDAMultiProcessors(place_.device);
   max_threads_per_mp = GetCUDAMaxThreadsPerMultiProcessor(place_.device);
+  grid_max_dims_ = GpuMaxGridDim(place_.device);
   PADDLE_ENFORCE(cudaStreamCreate(&stream_));
   eigen_stream_.reset(new EigenCudaStreamDevice());
   eigen_stream_->Reinitialize(&stream_, place);
@@ -239,6 +240,10 @@ int CUDADeviceContext::GetMaxPhysicalThreadCount() const {
   return multi_process * max_threads_per_mp;
 }
 
+std::tuple<int, int, int> CUDADeviceContext::GetMaxGridDims() const {
+  return grid_max_dims_;
+}
+
 Eigen::GpuDevice* CUDADeviceContext::eigen_device() const {
   return eigen_device_.get();
 }
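The context queries the device limit once at construction and serves it from a member afterwards, so GetMaxGridDims is a cheap cached-tuple accessor rather than a runtime call per launch. A hedged sketch of that caching shape (QueryGridLimits and DeviceLimits are hypothetical stand-ins, not Paddle types):

#include <tuple>

// Hypothetical stand-in for the cudaDeviceGetAttribute queries in gpu_info.cc.
static std::tuple<int, int, int> QueryGridLimits(int /*device_id*/) {
  return std::make_tuple(2147483647, 65535, 65535);  // typical modern-GPU grid limits
}

class DeviceLimits {
 public:
  // Query once; the attribute never changes for a given device.
  explicit DeviceLimits(int device_id)
      : grid_max_dims_(QueryGridLimits(device_id)) {}

  std::tuple<int, int, int> GetMaxGridDims() const { return grid_max_dims_; }

 private:
  std::tuple<int, int, int> grid_max_dims_;
};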

paddle/fluid/platform/device_context.h

Lines changed: 5 additions & 0 deletions
@@ -13,6 +13,7 @@ limitations under the License. */
 #include <memory>
 #include <mutex>  // NOLINT
 #include <string>
+#include <tuple>
 #include <unordered_map>
 #include <vector>
 
@@ -91,6 +92,8 @@ class CUDADeviceContext : public DeviceContext {
   /*! \brief Return the max physical thread count in the device context */
   int GetMaxPhysicalThreadCount() const;
 
+  std::tuple<int, int, int> GetMaxGridDims() const;
+
   /*! \brief Return eigen device in the device context. */
   Eigen::GpuDevice* eigen_device() const;
 
@@ -135,6 +138,8 @@ class CUDADeviceContext : public DeviceContext {
   cudaStream_t stream_;
   cublasHandle_t cublas_handle_;
 
+  std::tuple<int, int, int> grid_max_dims_;
+
   int compute_capability;
   int multi_process;
   int max_threads_per_mp;

paddle/fluid/platform/for_range.h

Lines changed: 29 additions & 10 deletions
@@ -48,35 +48,54 @@ __global__ static void ForRangeElemwiseOpGridIsOne(Function func) {
 }
 
 template <typename Function>
-__global__ static void ForRangeElemwiseOp(Function func, int limit) {
+__global__ static void ForRangeElemwiseOp(Function func, size_t limit) {
   size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
   if (idx < limit) {
     func(idx);
   }
 }
 
+template <typename Function>
+__global__ static void ForRangeElemwiseOpGridLarge(Function func, size_t limit,
+                                                   int grid_dim) {
+  size_t idx = static_cast<size_t>(blockIdx.x * blockDim.x + threadIdx.x);
+  while (idx < limit) {
+    func(idx);
+    idx += grid_dim;
+  }
+}
+
 template <>
 struct ForRange<CUDADeviceContext> {
   ForRange(const CUDADeviceContext& dev_ctx, size_t limit)
-      : dev_ctx_(dev_ctx), limit_(static_cast<int>(limit)) {}
+      : dev_ctx_(dev_ctx), limit_(limit) {}
 
   template <typename Function>
   inline void operator()(Function func) const {
     constexpr int num_threads = 1024;
     int block_size = limit_ <= num_threads ? limit_ : num_threads;
-    int grid_size = (limit_ + num_threads - 1) / num_threads;
-
-    if (grid_size == 1) {
-      ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
-          func);
+    size_t grid_size = (limit_ + num_threads - 1) / num_threads;
+
+    int max_grid_dim = std::get<0>(dev_ctx_.GetMaxGridDims());
+
+    if (grid_size < max_grid_dim) {
+      int grid_size_int = static_cast<int>(grid_size);
+      if (grid_size == 1) {
+        ForRangeElemwiseOpGridIsOne<<<1, block_size, 0, dev_ctx_.stream()>>>(
+            func);
+      } else {
+        ForRangeElemwiseOp<<<grid_size_int, block_size, 0, dev_ctx_.stream()>>>(
+            func, limit_);
+      }
     } else {
-      ForRangeElemwiseOp<<<grid_size, block_size, 0, dev_ctx_.stream()>>>(
-          func, limit_);
+      ForRangeElemwiseOpGridLarge<<<max_grid_dim, block_size, 0,
+                                    dev_ctx_.stream()>>>(func, limit_,
+                                                         max_grid_dim);
    }
   }
 
   const CUDADeviceContext& dev_ctx_;
-  int limit_;
+  size_t limit_;
 };
 
 #endif
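ForRangeElemwiseOpGridLarge advances each thread by grid_dim, the number of launched blocks. For reference, the widely used grid-stride convention advances by the total number of launched threads (gridDim.x * blockDim.x), so every index in [0, limit) is visited exactly once; a minimal sketch of that variant, with an illustrative body standing in for func(idx):

// Canonical grid-stride loop: stride = total threads in the launch.
__global__ void ForRangeGridStride(float* data, size_t limit) {
  size_t stride = static_cast<size_t>(gridDim.x) * blockDim.x;
  for (size_t idx = static_cast<size_t>(blockIdx.x) * blockDim.x + threadIdx.x;
       idx < limit; idx += stride) {
    data[idx] *= 2.0f;  // stand-in for func(idx)
  }
}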

paddle/fluid/platform/gpu_info.cc

Lines changed: 17 additions & 0 deletions
@@ -152,5 +152,22 @@ void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream) {
   PADDLE_ENFORCE(cudaMemsetAsync(dst, value, count, stream),
                  "cudaMemsetAsync failed in paddle::platform::GpuMemsetAsync");
 }
+
+std::tuple<int, int, int> GpuMaxGridDim(int id) {
+  std::tuple<int, int, int> result;
+  PADDLE_ENFORCE(
+      cudaDeviceGetAttribute(&std::get<0>(result), cudaDevAttrMaxBlockDimX, id),
+      "cudaDeviceGetAttribute failed in "
+      "cudaDevAttrMaxBlockDim");
+  PADDLE_ENFORCE(
+      cudaDeviceGetAttribute(&std::get<1>(result), cudaDevAttrMaxBlockDimY, id),
+      "cudaDeviceGetAttribute failed in "
+      "cudaDevAttrMaxBlockDim");
+  PADDLE_ENFORCE(
+      cudaDeviceGetAttribute(&std::get<2>(result), cudaDevAttrMaxBlockDimZ, id),
+      "cudaDeviceGetAttribute failed in "
+      "cudaDevAttrMaxBlockDim");
+  return result;
+}
 }  // namespace platform
 }  // namespace paddle
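cudaDeviceGetAttribute also exposes the grid-dimension limits, through cudaDevAttrMaxGridDimX/Y/Z, alongside the block-dimension attributes queried above. A standalone sketch using the plain CUDA runtime API (no PADDLE_ENFORCE) that reads one attribute of each family for device 0:

#include <cstdio>
#include <cuda_runtime.h>

int main() {
  int block_x = 0, grid_x = 0;
  cudaDeviceGetAttribute(&block_x, cudaDevAttrMaxBlockDimX, 0);
  cudaDeviceGetAttribute(&grid_x, cudaDevAttrMaxGridDimX, 0);
  // On most modern GPUs this prints 1024 and 2147483647 respectively.
  std::printf("max block dim x: %d, max grid dim x: %d\n", block_x, grid_x);
  return 0;
}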

paddle/fluid/platform/gpu_info.h

Lines changed: 3 additions & 0 deletions
@@ -19,6 +19,7 @@ limitations under the License. */
 #include <cuda_runtime.h>
 #include <stddef.h>
 #include <string>
+#include <tuple>
 
 namespace paddle {
 namespace platform {
@@ -72,6 +73,8 @@ void GpuMemcpyPeerSync(void *dst, int dst_device, const void *src,
 //! Set memory dst with value count size asynchronously
 void GpuMemsetAsync(void *dst, int value, size_t count, cudaStream_t stream);
 
+std::tuple<int, int, int> GpuMaxGridDim(int id);
+
 }  // namespace platform
 }  // namespace paddle

python/paddle/fluid/layers/io.py

Lines changed: 2 additions & 0 deletions
@@ -311,6 +311,7 @@ def _copy_reader_var_(block, var):
     new_var = block.create_var(name=var.name, type=core.VarDesc.VarType.READER)
     new_var.desc.set_shapes(var.desc.shapes())
     new_var.desc.set_dtypes(var.desc.dtypes())
+    new_var.desc.set_lod_levels(var.desc.lod_levels())
     new_var.persistable = True
     return new_var
 
@@ -632,6 +633,7 @@ def py_reader(capacity,
         })
 
     startup_var.desc.set_dtypes(dtypes)
+    startup_var.desc.set_lod_levels(lod_levels)
     startup_var.persistable = True
 
     main_prog_var = _copy_reader_var_(default_main_program().current_block(),
