Skip to content

Commit b0cf1fe

Browse files
authored
Merge pull request #12430 from jacquesqiao/add-test-for-split-ids-op
Add test for split ids op
2 parents f372f27 + 236fc1b commit b0cf1fe

File tree

2 files changed

+63
-4
lines changed

2 files changed

+63
-4
lines changed

paddle/fluid/operators/split_ids_op.h

Lines changed: 11 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -14,6 +14,7 @@ limitations under the License. */
1414

1515
#pragma once
1616

17+
#include <unordered_map>
1718
#include <vector>
1819
#include "paddle/fluid/framework/op_registry.h"
1920
#include "paddle/fluid/operators/math/selected_rows_functor.h"
@@ -67,10 +68,15 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
6768
const auto &ids_rows = ids_selected_rows->rows();
6869
auto outs = ctx.MultiOutput<framework::SelectedRows>("Out");
6970
const size_t shard_num = outs.size();
71+
for (auto &out : outs) {
72+
out->mutable_rows()->clear();
73+
}
7074
// get rows for outputs
71-
for (auto &id : ids_rows) {
72-
size_t shard_id = static_cast<size_t>(id) % shard_num;
73-
outs[shard_id]->mutable_rows()->push_back(id);
75+
std::unordered_map<int64_t, size_t> id_to_index;
76+
for (size_t i = 0; i < ids_rows.size(); ++i) {
77+
id_to_index[ids_rows[i]] = i;
78+
size_t shard_id = static_cast<size_t>(ids_rows[i]) % shard_num;
79+
outs[shard_id]->mutable_rows()->push_back(ids_rows[i]);
7480
}
7581

7682
int64_t row_width = ids_dims[1];
@@ -80,7 +86,8 @@ class SplitIdsOpKernel : public framework::OpKernel<T> {
8086
{static_cast<int64_t>(out->rows().size()), row_width});
8187
T *output = out->mutable_value()->mutable_data<T>(ddim, place);
8288
for (int64_t i = 0; i < ddim[0]; ++i) {
83-
memcpy(output + i * row_width, ids + out->rows()[i] * row_width,
89+
memcpy(output + i * row_width,
90+
ids + id_to_index[out->rows()[i]] * row_width,
8491
row_width * sizeof(T));
8592
}
8693
}

python/paddle/fluid/tests/unittests/test_split_ids_op.py

Lines changed: 52 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -15,6 +15,8 @@
1515
import unittest
1616
import numpy as np
1717
from op_test import OpTest
18+
import paddle.fluid.core as core
19+
from paddle.fluid.op import Operator
1820

1921

2022
class TestSplitIdsOp(OpTest):
@@ -31,5 +33,55 @@ def test_check_output(self):
3133
self.check_output()
3234

3335

36+
class TestSpliteIds(unittest.TestCase):
37+
def get_places(self):
38+
places = [core.CPUPlace()]
39+
return places
40+
41+
def test_check_output(self):
42+
for place in self.get_places():
43+
self.check_with_place(place)
44+
45+
def check_with_place(self, place):
46+
scope = core.Scope()
47+
rows = [0, 5, 7, 4, 9]
48+
height = 20
49+
row_numel = 2
50+
51+
# initialize input variable X
52+
x = scope.var('X').get_selected_rows()
53+
x.set_rows(rows)
54+
x.set_height(height)
55+
np_array = np.ones((len(rows), row_numel)).astype("float32")
56+
for i in range(len(rows)):
57+
for j in range(row_numel):
58+
np_array[i, j] = rows[i] + j
59+
x_tensor = x.get_tensor()
60+
x_tensor.set(np_array, place)
61+
62+
outs_name = ["out%d" % i for i in xrange(3)]
63+
outs = [
64+
scope.var(var_name).get_selected_rows() for var_name in outs_name
65+
]
66+
67+
# expected output selected rows
68+
expected_out_rows = [[0, 9], [7, 4], [5]]
69+
70+
op = Operator("split_ids", Ids="X", Out=outs_name)
71+
72+
for _ in range(3):
73+
op.run(scope, place)
74+
75+
for i in range(len(outs)):
76+
expected_rows = expected_out_rows[i]
77+
self.assertEqual(outs[i].rows(), expected_rows)
78+
for j in range(len(expected_rows)):
79+
row = expected_rows[j]
80+
self.assertAlmostEqual(
81+
float(row), np.array(outs[i].get_tensor())[j, 0])
82+
self.assertAlmostEqual(
83+
float(row + 1), np.array(outs[i].get_tensor())[j, 1])
84+
85+
3486
if __name__ == '__main__':
3587
unittest.main()

0 commit comments

Comments (0)