@@ -54,15 +54,21 @@ class CudnnConvTransposeOpKernel : public framework::OpKernel<T> {
54
54
ScopedTensorDescriptor output_desc;
55
55
ScopedFilterDescriptor filter_desc;
56
56
ScopedConvolutionDescriptor conv_desc;
57
- DataLayout layout = DataLayout::kNCHW ;
57
+ DataLayout layout;
58
+
59
+ if (strides.size () == 2U ) {
60
+ layout = DataLayout::kNCHW ;
61
+ } else {
62
+ layout = DataLayout::kNCDHW ;
63
+ }
58
64
59
- // N, M, H, W
65
+ // ( N, M, H, W) or (N, M, D, H, W)
60
66
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor <T>(
61
67
layout, framework::vectorize2int (input->dims ()));
62
- // N, C, O_h, O_w
68
+ // ( N, C, O_h, O_w) or (N, C, O_d, O_h, O_w)
63
69
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor <T>(
64
70
layout, framework::vectorize2int (output->dims ()));
65
- // M, C, K_h, K_w
71
+ // ( M, C, K_h, K_w) or (M, C, K_d, K_h, K_w)
66
72
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor <T>(
67
73
layout, framework::vectorize2int (filter->dims ()));
68
74
cudnnConvolutionDescriptor_t cudnn_conv_desc =
@@ -136,13 +142,13 @@ class CudnnConvTransposeGradOpKernel : public framework::OpKernel<T> {
136
142
ScopedConvolutionDescriptor conv_desc;
137
143
DataLayout layout = DataLayout::kNCHW ;
138
144
139
- // Input: (N, M, H, W)
145
+ // Input: (N, M, H, W) or (N, M, D, H, W)
140
146
cudnnTensorDescriptor_t cudnn_input_desc = input_desc.descriptor <T>(
141
147
layout, framework::vectorize2int (input->dims ()));
142
- // Output: (N, C, O_H, O_W )
148
+ // Output: (N, C, O_h, O_w) or (N, C, O_d, O_h, O_w )
143
149
cudnnTensorDescriptor_t cudnn_output_desc = output_desc.descriptor <T>(
144
150
layout, framework::vectorize2int (output_grad->dims ()));
145
- // Filter (M, C, K_H, K_W )
151
+ // Filter (M, C, K_h, K_w) or (M, C, K_d, K_h, K_w )
146
152
cudnnFilterDescriptor_t cudnn_filter_desc = filter_desc.descriptor <T>(
147
153
layout, framework::vectorize2int (filter->dims ()));
148
154
0 commit comments