Skip to content

Commit 99ccf24

Browse files
authored
[ROCm] Improve perf for elementwise broadcast with mixed dtype (#2671)
* cherry-pick of pytorch@2aadcea
1 parent 8d42697 commit 99ccf24

File tree

1 file changed

+29
-0
lines changed

1 file changed

+29
-0
lines changed

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -994,12 +994,41 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
994994
dtypes[i] = iter.dtype(i);
995995
}
996996
auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
997+
#ifdef USE_ROCM
998+
constexpr int grp_sz = 128;
999+
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
1000+
if (unrl) {
1001+
auto offsets0 = offset_calc.get(idx);
1002+
auto offsets1 = offset_calc.get(idx + grp_sz);
1003+
auto offsets2 = offset_calc.get(idx + grp_sz * 2);
1004+
auto offsets3 = offset_calc.get(idx + grp_sz * 3);
1005+
void* out0 = data[0] + offsets0[0];
1006+
void* out1 = data[0] + offsets1[0];
1007+
void* out2 = data[0] + offsets2[0];
1008+
void* out3 = data[0] + offsets3[0];
1009+
arg0_t result0 = invoke(f, &data[1], &offsets0[1], &dtypes[1], 1);
1010+
arg0_t result1 = invoke(f, &data[1], &offsets1[1], &dtypes[1], 1);
1011+
arg0_t result2 = invoke(f, &data[1], &offsets2[1], &dtypes[1], 1);
1012+
arg0_t result3 = invoke(f, &data[1], &offsets3[1], &dtypes[1], 1);
1013+
c10::cast_and_store<arg0_t>(dtypes[0], out0, result0);
1014+
c10::cast_and_store<arg0_t>(dtypes[0], out1, result1);
1015+
c10::cast_and_store<arg0_t>(dtypes[0], out2, result2);
1016+
c10::cast_and_store<arg0_t>(dtypes[0], out3, result3);
1017+
} else {
1018+
auto offsets = offset_calc.get(idx);
1019+
void* out = data[0] + offsets[0];
1020+
arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
1021+
c10::cast_and_store<arg0_t>(dtypes[0], out, result);
1022+
}
1023+
});
1024+
#else
9971025
launch_legacy_kernel<128, 4>(numel, [=] GPU_LAMBDA(int idx) {
9981026
auto offsets = offset_calc.get(idx);
9991027
void* out = data[0] + offsets[0];
10001028
arg0_t result = invoke(f, &data[1], &offsets[1], &dtypes[1], 1);
10011029
c10::cast_and_store<arg0_t>(dtypes[0], out, result);
10021030
});
1031+
#endif
10031032
}
10041033
}
10051034

0 commit comments

Comments
 (0)