Commit 0e85686

wip

1 parent 136a591 commit 0e85686

6 files changed, +111 -28 lines changed

paddle/framework/block_desc.cc

Lines changed: 23 additions & 21 deletions
@@ -42,28 +42,30 @@ bool BlockDesc::HasVar(const std::string &name) const {
   return vars_.find(name) != vars_.end();
 }
 
-void BlockDesc::RenameVar(const std::string &old_name,
-                          const std::string &new_name) {
-  if (this->HasVar(old_name)) {
-    auto *var = this->Var(old_name);
-    var->SetName(new_name);
-    vars_[new_name].reset(var);
-    vars_.erase(old_name);
-    // rename inputs and outputs
-    for (const auto &op : ops_) {
-      auto *it = op.get();
-      for (auto in_name : it->InputArgumentNames()) {
-        if (in_name == old_name) {
-          it->RenameInput(old_name, new_name);
-        }
-      }
-      for (auto out_name : it->OutputArgumentNames()) {
-        if (out_name == old_name) {
-          it->RenameOutput(old_name, new_name);
-        }
-      }
-    }
+VarDesc *BlockDesc::RenameVar(const std::string &old_name,
+                              const std::string &new_name) {
+  if (!this->HasVar(old_name)) {
+    return nullptr;
+  }
+  need_update_ = true;
+  auto *var = this->Var(old_name);
+  VarDesc *new_var = new VarDesc(*(var->Proto()));
+  new_var->SetName(new_name);
+  // new_var->SetShape(var->GetShape());
+  // new_var->SetType(var->GetType());
+  // new_var->SetDataType(var->GetDataType());
+  // new_var->SetLoDLevel(var->GetLoDLevel());
+  // new_var->SetPersistable(var->Persistable());
+
+  vars_[new_name].reset(new_var);
+
+  // rename inputs and outputs
+  for (const auto &op : ops_) {
+    auto *it = op.get();
+    it->Rename(old_name, new_name);
   }
+  vars_.erase(old_name);
+  return new_var;
 }
 
 VarDesc *BlockDesc::FindVarRecursive(const std::string &name) const {

paddle/framework/block_desc.h

Lines changed: 1 addition & 1 deletion
@@ -55,7 +55,7 @@ class BlockDesc {
 
   bool HasVar(const std::string &var_name) const;
 
-  void RenameVar(const std::string &old_name, const std::string &new_name);
+  VarDesc *RenameVar(const std::string &old_name, const std::string &new_name);
 
   VarDesc *FindVarRecursive(const std::string &name_bytes) const;
 

paddle/pybind/protobuf.cc

Lines changed: 6 additions & 4 deletions
@@ -170,12 +170,14 @@ void BindBlockDesc(py::module &m) {
            [](BlockDesc &self, py::bytes byte_name) {
              std::string name = byte_name;
              return self.HasVar(name);
-           })
+           },
+           py::return_value_policy::reference)
       .def("rename_var",
-           [](BlockDesc &self, py::bytes byte_name, py::bytes byte_name_new) {
+           [](BlockDesc &self, const py::bytes &byte_name,
+              const py::bytes &byte_name_new) {
              std::string name = byte_name;
              std::string new_name = byte_name_new;
-             return self.RenameVar(name, new_name);
+             self.RenameVar(name, new_name);
            })
       .def("has_var_recursive",
            [](BlockDesc &self, py::bytes byte_name) {
@@ -213,7 +215,7 @@ void BindVarDsec(py::module &m) {
   py::class_<VarDesc> var_desc(m, "VarDesc", "");
   var_desc
       .def("name",
-           [](const VarDesc &self) {
+           [](VarDesc &self) {
             py::bytes name = self.Name();
             return name;
           },
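Note on the C++/pybind changes above: BlockDesc::RenameVar now copies the variable's proto into a fresh VarDesc, renames every op input/output through OpDesc::Rename, and returns the new VarDesc (nullptr when the old name does not exist), while the pybind wrapper simply discards that return value. A minimal sketch of exercising the binding from Python 2; the module path paddle.v2.fluid.core, the use of ProgramDesc/BlockDesc.var as bound elsewhere in protobuf.cc, and the variable names are assumptions, not part of this diff:

    import paddle.v2.fluid.core as core

    prog = core.ProgramDesc()           # empty program; block 0 is the global block
    block = prog.block(0)
    block.var("w")                      # hypothetical variable name; var() creates it
    block.rename_var("w", "w.renamed")  # the new C++ return value is dropped by pybind
    assert not block.has_var("w")
    assert block.has_var("w.renamed")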

python/paddle/v2/dataset/common.py

Lines changed: 2 additions & 0 deletions
@@ -74,6 +74,8 @@ def download(url, module_name, md5sum, save_name=None):
     retry = 0
     retry_limit = 3
     while not (os.path.exists(filename) and md5file(filename) == md5sum):
+        if os.path.exists(filename):
+            print "file md5", md5file(filename), md5sum
         if retry < retry_limit:
             retry += 1
         else:

python/paddle/v2/fluid/distribute_transpiler.py

Lines changed: 12 additions & 1 deletion
@@ -175,6 +175,7 @@ def transpile(self,
             shape=[0])
 
         # create send_op
+        print("send inputs: ", send_inputs)
         send_op = program.global_block().append_op(
             type="send",
             inputs={"X": send_inputs},
@@ -204,12 +205,12 @@ def _create_vars_from_blocklist(self, program, block_list):
             block_map[varname].append((long(offset), long(size)))
         for varname, splited in block_map.iteritems():
             orig_var = program.global_block().var(varname)
-
             if len(splited) == 1:
                 # rename var to the trainer_id var
                 new_var_name = "%s.trainer_%d" % \
                     (orig_var.name, self.trainer_id)
                 program.global_block().rename_var(varname, new_var_name)
+                print("renaming OK...", varname, new_var_name)
                 var_mapping[varname] = \
                     [program.global_block().var(new_var_name)]
                 continue
@@ -375,7 +376,10 @@ def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
         new_inputs = dict()
         # update param/grad shape first, then other inputs like
        # moment can use the updated shape
+        print("mark1")
         for key in opt_op.input_names:
+            # print("opt type: ", opt_op.type)
+            # print("opt op input: ", key)
             if key == "Grad":
                 grad_block = None
                 for g in self.param_grad_ep_mapping[endpoint]["grads"]:
@@ -422,6 +426,7 @@ def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
 
             new_inputs[key] = tmpvar
 
+        print("mark2")
         for key in opt_op.input_names:
             if key in ["Param", "Grad"]:
                 continue
@@ -453,6 +458,7 @@ def _append_pserver_ops(self, program, pserver_program, opt_op, endpoint):
             inputs=new_inputs,
             outputs=outputs,
             attrs=opt_op.attrs)
+        print("mark3")
 
     def _append_pserver_non_opt_ops(self, program, pserver_program, opt_op):
         # Append the ops for parameters that do not need to be optimized/updated
@@ -523,6 +529,11 @@ def get_pserver_program(self, endpoint):
         optimize_sub_program = Program()
         # Iterate through the ops and append ops as needed
         for idx, opt_op in enumerate(self.optimize_ops):
+            print("mark0")
+            print(opt_op.inputs.keys())
+            for v in opt_op.inputs.values():
+                print(v.name)
+                print(v.shape)
             is_op_on_pserver = self._is_op_on_pserver(endpoint,
                                                       self.optimize_ops, idx)
             if not is_op_on_pserver:

python/paddle/v2/fluid/framework.py

Lines changed: 67 additions & 1 deletion
@@ -741,9 +741,75 @@ def rename_var(self, name, new_name):
         """
         if not self.has_var(name):
             raise ValueError("var %s is not in current" % name)
+        v = self.var(name)
+        stop_gradient = None
+        trainable = None
+        optimize_attr = None
+        regularizer = None
+        gradient_clip_attr = None
+        error_clip = None
+        if type(v) == Parameter:
+            stop_gradient = v.stop_gradient
+            trainable = v.trainable
+            optimize_attr = v.optimize_attr
+            regularizer = v.regularizer
+            gradient_clip_attr = v.gradient_clip_attr
+            error_clip = v.error_clip
+        elif type(v) == Variable:
+            error_clip = v.error_clip
+            stop_gradient = v.stop_gradient
+        else:
+            raise ValueError("unsupported var type: %s", type(v))
+
+        def _clear_op_io_for_var(name):
+            for op in self.ops:
+                for k in op.inputs.keys():
+
+                    if op.inputs[k].name == name:
+                        op.inputs[k] = None
+                for k in op.outputs.keys():
+                    if op.outputs[k].name == name:
+                        op.outputs[k] = None
+
+        _clear_op_io_for_var(name)
         self.desc.rename_var(name, new_name)
+        d = self.desc.find_var(new_name)
+        var = None
+        if type(v) == Parameter:
+            var = Parameter(
+                self,
+                d.shape(),
+                d.dtype(),
+                name=new_name,
+                stop_gradient=stop_gradient,
+                trainable=trainable,
+                optimize_attr=optimize_attr,
+                regularizer=regularizer,
+                gradient_clip_attr=gradient_clip_attr,
+                error_clip=error_clip)
+        elif type(v) == Variable:
+            var = Variable(
+                self,
+                name=new_name,
+                error_clip=error_clip,
+                stop_gradient=stop_gradient)
+
+        # rename the python side, sync_with_cpp will only add
+        # new vars/ops to python side.
+        self.vars[new_name] = var
+        for op in self.ops:
+            print("### rename op i/o ", name, op.inputs)
+            if op.inputs:
+                for k in op.inputs.keys():
+                    if op.inputs[k] == None:
+                        print("rename input: ", name, var)
+                        op.inputs[k] = var
+            if op.outputs:
+                for k in op.outputs.keys():
+                    if op.outputs[k] == None:
+                        op.outputs[k] = var
+        del self.vars[name]
         self.sync_with_cpp()
-        print("renamed var: ", self.var(new_name))
 
     def create_parameter(self, *args, **kwargs):
         global_block = self.program.global_block()
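At the fluid Python layer, Block.rename_var now rebuilds the Python-side Variable or Parameter and re-links op inputs/outputs, as the framework.py hunk above shows; distribute_transpiler relies on this when it renames an unsplit variable to its trainer-scoped name. A rough usage sketch, assuming a compiled paddle.v2.fluid build of roughly this era; the layer names and the trainer id are illustrative and not part of this commit:

    import paddle.v2.fluid as fluid

    x = fluid.layers.data(name="x", shape=[13], dtype="float32")
    y = fluid.layers.fc(input=x, size=1)

    block = fluid.default_main_program().global_block()
    # mirror _create_vars_from_blocklist: rename to a trainer-scoped name
    new_name = "%s.trainer_%d" % (y.name, 0)
    block.rename_var(y.name, new_name)
    assert block.has_var(new_name)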
