Skip to content

Commit f001ac4

Browse files
carlobertolli and amd-hhashemi
authored and committed
[ROCm] Input vectorization in elementwise kernels for tensors with heterogeneous types (pytorch#147527)
This patch exemplifies its use for input tensors with types (float,bfloat16) when functor type is float(float,float). Pull Request resolved: pytorch#147527 Approved by: https://github.com/jeffdaily Co-authored-by: Hashem Hashemi <[email protected]>
1 parent 3cdddfe commit f001ac4

File tree

2 files changed

+388
-24
lines changed

2 files changed

+388
-24
lines changed

aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 229 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -51,6 +51,23 @@
5151

5252
namespace at::native {
5353

54+
#ifdef USE_ROCM
55+
// Launch configuration for the templated vectorized elementwise kernel
// below. Kept in its own namespace so it can be tuned independently of
// the generic elementwise kernel configuration.
namespace vectorized_templated_config {

// Threads per block.
constexpr int num_threads() {
  return 512;
}

// Elements processed by each thread.
constexpr int elems_per_thread() {
  return 32;
}

// Total number of elements one block is responsible for.
constexpr int block_work_size() {
  return num_threads() * elems_per_thread();
}

} // namespace vectorized_templated_config
70+
#endif
5471

5572
template <typename args_t, size_t... Is>
5673
constexpr auto sum_of_sizes(args_t args, std::index_sequence<Is...>) {
@@ -255,6 +272,139 @@ static inline void launch_vectorized_kernel(
255272
}
256273
}
257274

275+
#ifdef USE_ROCM
276+
template <
277+
int vec_size,
278+
typename func_t,
279+
typename array_t,
280+
typename inp_calc_t,
281+
typename out_calc_t,
282+
typename loader_t,
283+
typename storer_t,
284+
typename OutputType,
285+
typename... InputTypes>
286+
C10_LAUNCH_BOUNDS_1(vectorized_templated_config::num_threads())
287+
__global__ void vectorized_templated_elementwise_kernel(
288+
int N,
289+
func_t f,
290+
array_t data,
291+
inp_calc_t inp_calc,
292+
out_calc_t out_calc,
293+
loader_t loader,
294+
storer_t storer) {
295+
int remaining =
296+
N - vectorized_templated_config::block_work_size() * blockIdx.x;
297+
if (remaining <
298+
vectorized_templated_config::block_work_size()) { // if this block handles
299+
// the reminder,
300+
// just do a naive unrolled loop
301+
auto policy = memory::policies::unroll_base<
302+
vectorized_templated_config::num_threads(),
303+
array_t,
304+
inp_calc_t,
305+
out_calc_t,
306+
loader_t,
307+
storer_t,
308+
vectorized_templated_config::elems_per_thread()>(
309+
data, remaining, inp_calc, out_calc, loader, storer);
310+
elementwise_kernel_helper(f, policy);
311+
} else { // if this block has a full `block_work_size` data to handle, use
312+
// vectorized memory access
313+
elementwise_kernel_helper(
314+
f,
315+
memory::policies::vectorized_templated<
316+
vec_size,
317+
array_t,
318+
vectorized_templated_config::elems_per_thread(),
319+
vectorized_templated_config::num_threads(),
320+
OutputType,
321+
InputTypes...>(data));
322+
}
323+
}
324+
325+
// This function assume trivial 1d and supports template specialization
326+
// to avoid dynamic casting.
327+
// Input vectorization size is based on runtime information, i.e.
328+
// the actual data types of the input and output tensor and cannot
329+
// be determined using the functor type, as in regular non-templated
330+
// vectorized kernels. The caller is in charge of selecting the correct input
331+
// vectorization length.
332+
template <
333+
typename func_t,
334+
typename array_t,
335+
typename inp_calc_t,
336+
typename out_calc_t,
337+
typename loader_t,
338+
typename storer_t,
339+
typename OutputType,
340+
typename... InputTypes>
341+
static inline void launch_vectorized_templated_kernel(
342+
int64_t N,
343+
const func_t& f,
344+
array_t data,
345+
inp_calc_t ic,
346+
out_calc_t oc,
347+
loader_t l,
348+
storer_t s) {
349+
TORCH_INTERNAL_ASSERT(N > 0 && N <= std::numeric_limits<int32_t>::max());
350+
using traits = function_traits<func_t>;
351+
int64_t grid = (N + vectorized_templated_config::block_work_size() - 1) /
352+
vectorized_templated_config::block_work_size();
353+
auto stream = at::cuda::getCurrentCUDAStream();
354+
int vec_size = memory::can_vectorize_up_to<func_t>(data);
355+
switch (vec_size) {
356+
case 8:
357+
vectorized_templated_elementwise_kernel<
358+
8,
359+
func_t,
360+
array_t,
361+
inp_calc_t,
362+
out_calc_t,
363+
loader_t,
364+
storer_t,
365+
OutputType,
366+
InputTypes...>
367+
<<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
368+
N, f, data, ic, oc, l, s);
369+
C10_CUDA_KERNEL_LAUNCH_CHECK();
370+
break;
371+
case 4:
372+
vectorized_templated_elementwise_kernel<
373+
4,
374+
func_t,
375+
array_t,
376+
inp_calc_t,
377+
out_calc_t,
378+
loader_t,
379+
storer_t,
380+
OutputType,
381+
InputTypes...>
382+
<<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
383+
N, f, data, ic, oc, l, s);
384+
C10_CUDA_KERNEL_LAUNCH_CHECK();
385+
break;
386+
case 2:
387+
vectorized_templated_elementwise_kernel<
388+
2,
389+
func_t,
390+
array_t,
391+
inp_calc_t,
392+
out_calc_t,
393+
loader_t,
394+
storer_t,
395+
OutputType,
396+
InputTypes...>
397+
<<<grid, vectorized_templated_config::num_threads(), 0, stream>>>(
398+
N, f, data, ic, oc, l, s);
399+
C10_CUDA_KERNEL_LAUNCH_CHECK();
400+
break;
401+
default:
402+
// vector size 1 is not handled as part of vectorize_templated kernel
403+
TORCH_INTERNAL_ASSERT(false, "Unexpected vectorization size");
404+
}
405+
}
406+
#endif
407+
258408
template <
259409
typename func_t,
260410
typename array_t,
@@ -500,6 +650,46 @@ void gpu_kernel_impl_nocast(TensorIteratorBase& iter, const func_t& f) {
500650
#endif
501651
}
502652

