@@ -60,7 +60,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
60
60
const int w_in_end = w_in_start + filter_width * dilate_width;
61
61
62
62
int in_offset;
63
- if (data_layout == DataLayout::kNCHW ) {
63
+ if (data_layout != DataLayout::kNHWC ) {
64
64
in_offset =
65
65
((batch * input_channels + c_in) * input_height) * input_width;
66
66
} else {
@@ -78,7 +78,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
78
78
if (h_in >= h_start && h_in < h_end && w_in >= w_start &&
79
79
w_in < w_end) {
80
80
int offset;
81
- if (data_layout == DataLayout::kNCHW ) {
81
+ if (data_layout != DataLayout::kNHWC ) {
82
82
offset = in_offset + h_in * input_width + w_in;
83
83
} else {
84
84
offset = in_offset +
@@ -94,7 +94,7 @@ __device__ __inline__ void KernelDepthwiseConv(ARG_DEFINE_KernelDepthwiseConv) {
94
94
}
95
95
}
96
96
int index;
97
- if (data_layout == DataLayout::kNCHW ) {
97
+ if (data_layout != DataLayout::kNHWC ) {
98
98
index = ((batch * gridDim .x + c_out) * output_height + h_out) *
99
99
output_width +
100
100
w_out;
@@ -131,7 +131,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
131
131
const int w_in_end = w_in_start + c_filter * dilate_width;
132
132
133
133
int in_offset;
134
- if (data_layout == DataLayout::kNCHW ) {
134
+ if (data_layout != DataLayout::kNHWC ) {
135
135
in_offset =
136
136
((batch * input_channels + c_in) * input_height) * input_width;
137
137
} else {
@@ -150,7 +150,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
150
150
if (h_in >= 0 && h_in < input_height && w_in >= 0 &&
151
151
w_in < input_width) {
152
152
int offset;
153
- if (data_layout == DataLayout::kNCHW ) {
153
+ if (data_layout != DataLayout::kNHWC ) {
154
154
offset = in_offset + h_in * input_width + w_in;
155
155
} else {
156
156
offset = in_offset +
@@ -166,7 +166,7 @@ __device__ __inline__ void KernelDepthwiseConvCFilter(
166
166
}
167
167
}
168
168
int index;
169
- if (data_layout == DataLayout::kNCHW ) {
169
+ if (data_layout != DataLayout::kNHWC ) {
170
170
index = ((batch * gridDim .x + c_out) * output_height + h_out) *
171
171
output_width +
172
172
w_out;
@@ -252,7 +252,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
252
252
253
253
T value = 0 ;
254
254
int index;
255
- if (data_layout == DataLayout::kNCHW ) {
255
+ if (data_layout != DataLayout::kNHWC ) {
256
256
index =
257
257
((batch * gridDim .x + c_in) * input_height + h_in) * input_width +
258
258
w_in;
@@ -283,7 +283,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGrad(
283
283
s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
284
284
s_w_out < output_width) {
285
285
int output_grad_offset;
286
- if (data_layout == DataLayout::kNCHW ) {
286
+ if (data_layout != DataLayout::kNHWC ) {
287
287
output_grad_offset =
288
288
((batch * output_channels + c_out) * output_height +
289
289
s_h_out) *
@@ -335,7 +335,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
335
335
336
336
T value = 0 ;
337
337
int index;
338
- if (data_layout == DataLayout::kNCHW ) {
338
+ if (data_layout != DataLayout::kNHWC ) {
339
339
index =
340
340
((batch * gridDim .x + c_in) * input_height + h_in) * input_width +
341
341
w_in;
@@ -363,7 +363,7 @@ __device__ __inline__ void KernelDepthwiseConvInputGradCFilter(
363
363
s_h_out >= 0 && s_h_out < output_height && s_w_out >= 0 &&
364
364
s_w_out < output_width) {
365
365
int output_grad_offset;
366
- if (data_layout == DataLayout::kNCHW ) {
366
+ if (data_layout != DataLayout::kNHWC ) {
367
367
output_grad_offset =
368
368
((batch * output_channels + c_out) * output_height +
369
369
s_h_out) *
@@ -449,7 +449,7 @@ __device__ __inline__ void KernelDepthwiseConvFilterGrad(
449
449
#define gaid_nhwc (N, H, W, C ) \
450
450
((((N)*output_height + (H)) * output_width + (W)) * gridDim .z + (C))
451
451
int input_id;
452
- if (data_layout == DataLayout::kNCHW ) {
452
+ if (data_layout != DataLayout::kNHWC ) {
453
453
input_id = ((bid * (gridDim .z / filter_multiplier) +
454
454
kernel_id / filter_multiplier) *
455
455
input_height +
@@ -528,19 +528,19 @@ class DepthwiseConvFunctor<platform::CUDADeviceContext, T,
528
528
const DataLayout data_layout = DataLayout::kNCHW ) {
529
529
const int batch_size = input.dims ()[0 ];
530
530
const int input_channels =
531
- (data_layout == DataLayout::kNCHW ? input.dims ()[1 ] : input.dims ()[3 ]);
531
+ (data_layout != DataLayout::kNHWC ? input.dims ()[1 ] : input.dims ()[3 ]);
532
532
const int input_height =
533
- (data_layout == DataLayout::kNCHW ? input.dims ()[2 ] : input.dims ()[1 ]);
533
+ (data_layout != DataLayout::kNHWC ? input.dims ()[2 ] : input.dims ()[1 ]);
534
534
const int input_width =
535
- (data_layout == DataLayout::kNCHW ? input.dims ()[3 ] : input.dims ()[2 ]);
535
+ (data_layout != DataLayout::kNHWC ? input.dims ()[3 ] : input.dims ()[2 ]);
536
536
const int output_channels =
537
- (data_layout == DataLayout::kNCHW ? output->dims ()[1 ]
537
+ (data_layout != DataLayout::kNHWC ? output->dims ()[1 ]
538
538
: output->dims ()[3 ]);
539
539
const int output_height =
540
- (data_layout == DataLayout::kNCHW ? output->dims ()[2 ]
540
+ (data_layout != DataLayout::kNHWC ? output->dims ()[2 ]
541
541
: output->dims ()[1 ]);
542
542
const int output_width =
543
- (data_layout == DataLayout::kNCHW ? output->dims ()[3 ]
543
+ (data_layout != DataLayout::kNHWC ? output->dims ()[3 ]
544
544
: output->dims ()[2 ]);
545
545
const int ksize_height = filter.dims ()[2 ];
546
546
const int ksize_width = filter.dims ()[3 ];
@@ -614,19 +614,19 @@ class DepthwiseConvInputGradFunctor<platform::CUDADeviceContext, T,
614
614
const DataLayout data_layout = DataLayout::kNCHW ) {
615
615
const int batch_size = input.dims ()[0 ];
616
616
const int input_channels =
617
- (data_layout == DataLayout::kNCHW ? input.dims ()[1 ] : input.dims ()[3 ]);
617
+ (data_layout != DataLayout::kNHWC ? input.dims ()[1 ] : input.dims ()[3 ]);
618
618
const int input_height =
619
- (data_layout == DataLayout::kNCHW ? input.dims ()[2 ] : input.dims ()[1 ]);
619
+ (data_layout != DataLayout::kNHWC ? input.dims ()[2 ] : input.dims ()[1 ]);
620
620
const int input_width =
621
- (data_layout == DataLayout::kNCHW ? input.dims ()[3 ] : input.dims ()[2 ]);
621
+ (data_layout != DataLayout::kNHWC ? input.dims ()[3 ] : input.dims ()[2 ]);
622
622
const int output_channels =
623
- (data_layout == DataLayout::kNCHW ? output_grad.dims ()[1 ]
623
+ (data_layout != DataLayout::kNHWC ? output_grad.dims ()[1 ]
624
624
: output_grad.dims ()[3 ]);
625
625
const int output_height =
626
- (data_layout == DataLayout::kNCHW ? output_grad.dims ()[2 ]
626
+ (data_layout != DataLayout::kNHWC ? output_grad.dims ()[2 ]
627
627
: output_grad.dims ()[1 ]);
628
628
const int output_width =
629
- (data_layout == DataLayout::kNCHW ? output_grad.dims ()[3 ]
629
+ (data_layout != DataLayout::kNHWC ? output_grad.dims ()[3 ]
630
630
: output_grad.dims ()[2 ]);
631
631
const int ksize_height = filter.dims ()[2 ];
632
632
const int ksize_width = filter.dims ()[3 ];
@@ -702,19 +702,19 @@ class DepthwiseConvFilterGradFunctor<platform::CUDADeviceContext, T,
702
702
const DataLayout data_layout = DataLayout::kNCHW ) {
703
703
const int batch_size = input.dims ()[0 ];
704
704
const int input_channels =
705
- (data_layout == DataLayout::kNCHW ? input.dims ()[1 ] : input.dims ()[3 ]);
705
+ (data_layout != DataLayout::kNHWC ? input.dims ()[1 ] : input.dims ()[3 ]);
706
706
const int input_height =
707
- (data_layout == DataLayout::kNCHW ? input.dims ()[2 ] : input.dims ()[1 ]);
707
+ (data_layout != DataLayout::kNHWC ? input.dims ()[2 ] : input.dims ()[1 ]);
708
708
const int input_width =
709
- (data_layout == DataLayout::kNCHW ? input.dims ()[3 ] : input.dims ()[2 ]);
709
+ (data_layout != DataLayout::kNHWC ? input.dims ()[3 ] : input.dims ()[2 ]);
710
710
const int output_channels =
711
- (data_layout == DataLayout::kNCHW ? output_grad.dims ()[1 ]
711
+ (data_layout != DataLayout::kNHWC ? output_grad.dims ()[1 ]
712
712
: output_grad.dims ()[3 ]);
713
713
const int output_height =
714
- (data_layout == DataLayout::kNCHW ? output_grad.dims ()[2 ]
714
+ (data_layout != DataLayout::kNHWC ? output_grad.dims ()[2 ]
715
715
: output_grad.dims ()[1 ]);
716
716
const int output_width =
717
- (data_layout == DataLayout::kNCHW ? output_grad.dims ()[3 ]
717
+ (data_layout != DataLayout::kNHWC ? output_grad.dims ()[3 ]
718
718
: output_grad.dims ()[2 ]);
719
719
const int ksize_height = filter_grad->dims ()[2 ];
720
720
const int ksize_width = filter_grad->dims ()[3 ];
0 commit comments