
Commit 130d937

[ROCm] Specialized binary elementwise broadcast kernel for mixed dtypes with float/bfloat16/half (#2795)
Cherry-pick of pytorch#167233. Fixes #SWDEV-551924.
1 parent a7b6c00 commit 130d937
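
For context on the diff below: the new launcher keys every dispatch off rt_binary_specializations, a small compile-time table of (output, lhs, rhs) ScalarType triples defined earlier in CUDALoops.cuh and not shown in this diff. A plausible shape for such a table, assuming the float/bfloat16/half combinations named in the commit title (illustrative only; rt_binary_specializations_example and its exact entries are hypothetical):

#include <array>
#include <c10/core/ScalarType.h>

// Hypothetical illustration -- the real rt_binary_specializations lives
// earlier in CUDALoops.cuh. Each entry is an (output, lhs, rhs) dtype
// triple that gets its own specialized kernel instantiation.
constexpr std::array<std::array<c10::ScalarType, 3>, 4>
    rt_binary_specializations_example{{
        {c10::ScalarType::Float, c10::ScalarType::Float, c10::ScalarType::BFloat16},
        {c10::ScalarType::Float, c10::ScalarType::BFloat16, c10::ScalarType::Float},
        {c10::ScalarType::Float, c10::ScalarType::Float, c10::ScalarType::Half},
        {c10::ScalarType::Float, c10::ScalarType::Half, c10::ScalarType::Float},
    }};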

File tree

1 file changed: +89 -0 lines changed

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 89 additions & 0 deletions
@@ -893,6 +893,69 @@ struct type_specialized_kernel_launcher {
   }
 };
 
+template <int arg_index>
+struct type_specialized_broadcast_kernel_launcher {
+  template <
+      typename func_t,
+      typename array_t,
+      typename dtypes_t,
+      typename calc_t>
+  static void apply(
+      int64_t numel,
+      func_t f,
+      array_t data,
+      dtypes_t dtypes,
+      calc_t offset_calc) {
+    using traits = function_traits<func_t>;
+    using ret_t = typename traits::result_type;
+    using arg0_t = typename traits::template arg<0>::type;
+    using arg1_t = typename traits::template arg<1>::type;
+    if (dtypes[0] == rt_binary_specializations[arg_index][0] &&
+        dtypes[1] == rt_binary_specializations[arg_index][1] &&
+        dtypes[2] == rt_binary_specializations[arg_index][2]) {
+      using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][0]>;
+      using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][1]>;
+      using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][2]>;
+      constexpr int grp_sz = 128;
+      launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
+        if (unrl) {
+          auto offsets0 = offset_calc.get(idx);
+          auto offsets1 = offset_calc.get(idx + grp_sz);
+          auto offsets2 = offset_calc.get(idx + grp_sz * 2);
+          auto offsets3 = offset_calc.get(idx + grp_sz * 3);
+          void* out0 = data[0] + offsets0[0];
+          void* out1 = data[0] + offsets1[0];
+          void* out2 = data[0] + offsets2[0];
+          void* out3 = data[0] + offsets3[0];
+          auto u = c10::load<arg0_cpp_t>(data[1] + offsets0[1]);
+          auto v = c10::load<arg1_cpp_t>(data[2] + offsets0[2]);
+          ret_t result0 = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
+          auto u1 = c10::load<arg0_cpp_t>(data[1] + offsets1[1]);
+          auto v1 = c10::load<arg1_cpp_t>(data[2] + offsets1[2]);
+          ret_t result1 = f(c10::convert<arg0_t>(u1), c10::convert<arg1_t>(v1));
+          auto u2 = c10::load<arg0_cpp_t>(data[1] + offsets2[1]);
+          auto v2 = c10::load<arg1_cpp_t>(data[2] + offsets2[2]);
+          ret_t result2 = f(c10::convert<arg0_t>(u2), c10::convert<arg1_t>(v2));
+          auto u3 = c10::load<arg0_cpp_t>(data[1] + offsets3[1]);
+          auto v3 = c10::load<arg1_cpp_t>(data[2] + offsets3[2]);
+          ret_t result3 = f(c10::convert<arg0_t>(u3), c10::convert<arg1_t>(v3));
+          *(ret_cpp_t*)out0 = c10::convert<ret_cpp_t>(result0);
+          *(ret_cpp_t*)out1 = c10::convert<ret_cpp_t>(result1);
+          *(ret_cpp_t*)out2 = c10::convert<ret_cpp_t>(result2);
+          *(ret_cpp_t*)out3 = c10::convert<ret_cpp_t>(result3);
+        } else {
+          auto offsets = offset_calc.get(idx);
+          void* out = data[0] + offsets[0];
+          auto u = c10::load<arg0_cpp_t>(data[1] + offsets[1]);
+          auto v = c10::load<arg1_cpp_t>(data[2] + offsets[2]);
+          ret_t result = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
+          *(ret_cpp_t*)out = c10::convert<ret_cpp_t>(result);
+        }
+      });
+    }
+  }
+};
+
 } // namespace
 #endif
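
A note on the dispatch mechanics: memory::detail::static_unroll (from ATen's MemoryAccess.cuh) expands a templated launcher across a compile-time index range, so each table entry gets its own apply<arg_index> instantiation. Since every instantiation guards itself with the dtypes comparison above, all of them are invoked at runtime but at most one launches a kernel. A simplified sketch of the pattern (illustrative, not the PyTorch source):

#include <utility>

// Compile-time loop: calls func<0>::apply(args...), then
// func<1>::apply(args...), ..., up to func<end - 1>::apply(args...).
template <template <int> class func, int end, int current = 0>
struct static_unroll {
  template <typename... Args>
  static inline void with_args(Args&&... args) {
    func<current>::apply(std::forward<Args>(args)...);
    static_unroll<func, end, current + 1>::with_args(args...);
  }
};

// Base case terminates the recursion once current reaches end.
template <template <int> class func, int end>
struct static_unroll<func, end, end> {
  template <typename... Args>
  static inline void with_args(Args&&... /*args*/) {}
};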

@@ -995,6 +1058,32 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
   }
   auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
 #ifdef USE_ROCM
+  if (check_binary_rt_types_for_specialization(iter)) {
+    // constexpr to reduce the number of kernels generated for broadcast
+    // elementwise with mixed dtypes and to limit which functors are
+    // actually applied to the loads and stores at compile time.
+    using func_tuple = typename traits::ArgsTuple;
+    if constexpr (
+        std::is_same_v<float, arg0_t> && traits::arity == 2 &&
+        check_binary_functor_types_for_specialization<
+            func_tuple,
+            float,
+            float,
+            traits::arity,
+            /*arg_num=*/0>::check()) {
+      memory::detail::static_unroll<
+          type_specialized_broadcast_kernel_launcher,
+          rt_binary_specializations.size()>::with_args(
+          numel,
+          f,
+          data,
+          dtypes,
+          offset_calc
+      );
+      return;
+    }
+  }
+
   constexpr int grp_sz = 128;
   launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
     if (unrl) {
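
The (idx, unrl) lambda protocol in both hunks encodes an assumed contract with launch_legacy_kernel_manual_unroll: when unrl is true, the lambda owns a full tile of 4 elements spaced grp_sz apart starting at idx; otherwise it handles exactly one in-range element. A minimal standalone sketch of a launcher honoring that contract (hypothetical names and grid math; the real helper is defined elsewhere in this file):

// Hypothetical sketch of the launch contract. Each block covers
// grp_sz * unroll elements; threads whose whole tile is in range take
// the batched path once, threads in the ragged tail call the functor
// once per surviving element.
template <int grp_sz, int unroll, typename func_t>
__global__ void manual_unroll_kernel_sketch(int64_t numel, func_t f) {
  int64_t base =
      static_cast<int64_t>(blockIdx.x) * grp_sz * unroll + threadIdx.x;
  if (base + static_cast<int64_t>(grp_sz) * (unroll - 1) < numel) {
    f(static_cast<int>(base), /*unrl=*/true);  // full tile, batched loads
  } else {
    for (int i = 0; i < unroll; ++i) {
      int64_t idx = base + static_cast<int64_t>(grp_sz) * i;
      if (idx < numel) {
        f(static_cast<int>(idx), /*unrl=*/false);  // tail, one at a time
      }
    }
  }
}

template <int grp_sz, int unroll, typename func_t>
void launch_manual_unroll_sketch(int64_t numel, func_t f) {
  if (numel == 0) return;
  const int64_t tile = static_cast<int64_t>(grp_sz) * unroll;
  const int64_t grid = (numel + tile - 1) / tile;
  manual_unroll_kernel_sketch<grp_sz, unroll><<<grid, grp_sz>>>(numel, f);
}

Under that contract, the batched branch lets the four pairs of loads issue back to back before any store, which is presumably what allows the specialization to hide global-memory latency while skipping the generic path's per-element dtype dispatch.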

0 commit comments