Skip to content

Commit ae8b1c3

Browse files
authored
Merge pull request #13821 from panyx0718/fix
Make variable::GetMutable robust
2 parents 9b3e0df + 2285066 commit ae8b1c3

File tree

6 files changed

+16
-24
lines changed

6 files changed

+16
-24
lines changed

paddle/fluid/framework/executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -66,7 +66,7 @@ void InitializeVariable(Variable* var, proto::VarType::Type var_type) {
6666
} else if (var_type == proto::VarType::FETCH_LIST) {
6767
var->GetMutable<FeedFetchList>();
6868
} else if (var_type == proto::VarType::STEP_SCOPES) {
69-
var->GetMutable<std::vector<framework::Scope>>();
69+
var->GetMutable<std::vector<framework::Scope*>>();
7070
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
7171
var->GetMutable<LoDRankTable>();
7272
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {

paddle/fluid/framework/feed_fetch_method.cc

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -27,8 +27,7 @@ void SetFeedVariable(Scope* scope, const LoDTensor& input,
2727
// be created.
2828
VLOG(3) << "SetFeedVariable name=" << var_name << " index=" << index;
2929
Variable* g_feed_value = scope->Var(var_name);
30-
auto& feed_inputs =
31-
*(g_feed_value->GetMutable<std::vector<paddle::framework::LoDTensor>>());
30+
auto& feed_inputs = *(g_feed_value->GetMutable<FeedFetchList>());
3231
if (index >= feed_inputs.size()) {
3332
feed_inputs.resize(index + 1);
3433
}

paddle/fluid/framework/naive_executor.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,7 +37,7 @@ static void InitializeVariable(Variable *var, proto::VarType::Type var_type) {
3737
} else if (var_type == proto::VarType::FETCH_LIST) {
3838
var->GetMutable<FeedFetchList>();
3939
} else if (var_type == proto::VarType::STEP_SCOPES) {
40-
var->GetMutable<std::vector<framework::Scope>>();
40+
var->GetMutable<std::vector<framework::Scope *>>();
4141
} else if (var_type == proto::VarType::LOD_RANK_TABLE) {
4242
var->GetMutable<LoDRankTable>();
4343
} else if (var_type == proto::VarType::LOD_TENSOR_ARRAY) {

paddle/fluid/framework/variable.h

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -38,8 +38,12 @@ class Variable {
3838

3939
template <typename T>
4040
T* GetMutable() {
41-
if (!IsType<T>()) {
41+
if (!holder_) {
4242
holder_.reset(new PlaceholderImpl<T>(new T()));
43+
} else {
44+
PADDLE_ENFORCE(IsType<T>(),
45+
"Variable must be type %s, the holding type is %s",
46+
typeid(T).name(), holder_->Type().name());
4347
}
4448
return static_cast<T*>(holder_->Ptr());
4549
}

paddle/fluid/framework/variable_test.cc

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -33,9 +33,10 @@ TEST(Variable, GetMutable) {
3333
const Tensor& tt = v->Get<Tensor>();
3434
EXPECT_EQ(1234, tt.content_);
3535

36-
std::string* s = v->GetMutable<std::string>();
37-
*s = "hello";
38-
39-
const std::string& ss = v->Get<std::string>();
40-
EXPECT_EQ("hello", ss);
36+
try {
37+
v->GetMutable<std::string>();
38+
} catch (std::exception& e) {
39+
return;
40+
}
41+
EXPECT_TRUE(false);
4142
}

python/paddle/fluid/tests/book/test_word2vec.py

Lines changed: 2 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,6 @@
1717
import paddle
1818
import paddle.fluid as fluid
1919
from paddle.fluid.layers.device import get_places
20-
from paddle.fluid.layers.control_flow import ParallelDo
2120
import unittest
2221
import os
2322
import numpy as np
@@ -84,18 +83,7 @@ def __network__(words):
8483
avg_cost, predict_word = __network__(
8584
[first_word, second_word, third_word, forth_word, next_word])
8685
else:
87-
places = get_places()
88-
pd = ParallelDo(places)
89-
with pd.do():
90-
avg_cost, predict_word = __network__(
91-
list(
92-
map(pd.read_input, [
93-
first_word, second_word, third_word, forth_word,
94-
next_word
95-
])))
96-
pd.write_output(avg_cost)
97-
98-
avg_cost = fluid.layers.mean(pd())
86+
raise ValueError('is_parallel=True not implemented')
9987

10088
sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
10189
sgd_optimizer.minimize(avg_cost)
@@ -262,7 +250,7 @@ def __impl__(*args, **kwargs):
262250

263251
for use_cuda in (False, True):
264252
for is_sparse in (False, True):
265-
for is_parallel in (False, True):
253+
for is_parallel in (False, ): # TODO(paddle-dev): Add parallel test.
266254
inject_test_method(use_cuda, is_sparse, is_parallel)
267255

268256
if __name__ == '__main__':

0 commit comments

Comments
 (0)