Skip to content

Commit 0c29fe4

Browse files
authored
[Dy2St] Add lock free interface for scope (#73522)
1 parent c42f0a7 commit 0c29fe4

File tree

3 files changed

+20
-6
lines changed

3 files changed

+20
-6
lines changed

paddle/fluid/eager/to_static/run_program_op_node.h

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -165,27 +165,27 @@ static void ShareTensorsIntoScopeWithName(
165165
name == paddle::framework::kEmptyVarName) {
166166
continue;
167167
}
168-
auto *var = scope->Var(name);
168+
auto *var = scope->VarLockFree(name);
169169
// share tensor
170170
auto tensor_base = tensors[i].impl();
171171
if (phi::DenseTensor::classof(tensor_base.get())) {
172172
auto *dst_tensor = var->GetMutable<phi::DenseTensor>();
173-
auto t = std::dynamic_pointer_cast<phi::DenseTensor>(tensor_base);
173+
auto t = std::static_pointer_cast<phi::DenseTensor>(tensor_base);
174174
*dst_tensor = *t;
175175
} else if (phi::SelectedRows::classof(tensor_base.get())) {
176176
auto *dst_tensor = var->GetMutable<phi::SelectedRows>();
177-
auto t = std::dynamic_pointer_cast<phi::SelectedRows>(tensor_base);
177+
auto t = std::static_pointer_cast<phi::SelectedRows>(tensor_base);
178178
*dst_tensor = *t;
179179
} else if (paddle::framework::VariableRefArray::classof(
180180
tensor_base.get())) {
181181
auto *dst_tensor = var->GetMutable<paddle::framework::VariableRefArray>();
182-
auto t = std::dynamic_pointer_cast<paddle::framework::VariableRefArray>(
182+
auto t = std::static_pointer_cast<paddle::framework::VariableRefArray>(
183183
tensor_base);
184184
*dst_tensor = *t;
185185
} else if (phi::distributed::DistTensor::classof(tensor_base.get())) {
186186
auto *dst_tensor = var->GetMutable<phi::DenseTensor>();
187187
auto t =
188-
std::dynamic_pointer_cast<phi::distributed::DistTensor>(tensor_base);
188+
std::static_pointer_cast<phi::distributed::DistTensor>(tensor_base);
189189
*dst_tensor = t->value();
190190
} else {
191191
PADDLE_THROW(common::errors::InvalidArgument(
@@ -230,7 +230,7 @@ static void ShareTensorsFromScopeWithName(const std::vector<Tensor *> &tensors,
230230
// skip stop_gradient.
231231
continue;
232232
}
233-
auto *var = scope->FindVar(name);
233+
auto *var = scope->FindVarLockFree(name);
234234
PADDLE_ENFORCE_NOT_NULL(
235235
var,
236236
common::errors::NotFound("The output tensor %s is not in "

paddle/fluid/framework/scope.cc

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -54,6 +54,10 @@ Variable* Scope::Var(const std::string& name) {
5454
return ret;
5555
}
5656

57+
Variable* Scope::VarLockFree(const std::string& name) {
58+
return VarInternal(name);
59+
}
60+
5761
Variable* Scope::Var(std::string* name) {
5862
Variable* ret = nullptr;
5963
std::string new_name;
@@ -74,6 +78,10 @@ Variable* Scope::FindVar(const std::string& name) const {
7478
return FindVarInternal(name);
7579
}
7680

81+
Variable* Scope::FindVarLockFree(const std::string& name) const {
82+
return FindVarInternal(name);
83+
}
84+
7785
Variable* Scope::GetVar(const std::string& name) const {
7886
auto* var = FindVar(name);
7987
PADDLE_ENFORCE_NOT_NULL(

paddle/fluid/framework/scope.h

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -74,6 +74,9 @@ class TEST_API Scope {
7474
/// Caller doesn't own the returned Variable.
7575
Variable* Var(const std::string& name);
7676

77+
/// Lock free version of Var
78+
Variable* VarLockFree(const std::string& name);
79+
7780
/// Create a variable with a scope-unique name.
7881
/// Caller doesn't own the returned Variable.
7982
Variable* Var(std::string* name = nullptr);
@@ -88,6 +91,9 @@ class TEST_API Scope {
8891
/// Caller doesn't own the returned Variable.
8992
Variable* FindVar(const std::string& name) const;
9093

94+
/// Lock free version of FindVar
95+
Variable* FindVarLockFree(const std::string& name) const;
96+
9197
// Get a variable in the scope or any of its ancestors. Enforce
9298
/// the returned Variable is not nullptr
9399
Variable* GetVar(const std::string& name) const;

0 commit comments

Comments
 (0)