|
@@ -24,7 +24,7 @@
 1. split each variable into multiple blocks, aligned by product(dim[1:]) (width); see the sketch after this list.
 2. rename the split grad variables to add the trainer_id suffix ".trainer_%d".
 3. modify the trainer program to add a split_op for each grad variable.
-4. append send_op to send the split variables to the server, and
+4. append send_op to send the split variables to the server, and
 5. add recv_op to fetch params (split blocks or the origin param) from the server.
 6. append concat_op to merge the split blocks to update local weights.

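The split/merge scheme in steps 1 and 6 can be sketched in plain numpy. This is a hypothetical illustration rather than the transpiler's actual code (the transpiler works on Variable descs inside a Program; split_by_width and merge_blocks below are made-up helpers):

import numpy as np

# Hypothetical helpers, NOT the transpiler's API: a minimal model of
# steps 1 and 6 above.
def split_by_width(tensor, num_blocks):
    # width = product(dim[1:]); blocks are cut along dim 0, so each
    # block's element count stays a multiple of width
    width = int(np.prod(tensor.shape[1:])) if tensor.ndim > 1 else 1
    rows_per_block = (tensor.shape[0] + num_blocks - 1) // num_blocks
    blocks = [tensor[i:i + rows_per_block]
              for i in range(0, tensor.shape[0], rows_per_block)]
    assert all(b.size % width == 0 for b in blocks)
    return blocks

def merge_blocks(blocks):
    # step 6: concat the fetched blocks back into the original parameter
    return np.concatenate(blocks, axis=0)

param = np.arange(24.0).reshape(6, 4)         # dim = [6, 4], width = 4
blocks = split_by_width(param, num_blocks=3)  # three [2, 4] blocks
assert np.array_equal(merge_blocks(blocks), param)
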
@@ -44,7 +44,7 @@
 from ps_dispatcher import RoundRobin, HashName, PSDispatcher
 from .. import core, framework
 from ..framework import Program, default_main_program, \
-    default_startup_program, \
+    default_startup_program, Block, \
     Variable, Parameter, grad_var_name
 from details import *

@@ -471,21 +471,47 @@ def __append_optimize_op__(op, block, grad_to_block_id, merged_var):
                 self._append_pserver_ops(block, op, endpoint, grad_to_block_id,
                                          self.origin_program, merged_var)
             else:
-                self._append_pserver_non_opt_ops(block, op, endpoint)
+                self._append_pserver_non_opt_ops(block, op)

         def __op_have_grad_input__(op):
             for varname in op.input_arg_names:
                 if varname.find("@GRAD") >= 0:
                     return varname
             return ""

+        def __clone_lr_op_sub_block__(op, program, new_block):
+            if not op.has_attr('sub_block'):
+                return
+
+            origin_block_desc = op.attr('sub_block')
+            origin_block = self.origin_program.block(origin_block_desc.id)
+            assert isinstance(origin_block, Block)
+            # put the new sub-block under new_block to follow the block
+            # hierarchy of the original program
+            new_sub_block = program.create_block(new_block.idx)
+
+            # clone vars
+            for var in origin_block.vars:
+                new_sub_block.clone_variable(var)
+
+            # clone ops, recursing into any sub-blocks they own
+            for sub_op in origin_block.ops:
+                self._clone_lr_op(program, new_sub_block, sub_op)
+                # clone the sub_block of the op as well
+                __clone_lr_op_sub_block__(sub_op, program, new_sub_block)
+
+            # re-point the outer op at the cloned sub-block
+            op.set_attr('sub_block', new_sub_block)
+
         # append lr decay ops to the child block if they exist
         lr_ops = self._get_lr_ops()
         if len(lr_ops) > 0:
             lr_decay_block = pserver_program.create_block(
                 pserver_program.num_blocks - 1)
             for _, op in enumerate(lr_ops):
-                self._append_pserver_non_opt_ops(lr_decay_block, op, endpoint)
+                self._append_pserver_non_opt_ops(lr_decay_block, op)
+                # append sub-blocks to pserver_program in lr_decay_op
+                __clone_lr_op_sub_block__(op, pserver_program, lr_decay_block)

         # append op to the current block
         grad_to_block_id = []

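The recursion in __clone_lr_op_sub_block__ above works as follows: an op that owns a sub_block (for example, a conditional block emitted by some LR decay schedules) gets a freshly created block in the destination program, parented under the destination block so the original block hierarchy is preserved, and its child ops are cloned into that new block recursively. A self-contained sketch of that shape, where Op, Block, and Program are simplified stand-ins, NOT Paddle's framework types:

# Simplified stand-ins (NOT Paddle's classes) to show the recursion shape.
class Op(object):
    def __init__(self, type, sub_block=None):
        self.type = type
        self.sub_block = sub_block            # None for ordinary ops

class Block(object):
    def __init__(self, idx, parent_idx=-1):
        self.idx = idx
        self.parent_idx = parent_idx
        self.ops = []

class Program(object):
    def __init__(self):
        self.blocks = [Block(0)]
    def create_block(self, parent_idx):
        block = Block(len(self.blocks), parent_idx)
        self.blocks.append(block)
        return block

def clone_op(op, dst_program, dst_block):
    new_op = Op(op.type)
    dst_block.ops.append(new_op)
    if op.sub_block is not None:
        # parent the cloned sub-block under dst_block so the copy keeps
        # the original block hierarchy
        new_sub = dst_program.create_block(dst_block.idx)
        for child_op in op.sub_block.ops:
            clone_op(child_op, dst_program, new_sub)
        new_op.sub_block = new_sub            # re-point the op at the clone
    return new_op

# a decay op whose sub-block holds two child ops
src = Program()
cond = src.create_block(0)
cond.ops = [Op('increment'), Op('scale')]
src.blocks[0].ops.append(Op('conditional_block', sub_block=cond))

dst = Program()
clone_op(src.blocks[0].ops[0], dst, dst.blocks[0])
assert len(dst.blocks) == 2 and len(dst.blocks[1].ops) == 2
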
@@ -1116,7 +1142,29 @@ def _is_splited_grad_var(self, var, var_dict):
                 break
         return grad_block

-    def _append_pserver_non_opt_ops(self, optimize_block, opt_op, endpoint):
+    def _clone_lr_op(self, program, block, op):
+        inputs = self._get_input_map_from_op(
+            self.origin_program.global_block().vars, op)
+        for key, varlist in inputs.iteritems():
+            if not isinstance(varlist, list):
+                varlist = [varlist]
+            for var in varlist:
+                if var not in program.global_block().vars:
+                    block.clone_variable(var)
+
+        outputs = self._get_output_map_from_op(
+            self.origin_program.global_block().vars, op)
+        for key, varlist in outputs.iteritems():
+            if not isinstance(varlist, list):
+                varlist = [varlist]
+            for var in varlist:
+                if var not in program.global_block().vars:
+                    block.clone_variable(var)
+
+        block.append_op(
+            type=op.type, inputs=inputs, outputs=outputs, attrs=op.attrs)
+
+    def _append_pserver_non_opt_ops(self, optimize_block, opt_op):
         program = optimize_block.program
         # Append the ops for parameters that do not need to be optimized/updated
         inputs = self._get_input_map_from_op(

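_clone_lr_op above maintains a simple invariant: every variable an op reads or writes must already exist in the destination program before the op itself is appended. A self-contained sketch of that invariant, where MiniBlock is a made-up stand-in (program and block collapsed into one object), not Paddle's Block API:

# MiniBlock is a hypothetical stand-in, NOT Paddle's Block API.
class MiniBlock(object):
    def __init__(self):
        self.vars = {}    # name -> shape, standing in for the var descs
        self.ops = []

    def clone_variable(self, name, shape):
        self.vars[name] = shape

    def append_op(self, op_type, inputs, outputs):
        self.ops.append((op_type, inputs, outputs))

def clone_op_into(src, dst, op_type, inputs, outputs):
    # inputs/outputs map a slot name to a list of variable names
    for varlist in list(inputs.values()) + list(outputs.values()):
        for name in varlist:
            if name not in dst.vars:          # clone each var only once
                dst.clone_variable(name, src.vars[name])
    dst.append_op(op_type, inputs, outputs)

src = MiniBlock()
src.vars = {'lr': [1], 'step': [1]}
dst = MiniBlock()
clone_op_into(src, dst, 'exponential_decay',
              {'LearningRate': ['lr'], 'Step': ['step']}, {'Out': ['lr']})
assert 'lr' in dst.vars and 'step' in dst.vars and len(dst.ops) == 1
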