@@ -893,6 +893,69 @@ struct type_specialized_kernel_launcher {
893893 }
894894};
895895
896+ template <int arg_index>
897+ struct type_specialized_broadcast_kernel_launcher {
898+ template <
899+ typename func_t ,
900+ typename array_t ,
901+ typename dtypes_t ,
902+ typename calc_t >
903+ static void apply (
904+ int64_t numel,
905+ func_t f,
906+ array_t data,
907+ dtypes_t dtypes,
908+ calc_t offset_calc) {
909+ using traits = function_traits<func_t >;
910+ using ret_t = typename traits::result_type;
911+ using arg0_t = typename traits::template arg<0 >::type;
912+ using arg1_t = typename traits::template arg<1 >::type;
913+ if (dtypes[0 ] == rt_binary_specializations[arg_index][0 ] &&
914+ dtypes[1 ] == rt_binary_specializations[arg_index][1 ] &&
915+ dtypes[2 ] == rt_binary_specializations[arg_index][2 ]) {
916+ using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][0 ]>;
917+ using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][1 ]>;
918+ using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][2 ]>;
919+ constexpr int grp_sz = 128 ;
920+ launch_legacy_kernel_manual_unroll<grp_sz, 4 >(numel, [=] GPU_LAMBDA (int idx, bool unrl) {
921+ if (unrl) {
922+ auto offsets0 = offset_calc.get (idx);
923+ auto offsets1 = offset_calc.get (idx + grp_sz);
924+ auto offsets2 = offset_calc.get (idx + grp_sz * 2 );
925+ auto offsets3 = offset_calc.get (idx + grp_sz * 3 );
926+ void * out0 = data[0 ] + offsets0[0 ];
927+ void * out1 = data[0 ] + offsets1[0 ];
928+ void * out2 = data[0 ] + offsets2[0 ];
929+ void * out3 = data[0 ] + offsets3[0 ];
930+ auto u = c10::load<arg0_cpp_t >(data[1 ] + offsets0[1 ]);
931+ auto v = c10::load<arg1_cpp_t >(data[2 ] + offsets0[2 ]);
932+ ret_t result0 = f (c10::convert<arg0_t >(u), c10::convert<arg1_t >(v));
933+ auto u1 = c10::load<arg0_cpp_t >(data[1 ] + offsets1[1 ]);
934+ auto v1 = c10::load<arg1_cpp_t >(data[2 ]+ offsets1[2 ]);
935+ ret_t result1 = f (c10::convert<arg0_t >(u1), c10::convert<arg1_t >(v1));
936+ auto u2 = c10::load<arg0_cpp_t >(data[1 ] + offsets2[1 ]);
937+ auto v2 = c10::load<arg1_cpp_t >(data[2 ] + offsets2[2 ]);
938+ ret_t result2 = f (c10::convert<arg0_t >(u2), c10::convert<arg1_t >(v2));
939+ auto u3 = c10::load<arg0_cpp_t >(data[1 ] + offsets3[1 ]);
940+ auto v3 = c10::load<arg1_cpp_t >(data[2 ] + offsets3[2 ]);
941+ ret_t result3 = f (c10::convert<arg0_t >(u3), c10::convert<arg1_t >(v3));
942+ *(ret_cpp_t *)out0 = c10::convert<ret_cpp_t >(result0);
943+ *(ret_cpp_t *)out1 = c10::convert<ret_cpp_t >(result1);
944+ *(ret_cpp_t *)out2 = c10::convert<ret_cpp_t >(result2);
945+ *(ret_cpp_t *)out3 = c10::convert<ret_cpp_t >(result3);
946+ } else {
947+ auto offsets = offset_calc.get (idx);
948+ void * out = data[0 ] + offsets[0 ];
949+ auto u = c10::load<arg0_cpp_t >(data[1 ] + offsets[1 ]);
950+ auto v = c10::load<arg1_cpp_t >(data[2 ] + offsets[2 ]);
951+ ret_t result = f (c10::convert<arg0_t >(u), c10::convert<arg1_t >(v));
952+ *(ret_cpp_t *)out = c10::convert<ret_cpp_t >(result);
953+ }
954+ });
955+ }
956+ }
957+ };
958+
896959} // namespace
897960#endif
898961
@@ -995,6 +1058,32 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
9951058 }
9961059 auto offset_calc = ::make_offset_calculator<traits::arity + 1 >(iter);
9971060#ifdef USE_ROCM
1061+ if (check_binary_rt_types_for_specialization (iter)) {
1062+      // Use constexpr checks to limit the number of kernels generated for
1063+      // broadcast elementwise ops with mixed dtypes, restricting which
1064+      // functors are applied to the load and store at compile time.
1065+ using func_tuple = typename traits::ArgsTuple;
1066+ if constexpr (
1067+ std::is_same_v<float , arg0_t > && traits::arity == 2 &&
1068+ check_binary_functor_types_for_specialization<
1069+ func_tuple,
1070+ float ,
1071+ float ,
1072+ traits::arity,
1073+ /* arg_num=*/ 0 >::check ()) {
1074+ memory::detail::static_unroll<
1075+ type_specialized_broadcast_kernel_launcher,
1076+ rt_binary_specializations.size ()>::with_args (
1077+ numel,
1078+ f,
1079+ data,
1080+ dtypes,
1081+ offset_calc
1082+ );
1083+ return ;
1084+ }
1085+ }
1086+
9981087 constexpr int grp_sz = 128 ;
9991088 launch_legacy_kernel_manual_unroll<grp_sz, 4 >(numel, [=] GPU_LAMBDA (int idx, bool unrl) {
10001089 if (unrl) {
0 commit comments