@@ -882,6 +882,69 @@ struct type_specialized_kernel_launcher {
882882 }
883883};
884884
885+ template <int arg_index>
886+ struct type_specialized_broadcast_kernel_launcher {
887+ template <
888+ typename func_t ,
889+ typename array_t ,
890+ typename dtypes_t ,
891+ typename calc_t >
892+ static void apply (
893+ int64_t numel,
894+ func_t f,
895+ array_t data,
896+ dtypes_t dtypes,
897+ calc_t offset_calc) {
898+ using traits = function_traits<func_t >;
899+ using ret_t = typename traits::result_type;
900+ using arg0_t = typename traits::template arg<0 >::type;
901+ using arg1_t = typename traits::template arg<1 >::type;
902+ if (dtypes[0 ] == rt_binary_specializations[arg_index][0 ] &&
903+ dtypes[1 ] == rt_binary_specializations[arg_index][1 ] &&
904+ dtypes[2 ] == rt_binary_specializations[arg_index][2 ]) {
905+ using ret_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][0 ]>;
906+ using arg0_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][1 ]>;
907+ using arg1_cpp_t = c10::impl::ScalarTypeToCPPTypeT<rt_binary_specializations[arg_index][2 ]>;
908+ constexpr int grp_sz = 128 ;
909+ launch_legacy_kernel_manual_unroll<grp_sz, 4 >(numel, [=] GPU_LAMBDA (int idx, bool unrl) {
910+ if (unrl) {
911+ auto offsets0 = offset_calc.get (idx);
912+ auto offsets1 = offset_calc.get (idx + grp_sz);
913+ auto offsets2 = offset_calc.get (idx + grp_sz * 2 );
914+ auto offsets3 = offset_calc.get (idx + grp_sz * 3 );
915+ void * out0 = data[0 ] + offsets0[0 ];
916+ void * out1 = data[0 ] + offsets1[0 ];
917+ void * out2 = data[0 ] + offsets2[0 ];
918+ void * out3 = data[0 ] + offsets3[0 ];
919+ auto u = c10::load<arg0_cpp_t >(data[1 ] + offsets0[1 ]);
920+ auto v = c10::load<arg1_cpp_t >(data[2 ] + offsets0[2 ]);
921+ ret_t result0 = f (c10::convert<arg0_t >(u), c10::convert<arg1_t >(v));
922+ auto u1 = c10::load<arg0_cpp_t >(data[1 ] + offsets1[1 ]);
923+ auto v1 = c10::load<arg1_cpp_t >(data[2 ]+ offsets1[2 ]);
924+ ret_t result1 = f (c10::convert<arg0_t >(u1), c10::convert<arg1_t >(v1));
925+ auto u2 = c10::load<arg0_cpp_t >(data[1 ] + offsets2[1 ]);
926+ auto v2 = c10::load<arg1_cpp_t >(data[2 ] + offsets2[2 ]);
927+ ret_t result2 = f (c10::convert<arg0_t >(u2), c10::convert<arg1_t >(v2));
928+ auto u3 = c10::load<arg0_cpp_t >(data[1 ] + offsets3[1 ]);
929+ auto v3 = c10::load<arg1_cpp_t >(data[2 ] + offsets3[2 ]);
930+ ret_t result3 = f (c10::convert<arg0_t >(u3), c10::convert<arg1_t >(v3));
931+ *(ret_cpp_t *)out0 = c10::convert<ret_cpp_t >(result0);
932+ *(ret_cpp_t *)out1 = c10::convert<ret_cpp_t >(result1);
933+ *(ret_cpp_t *)out2 = c10::convert<ret_cpp_t >(result2);
934+ *(ret_cpp_t *)out3 = c10::convert<ret_cpp_t >(result3);
935+ } else {
936+ auto offsets = offset_calc.get (idx);
937+ void * out = data[0 ] + offsets[0 ];
938+ auto u = c10::load<arg0_cpp_t >(data[1 ] + offsets[1 ]);
939+ auto v = c10::load<arg1_cpp_t >(data[2 ] + offsets[2 ]);
940+ ret_t result = f (c10::convert<arg0_t >(u), c10::convert<arg1_t >(v));
941+ *(ret_cpp_t *)out = c10::convert<ret_cpp_t >(result);
942+ }
943+ });
944+ }
945+ }
946+ };
947+
885948} // namespace
886949#endif
887950
@@ -1000,6 +1063,32 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
10001063 }
10011064 auto offset_calc = ::make_offset_calculator<traits::arity + 1 >(iter);
10021065#ifdef USE_ROCM
1066+ if (check_binary_rt_types_for_specialization (iter)) {
1067+      // constexpr to reduce the number of kernels generated for
1068+      // broadcast elementwise with mixed dtypes, and to limit which functors
1069+      // are actually applied to the loads and stores at compile time.
1070+ using func_tuple = typename traits::ArgsTuple;
1071+ if constexpr (
1072+ std::is_same_v<float , arg0_t > && traits::arity == 2 &&
1073+ check_binary_functor_types_for_specialization<
1074+ func_tuple,
1075+ float ,
1076+ float ,
1077+ traits::arity,
1078+ /* arg_num=*/ 0 >::check ()) {
1079+ memory::detail::static_unroll<
1080+ type_specialized_broadcast_kernel_launcher,
1081+ rt_binary_specializations.size ()>::with_args (
1082+ numel,
1083+ f,
1084+ data,
1085+ dtypes,
1086+ offset_calc
1087+ );
1088+ return ;
1089+ }
1090+ }
1091+
10031092 constexpr int grp_sz = 128 ;
10041093 launch_legacy_kernel_manual_unroll<grp_sz, 4 >(numel, [=] GPU_LAMBDA (int idx, bool unrl) {
10051094 if (unrl) {
0 commit comments