5151
namespace at::native {
5353
54+ #ifdef USE_ROCM
55+ // Custom configuration for vectorized elementwise kernel
56+ // with template instantiation.
namespace vectorized_templated_config {

// Threads launched per block for the templated vectorized kernel.
constexpr int num_threads() {
  return 512;
}

// Elements each thread is responsible for.
constexpr int elems_per_thread() {
  return 32;
}

// Total elements one block processes: elems_per_thread() * num_threads().
constexpr int block_work_size() {
  return num_threads() * elems_per_thread();
}

} // namespace vectorized_templated_config
70+ #endif
5471
5572template <typename args_t , size_t ... Is>
5673constexpr auto sum_of_sizes (args_t args, std::index_sequence<Is...>) {
@@ -255,6 +272,139 @@ static inline void launch_vectorized_kernel(
255272 }
256273}
257274
275+ #ifdef USE_ROCM
276+ template <
277+ int vec_size,
278+ typename func_t ,
279+ typename array_t ,
280+ typename inp_calc_t ,
281+ typename out_calc_t ,
282+ typename loader_t ,
283+ typename storer_t ,
284+ typename OutputType,
285+ typename ... InputTypes>
286+ C10_LAUNCH_BOUNDS_1 (vectorized_templated_config::num_threads())
287+ __global__ void vectorized_templated_elementwise_kernel(
288+ int N,
289+ func_t f,
290+ array_t data,
291+ inp_calc_t inp_calc,
292+ out_calc_t out_calc,
293+ loader_t loader,
294+ storer_t storer) {
295+ int remaining =
296+ N - vectorized_templated_config::block_work_size () * blockIdx .x ;
297+ if (remaining <
298+ vectorized_templated_config::block_work_size ()) { // if this block handles
299+ // the reminder,
300+ // just do a naive unrolled loop
301+ auto policy = memory::policies::unroll_base<
302+ vectorized_templated_config::num_threads (),
303+ array_t ,
304+ inp_calc_t ,
305+ out_calc_t ,
306+ loader_t ,
307+ storer_t ,
308+ vectorized_templated_config::elems_per_thread ()>(
309+ data, remaining, inp_calc, out_calc, loader, storer);
310+ elementwise_kernel_helper (f, policy);
311+ } else { // if this block has a full `block_work_size` data to handle, use
312+ // vectorized memory access
313+ elementwise_kernel_helper (
314+ f,
315+ memory::policies::vectorized_templated<
316+ vec_size,
317+ array_t ,
318+ vectorized_templated_config::elems_per_thread (),
319+ vectorized_templated_config::num_threads (),
320+ OutputType,
321+ InputTypes...>(data));
322+ }
323+ }
324+
325+ // This function assume trivial 1d and supports template specialization
326+ // to avoid dynamic casting.
327+ // Input vectorization size is based on runtime information, i.e.
328+ // the actual data types of the input and output tensor and cannot
329+ // be determined using the functor type, as in regular non-templated
330+ // vectorized kernels. The caller is in charge of selecting the correct input
331+ // vectorization length.
332+ template <
333+ typename func_t ,
334+ typename array_t ,
335+ typename inp_calc_t ,
336+ typename out_calc_t ,
337+ typename loader_t ,
338+ typename storer_t ,
339+ typename OutputType,
340+ typename ... InputTypes>
341+ static inline void launch_vectorized_templated_kernel (
342+ int64_t N,
343+ const func_t & f,
344+ array_t data,
345+ inp_calc_t ic,
346+ out_calc_t oc,
347+ loader_t l,
348+ storer_t s) {
349+ TORCH_INTERNAL_ASSERT (N > 0 && N <= std::numeric_limits<int32_t >::max ());
350+ using traits = function_traits<func_t >;
351+ int64_t grid = (N + vectorized_templated_config::block_work_size () - 1 ) /
352+ vectorized_templated_config::block_work_size ();
353+ auto stream = at::cuda::getCurrentCUDAStream ();
354+ int vec_size = memory::can_vectorize_up_to<func_t >(data);
355+ switch (vec_size) {
356+ case 8 :
357+ vectorized_templated_elementwise_kernel<
358+ 8 ,
359+ func_t ,
360+ array_t ,
361+ inp_calc_t ,
362+ out_calc_t ,
363+ loader_t ,
364+ storer_t ,
365+ OutputType,
366+ InputTypes...>
367+ <<<grid, vectorized_templated_config::num_threads(), 0 , stream>>> (
368+ N, f, data, ic, oc, l, s);
369+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
370+ break ;
371+ case 4 :
372+ vectorized_templated_elementwise_kernel<
373+ 4 ,
374+ func_t ,
375+ array_t ,
376+ inp_calc_t ,
377+ out_calc_t ,
378+ loader_t ,
379+ storer_t ,
380+ OutputType,
381+ InputTypes...>
382+ <<<grid, vectorized_templated_config::num_threads(), 0 , stream>>> (
383+ N, f, data, ic, oc, l, s);
384+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
385+ break ;
386+ case 2 :
387+ vectorized_templated_elementwise_kernel<
388+ 2 ,
389+ func_t ,
390+ array_t ,
391+ inp_calc_t ,
392+ out_calc_t ,
393+ loader_t ,
394+ storer_t ,
395+ OutputType,
396+ InputTypes...>
397+ <<<grid, vectorized_templated_config::num_threads(), 0 , stream>>> (
398+ N, f, data, ic, oc, l, s);
399+ C10_CUDA_KERNEL_LAUNCH_CHECK ();
400+ break ;
401+ default :
402+ // vector size 1 is not handled as part of vectorize_templated kernel
403+ TORCH_INTERNAL_ASSERT (false , " Unexpected vectorization size" );
404+ }
405+ }
406+ #endif
407+
258408template <
259409 typename func_t ,
260410 typename array_t ,
@@ -500,6 +650,46 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
500650#endif
501651}
502652
653+ #ifdef USE_ROCM
namespace {
// Compile-time argument-type filter: accepts only functors of arity 2 whose
// argument types, inspected position by position, are all float. Used by
// gpu_kernel_impl to gate the templated vectorized kernel path.
template <typename TupleLike, size_t arity, size_t arg_num = 0>
struct check_types {
  constexpr static inline bool check() {
    if constexpr (arity != 2) {
      return false;
    } else if constexpr (std::is_same_v<
                             float,
                             std::tuple_element_t<arg_num, TupleLike>>) {
      // This position matches; inspect the next argument.
      return check_types<TupleLike, arity, arg_num + 1>::check();
    } else {
      return false;
    }
  }
};

// Bottom case: every argument position was inspected and matched, so the
// functor is accepted — except when there are no arguments (arity == 0).
template <typename TupleLike, size_t arity>
struct check_types<TupleLike, arity, arity> {
  constexpr static inline bool check() {
    return arity != 0;
  }
};

// Empty argument tuples are always rejected.
template <typename TupleLike>
struct check_types<TupleLike, 0, 0> {
  constexpr static inline bool check() {
    return false;
  }
};
} // namespace
691+ #endif
692+
503693template <typename func_t >
504694void gpu_kernel_impl (TensorIteratorBase& iter, const func_t & f) {
505695 if (!needs_dynamic_casting<func_t >::check (iter)) {
@@ -524,6 +714,45 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
524714
525715 if (contiguous) {
526716#ifdef USE_ROCM
717+ // Attempt to call specialized vectorized elementwise kernel
718+ // that enables interleaving.
719+ using float_map = c10::CppTypeToScalarType<float >;
720+ using bfloat16_map = c10::CppTypeToScalarType<BFloat16>;
721+ if (iter.ninputs () == 2 && iter.input_dtype (0 ) == float_map::value &&
722+ iter.input_dtype (1 ) == bfloat16_map::value &&
723+ memory::can_vectorize_up_to<func_t >(data) > 1 ) {
724+ // constexpr to reduce the amount of kernels (empty) generated for
725+ // vectorized templated elementwise and limit which functors are actually
726+ // applied to the load and store at compile time.
727+ using func_tuple = typename traits::ArgsTuple;
728+ if constexpr (
729+ std::is_same_v<float , arg0_t > && traits::arity == 2 &&
730+ check_types<func_tuple, traits::arity, 0 >::check ()) {
731+ auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
732+ auto output_offset_calculator = TrivialOffsetCalculator<1 >();
733+ auto loader = memory::LoadWithCast<traits::arity>(iter);
734+ auto storer = memory::StoreWithCast<1 >(iter);
735+ launch_vectorized_templated_kernel<
736+ func_t ,
737+ std::array<char *, ntensors>,
738+ decltype (input_offset_calculator),
739+ decltype (output_offset_calculator),
740+ decltype (loader),
741+ decltype (storer),
742+ float ,
743+ float ,
744+ BFloat16>(
745+ numel,
746+ f,
747+ data,
748+ input_offset_calculator,
749+ output_offset_calculator,
750+ loader,
751+ storer);
752+ return ;
753+ }
754+ }
755+
527756 std::array<ScalarType, ntensors> dtypes;
528757 auto inner_strides = iter.get_inner_strides ();
529758 std::array<int , ntensors> strides;
0 commit comments