Skip to content

Commit ae39709

Browse files
committed
Polish code
1 parent 55d7f55 commit ae39709

File tree

16 files changed

+141
-134
lines changed

16 files changed

+141
-134
lines changed

paddle/fluid/framework/op_desc.cc

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -208,49 +208,44 @@ void OpDesc::SetAttr(const std::string &name, const Attribute &v) {
208208
proto::AttrType attr_type = static_cast<proto::AttrType>(v.which() - 1);
209209
if (attr_type == proto::AttrType::INTS &&
210210
boost::get<std::vector<int>>(v).size() == 0u) {
211-
proto::OpProto proto = OpInfoMap::Instance().Get(Type()).Proto();
212211
// Find current attr via attr name and set the correct attribute value
213-
for (int i = 0; i != proto.attrs_size(); ++i) {
214-
const proto::OpProto::Attr &attr = proto.attrs(i);
215-
if (attr.name() == name) {
216-
switch (attr.type()) {
217-
case proto::AttrType::BOOLEANS: {
218-
VLOG(11) << "SetAttr: " << Type() << ", " << name
219-
<< " from INTS to BOOLEANS";
220-
this->attrs_[name] = std::vector<bool>();
221-
break;
222-
}
223-
case proto::AttrType::INTS: {
224-
VLOG(11) << "SetAttr: " << Type() << ", " << name
225-
<< " from INTS to INTS";
226-
this->attrs_[name] = std::vector<int>();
227-
break;
228-
}
229-
case proto::AttrType::FLOATS: {
230-
VLOG(11) << "SetAttr: " << Type() << ", " << name
231-
<< " from INTS to FLOATS";
232-
this->attrs_[name] = std::vector<float>();
233-
break;
234-
}
235-
case proto::AttrType::STRINGS: {
236-
VLOG(11) << "SetAttr: " << Type() << ", " << name
237-
<< " from INTS to STRINGS";
238-
this->attrs_[name] = std::vector<std::string>();
239-
break;
240-
}
241-
case proto::AttrType::BLOCKS: {
242-
VLOG(11) << "SetAttr: " << Type() << ", " << name
243-
<< " from INTS to BLOCKS";
244-
this->SetBlocksAttr(name, std::vector<BlockDesc *>());
245-
return;
246-
}
247-
default:
248-
PADDLE_THROW("Wrong attr type %d", attr.type());
249-
}
250-
need_update_ = true;
212+
const proto::OpProto::Attr& attr = GetProtoAttr(name);
213+
switch (attr.type()) {
214+
case proto::AttrType::BOOLEANS: {
215+
VLOG(11) << "SetAttr: " << Type() << ", " << name
216+
<< " from INTS to BOOLEANS";
217+
this->attrs_[name] = std::vector<bool>();
218+
break;
219+
}
220+
case proto::AttrType::INTS: {
221+
VLOG(11) << "SetAttr: " << Type() << ", " << name
222+
<< " from INTS to INTS";
223+
this->attrs_[name] = std::vector<int>();
224+
break;
225+
}
226+
case proto::AttrType::FLOATS: {
227+
VLOG(11) << "SetAttr: " << Type() << ", " << name
228+
<< " from INTS to FLOATS";
229+
this->attrs_[name] = std::vector<float>();
230+
break;
231+
}
232+
case proto::AttrType::STRINGS: {
233+
VLOG(11) << "SetAttr: " << Type() << ", " << name
234+
<< " from INTS to STRINGS";
235+
this->attrs_[name] = std::vector<std::string>();
236+
break;
237+
}
238+
case proto::AttrType::BLOCKS: {
239+
VLOG(11) << "SetAttr: " << Type() << ", " << name
240+
<< " from INTS to BLOCKS";
241+
this->SetBlocksAttr(name, std::vector<BlockDesc *>());
251242
return;
252243
}
244+
default:
245+
PADDLE_THROW("Wrong attr type %d", attr.type());
253246
}
247+
need_update_ = true;
248+
return;
254249
}
255250

256251
this->attrs_[name] = v;
@@ -280,6 +275,18 @@ Attribute OpDesc::GetAttr(const std::string &name) const {
280275
return it->second;
281276
}
282277

278+
const proto::OpProto::Attr& OpDesc::GetProtoAttr(const std::string &name) {
279+
proto::OpProto& proto = OpInfoMap::Instance().Get(Type()).Proto();
280+
for (int i = 0; i != proto.attrs_size(); ++i) {
281+
const proto::OpProto::Attr &attr = proto.attrs(i);
282+
if (attr.name() == name) {
283+
return attr;
284+
}
285+
}
286+
287+
PADDLE_THROW("Attribute %s is not found in proto %s", name, proto.type());
288+
}
289+
283290
Attribute OpDesc::GetNullableAttr(const std::string &name) const {
284291
auto it = attrs_.find(name);
285292
if (it != attrs_.end()) {

paddle/fluid/framework/op_desc.h

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,8 @@ class OpDesc {
8181

8282
Attribute GetAttr(const std::string &name) const;
8383

84+
const proto::OpProto::Attr& GetProtoAttr(const std::string &name) const;
85+
8486
Attribute GetNullableAttr(const std::string &name) const;
8587

8688
int GetBlockAttr(const std::string &name) const;

python/paddle/dataset/cifar.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,8 @@ def read_batch(batch):
5555

5656
def reader():
5757
with tarfile.open(filename, mode='r') as f:
58-
names = [
59-
each_item.name for each_item in f if sub_name in each_item.name
60-
]
58+
names = (each_item.name for each_item in f
59+
if sub_name in each_item.name)
6160

6261
while True:
6362
for name in names:

python/paddle/dataset/common.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
import sys
2121
import importlib
2222
import paddle.dataset
23-
import paddle.fluid.compat as cpt
2423
import six.moves.cPickle as pickle
2524
import glob
2625

python/paddle/dataset/conll05.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -90,8 +90,8 @@ def reader():
9090
labels = []
9191
one_seg = []
9292
for word, label in zip(words_file, props_file):
93-
word = cpt.to_literal_str(word.strip())
94-
label = cpt.to_literal_str(label.strip().split())
93+
word = cpt.to_text(word.strip())
94+
label = cpt.to_text(label.strip().split())
9595

9696
if len(label) == 0: # end of sentence
9797
for i in range(len(one_seg[0])):

python/paddle/dataset/movielens.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -114,7 +114,7 @@ def __initialize_meta_info__():
114114
categories_set = set()
115115
with package.open('ml-1m/movies.dat') as movie_file:
116116
for i, line in enumerate(movie_file):
117-
line = cpt.to_literal_str(line, encoding='latin')
117+
line = cpt.to_text(line, encoding='latin')
118118
movie_id, title, categories = line.strip().split('::')
119119
categories = categories.split('|')
120120
for c in categories:
@@ -139,7 +139,7 @@ def __initialize_meta_info__():
139139
USER_INFO = dict()
140140
with package.open('ml-1m/users.dat') as user_file:
141141
for line in user_file:
142-
line = cpt.to_literal_str(line, encoding='latin')
142+
line = cpt.to_text(line, encoding='latin')
143143
uid, gender, age, job, _ = line.strip().split("::")
144144
USER_INFO[int(uid)] = UserInfo(
145145
index=uid, gender=gender, age=age, job_id=job)
@@ -152,7 +152,7 @@ def __reader__(rand_seed=0, test_ratio=0.1, is_test=False):
152152
with zipfile.ZipFile(file=fn) as package:
153153
with package.open('ml-1m/ratings.dat') as rating:
154154
for line in rating:
155-
line = cpt.to_literal_str(line, encoding='latin')
155+
line = cpt.to_text(line, encoding='latin')
156156
if (rand.random() < test_ratio) == is_test:
157157
uid, mov_id, rating, _ = line.strip().split("::")
158158
uid = int(uid)

python/paddle/dataset/wmt14.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -55,7 +55,7 @@ def __to_dict(fd, size):
5555
out_dict = dict()
5656
for line_count, line in enumerate(fd):
5757
if line_count < size:
58-
out_dict[cpt.to_literal_str(line.strip())] = line_count
58+
out_dict[cpt.to_text(line.strip())] = line_count
5959
else:
6060
break
6161
return out_dict

python/paddle/dataset/wmt16.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -89,9 +89,9 @@ def __load_dict(tar_file, dict_size, lang, reverse=False):
8989
with open(dict_path, "rb") as fdict:
9090
for idx, line in enumerate(fdict):
9191
if reverse:
92-
word_dict[idx] = cpt.to_literal_str(line.strip())
92+
word_dict[idx] = cpt.to_text(line.strip())
9393
else:
94-
word_dict[cpt.to_literal_str(line.strip())] = idx
94+
word_dict[cpt.to_text(line.strip())] = idx
9595
return word_dict
9696

9797

python/paddle/fluid/backward.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -103,8 +103,8 @@ def _some_in_set_(cands, s):
103103
"""
104104
if len(cands) == 0:
105105
return False
106-
literal_set = cpt.to_literal_str(s)
107-
literal_cands = cpt.to_literal_str(cands)
106+
literal_set = cpt.to_text(s)
107+
literal_cands = cpt.to_text(cands)
108108
for c in literal_cands:
109109
if c in literal_set:
110110
return True
@@ -117,7 +117,7 @@ def _strip_grad_suffix_(name):
117117
e.g. x@GRAD ==> x
118118
y@GRAD@RENAME@1 ==> y
119119
"""
120-
name = cpt.to_literal_str(name)
120+
name = cpt.to_text(name)
121121
pos = name.find(core.grad_var_suffix())
122122
return name[:pos] if pos != -1 else name
123123

@@ -127,7 +127,7 @@ def _append_grad_suffix_(name):
127127
Append grad suffix to the given variable name
128128
e.g. x ==> x@GRAD
129129
"""
130-
return cpt.to_literal_str(name) + core.grad_var_suffix()
130+
return cpt.to_text(name) + core.grad_var_suffix()
131131

132132

133133
def _addup_repetitive_outputs_(op_descs):
@@ -365,7 +365,7 @@ def _append_backward_ops_(block,
365365
# Getting op's corresponding grad_op
366366
grad_op_desc, op_grad_to_var = core.get_grad_op_desc(
367367
op.desc,
368-
cpt.to_literal_str(no_grad_dict[block.idx]), grad_sub_block_list)
368+
cpt.to_text(no_grad_dict[block.idx]), grad_sub_block_list)
369369

370370
grad_op_descs.extend(grad_op_desc)
371371
grad_to_var.update(op_grad_to_var)
@@ -600,7 +600,7 @@ def append_backward(loss, parameter_list=None, no_grad_set=None,
600600

601601
params_and_grads = []
602602
for param in parameters:
603-
if cpt.to_literal_str(param) not in grad_info_map:
603+
if cpt.to_text(param) not in grad_info_map:
604604
continue
605605
grad_info = grad_info_map[param]
606606
grad_block = grad_info[1]

python/paddle/fluid/compat.py

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
__all__ = [
1919
'long_type',
20-
'to_literal_str',
20+
'to_text',
2121
'to_bytes',
2222
'round',
2323
'floor_division',
@@ -33,7 +33,7 @@
3333

3434

3535
# str and bytes related functions
36-
def to_literal_str(obj, encoding='utf-8', inplace=False):
36+
def to_text(obj, encoding='utf-8', inplace=False):
3737
"""
3838
All string in PaddlePaddle should be represented as a literal string.
3939
This function will convert object to a literal string without any encoding.
@@ -60,23 +60,23 @@ def to_literal_str(obj, encoding='utf-8', inplace=False):
6060
if isinstance(obj, list):
6161
if inplace:
6262
for i in six.moves.xrange(len(obj)):
63-
obj[i] = _to_literal_str(obj[i], encoding)
63+
obj[i] = _to_text(obj[i], encoding)
6464
return obj
6565
else:
66-
return [_to_literal_str(item, encoding) for item in obj]
66+
return [_to_text(item, encoding) for item in obj]
6767
elif isinstance(obj, set):
6868
if inplace:
6969
for item in obj:
7070
obj.remove(item)
71-
obj.add(_to_literal_str(item, encoding))
71+
obj.add(_to_text(item, encoding))
7272
return obj
7373
else:
74-
return set([_to_literal_str(item, encoding) for item in obj])
74+
return set([_to_text(item, encoding) for item in obj])
7575
else:
76-
return _to_literal_str(obj, encoding)
76+
return _to_text(obj, encoding)
7777

7878

79-
def _to_literal_str(obj, encoding):
79+
def _to_text(obj, encoding):
8080
"""
8181
In Python3:
8282
Decode the bytes type object to str type with specific encoding

0 commit comments

Comments
 (0)