@@ -292,8 +292,11 @@ __global__ void vectorized_templated_elementwise_kernel(
     out_calc_t out_calc,
     loader_t loader,
     storer_t storer) {
-  int remaining =
-      N - vectorized_templated_config::block_work_size() * blockIdx.x;
+  int remaining = N -
+      vectorized_templated_config::block_work_size() *
+          (gridDim.x - blockIdx.x - 1);
+  constexpr bool reverted_idx = true;
+
   if (remaining <
       vectorized_templated_config::block_work_size()) { // if this block handles
                                                         // the remainder,
@@ -307,18 +310,17 @@ __global__ void vectorized_templated_elementwise_kernel(
         storer_t,
         vectorized_templated_config::elems_per_thread()>(
         data, remaining, inp_calc, out_calc, loader, storer);
-    elementwise_kernel_helper(f, policy);
+    elementwise_kernel_helper<reverted_idx>(f, policy);
   } else { // if this block has a full `block_work_size` data to handle, use
            // vectorized memory access
-    elementwise_kernel_helper(
-        f,
-        memory::policies::vectorized_templated<
-            vec_size,
-            array_t,
-            vectorized_templated_config::elems_per_thread(),
-            vectorized_templated_config::num_threads(),
-            OutputType,
-            InputTypes...>(data));
+    auto policy = memory::policies::vectorized_templated<
+        vec_size,
+        array_t,
+        vectorized_templated_config::elems_per_thread(),
+        vectorized_templated_config::num_threads(),
+        OutputType,
+        InputTypes...>(data);
+    elementwise_kernel_helper<reverted_idx>(f, policy);
   }
 }
 
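For orientation, here is a small host-side sketch (not part of the commit) of the reversed-index arithmetic introduced in the hunk above. The concrete block_work_size and element count are assumptions chosen only to make the asserts readable; in the kernel these come from vectorized_templated_config and the launch configuration.

    // Illustrative constants only; the real values come from
    // vectorized_templated_config and the kernel launch.
    constexpr int block_work_size = 512;

    // Mirrors: N - block_work_size() * (gridDim.x - blockIdx.x - 1)
    constexpr int remaining_for_block(int N, int grid, int block_idx) {
      return N - block_work_size * (grid - block_idx - 1);
    }

    constexpr int N = 3 * block_work_size + 100; // three full chunks plus a tail
    constexpr int grid = (N + block_work_size - 1) / block_work_size; // 4 blocks

    // With the reversed index, blockIdx.x == 0 is the block that sees fewer
    // than block_work_size elements and takes the unrolled remainder path...
    static_assert(remaining_for_block(N, grid, 0) == 100, "tail goes to block 0");
    // ...while the highest block index sees a full chunk (or more) and keeps
    // the vectorized path.
    static_assert(
        remaining_for_block(N, grid, grid - 1) >= block_work_size,
        "last block stays vectorized");

Under the original indexing (block_work_size() * blockIdx.x) the partially filled tail chunk would instead land on the last block.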
@@ -652,41 +654,143 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
 
 #ifdef USE_ROCM
 namespace {
-template <typename TupleLike, size_t arity, size_t arg_num = 0>
-struct check_types {
+template <
+    typename TupleLike,
+    typename FirstParamTy,
+    typename SecondParamTy,
+    size_t arity,
+    size_t arg_num = 0>
+struct check_binary_functor_types_for_specialization {
   constexpr static inline bool check() {
     if constexpr (arity != 2)
       return false;
     if constexpr (arg_num == 0) {
       using SelectedType = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType>)
-        return check_types<TupleLike, arity, arg_num + 1>::check();
+      if constexpr (std::is_same_v<FirstParamTy, SelectedType>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            FirstParamTy,
+            SecondParamTy,
+            arity,
+            arg_num + 1>::check();
     } else if constexpr (arg_num == 1) {
       using SelectedType2 = std::tuple_element_t<arg_num, TupleLike>;
-      if constexpr (std::is_same_v<float, SelectedType2>)
-        return check_types<TupleLike, arity, arg_num + 1>::check();
+      if constexpr (std::is_same_v<SecondParamTy, SelectedType2>)
+        return check_binary_functor_types_for_specialization<
+            TupleLike,
+            FirstParamTy,
+            SecondParamTy,
+            arity,
+            arg_num + 1>::check();
     }
     return false;
   }
 };
 
 // Bottom case: if we got this far, assume correct type matching except
 // when there are no arguments (arity == 0).
-template <typename TupleLike, size_t arity>
-struct check_types<TupleLike, arity, arity> {
+template <
+    typename TupleLike,
+    typename FirstParamTy,
+    typename SecondParamTy,
+    size_t arity>
+struct check_binary_functor_types_for_specialization<
+    TupleLike,
+    FirstParamTy,
+    SecondParamTy,
+    arity,
+    arity> {
   constexpr static inline bool check() {
     if constexpr (arity != 0)
       return true;
     return false;
   }
 };
 
-template <typename TupleLike>
-struct check_types<TupleLike, 0, 0> {
+template <typename TupleLike, typename FirstParamTy, typename SecondParamTy>
+struct check_binary_functor_types_for_specialization<
+    TupleLike,
+    FirstParamTy,
+    SecondParamTy,
+    0,
+    0> {
   constexpr static inline bool check() {
     return false;
   }
 };
+
+// The following is a list of type specializations for the vectorized_templated
+// elementwise kernel. It refers to the first and second runtime types of the
+// arguments of a binary functor.
+
+constexpr std::array rt_binary_specializations = {
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<BFloat16>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<Half>::value}),
+    std::array<c10::ScalarType, 2>(
+        {c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<float>::value})};
+
+bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
+  if (iter.ninputs() != 2)
+    return false;
+  for (auto spec : rt_binary_specializations)
+    if (iter.input_dtype(0) == spec[0] && iter.input_dtype(1) == spec[1])
+      return true;
+  return false;
+}
+
+template <int arg_index>
+struct type_specialized_kernel_launcher {
+  template <
+      typename func_t,
+      typename array_t,
+      typename inp_calc_t,
+      typename out_calc_t,
+      typename loader_t,
+      typename storer_t>
+  static void apply(
+      ScalarType arg0_t,
+      ScalarType arg1_t,
+      int64_t numel,
+      func_t f,
+      array_t data,
+      inp_calc_t input_offset_calculator,
+      out_calc_t output_offset_calculator,
+      loader_t loader,
+      storer_t storer) {
+    using traits = function_traits<func_t>;
+    using return_t = typename traits::result_type;
+    if (arg0_t == rt_binary_specializations[arg_index][0] &&
+        arg1_t == rt_binary_specializations[arg_index][1])
+      launch_vectorized_templated_kernel<
+          func_t,
+          array_t,
+          inp_calc_t,
+          out_calc_t,
+          loader_t,
+          storer_t,
+          return_t,
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][0]>::t),
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][1]>::t)>(
+          numel,
+          f,
+          data,
+          input_offset_calculator,
+          output_offset_calculator,
+          loader,
+          storer);
+  }
+};
+
 } // namespace
 #endif
 
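A compile-time usage sketch (not part of the commit) of the trait defined above, assuming check_binary_functor_types_for_specialization is in scope. The TupleLike argument stands in for the functor's argument tuple (traits::ArgsTuple at the call site); the tuples below are made-up examples.

    #include <tuple>

    // Accepted: a binary functor whose parameters are exactly (float, float).
    static_assert(
        check_binary_functor_types_for_specialization<
            std::tuple<float, float>, float, float, /*arity=*/2>::check(),
        "");
    // Rejected: the second parameter type does not match.
    static_assert(
        !check_binary_functor_types_for_specialization<
            std::tuple<float, double>, float, float, /*arity=*/2>::check(),
        "");
    // Rejected: not a binary functor.
    static_assert(
        !check_binary_functor_types_for_specialization<
            std::tuple<float>, float, float, /*arity=*/1>::check(),
        "");

This is what the `if constexpr` guard in gpu_kernel_impl (next hunk) relies on to keep the specialized launch path from being instantiated for functors it cannot serve.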
@@ -716,43 +820,46 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
 #ifdef USE_ROCM
   // Attempt to call specialized vectorized elementwise kernel
   // that enables interleaving.
-  using float_map = c10::CppTypeToScalarType<float>;
-  using bfloat16_map = c10::CppTypeToScalarType<BFloat16>;
-  if (iter.ninputs() == 2 && iter.input_dtype(0) == float_map::value &&
-      iter.input_dtype(1) == bfloat16_map::value &&
+
+  if (check_binary_rt_types_for_specialization(iter) &&
       memory::can_vectorize_up_to<func_t>(data) > 1) {
-    // constexpr to reduce the amount of kernels (empty) generated for
+    // constexpr to reduce the amount of kernels generated for
     // vectorized templated elementwise and limit which functors are actually
     // applied to the load and store at compile time.
     using func_tuple = typename traits::ArgsTuple;
     if constexpr (
         std::is_same_v<float, arg0_t> && traits::arity == 2 &&
-        check_types<func_tuple, traits::arity, 0>::check()) {
+        check_binary_functor_types_for_specialization<
+            func_tuple,
+            float,
+            float,
+            traits::arity,
+            /*arg_num=*/0>::check()) {
+      // If we got here, we know we are in one of the specialized cases. We
+      // need to translate the runtime type to a statically known type. This
+      // effectively hoists to the host the switch over the runtime type that
+      // fetch_and_cast would otherwise do in the kernel. The loader, storer,
+      // and offset calculators are only needed for the remainder loop.
       auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
       auto output_offset_calculator = TrivialOffsetCalculator<1>();
       auto loader = memory::LoadWithCast<traits::arity>(iter);
       auto storer = memory::StoreWithCast<1>(iter);
-      launch_vectorized_templated_kernel<
-          func_t,
-          std::array<char*, ntensors>,
-          decltype(input_offset_calculator),
-          decltype(output_offset_calculator),
-          decltype(loader),
-          decltype(storer),
-          float,
-          float,
-          BFloat16>(
-          numel,
-          f,
-          data,
-          input_offset_calculator,
-          output_offset_calculator,
-          loader,
-          storer);
+      memory::detail::static_unroll<
+          type_specialized_kernel_launcher,
+          rt_binary_specializations.size()>::
+          with_args(
+              iter.input_dtype(0),
+              iter.input_dtype(1),
+              numel,
+              f,
+              data,
+              input_offset_calculator,
+              output_offset_calculator,
+              loader,
+              storer);
       return;
     }
   }
-
   std::array<ScalarType, ntensors> dtypes;
   auto inner_strides = iter.get_inner_strides();
   std::array<int, ntensors> strides;
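For readers unfamiliar with the unrolling helper used above: memory::detail::static_unroll is an existing ATen utility, and the sketch below is only a simplified restatement of the pattern (names are illustrative, not the actual implementation). with_args instantiates functor<0> through functor<end - 1> and calls each one's apply with the same arguments, so every type_specialized_kernel_launcher<index> gets to compare the runtime dtype pair against rt_binary_specializations[index], and at most one of the unrolled calls actually launches a kernel; the rest reduce to cheap runtime comparisons.

    // Simplified sketch of a compile-time-unrolled dispatcher; not the ATen
    // implementation.
    template <template <int> class functor, int end, int current = 0>
    struct static_unroll_sketch {
      template <typename... Args>
      static void with_args(Args&&... args) {
        // Call the current specialization, then recurse to the next index.
        functor<current>::apply(args...);
        static_unroll_sketch<functor, end, current + 1>::with_args(args...);
      }
    };

    // Base case: every specialization index has been visited.
    template <template <int> class functor, int end>
    struct static_unroll_sketch<functor, end, end> {
      template <typename... Args>
      static void with_args(Args&&...) {}
    };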