Skip to content

Commit 5a1a04f

Browse files
authored
Merge pull request #6326 from jacquesqiao/fix-int-overflow
fix int overflow
2 parents b30e8bc + d303f7a commit 5a1a04f

File tree

1 file changed

+6
-3
lines changed

1 file changed

+6
-3
lines changed

paddle/operators/conv_cudnn_op.cu.cc

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,8 @@ using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
2828
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
2929
using DataLayout = platform::DataLayout;
3030

31-
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024;
31+
static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
32+
static_cast<size_t>(1024) * 1024 * 1024;
3233

3334
template <typename T>
3435
class CudnnConvOpKernel : public framework::OpKernel<T> {
@@ -44,7 +45,8 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
4445
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
4546
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
4647
int groups = ctx.Attr<int>("groups");
47-
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
48+
int64_t user_workspace_size =
49+
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
4850

4951
const T* input_data = input->data<T>();
5052
const T* filter_data = filter->data<T>();
@@ -163,7 +165,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
163165
std::vector<int> paddings = ctx.Attr<std::vector<int>>("paddings");
164166
std::vector<int> dilations = ctx.Attr<std::vector<int>>("dilations");
165167
int groups = ctx.Attr<int>("groups");
166-
int user_workspace_size = ctx.Attr<int>("workspace_size_MB");
168+
int64_t user_workspace_size =
169+
static_cast<size_t>(ctx.Attr<int>("workspace_size_MB"));
167170

168171
// ------------------- cudnn descriptors ---------------------
169172
ScopedTensorDescriptor input_desc;

0 commit comments

Comments
 (0)