Skip to content

Commit 2aadcea

Browse files
jerrymannil authored and pytorchmergebot committed
[ROCm] Improve perf for elementwise broadcast with mixed dtype (pytorch#163562)
* Unroll loops manually to hide memory access latency. Co-author: @amd-hhashemi. Pull Request resolved: pytorch#163562. Approved by: https://github.com/jeffdaily
1 parent fde929c commit 2aadcea

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -999,12 +999,41 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
999999
dtypes[i] = iter.dtype(i);
10001000
}
10011001
auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
1002+
#ifdef USE_ROCM
1003+
constexpr int grp_sz = 128;
1004+
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
1005+
if (unrl) {
1006+
auto offsets0 = offset_calc.get(idx);
1007+
auto offsets1 = offset_calc.get(idx + grp_sz);
1008+
auto offsets2 = offset_calc.get(idx + grp_sz * 2);
1009+
auto offsets3 = offset_calc.get(idx + grp_sz * 3);
1010+
void* out0 = data[0] + offsets0[0];
1011+
void* out1 = data[0] + offsets1[0];
1012+
void* out2 = data[0] + offsets2[0];
1013+
void* out3 = data[0] + offsets3[0];
1014+
arg0_t result0 = invoke(f, &data[1], &offsets0[1], &dtypes[1], 1);
1015+
arg0_t result1 = invoke(f, &data[1], &offsets1[1], &dtypes[1], 1);
1016+
arg0_t result2 = invoke(f, &data[1], &offsets2[1], &dtypes[1], 1);
1017+
arg0_t result3 = invoke(f, &data[1], &offsets3[1], &dtypes[1], 1);
1018+
c10::cast_and_store<arg0_t>(dtypes[0], out0, result0);
1019+
c10::cast_and_store<arg0_t>(dtypes[0], out1, result1);
1020+
c10::cast_and_store<arg0_t>(dtypes[0], out2, result2);
1021+
c10::cast_and_store<arg0_t>(dtypes[0], out3, result3);
1022+
} else {
1023+
auto offsets = offset_calc.get(idx);
1024+
void* out = data[0] + offsets[0];
1025+
arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
1026+
c10::cast_and_store<arg0_t>(dtypes[0], out, result);
1027+
}
1028+
});
1029+
#else
10021030
launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
10031031
auto offsets = offset_calc.get(idx);
10041032
void* out = data[0] + offsets[0];
10051033
arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
10061034
c10::cast_and_store<arg0_t>(dtypes[0], out, result);
10071035
});
1036+
#endif
10081037
}
10091038
}
10101039

0 commit comments

Comments
 (0)