Skip to content

Commit f4dec5c

Browse files
authored
Check collective server's data. (#15449)
1 parent 58727e8 commit f4dec5c

File tree

1 file changed

+16
-1
lines changed

1 file changed

+16
-1
lines changed

paddle/fluid/operators/distributed/collective_server_test.cc

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@ limitations under the License. */
2020
#include "paddle/fluid/framework/block_desc.h"
2121
#include "paddle/fluid/framework/op_registry.h"
2222
#include "paddle/fluid/framework/operator.h"
23+
#include "paddle/fluid/framework/tensor_util.h"
2324

2425
#include "paddle/fluid/operators/distributed/collective_client.h"
2526
#include "paddle/fluid/operators/distributed/collective_server.h"
@@ -57,7 +58,7 @@ std::unique_ptr<framework::Scope> GenerateVars(platform::Place place) {
5758
auto* tensor = slr->mutable_value();
5859
auto* rows = slr->mutable_rows();
5960

60-
tensor->Resize(framework::make_ddim({20000, 1024}));
61+
tensor->Resize(framework::make_ddim({3, 1024}));
6162
tensor->mutable_data<float>(place);
6263

6364
paddle::operators::math::set_constant(ctx, tensor, 32.7);
@@ -80,6 +81,20 @@ void Gather(const std::vector<distributed::RemoteVar>& vars,
8081
std::vector<const framework::SelectedRows*> dst;
8182
client->Gather(vars, &dst, *dev_ctx, scope);
8283
std::cout << "dst:" << distributed::GetSelectedRowsInfo(*dst[0]);
84+
dev_ctx->Wait();
85+
86+
ASSERT_EQ(dst[0]->value().dims(), framework::make_ddim({3, 1024}));
87+
ASSERT_EQ(dst[0]->height(), 20000);
88+
ASSERT_EQ(dst[0]->rows().size(), static_cast<size_t>(3));
89+
for (int i = 0; i < 3; i++) {
90+
ASSERT_EQ(dst[0]->rows()[i], i);
91+
}
92+
93+
std::vector<float> vec;
94+
TensorToVector(dst[0]->value(), *dev_ctx, &vec);
95+
for (size_t i = 0; i < 3 * 1024; i++) {
96+
ASSERT_FLOAT_EQ(vec[i], 32.7);
97+
}
8398
}
8499

85100
TEST(CollectiveServer, GPU) {

0 commit comments

Comments
 (0)