@@ -28,7 +28,8 @@ using ScopedFilterDescriptor = platform::ScopedFilterDescriptor;
28
28
using ScopedConvolutionDescriptor = platform::ScopedConvolutionDescriptor;
29
29
using DataLayout = platform::DataLayout;
30
30
31
- static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES = 1024 * 1024 * 1024 ;
31
+ static constexpr size_t kCONV_CUDNN_WORKSPACE_LIMIT_BYTES =
32
+ static_cast <size_t >(1024 ) * 1024 * 1024 ;
32
33
33
34
template <typename T>
34
35
class CudnnConvOpKernel : public framework ::OpKernel<T> {
@@ -44,7 +45,8 @@ class CudnnConvOpKernel : public framework::OpKernel<T> {
44
45
std::vector<int > paddings = ctx.Attr <std::vector<int >>(" paddings" );
45
46
std::vector<int > dilations = ctx.Attr <std::vector<int >>(" dilations" );
46
47
int groups = ctx.Attr <int >(" groups" );
47
- int user_workspace_size = ctx.Attr <int >(" workspace_size_MB" );
48
+ int64_t user_workspace_size =
49
+ static_cast <size_t >(ctx.Attr <int >(" workspace_size_MB" ));
48
50
49
51
const T* input_data = input->data <T>();
50
52
const T* filter_data = filter->data <T>();
@@ -163,7 +165,8 @@ class CudnnConvGradOpKernel : public framework::OpKernel<T> {
163
165
std::vector<int > paddings = ctx.Attr <std::vector<int >>(" paddings" );
164
166
std::vector<int > dilations = ctx.Attr <std::vector<int >>(" dilations" );
165
167
int groups = ctx.Attr <int >(" groups" );
166
- int user_workspace_size = ctx.Attr <int >(" workspace_size_MB" );
168
+ int64_t user_workspace_size =
169
+ static_cast <size_t >(ctx.Attr <int >(" workspace_size_MB" ));
167
170
168
171
// ------------------- cudnn descriptors ---------------------
169
172
ScopedTensorDescriptor input_desc;
0 commit comments