File tree Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Expand file tree Collapse file tree 1 file changed +3
-6
lines changed Original file line number Diff line number Diff line change @@ -30,19 +30,16 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
30
30
PADDLE_THROW (" SplitIds do not support GPU kernel" );
31
31
}
32
32
33
- const auto * ids_t = ctx.Input <framework::LoDTensor>(" Ids" );
34
- auto & ids_dims = ids_t -> dims ();
33
+ auto & ids_dims = ctx.Input <framework::LoDTensor>(" Ids" )-> dims ( );
34
+ const T* ids = ctx. Input <framework::LoDTensor>( " Ids " )-> data <T> ();
35
35
auto outs = ctx.MultiOutput <framework::LoDTensor>(" Out" );
36
-
37
- const T* ids = ids_t ->data <T>();
38
-
39
36
const size_t shard_num = outs.size ();
40
37
41
38
std::vector<std::vector<T>> out_ids;
42
39
out_ids.resize (outs.size ());
43
40
44
41
// split id by their shard_num.
45
- for (size_t i = 0 ; i < ids_dims[0 ]; ++i) {
42
+ for (int i = 0 ; i < ids_dims[0 ]; ++i) {
46
43
T id = ids[i];
47
44
size_t shard_id = static_cast <size_t >(id) % shard_num;
48
45
out_ids[shard_id].push_back (id);
You can’t perform that action at this time.
0 commit comments