
Commit 6e62a7c

carlobertolli authored and pruthvistony committed
[ROCm] Extend vectorized elementwise kernel to more heterogeneous tensor types. (pytorch#149738)
This patch extends the initial support for "vectorized templated" kernels to the following pairs of input tensor types:

(BFloat16, float)
(float, float16)
(float16, float)

Pull Request resolved: pytorch#149738
Approved by: https://github.com/jeffdaily
1 parent f001ac4 · commit 6e62a7c
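For context: with this change, a binary elementwise op whose two inputs match one of the specialized dtype pairs can take the vectorized templated kernel path on ROCm. Below is a minimal, hypothetical ATen snippet (not part of the commit; assumes a HIP/ROCm build, where at::kCUDA maps to the HIP device) that produces one such pair:

// Hypothetical repro, not from the commit: a (float, BFloat16) binary op
// on the GPU, matching one of the newly specialized input dtype pairs.
#include <ATen/ATen.h>

int main() {
  auto a = at::rand({1 << 20}, at::device(at::kCUDA).dtype(at::kFloat));
  auto b = at::rand({1 << 20}, at::device(at::kCUDA).dtype(at::kFloat))
               .to(at::kBFloat16);
  // add computes in float while its inputs are (float, BFloat16), so
  // gpu_kernel_impl may dispatch the specialized vectorized kernel here.
  auto c = a.add(b);
  return 0;
}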

File tree: 2 files changed, +156 −47 lines

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 153 additions & 46 deletions
@@ -292,8 +292,11 @@ __global__ void vectorized_templated_elementwise_kernel(
     out_calc_t out_calc,
     loader_t loader,
     storer_t storer) {
-  int remaining =
-      N - vectorized_templated_config::block_work_size() * blockIdx.x;
+  int remaining = N -
+      vectorized_templated_config::block_work_size() *
+          (gridDim.x - blockIdx.x - 1);
+  constexpr bool reverted_idx = true;
+
   if (remaining <
       vectorized_templated_config::block_work_size()) { // if this block handles
                                                         // the reminder,
@@ -307,18 +310,17 @@ __global__ void vectorized_templated_elementwise_kernel(
         storer_t,
         vectorized_templated_config::elems_per_thread()>(
         data, remaining, inp_calc, out_calc, loader, storer);
-    elementwise_kernel_helper(f, policy);
+    elementwise_kernel_helper<reverted_idx>(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
            // vectorized memory access
-    elementwise_kernel_helper(
-        f,
-        memory::policies::vectorized_templated<
-            vec_size,
-            array_t,
-            vectorized_templated_config::elems_per_thread(),
-            vectorized_templated_config::num_threads(),
-            OutputType,
-            InputTypes...>(data));
+    auto policy = memory::policies::vectorized_templated<
+        vec_size,
+        array_t,
+        vectorized_templated_config::elems_per_thread(),
+        vectorized_templated_config::num_threads(),
+        OutputType,
+        InputTypes...>(data);
+    elementwise_kernel_helper<reverted_idx>(f, policy);
   }
 }
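The arithmetic above reverses the block-to-chunk mapping: block 0 now owns the tail of the array (and hence the remainder path), while the highest-numbered block owns the start. A small host-side sketch of the same formulas (illustrative names, not kernel code):

// Host-side model of the reversed mapping: with reverted_idx, block b
// processes chunk (grid - b - 1), so block 0 receives the partial tail
// chunk and takes the remainder (non-vectorized) path.
#include <cstdio>

int main() {
  const int N = 10;         // total elements
  const int block_work = 4; // stand-in for block_work_size()
  const int grid = (N + block_work - 1) / block_work;
  for (int b = 0; b < grid; ++b) {
    int chunk = grid - b - 1;               // reversed block index
    int remaining = N - block_work * chunk; // same formula as the kernel
    std::printf(
        "block %d -> chunk %d, remaining %d%s\n",
        b,
        chunk,
        remaining,
        remaining < block_work ? " (remainder path)" : "");
  }
  return 0;
}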

@@ -652,41 +654,143 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {

 #ifdef USE_ROCM
 namespace {
-template <typename TupleLike, size_t arity, size_t arg_num = 0>
-struct check_types {
+template <
+    typename TupleLike,
+    typename FirstParamTy,
+    typename SecondParamTy,
+    size_t arity,
+    size_t arg_num = 0>
+struct check_binary_functor_types_for_specialization {
   constexpr static inline bool check() {
     if constexpr (arity != 2)
       return false;
     if constexpr (arg_num == 0) {
       using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType>)
-        return check_types<TupleLike, arity, arg_num + 1>::check();
+      if constexpr (std::is_same_v<FirstParamTy, SelectedType>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            FirstParamTy,
+            SecondParamTy,
+            arity,
+            arg_num + 1>::check();
     } else if constexpr (arg_num == 1) {
       using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType2>)
-        return check_types<TupleLike, arity, arg_num + 1>::check();
+      if constexpr (std::is_same_v<SecondParamTy, SelectedType2>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            FirstParamTy,
+            SecondParamTy,
+            arity,
+            arg_num + 1>::check();
     }
     return false;
   }
 };

 // Bottom case: if we got this far, assume correct type matching except
 // when there are no arguments (arity == 0).
-template <typename TupleLike, size_t arity>
-struct check_types<TupleLike, arity, arity> {
+template <
+    typename TupleLike,
+    typename FirstParamTy,
+    typename SecondParamTy,
+    size_t arity>
+struct check_binary_functor_types_for_specialization<
+    TupleLike,
+    FirstParamTy,
+    SecondParamTy,
+    arity,
+    arity> {
   constexpr static inline bool check() {
     if constexpr (arity != 0)
       return true;
     return false;
   }
 };

-template <typename TupleLike>
-struct check_types<TupleLike, 0, 0> {
+template <typename TupleLike, typename FirstParamTy, typename SecondParamTy>
+struct check_binary_functor_types_for_specialization<
+    TupleLike,
+    FirstParamTy,
+    SecondParamTy,
+    0,
+    0> {
   constexpr static inline bool check() {
     return false;
   }
 };
+
+// The following is a list of type specializations for vectorized_templated
+// elementwise kernel. It refers to the first and second runtime types of the
+// arguments of a binary functor.
+
+constexpr std::array rt_binary_specializations = {
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<BFloat16>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<Half>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<float>::value})};
+
+bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
+  if (iter.ninputs() != 2)
+    return false;
+  for (auto spec : rt_binary_specializations)
+    if (iter.input_dtype(0) == spec[0] && iter.input_dtype(1) == spec[1])
+      return true;
+  return false;
+}
+
+template <int arg_index>
+struct type_specialized_kernel_launcher {
+  template <
+      typename func_t,
+      typename array_t,
+      typename inp_calc_t,
+      typename out_calc_t,
+      typename loader_t,
+      typename storer_t>
+  static void apply(
+      ScalarType arg0_t,
+      ScalarType arg1_t,
+      int64_t numel,
+      func_t f,
+      array_t data,
+      inp_calc_t input_offset_calculator,
+      out_calc_t output_offset_calculator,
+      loader_t loader,
+      storer_t storer) {
+    using traits = function_traits<func_t>;
+    using return_t = typename traits::result_type;
+    if (arg0_t == rt_binary_specializations[arg_index][0] &&
+        arg1_t == rt_binary_specializations[arg_index][1])
+      launch_vectorized_templated_kernel<
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          return_t,
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][0]>::t),
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][1]>::t)>(
+          numel,
+          f,
+          data,
+          input_offset_calculator,
+          output_offset_calculator,
+          loader,
+          storer);
+  }
+};
+
 } // namespace
 #endif
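The recursion above walks a binary functor's argument tuple one position at a time. A compressed, self-contained model of the same compile-time check (simplified; not the PyTorch implementation) shows the intent:

// Simplified model of check_binary_functor_types_for_specialization:
// accept only binary functors whose two argument types match the
// expected parameter types, so nothing else instantiates the
// specialized kernel.
#include <cstddef>
#include <tuple>
#include <type_traits>

template <typename Tuple, typename A0, typename A1, std::size_t arity>
constexpr bool check_binary_args() {
  if constexpr (arity != 2) {
    return false;
  } else {
    return std::is_same_v<std::tuple_element_t<0, Tuple>, A0> &&
        std::is_same_v<std::tuple_element_t<1, Tuple>, A1>;
  }
}

static_assert(check_binary_args<std::tuple<float, float>, float, float, 2>());
static_assert(!check_binary_args<std::tuple<float, int>, float, float, 2>());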

@@ -716,43 +820,46 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
 #ifdef USE_ROCM
   // Attempt to call specialized vectorized elementwise kernel
   // that enables interleaving.
-  using float_map = c10::CppTypeToScalarType<float>;
-  using bfloat16_map = c10::CppTypeToScalarType<BFloat16>;
-  if (iter.ninputs() == 2 && iter.input_dtype(0) == float_map::value &&
-      iter.input_dtype(1) == bfloat16_map::value &&
+
+  if (check_binary_rt_types_for_specialization(iter) &&
       memory::can_vectorize_up_to<func_t>(data) > 1) {
-    // constexpr to reduce the amount of kernels (empty) generated for
+    // constexpr to reduce the amount of kernels generated for
     // vectorized templated elementwise and limit which functors are actually
     // applied to the load and store at compile time.
     using func_tuple = typename traits::ArgsTuple;
     if constexpr (
         std::is_same_v<float, arg0_t> && traits::arity == 2 &&
-        check_types<func_tuple, traits::arity, 0>::check()) {
+        check_binary_functor_types_for_specialization<
+            func_tuple,
+            float,
+            float,
+            traits::arity,
+            /*arg_num=*/0>::check()) {
+      // If we got here, we know we are in one of the specialized cases. We
+      // need to translate the runtime type to a statically known type. This
+      // is effectively hoisting to the host the switch over runtime type in
+      // the kernel in fetch_and_cast. Loader, storer, offset calculators are
+      // only needed for the reminder loop.
       auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
       auto output_offset_calculator = TrivialOffsetCalculator<1>();
       auto loader = memory::LoadWithCast<traits::arity>(iter);
       auto storer = memory::StoreWithCast<1>(iter);
-      launch_vectorized_templated_kernel<
-          func_t,
-          std::array<char*, ntensors>,
-          decltype(input_offset_calculator),
-          decltype(output_offset_calculator),
-          decltype(loader),
-          decltype(storer),
-          float,
-          float,
-          BFloat16>(
-          numel,
-          f,
-          data,
-          input_offset_calculator,
-          output_offset_calculator,
-          loader,
-          storer);
+      memory::detail::static_unroll<
+          type_specialized_kernel_launcher,
+          rt_binary_specializations.size()>::
+          with_args(
+              iter.input_dtype(0),
+              iter.input_dtype(1),
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
       return;
     }
   }
-
   std::array<ScalarType, ntensors> dtypes;
   auto inner_strides = iter.get_inner_strides();
   std::array<int, ntensors> strides;
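memory::detail::static_unroll is an existing PyTorch helper that invokes the launcher's apply for every index 0..N-1 at compile time; only the instantiation whose table entry matches the runtime dtype pair actually launches. A simplified, self-contained model of that dispatch pattern (illustrative, not PyTorch's actual helper):

// Simplified model of the static_unroll dispatch used above: instantiate
// one launcher per entry of a constexpr specialization table; each
// instantiation compares the runtime dtype pair against its own entry.
#include <array>
#include <cstddef>
#include <cstdio>

enum class ScalarType { Float, Half, BFloat16 };

constexpr std::array<std::array<ScalarType, 2>, 2> specializations = {
    {{ScalarType::Float, ScalarType::BFloat16},
     {ScalarType::BFloat16, ScalarType::Float}}};

template <std::size_t arg_index>
struct launcher {
  static void apply(ScalarType a, ScalarType b) {
    // Only the matching instantiation "launches"; the rest are no-ops.
    if (a == specializations[arg_index][0] &&
        b == specializations[arg_index][1])
      std::printf("dispatching specialization %zu\n", arg_index);
  }
};

template <template <std::size_t> class F, std::size_t N, std::size_t i = 0>
struct static_unroll {
  template <typename... Args>
  static void with_args(Args... args) {
    F<i>::apply(args...);
    if constexpr (i + 1 < N)
      static_unroll<F, N, i + 1>::with_args(args...);
  }
};

int main() {
  static_unroll<launcher, specializations.size()>::with_args(
      ScalarType::BFloat16, ScalarType::Float); // prints index 1
  return 0;
}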

aten/src/ATen/native/cuda/Loops.cuh

Lines changed: 3 additions & 1 deletion
@@ -41,14 +41,16 @@ static OffsetCalculator<num_outputs> make_output_offset_calculator(const TensorI
   return OffsetCalculator<num_outputs>(iter.ndim(), iter.shape().data(), strides.data(), element_sizes);
 }

-template<typename func_t, typename policy_t>
+template <bool reverted_idx = false, typename func_t, typename policy_t>
 __device__ inline void elementwise_kernel_helper(func_t f, policy_t policy) {
   using traits = function_traits<func_t>;
   using return_t = typename traits::result_type;
   using args_t = typename traits::ArgsTuple;
   constexpr int elems_per_thread = policy_t::tws;

   int idx = blockIdx.x;
+  if constexpr (reverted_idx)
+    idx = gridDim.x - blockIdx.x - 1;

   return_t results[elems_per_thread];
   args_t args[elems_per_thread];
