Skip to content

Commit 342e436

Browse files
committed
Make Var::GetMutable robust
test=develop
1 parent af91d41 commit 342e436

File tree

7 files changed

+35
-11
lines changed

7 files changed

+35
-11
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
101101
} else if (var_type == proto::VarType::FETCH_LIST) {
102102
var->GetMutable<FeedFetchList>();
103103
} else if (var_type == proto::VarType::STEP_SCOPES) {
104-
var->GetMutable<std::vector<framework::Scope>>();
104+
var->GetMutable<std::vector<framework::Scope*>>();
105105
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
106106
var->GetMutable<LoDRankTable>();
107107
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {

paddle/fluid/framework/feed_fetch_method.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input,
2727
// be created.
2828
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
2929
Variable* g_feed_value = scope->Var(var_name);
30-
auto& feed_inputs =
31-
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
30+
auto& feed_inputs = *(g_feed_value->GetMutable<FeedFetchList>());
3231
if (index >= feed_inputs.size()) {
3332
feed_inputs.resize(index + 1);
3433
}

paddle/fluid/framework/naive_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
3737
} else if (var_type == proto::VarType::FETCH_LIST) {
3838
var->GetMutable<FeedFetchList>();
3939
} else if (var_type == proto::VarType::STEP_SCOPES) {
40-
var->GetMutable<std::vector<framework::Scope>>();
40+
var->GetMutable<std::vector<framework::Scope *>>();
4141
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
4242
var->GetMutable<LoDRankTable>();
4343
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {

paddle/fluid/framework/var_desc.h

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ class VarDesc {
5959
public:
6060
explicit VarDesc(const std::string &name) {
6161
desc_.set_name(name);
62+
// TODO(paddle-dev): Why default to lodtensor.
6263
desc_.mutable_type()->set_type(proto::VarType::LOD_TENSOR);
6364
}
6465

paddle/fluid/framework/variable.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,12 @@ class Variable {
3838

3939
template <typename T>
4040
T* GetMutable() {
41-
if (!IsType<T>()) {
41+
if (!holder_) {
4242
holder_.reset(new PlaceholderImpl<T>(new T()));
43+
} else {
44+
PADDLE_ENFORCE(IsType<T>(),
45+
"Variable must be type %s, the holding type is %s",
46+
typeid(T).name(), holder_->Type().name());
4347
}
4448
return static_cast<T*>(holder_->Ptr());
4549
}

paddle/fluid/framework/variable_test.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ TEST(Variable, GetMutable) {
3333
const Tensor& tt = v->Get<Tensor>();
3434
EXPECT_EQ(1234, tt.content_);
3535

36-
std::string* s = v->GetMutable<std::string>();
37-
*s = "hello";
38-
39-
const std::string& ss = v->Get<std::string>();
40-
EXPECT_EQ("hello", ss);
36+
try {
37+
v->GetMutable<std::string>();
38+
} catch (std::exception& e) {
39+
return;
40+
}
41+
EXPECT_TRUE(false);
4142
}

paddle/fluid/operators/parallel_do_op.cc

Lines changed: 20 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -397,11 +397,30 @@ class ParallelDoGradOpShapeInference : public framework::InferShapeBase {
397397
}
398398
};
399399

400+
class ParallelDoGradOpVarTypeInference : public framework::VarTypeInference {
401+
public:
402+
void operator()(const framework::OpDesc &op_desc,
403+
framework::BlockDesc *block) const override {
404+
framework::BlockDesc *sub_block =
405+
boost::get<framework::BlockDesc *>(op_desc.GetAttr(kParallelBlock));
406+
for (auto &out_vars : op_desc.Outputs()) {
407+
for (auto &out_var : out_vars.second) {
408+
auto &var = block->FindRecursiveOrCreateVar(out_var);
409+
auto sub_var = sub_block->FindRecursiveOrCreateVar(out_var);
410+
if (sub_var.GetType() != var.GetType()) {
411+
var.SetType(sub_var.GetType());
412+
}
413+
}
414+
}
415+
}
416+
};
417+
400418
} // namespace operators
401419
} // namespace paddle
402420

403421
REGISTER_OPERATOR(parallel_do, paddle::operators::ParallelDoOp,
404422
paddle::operators::ParallelDoOpProtoMaker,
405423
paddle::operators::ParallelDoGradOpDescMaker);
406424
REGISTER_OPERATOR(parallel_do_grad, paddle::operators::ParallelDoGradOp,
407-
paddle::operators::ParallelDoGradOpShapeInference);
425+
paddle::operators::ParallelDoGradOpShapeInference,
426+
paddle::operators::ParallelDoGradOpVarTypeInference);

0 commit comments

Comments
 (0)