Skip to content

Commit a9d7922

Browse files
authored
Merge pull request #13905 from PaddlePaddle/revert-13872-fix2
Revert "Revert "Revert "Make variable::GetMutable robust"""
2 parents c26f2b2 + 288a112 commit a9d7922

File tree

8 files changed

+12
-36
lines changed

8 files changed

+12
-36
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
6666
} else if (var_type == proto::VarType::FETCH_LIST) {
6767
var->GetMutable<FeedFetchList>();
6868
} else if (var_type == proto::VarType::STEP_SCOPES) {
69-
var->GetMutable<std::vector<framework::Scope*>>();
69+
var->GetMutable<std::vector<framework::Scope>>();
7070
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
7171
var->GetMutable<LoDRankTable>();
7272
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {

paddle/fluid/framework/feed_fetch_method.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,8 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input,
2727
// be created.
2828
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
2929
Variable* g_feed_value = scope->Var(var_name);
30-
auto& feed_inputs = *(g_feed_value->GetMutable<FeedFetchList>());
30+
auto& feed_inputs =
31+
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
3132
if (index >= feed_inputs.size()) {
3233
feed_inputs.resize(index + 1);
3334
}

paddle/fluid/framework/naive_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
3737
} else if (var_type == proto::VarType::FETCH_LIST) {
3838
var->GetMutable<FeedFetchList>();
3939
} else if (var_type == proto::VarType::STEP_SCOPES) {
40-
var->GetMutable<std::vector<framework::Scope *>>();
40+
var->GetMutable<std::vector<framework::Scope>>();
4141
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
4242
var->GetMutable<LoDRankTable>();
4343
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {

paddle/fluid/framework/var_desc.h

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -59,7 +59,6 @@ class VarDesc {
5959
public:
6060
explicit VarDesc(const std::string &name) {
6161
desc_.set_name(name);
62-
// TODO(paddle-dev): Why default to lodtensor.
6362
desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
6463
}
6564

paddle/fluid/framework/variable.h

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -38,12 +38,8 @@ class Variable {
3838

3939
template <typename T>
4040
T* GetMutable() {
41-
if (!holder_) {
41+
if (!IsType<T>()) {
4242
holder_.reset(new PlaceholderImpl<T>(new T()));
43-
} else {
44-
PADDLE_ENFORCE(IsType<T>(),
45-
"Variable must be type %s, the holding type is %s",
46-
typeid(T).name(), holder_->Type().name());
4743
}
4844
return static_cast<T*>(holder_->Ptr());
4945
}

paddle/fluid/framework/variable_test.cc

Lines changed: 5 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -33,10 +33,9 @@ TEST(Variable, GetMutable) {
3333
const Tensor& tt = v->Get<Tensor>();
3434
EXPECT_EQ(1234, tt.content_);
3535

36-
try {
37-
v->GetMutable<std::string>();
38-
} catch (std::exception& e) {
39-
return;
40-
}
41-
EXPECT_TRUE(false);
36+
std::string* s = v->GetMutable<std::string>();
37+
*s = "hello";
38+
39+
const std::string& ss = v->Get<std::string>();
40+
EXPECT_EQ("hello", ss);
4241
}

paddle/fluid/operators/parallel_do_op.cc

Lines changed: 1 addition & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -397,30 +397,11 @@ class ParallelDoGradOpShapeInference : public framework::InferShapeBase {
397397
}
398398
};
399399

400-
class ParallelDoGradOpVarTypeInference : public framework::VarTypeInference {
401-
public:
402-
void operator()(const framework::OpDesc &op_desc,
403-
framework::BlockDesc *block) const override {
404-
framework::BlockDesc *sub_block =
405-
boost::get<framework::BlockDesc *>(op_desc.GetAttr(kParallelBlock));
406-
for (auto &out_vars : op_desc.Outputs()) {
407-
for (auto &out_var : out_vars.second) {
408-
auto &var = block->FindRecursiveOrCreateVar(out_var);
409-
auto sub_var = sub_block->FindRecursiveOrCreateVar(out_var);
410-
if (sub_var.GetType() != var.GetType()) {
411-
var.SetType(sub_var.GetType());
412-
}
413-
}
414-
}
415-
}
416-
};
417-
418400
} // namespace operators
419401
} // namespace paddle
420402

421403
REGISTER_OPERATOR(parallel_do, paddle::operators::ParallelDoOp,
422404
paddle::operators::ParallelDoOpProtoMaker,
423405
paddle::operators::ParallelDoGradOpDescMaker);
424406
REGISTER_OPERATOR(parallel_do_grad, paddle::operators::ParallelDoGradOp,
425-
paddle::operators::ParallelDoGradOpShapeInference,
426-
paddle::operators::ParallelDoGradOpVarTypeInference);
407+
paddle::operators::ParallelDoGradOpShapeInference);

python/paddle/fluid/layers/control_flow.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1267,7 +1267,7 @@ def complete(self):
12671267
]
12681268

12691269
step_scope = parent_block.create_var(
1270-
name='control_scope', type=core.VarDesc.VarType.STEP_SCOPES)
1270+
type=core.VarDesc.VarType.STEP_SCOPES)
12711271
parent_block.append_op(
12721272
type='conditional_block',
12731273
inputs={

0 commit comments

Comments (0)