Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
89 changes: 89 additions & 0 deletions aten/src/ATen/native/cuda/CUDALoops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -882,6 +882,69 @@ struct type_specialized_kernel_launcher {
}
};

template <int arg_index>
struct type_specialized_broadcast_kernel_launcher {
template <
typename func_t,
typename array_t,
typename dtypes_t,
typename calc_t>
static void apply(
int64_t numel,
func_t f,
array_t data,
dtypes_t dtypes,
calc_t offset_calc) {
using traits = function_traits<func_t>;
using ret_t = typename traits::result_type;
using arg0_t = typename traits::template arg<0>::type;
using arg1_t = typename traits::template arg<1>::type;
if (dtypes[0] == rt_binary_specializations[arg_index][0] &&
dtypes[1] == rt_binary_specializations[arg_index][1] &&
dtypes[2] == rt_binary_specializations[arg_index][2]) {
using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][0]>;
using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][1]>;
using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][2]>;
constexpr int grp_sz = 128;
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
if (unrl) {
auto offsets0 = offset_calc.get(idx);
auto offsets1 = offset_calc.get(idx + grp_sz);
auto offsets2 = offset_calc.get(idx + grp_sz * 2);
auto offsets3 = offset_calc.get(idx + grp_sz * 3);
void* out0 = data[0] + offsets0[0];
void* out1 = data[0] + offsets1[0];
void* out2 = data[0] + offsets2[0];
void* out3 = data[0] + offsets3[0];
auto u = c10::load<arg0_cpp_t>(data[1] + offsets0[1]);
auto v = c10::load<arg1_cpp_t>(data[2] + offsets0[2]);
ret_t result0 = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
auto u1 = c10::load<arg0_cpp_t>(data[1] + offsets1[1]);
auto v1 = c10::load<arg1_cpp_t>(data[2]+ offsets1[2]);
ret_t result1 = f(c10::convert<arg0_t>(u1), c10::convert<arg1_t>(v1));
auto u2 = c10::load<arg0_cpp_t>(data[1] + offsets2[1]);
auto v2 = c10::load<arg1_cpp_t>(data[2] + offsets2[2]);
ret_t result2 = f(c10::convert<arg0_t>(u2), c10::convert<arg1_t>(v2));
auto u3 = c10::load<arg0_cpp_t>(data[1] + offsets3[1]);
auto v3 = c10::load<arg1_cpp_t>(data[2] + offsets3[2]);
ret_t result3 = f(c10::convert<arg0_t>(u3), c10::convert<arg1_t>(v3));
*(ret_cpp_t*)out0 = c10::convert<ret_cpp_t>(result0);
*(ret_cpp_t*)out1 = c10::convert<ret_cpp_t>(result1);
*(ret_cpp_t*)out2 = c10::convert<ret_cpp_t>(result2);
*(ret_cpp_t*)out3 = c10::convert<ret_cpp_t>(result3);
} else {
auto offsets = offset_calc.get(idx);
void* out = data[0] + offsets[0];
auto u = c10::load<arg0_cpp_t>(data[1] + offsets[1]);
auto v = c10::load<arg1_cpp_t>(data[2] + offsets[2]);
ret_t result = f(c10::convert<arg0_t>(u), c10::convert<arg1_t>(v));
*(ret_cpp_t*)out = c10::convert<ret_cpp_t>(result);
}
});
}
}
};

} // namespace
#endif

Expand Down Expand Up @@ -1000,6 +1063,32 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
}
auto offset_calc = ::make_offset_calculator<traits::arity + 1>(iter);
#ifdef USE_ROCM
if (check_binary_rt_types_for_specialization(iter)) {
// constexpr to reduce the amount of kernels generated for
// broadcast elementwise with mexed dtypes and limit which functors are actually
// applied to the load and store at compile time.
using func_tuple = typename traits::ArgsTuple;
if constexpr (
std::is_same_v<float, arg0_t> && traits::arity == 2 &&
check_binary_functor_types_for_specialization<
func_tuple,
float,
float,
traits::arity,
/*arg_num=*/0>::check()) {
memory::detail::static_unroll<
type_specialized_broadcast_kernel_launcher,
rt_binary_specializations.size()>::with_args(
numel,
f,
data,
dtypes,
offset_calc
);
return;
}
}

constexpr int grp_sz = 128;
launch_legacy_kernel_manual_unroll<grp_sz, 4>(numel, [=] GPU_LAMBDA(int idx, bool unrl) {
if (unrl) {
Expand Down