
Commit 1cb818b

carlobertolli authored and pruthvistony committed
[ROCM] Fix in-place aten sum with specialized templated kernels. (pytorch#151230)
We noticed a regression when doing aten.sum in-place (a += b) where the dtype of the output tensor is not the same as the functor's compute type.

Co-authored-by: Jerry Mannil <[email protected]>
Pull Request resolved: pytorch#151230
Approved by: https://github.com/jeffdaily
1 parent 6e62a7c commit 1cb818b
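
For context, a minimal repro along the lines the commit message describes (a sketch, not taken from the PR; the shape, the use of mean, and running on a ROCm build of libtorch are all illustrative assumptions): an in-place accumulation whose output dtype (bf16) is narrower than the binary functor's float compute type.

// Hypothetical repro sketch (not from the PR): in-place a += b where the
// output dtype (BFloat16) differs from the functor's compute type (float).
#include <torch/torch.h>
#include <iostream>

int main() {
  auto opts = torch::TensorOptions().device(torch::kCUDA);
  auto a = torch::ones({1 << 20}, opts.dtype(torch::kBFloat16));
  auto b = torch::ones({1 << 20}, opts.dtype(torch::kFloat));
  a += b;  // add computes in float, must store the result back as bf16
  std::cout << a.mean(torch::kFloat) << std::endl;  // expected: 2
  return 0;
}

Before this fix, the specialized ROCm kernel selected for this pattern stored the float results with the wrong element type and stride, corrupting the bf16 output.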

File tree: 2 files changed (+31, -17 lines)


aten/src/ATen/native/cuda/CUDALoops.cuh

Lines changed: 29 additions & 15 deletions
@@ -720,28 +720,41 @@ struct check_binary_functor_types_for_specialization<
 };
 
 // The following is a list of type specializations for vectorized_templated
-// elementwise kernel. It refers to the first and second runtime types of the
-// arguments of a binary functor.
-
+// elementwise kernel. The three types refer to runtime types of the output
+// tensor, first tensor argument, and the second tensor argument used for a
+// binary functor.
 constexpr std::array rt_binary_specializations = {
-    std::array<c10::ScalarType, 2>(
+    std::array<c10::ScalarType, 3>(
         {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<float>::value,
          c10::CppTypeToScalarType<BFloat16>::value}),
-    std::array<c10::ScalarType, 2>(
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 3>(
         {c10::CppTypeToScalarType<BFloat16>::value,
+         c10::CppTypeToScalarType<BFloat16>::value,
          c10::CppTypeToScalarType<float>::value}),
-    std::array<c10::ScalarType, 2>(
+    std::array<c10::ScalarType, 3>(
         {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<float>::value,
          c10::CppTypeToScalarType<Half>::value}),
-    std::array<c10::ScalarType, 2>(
+    std::array<c10::ScalarType, 3>(
+        {c10::CppTypeToScalarType<float>::value,
+         c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<float>::value}),
+    std::array<c10::ScalarType, 3>(
         {c10::CppTypeToScalarType<Half>::value,
+         c10::CppTypeToScalarType<Half>::value,
          c10::CppTypeToScalarType<float>::value})};
 
 bool check_binary_rt_types_for_specialization(TensorIteratorBase& iter) {
   if (iter.ninputs() != 2)
     return false;
   for (auto spec : rt_binary_specializations)
-    if (iter.input_dtype(0) == spec[0] && iter.input_dtype(1) == spec[1])
+    if (iter.dtype(0) == spec[0] && iter.input_dtype(0) == spec[1] &&
+        iter.input_dtype(1) == spec[2])
       return true;
   return false;
 }
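
To make the widened table concrete, here is a standalone sketch (stand-in enum and names, not PyTorch source) of how a (output, input0, input1) dtype triple is matched. For example, in-place a += b with a in bf16 and b in float yields the triple (BFloat16, BFloat16, Float), the third entry, which the old two-element table could not distinguish from a float-output case.

// Standalone sketch of the triple matching (stand-in types, not PyTorch source).
#include <algorithm>
#include <array>

enum class ScalarType { Float, Half, BFloat16 };

constexpr std::array<std::array<ScalarType, 3>, 6> specs{{
    {ScalarType::Float, ScalarType::Float, ScalarType::BFloat16},
    {ScalarType::Float, ScalarType::BFloat16, ScalarType::Float},
    {ScalarType::BFloat16, ScalarType::BFloat16, ScalarType::Float},
    {ScalarType::Float, ScalarType::Float, ScalarType::Half},
    {ScalarType::Float, ScalarType::Half, ScalarType::Float},
    {ScalarType::Half, ScalarType::Half, ScalarType::Float},
}};

// out: output tensor dtype; in0/in1: runtime dtypes of the two inputs.
bool is_specialized(ScalarType out, ScalarType in0, ScalarType in1) {
  return std::any_of(specs.begin(), specs.end(), [&](const auto& s) {
    return s[0] == out && s[1] == in0 && s[2] == in1;
  });
}

// e.g. a(bf16) += b(float): is_specialized(BFloat16, BFloat16, Float) == true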
@@ -756,6 +769,7 @@ struct type_specialized_kernel_launcher {
       typename loader_t,
       typename storer_t>
   static void apply(
+      ScalarType ret_t,
       ScalarType arg0_t,
       ScalarType arg1_t,
       int64_t numel,
@@ -765,22 +779,22 @@ struct type_specialized_kernel_launcher {
       out_calc_t output_offset_calculator,
       loader_t loader,
       storer_t storer) {
-    using traits = function_traits<func_t>;
-    using return_t = typename traits::result_type;
-    if (arg0_t == rt_binary_specializations[arg_index][0] &&
-        arg1_t == rt_binary_specializations[arg_index][1])
+    if (ret_t == rt_binary_specializations[arg_index][0] &&
+        arg0_t == rt_binary_specializations[arg_index][1] &&
+        arg1_t == rt_binary_specializations[arg_index][2])
       launch_vectorized_templated_kernel<
           func_t,
           array_t,
           inp_calc_t,
           out_calc_t,
           loader_t,
           storer_t,
-          return_t,
           decltype(c10::impl::ScalarTypeToCPPType<
                    rt_binary_specializations[arg_index][0]>::t),
           decltype(c10::impl::ScalarTypeToCPPType<
-                   rt_binary_specializations[arg_index][1]>::t)>(
+                   rt_binary_specializations[arg_index][1]>::t),
+          decltype(c10::impl::ScalarTypeToCPPType<
+                   rt_binary_specializations[arg_index][2]>::t)>(
           numel,
           f,
           data,
@@ -820,7 +834,6 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
 #ifdef USE_ROCM
   // Attempt to call specialized vectorized elementwise kernel
   // that enables interleaving.
-
   if (check_binary_rt_types_for_specialization(iter) &&
       memory::can_vectorize_up_to<func_t>(data) > 1) {
     // constexpr to reduce the amount of kernels generated for
@@ -848,6 +861,7 @@ void gpu_kernel_impl(TensorIteratorBase& iter, const func_t& f) {
           type_specialized_kernel_launcher,
           rt_binary_specializations.size()>::
           with_args(
+              iter.dtype(0),
               iter.input_dtype(0),
               iter.input_dtype(1),
               numel,

aten/src/ATen/native/cuda/MemoryAccess.cuh

Lines changed: 2 additions & 2 deletions
@@ -407,8 +407,8 @@ struct vectorized_templated {
   // float(float,bfloat16) and functor add on float(float,float).
   template <typename scalar_t>
   __device__ inline void store(scalar_t* from, int idx) {
-    using vec_t = aligned_vector<scalar_t, vec_size>;
-    scalar_t* to = reinterpret_cast<scalar_t*>(data[0]) + block_work_size * idx;
+    using vec_t = aligned_vector<CastToT, vec_size>;
+    CastToT* to = reinterpret_cast<CastToT*>(data[0]) + block_work_size * idx;
     vec_t* to_ = reinterpret_cast<vec_t*>(to);
     int thread_idx = threadIdx.x;
 #pragma unroll
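
This MemoryAccess.cuh change is the heart of the fix: store() previously indexed the output pointer with the functor's scalar_t (e.g. 4-byte float), while data[0] actually points at storage of the output type CastToT (e.g. 2-byte bf16), so writes landed at the wrong stride with the wrong bit pattern. A minimal CUDA sketch of the corrected pattern (types, names, and the scalar loop are illustrative; the real code performs a single vectorized store through aligned_vector<CastToT, vec_size>):

// Illustrative CUDA sketch (not the PyTorch code): store float results into a
// bf16 output buffer by converting and indexing in the *output* element type.
#include <cuda_bf16.h>

template <int vec_size>
__device__ void store_cast(void* out_base, int offset, const float* results) {
  // Index in units of the output element type (2-byte bf16), not float.
  __nv_bfloat16* to = reinterpret_cast<__nv_bfloat16*>(out_base) + offset;
#pragma unroll
  for (int i = 0; i < vec_size; ++i) {
    to[i] = __float2bfloat16(results[i]);  // convert before storing
  }
}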
