@@ -165,27 +165,27 @@ static void ShareTensorsIntoScopeWithName(
165
165
name == paddle::framework::kEmptyVarName ) {
166
166
continue ;
167
167
}
168
- auto *var = scope->Var (name);
168
+ auto *var = scope->VarLockFree (name);
169
169
// share tensor
170
170
auto tensor_base = tensors[i].impl ();
171
171
if (phi::DenseTensor::classof (tensor_base.get ())) {
172
172
auto *dst_tensor = var->GetMutable <phi::DenseTensor>();
173
- auto t = std::dynamic_pointer_cast <phi::DenseTensor>(tensor_base);
173
+ auto t = std::static_pointer_cast <phi::DenseTensor>(tensor_base);
174
174
*dst_tensor = *t;
175
175
} else if (phi::SelectedRows::classof (tensor_base.get ())) {
176
176
auto *dst_tensor = var->GetMutable <phi::SelectedRows>();
177
- auto t = std::dynamic_pointer_cast <phi::SelectedRows>(tensor_base);
177
+ auto t = std::static_pointer_cast <phi::SelectedRows>(tensor_base);
178
178
*dst_tensor = *t;
179
179
} else if (paddle::framework::VariableRefArray::classof (
180
180
tensor_base.get ())) {
181
181
auto *dst_tensor = var->GetMutable <paddle::framework::VariableRefArray>();
182
- auto t = std::dynamic_pointer_cast <paddle::framework::VariableRefArray>(
182
+ auto t = std::static_pointer_cast <paddle::framework::VariableRefArray>(
183
183
tensor_base);
184
184
*dst_tensor = *t;
185
185
} else if (phi::distributed::DistTensor::classof (tensor_base.get ())) {
186
186
auto *dst_tensor = var->GetMutable <phi::DenseTensor>();
187
187
auto t =
188
- std::dynamic_pointer_cast <phi::distributed::DistTensor>(tensor_base);
188
+ std::static_pointer_cast <phi::distributed::DistTensor>(tensor_base);
189
189
*dst_tensor = t->value ();
190
190
} else {
191
191
PADDLE_THROW (common::errors::InvalidArgument (
@@ -230,7 +230,7 @@ static void ShareTensorsFromScopeWithName(const std::vector<Tensor *> &tensors,
230
230
// skip stop_gradient.
231
231
continue ;
232
232
}
233
- auto *var = scope->FindVar (name);
233
+ auto *var = scope->FindVarLockFree (name);
234
234
PADDLE_ENFORCE_NOT_NULL (
235
235
var,
236
236
common::errors::NotFound (" The output tensor %s is not in "
0 commit comments