@@ -720,28 +720,41 @@ struct check_binary_functor_types_for_specialization<
720720};
721721
722722// The following is a list of type specializations for vectorized_templated
723- // elementwise kernel. It refers to the first and second runtime types of the
724- // arguments of a binary functor.
725-
723+ // elementwise kernel. The three types refer to runtime types of the output
724+ // tensor, first tensor argument, and the second tensor argument used for a
725+ // binary functor.
726726constexpr std::array rt_binary_specializations = {
727- std::array<c10::ScalarType, 2 >(
727+ std::array<c10::ScalarType, 3 >(
728728 {c10::CppTypeToScalarType<float >::value,
729+ c10::CppTypeToScalarType<float >::value,
729730 c10::CppTypeToScalarType<BFloat16>::value}),
730- std::array<c10::ScalarType, 2 >(
731+ std::array<c10::ScalarType, 3 >(
732+ {c10::CppTypeToScalarType<float >::value,
733+ c10::CppTypeToScalarType<BFloat16>::value,
734+ c10::CppTypeToScalarType<float >::value}),
735+ std::array<c10::ScalarType, 3 >(
731736 {c10::CppTypeToScalarType<BFloat16>::value,
737+ c10::CppTypeToScalarType<BFloat16>::value,
732738 c10::CppTypeToScalarType<float >::value}),
733- std::array<c10::ScalarType, 2 >(
739+ std::array<c10::ScalarType, 3 >(
734740 {c10::CppTypeToScalarType<float >::value,
741+ c10::CppTypeToScalarType<float >::value,
735742 c10::CppTypeToScalarType<Half>::value}),
736- std::array<c10::ScalarType, 2 >(
743+ std::array<c10::ScalarType, 3 >(
744+ {c10::CppTypeToScalarType<float >::value,
745+ c10::CppTypeToScalarType<Half>::value,
746+ c10::CppTypeToScalarType<float >::value}),
747+ std::array<c10::ScalarType, 3 >(
737748 {c10::CppTypeToScalarType<Half>::value,
749+ c10::CppTypeToScalarType<Half>::value,
738750 c10::CppTypeToScalarType<float >::value})};
739751
740752bool check_binary_rt_types_for_specialization (TensorIteratorBase& iter) {
741753 if (iter.ninputs () != 2 )
742754 return false ;
743755 for (auto spec : rt_binary_specializations)
744- if (iter.input_dtype (0 ) == spec[0 ] && iter.input_dtype (1 ) == spec[1 ])
756+ if (iter.dtype (0 ) == spec[0 ] && iter.input_dtype (0 ) == spec[1 ] &&
757+ iter.input_dtype (1 ) == spec[2 ])
745758 return true ;
746759 return false ;
747760}
@@ -756,6 +769,7 @@ struct type_specialized_kernel_launcher {
756769 typename loader_t ,
757770 typename storer_t >
758771 static void apply (
772+ ScalarType ret_t ,
759773 ScalarType arg0_t ,
760774 ScalarType arg1_t ,
761775 int64_t numel,
@@ -765,22 +779,22 @@ struct type_specialized_kernel_launcher {
765779 out_calc_t output_offset_calculator,
766780 loader_t loader,
767781 storer_t storer) {
768- using traits = function_traits<func_t >;
769- using return_t = typename traits::result_type;
770- if (arg0_t == rt_binary_specializations[arg_index][0 ] &&
771- arg1_t == rt_binary_specializations[arg_index][1 ])
782+ if (ret_t == rt_binary_specializations[arg_index][0 ] &&
783+ arg0_t == rt_binary_specializations[arg_index][1 ] &&
784+ arg1_t == rt_binary_specializations[arg_index][2 ])
772785 launch_vectorized_templated_kernel<
773786 func_t ,
774787 array_t ,
775788 inp_calc_t ,
776789 out_calc_t ,
777790 loader_t ,
778791 storer_t ,
779- return_t ,
780792 decltype (c10::impl::ScalarTypeToCPPType<
781793 rt_binary_specializations[arg_index][0 ]>::t),
782794 decltype (c10::impl::ScalarTypeToCPPType<
783- rt_binary_specializations[arg_index][1 ]>::t)>(
795+ rt_binary_specializations[arg_index][1 ]>::t),
796+ decltype (c10::impl::ScalarTypeToCPPType<
797+ rt_binary_specializations[arg_index][2 ]>::t)>(
784798 numel,
785799 f,
786800 data,
@@ -820,7 +834,6 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
820834#ifdef USE_ROCM
821835 // Attempt to call specialized vectorized elementwise kernel
822836 // that enables interleaving.
823-
824837 if (check_binary_rt_types_for_specialization (iter) &&
825838 memory::can_vectorize_up_to<func_t >(data) > 1 ) {
826839 // constexpr to reduce the amount of kernels generated for
@@ -848,6 +861,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
848861 type_specialized_kernel_launcher,
849862 rt_binary_specializations.size ()>::
850863 with_args (
864+ iter.dtype (0 ),
851865 iter.input_dtype (0 ),
852866 iter.input_dtype (1 ),
853867 numel,
0 commit comments