Skip to content

Commit ba927b8

Browse files
authored
Merge pull request #10060 from jacquesqiao/update-variable-response
VariableResponse support deserialize var into local scope
2 parents 84ceffd + 65b3138 commit ba927b8

File tree

7 files changed

+32
-17
lines changed

7 files changed

+32
-17
lines changed

paddle/fluid/framework/scope.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -91,7 +91,7 @@ std::vector<std::string> Scope::LocalVarNames() const {
9191
return known_vars;
9292
}
9393

94-
void Scope::DeleteScope(Scope* scope) {
94+
void Scope::DeleteScope(Scope* scope) const {
9595
std::unique_lock<std::mutex> lock(mutex_);
9696
auto it = std::find(this->kids_.begin(), this->kids_.end(), scope);
9797
PADDLE_ENFORCE(it != this->kids_.end(), "Cannot find %p as kid scope", scope);

paddle/fluid/framework/scope.h

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ class Scope {
6363
/// Find the scope or an ancestor scope that contains the given variable.
6464
const Scope* FindScope(const Variable* var) const;
6565

66-
void DeleteScope(Scope* scope);
66+
void DeleteScope(Scope* scope) const;
6767

6868
/// Drop all kids scopes belonged to this scope.
6969
void DropKids();

paddle/fluid/operators/detail/grpc_server.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ class RequestSend final : public RequestBase {
6060
framework::Scope* scope, ReceivedQueue* queue,
6161
const platform::DeviceContext* dev_ctx)
6262
: RequestBase(service, cq, dev_ctx), queue_(queue), responder_(&ctx_) {
63-
request_.reset(new VariableResponse(scope, dev_ctx_));
63+
request_.reset(new VariableResponse(false, scope, dev_ctx_));
6464
int method_id = static_cast<int>(detail::GrpcMethod::kSendVariable);
6565
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
6666
cq_, cq_, this);
@@ -146,7 +146,7 @@ class RequestPrefetch final : public RequestBase {
146146
executor_(executor),
147147
program_(program),
148148
prefetch_ctx_(prefetch_ctx) {
149-
request_.reset(new VariableResponse(scope, dev_ctx_));
149+
request_.reset(new VariableResponse(false, scope, dev_ctx_));
150150
int method_id = static_cast<int>(detail::GrpcMethod::kPrefetchVariable);
151151
service_->RequestAsyncUnary(method_id, &ctx_, request_.get(), &responder_,
152152
cq_, cq_, this);

paddle/fluid/operators/detail/sendrecvop_utils.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,7 @@ void DeserializeFromByteBuffer(const ::grpc::ByteBuffer& msg,
186186
const platform::DeviceContext& ctx,
187187
const framework::Scope* scope,
188188
framework::Variable** var) {
189-
operators::detail::VariableResponse resp(scope, &ctx);
189+
operators::detail::VariableResponse resp(false, scope, &ctx);
190190
PADDLE_ENFORCE(resp.Parse(msg) == 0, "parse bytebuffer to tensor error!");
191191
*var = resp.GetVar();
192192
}

paddle/fluid/operators/detail/serde_test.cc

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -84,7 +84,7 @@ void RunSerdeTestSelectedRows(platform::Place place) {
8484
// operators::detail::DeserializeFromByteBuffer(msg, ctx, &var2);
8585
framework::Scope scope;
8686
scope.Var("myvar");
87-
operators::detail::VariableResponse resp(&scope, &ctx);
87+
operators::detail::VariableResponse resp(false, &scope, &ctx);
8888
EXPECT_EQ(resp.Parse(msg), 0);
8989

9090
framework::Variable* var2 = resp.GetVar();
@@ -171,7 +171,7 @@ void RunTestLodTensor(platform::Place place, int from_type = 0) {
171171
// deserialize zero-copy
172172
framework::Scope scope;
173173
scope.Var("myvar");
174-
operators::detail::VariableResponse resp(&scope, &ctx);
174+
operators::detail::VariableResponse resp(false, &scope, &ctx);
175175
if (from_type == 0) {
176176
EXPECT_EQ(resp.Parse(msg), 0);
177177
} else {

paddle/fluid/operators/detail/variable_response.cc

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -114,8 +114,7 @@ bool VariableResponse::CopyLodTensorData(
114114
::google::protobuf::io::CodedInputStream* input,
115115
const platform::DeviceContext& ctx, const framework::DDim& dims,
116116
int length) {
117-
auto var = scope_->FindVar(meta_.varname());
118-
auto* tensor = var->GetMutable<framework::LoDTensor>();
117+
auto* tensor = InitVar()->GetMutable<framework::LoDTensor>();
119118
tensor->Resize(dims);
120119

121120
framework::LoD lod;
@@ -151,8 +150,7 @@ bool VariableResponse::CopySelectRowsTensorData(
151150
::google::protobuf::io::CodedInputStream* input,
152151
const platform::DeviceContext& ctx, const framework::DDim& dims,
153152
int length) {
154-
auto var = scope_->FindVar(meta_.varname());
155-
auto* slr = var->GetMutable<framework::SelectedRows>();
153+
auto* slr = InitVar()->GetMutable<framework::SelectedRows>();
156154
slr->set_height(meta_.slr_height());
157155
auto* tensor = slr->mutable_value();
158156
tensor->Resize(dims);
@@ -174,8 +172,7 @@ bool VariableResponse::CopySelectRowsTensorData(
174172
bool VariableResponse::CopySelectRowsData(
175173
::google::protobuf::io::CodedInputStream* input,
176174
const platform::DeviceContext& ctx, int length) {
177-
auto var = scope_->FindVar(meta_.varname());
178-
auto* slr = var->GetMutable<framework::SelectedRows>();
175+
auto* slr = InitVar()->GetMutable<framework::SelectedRows>();
179176
slr->mutable_rows()->resize(length /
180177
framework::SizeOfType(typeid(int64_t))); // int64
181178
int64_t* rows_data = slr->mutable_rows()->data();

paddle/fluid/operators/detail/variable_response.h

Lines changed: 22 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -36,11 +36,13 @@ namespace detail {
3636

3737
class VariableResponse {
3838
public:
39-
VariableResponse(const framework::Scope* scope,
39+
VariableResponse(bool use_local_scope, const framework::Scope* scope,
4040
const platform::DeviceContext* dev_ctx)
41-
: scope_(scope), dev_ctx_(dev_ctx) {}
41+
: use_local_scope_(use_local_scope), scope_(scope), dev_ctx_(dev_ctx) {
42+
local_scope_ = &scope->NewScope();
43+
}
4244

43-
virtual ~VariableResponse() {}
45+
virtual ~VariableResponse() { scope_->DeleteScope(local_scope_); }
4446

4547
// return:
4648
// 0:ok.
@@ -54,11 +56,25 @@ class VariableResponse {
5456
// other: number of error field.
5557
int Parse(const ::grpc::ByteBuffer& byte_buffer);
5658

59+
const framework::Scope& GetLocalScope() const { return *local_scope_; }
60+
5761
inline std::string Varname() { return meta_.varname(); }
5862
inline std::string OutVarname() { return meta_.out_varname(); }
5963

6064
// should call parse first.
61-
framework::Variable* GetVar() { return scope_->FindVar(meta_.varname()); }
65+
framework::Variable* GetVar() {
66+
return local_scope_->FindVar(meta_.varname());
67+
}
68+
69+
framework::Variable* InitVar() {
70+
if (use_local_scope_) {
71+
bool has_var = (scope_->FindVar(meta_.varname()) != nullptr);
72+
PADDLE_ENFORCE(has_var);
73+
return local_scope_->Var(meta_.varname());
74+
} else {
75+
return scope_->FindVar(meta_.varname());
76+
}
77+
}
6278

6379
private:
6480
bool CopySelectRowsTensorData(::google::protobuf::io::CodedInputStream* input,
@@ -73,7 +89,9 @@ class VariableResponse {
7389
const framework::DDim& dims, int length);
7490

7591
private:
92+
bool use_local_scope_ = false;
7693
const framework::Scope* scope_;
94+
framework::Scope* local_scope_ = nullptr;
7795
const platform::DeviceContext* dev_ctx_;
7896
// only Skeleton
7997
sendrecv::VariableMessage meta_;

0 commit comments

Comments
 (0)