Skip to content

Commit bf0079d

Browse files
okakarpa and jeffdaily authored
[release/2.7] [SWDEV-535259] enable miopen channels last 3d for conv and batchnorm (#2232)
Cherry-pick of #2209 Co-authored-by: Jeff Daily <[email protected]>
1 parent 2337da4 commit bf0079d

File tree

4 files changed

+18
-13
lines changed

4 files changed

+18
-13
lines changed

aten/src/ATen/native/ConvUtils.h

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -362,20 +362,24 @@ inline bool miopen_conv_use_channels_last(const at::Tensor& input, const at::Ten
362362
return false;
363363
}
364364

365-
bool can_use_miopen_channels_last_2d = false;
366365
// TODO: Remove PYTORCH_MIOPEN_SUGGEST_NHWC once ROCm officially supports NHWC in MIOpen
367366
// See #64427
368367
static std::optional<bool> PYTORCH_MIOPEN_SUGGEST_NHWC = c10::utils::check_env("PYTORCH_MIOPEN_SUGGEST_NHWC");
368+
static bool suggest_nhwc = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC;
369369

370370
auto input_memory_format = input.suggest_memory_format();
371371
auto weight_memory_format = weight.suggest_memory_format();
372+
auto weight_ndim = weight.ndimension();
372373

373-
can_use_miopen_channels_last_2d = PYTORCH_MIOPEN_SUGGEST_NHWC && *PYTORCH_MIOPEN_SUGGEST_NHWC && (
374-
( (input_memory_format == at::MemoryFormat::ChannelsLast) ||
375-
(weight_memory_format == at::MemoryFormat::ChannelsLast) )
376-
);
374+
bool can_use_miopen_channels_last_2d = suggest_nhwc && (weight_ndim == 4) && (
375+
(input_memory_format == at::MemoryFormat::ChannelsLast) ||
376+
(weight_memory_format == at::MemoryFormat::ChannelsLast)
377+
);
377378

378-
bool can_use_miopen_channels_last_3d = false;
379+
bool can_use_miopen_channels_last_3d = suggest_nhwc && (weight_ndim == 5) && (
380+
(input_memory_format == at::MemoryFormat::ChannelsLast3d) ||
381+
(weight_memory_format == at::MemoryFormat::ChannelsLast3d)
382+
);
379383

380384
return can_use_miopen_channels_last_2d || can_use_miopen_channels_last_3d;
381385
}

aten/src/ATen/native/Convolution.cpp

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1430,7 +1430,7 @@ static inline at::MemoryFormat determine_backend_memory_format(
14301430
if (detail::getCUDAHooks().compiledWithMIOpen() && miopen_conv_use_channels_last(input, weight)) {
14311431
TORCH_INTERNAL_ASSERT((k == 4 || k == 5),
14321432
"Expected 4D or 5D input for miopen memory format selection in determine_backend_memory_format()");
1433-
backend_memory_format = (k == 5) ? at::MemoryFormat::Contiguous /*at::MemoryFormat::ChannelsLast3d*/ : at::MemoryFormat::ChannelsLast;
1433+
backend_memory_format = (k == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
14341434
}
14351435
break;
14361436
case ConvBackend::Mkldnn:

aten/src/ATen/native/Normalization.cpp

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -544,6 +544,7 @@ BatchNormBackend _select_batch_norm_backend(
544544
&& (input.suggest_memory_format() == MemoryFormat::Contiguous
545545
#if (defined(USE_ROCM) && ROCM_VERSION >= 60500)
546546
|| (input.suggest_memory_format() == MemoryFormat::ChannelsLast && PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM)
547+
|| (input.suggest_memory_format() == MemoryFormat::ChannelsLast3d && PYTORCH_MIOPEN_SUGGEST_NHWC_BATCHNORM)
547548
#endif
548549
)
549550
) {

aten/src/ATen/native/miopen/Conv_miopen.cpp

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -763,7 +763,7 @@ Tensor miopen_convolution_forward(
763763

764764
auto memory_format = at::MemoryFormat::Contiguous;
765765
if (miopen_conv_use_channels_last(*input, *weight)) {
766-
memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
766+
memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
767767
}
768768

769769
Tensor output_t = at::detail::empty_cuda(
@@ -872,7 +872,7 @@ Tensor miopen_depthwise_convolution_forward(
872872

873873
auto memory_format = at::MemoryFormat::Contiguous;
874874
if (miopen_conv_use_channels_last(*input, *weight)) {
875-
memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
875+
memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
876876
}
877877

878878
Tensor output_t = at::detail::empty_cuda(
@@ -1074,7 +1074,7 @@ Tensor miopen_depthwise_convolution_backward_weight(
10741074

10751075
auto memory_format = at::MemoryFormat::Contiguous;
10761076
if (miopen_conv_use_channels_last(*input, *grad_output)) {
1077-
memory_format = (input->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
1077+
memory_format = (input->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
10781078
}
10791079

10801080
Tensor grad_output_contig_t = grad_output->contiguous(memory_format);
@@ -1127,7 +1127,7 @@ Tensor miopen_convolution_backward_weight(
11271127

11281128
auto memory_format = at::MemoryFormat::Contiguous;
11291129
if (miopen_conv_use_channels_last(*input, *grad_output)) {
1130-
memory_format = (input->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
1130+
memory_format = (input->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
11311131
}
11321132

11331133
Tensor grad_output_contig_t = grad_output->contiguous(memory_format);
@@ -1281,7 +1281,7 @@ Tensor miopen_convolution_backward_input(
12811281

12821282
auto memory_format = at::MemoryFormat::Contiguous;
12831283
if (miopen_conv_use_channels_last(*grad_output, *weight)) {
1284-
memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
1284+
memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
12851285
}
12861286

12871287
Tensor grad_input_t = at::detail::empty_cuda(
@@ -1389,7 +1389,7 @@ Tensor miopen_depthwise_convolution_backward_input(
13891389

13901390
auto memory_format = at::MemoryFormat::Contiguous;
13911391
if (miopen_conv_use_channels_last(*grad_output, *weight)) {
1392-
memory_format = (weight->ndimension() == 5) ? /*at::MemoryFormat::ChannelsLast3d*/at::MemoryFormat::Contiguous : at::MemoryFormat::ChannelsLast;
1392+
memory_format = (weight->ndimension() == 5) ? at::MemoryFormat::ChannelsLast3d : at::MemoryFormat::ChannelsLast;
13931393
}
13941394

13951395
Tensor grad_input_t = at::detail::empty_cuda(

0 commit comments

Comments (0)