Skip to content

Commit 842fb02

Browse files
authored
Fix clone() bug. (#12583)
1 parent 7b03b18 commit 842fb02

File tree

13 files changed

+372
-55
lines changed

13 files changed

+372
-55
lines changed

paddle/fluid/API.spec

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,9 @@ paddle.fluid.Operator.all_attrs ArgSpec(args=['self'], varargs=None, keywords=No
1818
paddle.fluid.Operator.attr ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None)
1919
paddle.fluid.Operator.attr_type ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None)
2020
paddle.fluid.Operator.block_attr ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None)
21+
paddle.fluid.Operator.block_attr_id ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None)
22+
paddle.fluid.Operator.blocks_attr ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None)
23+
paddle.fluid.Operator.blocks_attr_ids ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None)
2124
paddle.fluid.Operator.has_attr ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None)
2225
paddle.fluid.Operator.has_kernel ArgSpec(args=['self', 'op_type'], varargs=None, keywords=None, defaults=None)
2326
paddle.fluid.Operator.input ArgSpec(args=['self', 'name'], varargs=None, keywords=None, defaults=None)

paddle/fluid/framework/op_desc.cc

Lines changed: 14 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -238,7 +238,20 @@ Attribute OpDesc::GetNullableAttr(const std::string &name) const {
238238
}
239239
}
240240

