Skip to content

Commit 02cee70

Browse files
committed
[ROCm] Fix 3D tensor perf degradation with NHWC format (#2175)
Co-author: @doru1004
1 parent 1fee196 commit 02cee70

File tree

1 file changed

+2
-1
lines changed

1 file changed

+2
-1
lines changed

aten/src/ATen/native/cuda/Reduce.cuh

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -1159,7 +1159,8 @@ ReduceConfig setReduceConfig(const TensorIterator& iter){
1159 1159
config.ctas_per_output = div_up(num_mp, 2);
1160 1160
else if (config.ctas_per_output < 16)
1161 1161
config.ctas_per_output = 1;
1162-
if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension)
1162+
bool is_channel_last = iter.tensor_base(1).is_contiguous(at::MemoryFormat::ChannelsLast);
1163+
if (iter.ndim() == 3 && !reduction_on_fastest_striding_dimension && !is_channel_last)
1163 1164
config.ctas_per_output = 4;
1164 1165
#endif
1165 1166
if (config.ctas_per_output > 1) {

0 commit comments

Comments
 (0)