Commit 5729657

[ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2791)
Cherry-pick of pytorch#167233. Fixes #SWDEV-551924.
1 parent 93dd572 commit 5729657

File tree

1 file changed: +89 −0 lines changed


aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 89 additions & 0 deletions
@@ -882,6 +882,69 @@ struct type_specialized_kernel_launcher {
   }
 };
 
+template <int arg_index>
+struct type_specialized_broadcast_kernel_launcher {
+  template <
+      typename func_t,
+      typename array_t,
+      typename dtypes_t,
+      typename calc_t>
+  static void apply(
+      int64_t numel,
+      func_t f,
+      array_t data,
+      dtypes_t dtypes,
+      calc_t offset_calc) {
+    using traits = function_traits<func_t>;
+    using ret_t = typename traits::result_type;
+    using arg0_t = typename traits::template arg<0>::type;
+    using arg1_t = typename traits::template arg<1>::type;
+    if (dtypes[0] == rt_binary_specializations[arg_index][0] &&
+        dtypes[1] == rt_binary_specializations[arg_index][1] &&
+        dtypes[2] == rt_binary_specializations[arg_index][2]) {
+      using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][0]>;
+      using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][1]>;
+      using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][2]>;
+      constexpr int grp_sz = 128;
+      launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
+        if (unrl) {
+          auto offsets0 = offset_calc.get(idx);
+          auto offsets1 = offset_calc.get(idx + grp_sz);
+          auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+          auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+          void* out0 = data[0] + offsets0[0];
+          void* out1 = data[0] + offsets1[0];
+          void* out2 = data[0] + offsets2[0];
+          void* out3 = data[0] + offsets3[0];
+          auto u = c10::load<arg0_cpp_t>(data[1] + offsets0[1]);
+          auto v = c10::load<arg1_cpp_t>(data[2] + offsets0[2]);
+          ret_t result0 = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
+          auto u1 = c10::load<arg0_cpp_t>(data[1] + offsets1[1]);
+          auto v1 = c10::load<arg1_cpp_t>(data[2] + offsets1[2]);
+          ret_t result1 = f(c10::convert<arg0_t>(u1), c10::convert<arg1_t>(v1));
+          auto u2 = c10::load<arg0_cpp_t>(data[1] + offsets2[1]);
+          auto v2 = c10::load<arg1_cpp_t>(data[2] + offsets2[2]);
+          ret_t result2 = f(c10::convert<arg0_t>(u2), c10::convert<arg1_t>(v2));
+          auto u3 = c10::load<arg0_cpp_t>(data[1] + offsets3[1]);
+          auto v3 = c10::load<arg1_cpp_t>(data[2] + offsets3[2]);
+          ret_t result3 = f(c10::convert<arg0_t>(u3), c10::convert<arg1_t>(v3));
+          *(ret_cpp_t*)out0 = c10::convert<ret_cpp_t>(result0);
+          *(ret_cpp_t*)out1 = c10::convert<ret_cpp_t>(result1);
+          *(ret_cpp_t*)out2 = c10::convert<ret_cpp_t>(result2);
+          *(ret_cpp_t*)out3 = c10::convert<ret_cpp_t>(result3);
+        } else {
+          auto offsets = offset_calc.get(idx);
+          void* out = data[0] + offsets[0];
+          auto u = c10::load<arg0_cpp_t>(data[1] + offsets[1]);
+          auto v = c10::load<arg1_cpp_t>(data[2] + offsets[2]);
+          ret_t result = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
+          *(ret_cpp_t*)out = c10::convert<ret_cpp_t>(result);
+        }
+      });
+    }
+  }
+};
+
 } // namespace
 #endif
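The heart of the launcher above is the pairing of a compile-time table index with a runtime check: `rt_binary_specializations[arg_index]` is compared against the iterator's runtime dtype triple, and only on a match does the instantiation launch a body that loads values in their stored dtypes, converts them to the functor's argument types, applies the functor, and converts the result back to the output dtype. Below is a minimal, self-contained host-side sketch of that pattern only; it omits the offset-calculator broadcast indexing and the 4-way manual unroll, substitutes double for half/bfloat16 so it compiles standalone, and uses illustrative names (DType, kSpecializations, ToCpp, apply_if_match) rather than the ATen identifiers.

#include <array>
#include <cstdint>
#include <cstring>

enum class DType : int { Float, Double };

// Stand-in for rt_binary_specializations: (out, lhs, rhs) dtype triples.
constexpr std::array<std::array<DType, 3>, 2> kSpecializations{{
    {DType::Float, DType::Double, DType::Double},
    {DType::Float, DType::Float, DType::Double},
}};

// Stand-in for c10::impl::ScalarTypeToCPPTypeT.
template <DType D> struct ToCpp;
template <> struct ToCpp<DType::Float>  { using type = float; };
template <> struct ToCpp<DType::Double> { using type = double; };

// One "specialization": arg_index picks a table entry at compile time,
// the runtime dtype triple decides whether this instantiation runs at all.
template <int arg_index, typename func_t>
void apply_if_match(int64_t numel, func_t f, std::array<char*, 3> data,
                    std::array<DType, 3> dtypes) {
  if (dtypes != kSpecializations[arg_index]) return;  // runtime dtype check
  using out_t = typename ToCpp<kSpecializations[arg_index][0]>::type;
  using lhs_t = typename ToCpp<kSpecializations[arg_index][1]>::type;
  using rhs_t = typename ToCpp<kSpecializations[arg_index][2]>::type;
  for (int64_t i = 0; i < numel; ++i) {
    lhs_t u;  // load in the stored input dtypes...
    rhs_t v;
    std::memcpy(&u, data[1] + i * sizeof(lhs_t), sizeof(lhs_t));
    std::memcpy(&v, data[2] + i * sizeof(rhs_t), sizeof(rhs_t));
    // ...compute in the functor's argument type (float here)...
    float r = f(static_cast<float>(u), static_cast<float>(v));
    // ...and store back in the output dtype.
    out_t o = static_cast<out_t>(r);
    std::memcpy(data[0] + i * sizeof(out_t), &o, sizeof(out_t));
  }
}

In this sketch, an op producing float output from two double inputs is served by apply_if_match<0>, while every other instantiation returns immediately after the dtype comparison.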

@@ -1000,6 +1063,32 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
   }
   auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
 #ifdef USE_ROCM
+  if (check_binary_rt_types_for_specialization(iter)) {
+    // constexpr to reduce the number of kernels generated for broadcast
+    // elementwise with mixed dtypes and to limit which functors are actually
+    // applied to the load and store at compile time.
+    using func_tuple = typename traits::ArgsTuple;
+    if constexpr (
+        std::is_same_v<float, arg0_t> && traits::arity == 2 &&
+        check_binary_functor_types_for_specialization<
+            func_tuple,
+            float,
+            float,
+            traits::arity,
+            /*arg_num=*/0>::check()) {
+      memory::detail::static_unroll<
+          type_specialized_broadcast_kernel_launcher,
+          rt_binary_specializations.size()>::with_args(
+          numel,
+          f,
+          data,
+          dtypes,
+          offset_calc
+      );
+      return;
+    }
+  }
+
   constexpr int grp_sz = 128;
   launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
     if (unrl) {
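
The static_unroll call above is what ties the specialization table to the launcher: it instantiates type_specialized_broadcast_kernel_launcher<0> through <N-1> at compile time and invokes each one with the same runtime arguments, so at most one of them (the one whose table entry matches the runtime dtypes) launches a kernel. A minimal, self-contained sketch of that fan-out follows; StaticUnroll, Launcher, and kSpecializations are illustrative stand-ins, not the ATen identifiers.

#include <array>
#include <cstdio>

enum class DType : int { Float, Half, BFloat16 };

// Illustrative (out, lhs, rhs) dtype triples, cf. rt_binary_specializations.
constexpr std::array<std::array<DType, 3>, 2> kSpecializations{{
    {DType::Float, DType::Half, DType::Half},
    {DType::Float, DType::BFloat16, DType::BFloat16},
}};

template <int arg_index>
struct Launcher {
  static void apply(const std::array<DType, 3>& dtypes) {
    // Only the instantiation whose table entry matches the runtime dtypes
    // does any work; the others amount to one failed comparison each.
    if (dtypes == kSpecializations[arg_index]) {
      std::printf("specialization %d selected\n", arg_index);
    }
  }
};

// Compile-time loop over [0, N): instantiate F<0> ... F<N-1> and call each.
template <template <int> class F, int N, int I = 0>
struct StaticUnroll {
  template <typename... Args>
  static void with_args(Args&&... args) {
    F<I>::apply(args...);
    StaticUnroll<F, N, I + 1>::with_args(args...);
  }
};
template <template <int> class F, int N>
struct StaticUnroll<F, N, N> {
  template <typename... Args>
  static void with_args(Args&&...) {}
};

int main() {
  // A float = half + half broadcast op would select entry 0 of the table.
  StaticUnroll<Launcher, (int)kSpecializations.size()>::with_args(
      std::array<DType, 3>{DType::Float, DType::Half, DType::Half});
}

Because the dtype comparison guards the kernel launch in each instantiation, the non-matching specializations cost a few comparisons rather than extra kernel launches.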