241-
int OpDesc::GetBlockAttr(const std::string &name) const {
241+
std::vector<int> OpDesc::GetBlocksAttrIds(const std::string &name) const {
242+
auto it = attrs_.find(name);
243+
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
244+
auto blocks = boost::get<std::vector<BlockDesc *>>(it->second);
245+
246+
std::vector<int> ids;
247+
for (auto n : blocks) {
248+
ids.push_back(n->ID());
249+
}
250+
251+
return ids;
252+
}
253+
254+
int OpDesc::GetBlockAttrId(const std::string &name) const {
242255
auto it = attrs_.find(name);
243256
PADDLE_ENFORCE(it != attrs_.end(), "Attribute %s is not found", name);
244257
return boost::get<BlockDesc *>(it->second)->ID();

paddle/fluid/framework/op_desc.h

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,9 @@ class OpDesc {
8383

8484
Attribute GetNullableAttr(const std::string &name) const;
8585

86-
int GetBlockAttr(const std::string &name) const;
86+
int GetBlockAttrId(const std::string &name) const;
87+
88+
std::vector<int> GetBlocksAttrIds(const std::string &name) const;
8789

8890
void Rename(const std::string &old_name, const std::string &new_name);
8991

paddle/fluid/framework/program_desc.cc

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -58,7 +58,7 @@ ProgramDesc::ProgramDesc(const ProgramDesc &o) {
5858
for (const std::string &attr_name : op->AttrNames()) {
5959
if (op->GetAttrType(attr_name) == proto::AttrType::BLOCK) {
6060
int sub_block_id =
61-
o.Block(block_id).Op(op_id)->GetBlockAttr(attr_name);
61+
o.Block(block_id).Op(op_id)->GetBlockAttrId(attr_name);
6262
op->SetBlockAttr(attr_name, MutableBlock(sub_block_id));
6363
}
6464
}

paddle/fluid/pybind/protobuf.cc

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -301,7 +301,8 @@ void BindOpDesc(pybind11::module *m) {
301301
std::string ser(seriralized);
302302
self.SetAttr(name, ser);
303303
})
304-
.def("block_attr", &pd::OpDesc::GetBlockAttr)
304+
.def("block_attr_id", &pd::OpDesc::GetBlockAttrId)
305+
.def("blocks_attr_ids", &pd::OpDesc::GetBlocksAttrIds)
305306
.def("check_attrs", &pd::OpDesc::CheckAttrs)
306307
.def("infer_shape", &pd::OpDesc::InferShape)
307308
.def("infer_var_type", &pd::OpDesc::InferVarType)

python/paddle/fluid/backward.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -344,7 +344,7 @@ def _append_backward_ops_(block,
344344
grad_sub_block_list = []
345345
# If the op has its own sub-block, deal with the sub-block first
346346
if op.has_attr("sub_block"):
347-
sub_block = program.block(op.block_attr("sub_block"))
347+
sub_block = program.block(op.block_attr_id("sub_block"))
348348
grad_sub_block = program.create_block()
349349
grad_sub_block._set_forward_block_idx(sub_block.idx)
350350
cb = _callback_lookup_(op)
@@ -406,7 +406,7 @@ def _append_backward_vars_(block, start_op_idx, grad_to_var, grad_info_map):
406406
for op_idx in range(start_op_idx, block.desc.op_size()):
407407
op_desc = block.desc.op(op_idx)
408408
if op_desc.has_attr("sub_block"):
409-
sub_block = block.program.block(op_desc.block_attr("sub_block"))
409+
sub_block = block.program.block(op_desc.block_attr_id("sub_block"))
410410
_append_backward_vars_(sub_block, 0, grad_to_var, grad_info_map)
411411
new_vars = set()
412412
# create new gradient variables

python/paddle/fluid/framework.py

Lines changed: 81 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -476,23 +476,25 @@ def __init__(self,
476476
attrs=None):
477477
self.block = block
478478
self.desc = desc
479-
self.attrs = attrs
480-
if self.attrs is None:
481-
self.attrs = dict()
479+
# note: not add self.attrs here:
480+
# https://github.com/PaddlePaddle/Paddle/pull/12583#pullrequestreview-145093173
481+
op_attrs = attrs
482+
if op_attrs is None:
483+
op_attrs = dict()
482484
del attrs
483485

484486
op_maker = core.op_proto_and_checker_maker
485487

486-
if op_maker.kOpRoleAttrName() not in self.attrs:
487-
self.attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role
488+
if op_maker.kOpRoleAttrName() not in op_attrs:
489+
op_attrs[op_maker.kOpRoleAttrName()] = self.block.program.op_role
488490

489491
role_var_name = op_maker.kOpRoleVarAttrName()
490492
if len(self.block.program.
491-
op_role_var) != 0 and role_var_name not in self.attrs:
492-
self.attrs[role_var_name] = self.block.program.op_role_var
493+
op_role_var) != 0 and role_var_name not in op_attrs:
494+
op_attrs[role_var_name] = self.block.program.op_role_var
493495

494-
if role_var_name in self.attrs and len(self.attrs[role_var_name]) == 0:
495-
del self.attrs[role_var_name]
496+
if role_var_name in op_attrs and len(op_attrs[role_var_name]) == 0:
497+
del op_attrs[role_var_name]
496498

497499
if len(self.desc.type()) != 0:
498500
return
@@ -576,15 +578,14 @@ def find_name(var_list, name):
576578
arg.op = self
577579
self.desc.set_output(out_proto.name, out_arg_names)
578580

579-
if self.attrs is not None:
580-
if not isinstance(self.attrs, dict):
581+
if op_attrs is not None:
582+
if not isinstance(op_attrs, dict):
581583
raise TypeError("'attrs' should be a dict.")
582584
for attr in proto.attrs:
583585
attr_name = attr.name
584-
if (attr_name not in self.attrs) or (
585-
self.attrs[attr_name] is None):
586+
if (attr_name not in op_attrs) or (op_attrs[attr_name] is None):
586587
continue
587-
attr_val = self.attrs[attr_name]
588+
attr_val = op_attrs[attr_name]
588589
self._update_desc_attr(attr_name, attr_val)
589590

590591
self.desc.check_attrs()
@@ -732,7 +733,6 @@ def set_attr(self, name, val):
732733
Raises:
733734
ValueError: If the type of value doesn't match with desc.attr_type(name).
734735
"""
735-
self.attrs[name] = val
736736
self._update_desc_attr(name, val)
737737

738738
def _update_desc_attr(self, name, val):
@@ -774,32 +774,84 @@ def attr(self, name):
774774
"""
775775
return self.desc.attr(name)
776776

777-
def block_attr(self, name):
777+
def block_attr_id(self, name):
778778
"""
779-
Get the block attribute by name.
779+
Get the block attribute's id by name.
780780
781781
Args:
782782
name(str): the attribute name.
783783
784784
Returns:
785785
int: the block index.
786786
"""
787-
return self.desc.block_attr(name)
787+
return self.desc.block_attr_id(name)
788+
789+
def block_attr(self, name):
790+
"""
791+
Get the block attribute by name.
792+
793+
Args:
794+
name(str): the attribute name.
795+
796+
Returns:
797+
block: the block attribute.
798+
"""
799+
800+
id = self.block_attr_id(name)
801+
assert (id >= 0 and id < len(self.block.program.blocks))
802+
return self.block.program.blocks[id]
803+
804+
def blocks_attr(self, name):
805+
"""
806+
Get the blocks attribute by name.
807+
808+
Args:
809+
name(str): the attribute name.
810+
811+
Returns:
812+
list: list of the blocks attribute.
813+
"""
814+
attrs = []
815+
for i in self.blocks_attr_ids(name):
816+
assert (i >= 0 and i < len(self.block.program.blocks))
817+
attrs.append(self.block.program.blocks[i])
818+
819+
return attrs
820+
821+
def blocks_attr_ids(self, name):
822+
"""
823+
Get the blocks attribute's ids by name.
824+
825+
Args:
826+
name(str): the attribute name.
827+
828+
Returns:
829+
list: list of the blocks ids.
830+
"""
831+
832+
return self.desc.blocks_attr_ids(name)
788833

789834
def all_attrs(self):
790835
"""
791836
Get the attribute dict.
792837
793838
Returns:
794-
dict: The Operator's attribute dict.
839+
dict: The Operator's attribute dict, name->attr.
795840
"""
796841
attr_names = self.attr_names
797842
attr_map = {}
798843
for n in attr_names:
799-
if n == 'sub_block':
844+
attr_type = self.desc.attr_type(n)
845+
if attr_type == core.AttrType.BLOCK:
800846
attr_map[n] = self.block_attr(n)
801-
else:
802-
attr_map[n] = self.attr(n)
847+
continue
848+
849+
if attr_type == core.AttrType.BLOCKS:
850+
attr_map[n] = self.blocks_attr(n)
851+
continue
852+
853+
attr_map[n] = self.attr(n)
854+
803855
return attr_map
804856

805857

@@ -1521,8 +1573,14 @@ def clone(self, for_test=False):
15211573
p = self.inference_optimize(export_for_deployment=False)
15221574
else:
15231575
p = Program()
1576+
p.current_block_idx = self.current_block_idx
1577+
p._seed = self._seed
15241578
p.desc = core.ProgramDesc(self.desc)
1525-
p.blocks = [Block(p, i) for i in range(self.desc.num_blocks())]
1579+
p.blocks = [Block(p, i) for i in xrange(self.desc.num_blocks())]
1580+
1581+
p._current_role = self._current_role
1582+
p._op_role_var = self._op_role_var
1583+
15261584
p._sync_with_cpp()
15271585

15281586
p._copy_param_info_from(self)

python/paddle/fluid/initializer.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -264,7 +264,8 @@ def __call__(self, var, block):
264264
"dtype": int(var.dtype),
265265
"mean": self._mean,
266266
"std": self._std_dev,
267-
"seed": self._seed
267+
"seed": self._seed,
268+
"use_mkldnn": False
268269
})
269270
var.op = op
270271
return op

0 commit comments

Comments
 (0)