@@ -45,43 +45,45 @@ framework::BlockDesc* AppendPrefetchBlcok(framework::ProgramDesc* program) {
45
45
op->SetOutput (" Out" , {" out" });
46
46
47
47
auto & out = *root_block->Var (" out" );
48
- out.SetType (framework::proto::VarType::LOD_TENSOR );
48
+ out.SetType (framework::proto::VarType::SELECTED_ROWS );
49
49
out.SetShape ({10 , 10 });
50
50
51
51
return block;
52
52
}
53
53
54
54
void CreateVarsOnScope (framework::Scope* scope, platform::CPUPlace* place) {
55
55
auto w_var = scope->Var (" w" );
56
- auto w = w_var->GetMutable <framework::LoDTensor>();
57
- w->Resize ({10 , 10 });
58
- w->mutable_data <float >(*place);
56
+ w_var->GetMutable <framework::SelectedRows>();
59
57
60
58
auto out_var = scope->Var (" out" );
61
- auto out = out_var->GetMutable <framework::LoDTensor>();
62
- out->Resize ({5 , 10 });
63
- out->mutable_data <float >(*place);
59
+ out_var->GetMutable <framework::SelectedRows>();
64
60
65
61
auto ids_var = scope->Var (" ids" );
66
- auto ids = ids_var->GetMutable <framework::LoDTensor>();
67
- ids->Resize ({5 , 1 });
62
+ ids_var->GetMutable <framework::SelectedRows>();
68
63
}
69
64
70
- void InitTensorsOnClient (framework::Scope* scope, platform::CPUPlace* place) {
65
+ void InitTensorsOnClient (framework::Scope* scope, platform::CPUPlace* place,
66
+ int64_t rows_numel) {
71
67
CreateVarsOnScope (scope, place);
72
- auto ids = scope->Var (" ids" )->GetMutable <framework::LoDTensor >();
73
- auto ptr = ids-> mutable_data < int64_t >(*place );
74
- for (int64_t i = 0 ; i < ids-> numel () ; ++i) {
75
- ptr[i] = i * 2 ;
76
- }
68
+ auto ids_var = scope->Var (" ids" )->GetMutable <framework::SelectedRows >();
69
+ auto rows = ids_var-> mutable_rows ( );
70
+ for (int64_t i = 0 ; i < rows_numel ; ++i) rows-> push_back (i * 2 );
71
+ ids_var-> mutable_value ()-> Resize ({rows_numel, 1 }) ;
72
+ ids_var-> mutable_value ()-> mutable_data < float >(*place);
77
73
}
78
74
79
- void InitTensorsOnServer (framework::Scope* scope, platform::CPUPlace* place) {
75
+ void InitTensorsOnServer (framework::Scope* scope, platform::CPUPlace* place,
76
+ int64_t rows_numel) {
80
77
CreateVarsOnScope (scope, place);
81
- auto w_var = scope->Var (" w" );
82
- auto w = w_var->GetMutable <framework::LoDTensor>();
83
- auto ptr = w->mutable_data <float >(*place);
84
- for (int64_t i = 0 ; i < w->numel (); ++i) {
78
+ auto w = scope->Var (" w" )->GetMutable <framework::SelectedRows>();
79
+ auto rows = w->mutable_rows ();
80
+ for (int64_t i = 0 ; i < rows_numel; ++i) rows->push_back (i);
81
+ auto w_value = w->mutable_value ();
82
+ w_value->Resize ({rows_numel, 10 });
83
+
84
+ auto ptr = w_value->mutable_data <float >(*place);
85
+
86
+ for (int64_t i = 0 ; i < w_value->numel (); ++i) {
85
87
ptr[i] = static_cast <float >(i / 10 );
86
88
}
87
89
}
@@ -94,7 +96,7 @@ void StartServer(const std::string& endpoint) {
94
96
framework::Executor exe (place);
95
97
platform::CPUDeviceContext ctx (place);
96
98
auto * block = AppendPrefetchBlcok (&program);
97
- InitTensorsOnServer (&scope, &place);
99
+ InitTensorsOnServer (&scope, &place, 10 );
98
100
99
101
rpc_service_->SetProgram (&program);
100
102
rpc_service_->SetPrefetchBlkdId (block->ID ());
@@ -107,15 +109,14 @@ void StartServer(const std::string& endpoint) {
107
109
108
110
TEST (PREFETCH, CPU) {
109
111
// start up a server instance backend
110
- // TODO(Yancey1989): Need to start a server with optimize blocks and
111
- // prefetch blocks.
112
112
std::thread server_thread (StartServer, " 127.0.0.1:8889" );
113
113
sleep (2 );
114
114
framework::Scope scope;
115
115
platform::CPUPlace place;
116
116
platform::CPUDeviceContext ctx (place);
117
117
// create var on local scope
118
- InitTensorsOnClient (&scope, &place);
118
+ int64_t rows_numel = 5 ;
119
+ InitTensorsOnClient (&scope, &place, rows_numel);
119
120
std::string in_var_name (" ids" );
120
121
std::string out_var_name (" out" );
121
122
@@ -124,18 +125,16 @@ TEST(PREFETCH, CPU) {
124
125
out_var_name);
125
126
client.Wait ();
126
127
127
- auto out_var = scope.Var (out_var_name);
128
- auto out = out_var->Get <framework::LoDTensor>();
128
+ // auto out_var = scope.Var(out_var_name);
129
+ auto var = scope.Var (out_var_name);
130
+ auto value = var->GetMutable <framework::SelectedRows>()->value ();
131
+ auto ptr = value.mutable_data <float >(place);
129
132
130
- auto out_ptr = out.data <float >();
131
133
rpc_service_->ShutDown ();
132
134
server_thread.join ();
133
135
rpc_service_.reset (nullptr );
134
136
135
- EXPECT_EQ (out.dims ().size (), 2 );
136
- EXPECT_EQ (out_ptr[0 ], static_cast <float >(0 ));
137
- EXPECT_EQ (out_ptr[0 + 1 * out.dims ()[1 ]], static_cast <float >(2 ));
138
- EXPECT_EQ (out_ptr[0 + 2 * out.dims ()[1 ]], static_cast <float >(4 ));
139
- EXPECT_EQ (out_ptr[0 + 3 * out.dims ()[1 ]], static_cast <float >(6 ));
140
- EXPECT_EQ (out_ptr[0 + 4 * out.dims ()[1 ]], static_cast <float >(8 ));
137
+ for (int64_t i = 0 ; i < rows_numel; ++i) {
138
+ EXPECT_EQ (ptr[0 + i * value.dims ()[1 ]], static_cast <float >(i * 2 ));
139
+ }
141
140
}
0 commit comments