653+
#ifdef USE_ROCM
654+
namespace {
// Compile-time predicate gating the templated vectorized kernel path:
// check() is true only for binary functors (arity == 2) whose argument
// tuple consists entirely of float.
template <typename TupleLike, size_t arity, size_t arg_num = 0>
struct check_types {
  constexpr static inline bool check() {
    if constexpr (arity != 2) {
      return false;
    } else if constexpr (arg_num < arity) {
      // Inspect one argument per recursion step; stop at the first
      // argument that is not float.
      using ArgType = std::tuple_element_t<arg_num, TupleLike>;
      if constexpr (std::is_same_v<float, ArgType>)
        return check_types<TupleLike, arity, arg_num + 1>::check();
    }
    return false;
  }
};

// Bottom case: every argument was inspected and matched, so accept —
// except when there are no arguments at all (arity == 0).
template <typename TupleLike, size_t arity>
struct check_types<TupleLike, arity, arity> {
  constexpr static inline bool check() {
    return arity != 0;
  }
};

// Zero-argument functors are never eligible.
template <typename TupleLike>
struct check_types<TupleLike, 0, 0> {
  constexpr static inline bool check() {
    return false;
  }
};
} // namespace
691+
#endif
692+
503693
template <typename func_t>
504694
void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
505695
if (!needs_dynamic_casting<func_t>::check(iter)) {
@@ -524,6 +714,45 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
524714

525715
if (contiguous) {
526716
#ifdef USE_ROCM
717+
// Attempt to call specialized vectorized elementwise kernel
718+
// that enables interleaving.
719+
using float_map = c10::CppTypeToScalarType<float>;
720+
using bfloat16_map = c10::CppTypeToScalarType<BFloat16>;
721+
if (iter.ninputs() == 2 && iter.input_dtype(0) == float_map::value &&
722+
iter.input_dtype(1) == bfloat16_map::value &&
723+
memory::can_vectorize_up_to<func_t>(data) > 1) {
724+
// constexpr to reduce the amount of kernels (empty) generated for
725+
// vectorized templated elementwise and limit which functors are actually
726+
// applied to the load and store at compile time.
727+
using func_tuple = typename traits::ArgsTuple;
728+
if constexpr (
729+
std::is_same_v<float, arg0_t> && traits::arity == 2 &&
730+
check_types<func_tuple, traits::arity, 0>::check()) {
731+
auto input_offset_calculator = TrivialOffsetCalculator<traits::arity>();
732+
auto output_offset_calculator = TrivialOffsetCalculator<1>();
733+
auto loader = memory::LoadWithCast<traits::arity>(iter);
734+
auto storer = memory::StoreWithCast<1>(iter);
735+
launch_vectorized_templated_kernel<
736+
func_t,
737+
std::array<char*, ntensors>,
738+
decltype(input_offset_calculator),
739+
decltype(output_offset_calculator),
740+
decltype(loader),
741+
decltype(storer),
742+
float,
743+
float,
744+
BFloat16>(
745+
numel,
746+
f,
747+
data,
748+
input_offset_calculator,
749+
output_offset_calculator,
750+
loader,
751+
storer);
752+
return;
753+
}
754+
}
755+
527756
std::array<ScalarType, ntensors> dtypes;
528757
auto inner_strides = iter.get_inner_strides();
529758
std::array<int, ntensors> strides;

0 commit comments

Comments
 (0)