Skip to content

Commit 4cebc8c

Browse files
[slice] add IndexPutWithSortKernel in index_elementwise_get_grad slice-check (#74344)
* [slice]add IndexPutWithSortKernel in index_elementwise_get_grad slice-check * slice-check * slice-check * update slice-check
1 parent 2d61a9b commit 4cebc8c

17 files changed

+802
-26
lines changed

paddle/fluid/pybind/slice_utils.h

Lines changed: 19 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -821,6 +821,7 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
821821
}
822822

823823
AdvancedIndex ad = AdvancedIndex(tensor, indices_int64);
824+
const bool is_combined = false;
824825
const bool accumulate = false;
825826

826827
return index_elementwise_get_ad_func(self_tensor,
@@ -830,7 +831,8 @@ static paddle::Tensor getValueForBoolTensor(const paddle::Tensor& tensor,
830831
ad.indexed_sizes,
831832
ad.indexed_strides,
832833
slice_offset,
833-
accumulate);
834+
accumulate,
835+
is_combined);
834836
} else {
835837
if (bool_index.shape().size() == 1)
836838
return gather_ad_func(tensor, bool_2_idx);
@@ -1286,23 +1288,22 @@ static void ApplyGetitem(const int index_size,
12861288
&transed_index_int64);
12871289

12881290
AdvancedIndex ad = AdvancedIndex(*transed_tensor, transed_index_int64);
1289-
if (index_size == 1) {
1290-
paddle::Tensor flattened_tensor =
1291-
flatten_ad_func((*transed_index)[0], 0, -1);
1292-
*out = gather_ad_func(*transed_tensor, flattened_tensor);
1293-
*out = reshape_ad_func(*out, ad.src_sizes);
1294-
} else {
1295-
const bool accumulate = true;
1296-
*out = index_elementwise_get_ad_func(*self_tensor,
1297-
ad.indices,
1298-
ad.src_sizes,
1299-
ad.src_strides,
1300-
ad.indexed_sizes,
1301-
ad.indexed_strides,
1302-
slice_offset,
1303-
accumulate);
1304-
}
1305-
1291+
// is_combined:
1292+
// Distinguishes between regular indexing (single index) and combined
1293+
// indexing (multiple indices). When false (single index case), enables
1294+
// optimized backward pass using IndexPutWithSortKernel for better
1295+
// performance.
1296+
const bool is_combined = (index_size == 1) ? false : true;
1297+
const bool accumulate = true;
1298+
*out = index_elementwise_get_ad_func(*self_tensor,
1299+
ad.indices,
1300+
ad.src_sizes,
1301+
ad.src_strides,
1302+
ad.indexed_sizes,
1303+
ad.indexed_strides,
1304+
slice_offset,
1305+
accumulate,
1306+
is_combined);
13061307
return;
13071308
} else {
13081309
paddle::Tensor transed_advanced_index_tensor;

paddle/phi/infermeta/backward.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2168,6 +2168,7 @@ void IndexElementwiseGetGradInferMeta(
21682168
const std::vector<int64_t>& index_strides,
21692169
const int64_t slice_offset,
21702170
const bool accumulate,
2171+
const bool is_combined,
21712172
MetaTensor* x_grad) {
21722173
if (x_grad) {
21732174
x_grad->share_meta(x);

paddle/phi/infermeta/backward.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -788,5 +788,6 @@ void IndexElementwiseGetGradInferMeta(
788788
const std::vector<int64_t>& index_strides,
789789
const int64_t slice_offset,
790790
const bool accumulate,
791+
const bool is_combined,
791792
MetaTensor* x_grad);
792793
} // namespace phi

paddle/phi/infermeta/binary.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2598,6 +2598,7 @@ void IndexElementwiseGetInferMeta(const MetaTensor& x,
25982598
const std::vector<int64_t>& index_stride,
25992599
const int64_t slice_offset,
26002600
const bool accumulate,
2601+
const bool is_combined,
26012602
MetaTensor* out) {
26022603
out->set_dims(common::make_ddim(input_dims));
26032604
out->set_dtype(x.dtype());

paddle/phi/infermeta/binary.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -493,6 +493,7 @@ void IndexElementwiseGetInferMeta(const MetaTensor& x,
493493
const std::vector<int64_t>& index_stride,
494494
const int64_t slice_offset,
495495
const bool accumulate,
496+
const bool is_combined,
496497
MetaTensor* out);
497498

498499
void KronInferMeta(const MetaTensor& x, const MetaTensor& y, MetaTensor* out);

paddle/phi/kernels/cpu/index_elementwise_get_grad_kernel.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,7 @@ void IndexElementwiseGetGradKernel(const Context& dev_ctx,
131131
const std::vector<int64_t>& index_strides,
132132
const int64_t slice_offset,
133133
const bool accumulate,
134+
const bool is_combined,
134135
DenseTensor* x_grad) {
135136
dev_ctx.template Alloc<T>(x_grad);
136137
auto dxt = phi::EigenVector<T>::Flatten(*x_grad);

paddle/phi/kernels/cpu/index_elementwise_get_kernel.cc

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -100,6 +100,7 @@ void IndexElementwiseGetKernel(const Context& dev_ctx,
100100
const std::vector<int64_t>& index_stride,
101101
const int64_t slice_offset,
102102
const bool accumulate,
103+
const bool is_combined,
103104
DenseTensor* out) {
104105
const auto& index_type = index[0]->dtype();
105106
PADDLE_ENFORCE_EQ(index_type == phi::DataType::INT64,
Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
2+
//
3+
// Licensed under the Apache License, Version 2.0 (the "License");
4+
// you may not use this file except in compliance with the License.
5+
// You may obtain a copy of the License at
6+
//
7+
// http://www.apache.org/licenses/LICENSE-2.0
8+
//
9+
// Unless required by applicable law or agreed to in writing, software
10+
// distributed under the License is distributed on an "AS IS" BASIS,
11+
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
// See the License for the specific language governing permissions and
13+
// limitations under the License.
14+
15+
#include "paddle/phi/kernels/funcs/radix_sort.h"
16+
#include "paddle/phi/common/memory_utils.h"
17+
18+
namespace phi {
19+
namespace funcs {
20+
21+
#ifdef PADDLE_WITH_CUDA
22+
namespace {
23+
template <typename T>
24+
struct CudaType {
25+
using type = T;
26+
};
27+
28+
template <>
29+
struct CudaType<int64_t> {
30+
using type = long long; // NOLINT
31+
};
32+
33+
#define PADDLE_CUB_WRAPPER(func, ...) \
34+
do { \
35+
size_t temp_storage_bytes = 0; \
36+
func(nullptr, temp_storage_bytes, __VA_ARGS__); \
37+
auto temp_storage = \
38+
phi::memory_utils::Alloc(dev_ctx.GetPlace(), temp_storage_bytes); \
39+
func(temp_storage->ptr(), temp_storage_bytes, __VA_ARGS__); \
40+
} while (0)
41+
42+
} // namespace
43+
44+
template <typename key_t, int value_size>
45+
void RadixSortPairsImpl(const phi::GPUContext& dev_ctx,
46+
const key_t* keys_in,
47+
key_t* keys_out,
48+
const OpaqueTypeRadix<value_size>* values_in,
49+
OpaqueTypeRadix<value_size>* values_out,
50+
int64_t n,
51+
bool descending,
52+
int64_t begin_bit,
53+
int64_t end_bit) {
54+
PADDLE_ENFORCE_LE(
55+
n,
56+
std::numeric_limits<int>::max(),
57+
phi::errors::InvalidArgument(
58+
"CUB sort does not support sorting more than INT_MAX elements"));
59+
60+
using key_t_ = typename CudaType<key_t>::type;
61+
62+
phi::Allocator::AllocationPtr keys_out_owner;
63+
if (keys_out == nullptr) {
64+
keys_out_owner =
65+
phi::memory_utils::Alloc(dev_ctx.GetPlace(), n * sizeof(key_t));
66+
keys_out = reinterpret_cast<key_t*>(keys_out_owner->ptr());
67+
}
68+
69+
const key_t_* keys_in_ = reinterpret_cast<const key_t_*>(keys_in);
70+
key_t_* keys_out_ = reinterpret_cast<key_t_*>(keys_out);
71+
72+
if (descending) {
73+
PADDLE_CUB_WRAPPER(cub::DeviceRadixSort::SortPairsDescending,
74+
keys_in_,
75+
keys_out_,
76+
values_in,
77+
values_out,
78+
static_cast<int>(n),
79+
begin_bit,
80+
end_bit,
81+
dev_ctx.stream());
82+
} else {
83+
PADDLE_CUB_WRAPPER(cub::DeviceRadixSort::SortPairs,
84+
keys_in_,
85+
keys_out_,
86+
values_in,
87+
values_out,
88+
static_cast<int>(n),
89+
begin_bit,
90+
end_bit,
91+
dev_ctx.stream());
92+
}
93+
}
94+
95+
#define INSTANTIATE_SORT_PAIRS(key_t, value_size) \
96+
template void RadixSortPairsImpl<key_t, value_size>( \
97+
const phi::GPUContext&, \
98+
const key_t*, \
99+
key_t*, \
100+
const OpaqueTypeRadix<value_size>*, \
101+
OpaqueTypeRadix<value_size>*, \
102+
int64_t, \
103+
bool, \
104+
int64_t, \
105+
int64_t);
106+
107+
INSTANTIATE_SORT_PAIRS(int32_t, 1)
108+
INSTANTIATE_SORT_PAIRS(int32_t, 2)
109+
INSTANTIATE_SORT_PAIRS(int32_t, 4)
110+
INSTANTIATE_SORT_PAIRS(int64_t, 1)
111+
INSTANTIATE_SORT_PAIRS(int64_t, 2)
112+
INSTANTIATE_SORT_PAIRS(int64_t, 4)
113+
INSTANTIATE_SORT_PAIRS(int32_t, 8)
114+
INSTANTIATE_SORT_PAIRS(int64_t, 8)
115+
116+
#endif
117+
} // namespace funcs
118+
} // namespace phi

paddle/phi/kernels/funcs/radix_sort.h

Lines changed: 84 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,84 @@
// Copyright (c) 2025 PaddlePaddle Authors. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#pragma once

#ifdef PADDLE_WITH_CUDA
#include <cub/cub.cuh>
#endif

#include <cstdint>
#include <type_traits>

#include "paddle/phi/backends/gpu/gpu_context.h"
#include "paddle/phi/core/dense_tensor.h"

namespace phi {
namespace funcs {

#ifdef PADDLE_WITH_CUDA

// Fixed-size opaque byte blob used to push arbitrary trivially copyable
// values through CUB's radix sort, so the sort is instantiated per value
// *size* instead of per value *type*.
template <int kValueSize>
struct OpaqueTypeRadix {
  uint8_t data[kValueSize];
  __device__ __host__ OpaqueTypeRadix() = default;
};

// Type-erased backend for RadixSortPairs; explicitly instantiated in
// radix_sort.cu for key types int32_t/int64_t and value sizes 1/2/4/8.
// `keys_out` may be null (sorted keys discarded); only bits
// [begin_bit, end_bit) of each key participate in the sort.
template <typename key_t, int kValueSize>
void RadixSortPairsImpl(const phi::GPUContext& dev_ctx,
                        const key_t* keys_in,
                        key_t* keys_out,
                        const OpaqueTypeRadix<kValueSize>* values_in,
                        OpaqueTypeRadix<kValueSize>* values_out,
                        int64_t n,
                        bool descending = false,
                        int64_t begin_bit = 0,
                        int64_t end_bit = sizeof(key_t) * 8);

// Stably sorts `n` (key, value) pairs on the GPU by key.
//
// value_t must be trivially copyable with a size of 1, 2, 4, or 8 bytes
// and natural (size == alignment) alignment, so it can be reinterpreted
// as an OpaqueTypeRadix blob of the same size. All constraints are
// compile-time properties of value_t, so they are enforced with
// static_assert rather than runtime checks.
template <typename key_t, typename value_t>
void RadixSortPairs(const phi::GPUContext& dev_ctx,
                    const key_t* keys_in,
                    key_t* keys_out,
                    const value_t* values_in,
                    value_t* values_out,
                    int64_t n,
                    bool descending = false,
                    int64_t begin_bit = 0,
                    int64_t end_bit = sizeof(key_t) * 8) {
  static_assert(std::is_trivially_copyable<value_t>::value,
                "RadixSortPairs value type must be trivially copyable");
  static_assert(
      sizeof(value_t) <= 8 && (sizeof(value_t) & (sizeof(value_t) - 1)) == 0,
      "Unsupported value_t size (must be 1, 2, 4, or 8 bytes)");
  static_assert(sizeof(value_t) == alignof(value_t),
                "Expected value_t to be size-aligned");

  using opaque_t = OpaqueTypeRadix<sizeof(value_t)>;
  RadixSortPairsImpl<key_t, sizeof(value_t)>(
      dev_ctx,
      keys_in,
      keys_out,
      reinterpret_cast<const opaque_t*>(values_in),
      reinterpret_cast<opaque_t*>(values_out),
      n,
      descending,
      begin_bit,
      end_bit);
}

#endif
}  // namespace funcs
}  // namespace phi

0 commit comments

Comments
 (0)