@@ -20,6 +20,7 @@ limitations under the License. */
20
20
#include " paddle/fluid/framework/block_desc.h"
21
21
#include " paddle/fluid/framework/op_registry.h"
22
22
#include " paddle/fluid/framework/operator.h"
23
+ #include " paddle/fluid/framework/tensor_util.h"
23
24
24
25
#include " paddle/fluid/operators/distributed/collective_client.h"
25
26
#include " paddle/fluid/operators/distributed/collective_server.h"
@@ -57,7 +58,7 @@ std::unique_ptr<framework::Scope> GenerateVars(platform::Place place) {
57
58
auto * tensor = slr->mutable_value ();
58
59
auto * rows = slr->mutable_rows ();
59
60
60
- tensor->Resize (framework::make_ddim ({20000 , 1024 }));
61
+ tensor->Resize (framework::make_ddim ({3 , 1024 }));
61
62
tensor->mutable_data <float >(place);
62
63
63
64
paddle::operators::math::set_constant (ctx, tensor, 32.7 );
@@ -80,6 +81,20 @@ void Gather(const std::vector<distributed::RemoteVar>& vars,
80
81
std::vector<const framework::SelectedRows*> dst;
81
82
client->Gather (vars, &dst, *dev_ctx, scope);
82
83
std::cout << " dst:" << distributed::GetSelectedRowsInfo (*dst[0 ]);
84
+ dev_ctx->Wait ();
85
+
86
+ ASSERT_EQ (dst[0 ]->value ().dims (), framework::make_ddim ({3 , 1024 }));
87
+ ASSERT_EQ (dst[0 ]->height (), 20000 );
88
+ ASSERT_EQ (dst[0 ]->rows ().size (), static_cast <size_t >(3 ));
89
+ for (int i = 0 ; i < 3 ; i++) {
90
+ ASSERT_EQ (dst[0 ]->rows ()[i], i);
91
+ }
92
+
93
+ std::vector<float > vec;
94
+ TensorToVector (dst[0 ]->value (), *dev_ctx, &vec);
95
+ for (size_t i = 0 ; i < 3 * 1024 ; i++) {
96
+ ASSERT_FLOAT_EQ (vec[i], 32.7 );
97
+ }
83
98
}
84
99
85
100
TEST (CollectiveServer, GPU) {
0 commit comments