
Commit 6588d2e

Author: yi.wu (committed)
Commit message: complete dist transpiler doc
1 parent: 4c3eb44

File tree

2 files changed: +207 / -151 lines


python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 197 additions & 149 deletions
@@ -12,14 +12,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """
-Transpile the program to distributed data-parallelism programs.
-The main_program will be transformed to use a remote parameter server
-to do parameter optimization. And the optimization graph will be put
-into a parameter server program.
-
-Use different methods to split trainable variables to different
-parameter servers.
-
 Steps to transpile trainer:
 1. split variable to multiple blocks, aligned by product(dim[1:]) (width).
 2. rename splited grad variables to add trainer_id suffix ".trainer_%d".
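A quick illustrative sketch of step 1's "aligned by product(dim[1:]) (width)" rule; the shape, pserver count, and numbers below are made up for illustration and are not part of the diff.

import functools
import operator

shape = [1024, 512]                                   # e.g. a fully-connected weight
width = functools.reduce(operator.mul, shape[1:], 1)  # 512, product of trailing dims
total = shape[0] * width                              # 524288 elements in total

# Each split block holds whole rows, so its element count is a multiple of
# `width`; splitting evenly across 4 pservers gives 131072 elements per block.
block_elems = (total // 4 // width) * width
print(width, block_elems)  # 512 131072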
@@ -118,128 +110,40 @@ def slice_variable(var_list, slice_count, min_block_size=8192):


 class DistributeTranspiler:
-    def _has_distributed_lookup_table(self):
-        # process lookup_table_op
-        # 1. check all lookup_table_op is distributed
-        # 2. check all lookup_table_op share the same table.
-        distributed_lookup_table_ops = []
-        # support only one distributed_lookup_table now
-        self.table_name = None
-        for op in self.origin_program.global_block().ops:
-            if op.type == LOOKUP_TABLE_TYPE:
-                if op.attrs['is_distributed'] is True:
-                    if self.table_name is None:
-                        self.table_name = op.input("W")[0]
-                    if self.table_name != op.input("W")[0]:
-                        raise RuntimeError("all distributed lookup_table_ops"
-                                           " should have only one table")
-                    distributed_lookup_table_ops.append(op)
-                else:
-                    if self.table_name is not None:
-                        assert op.input("W")[0] != self.table_name
-
-        return len(distributed_lookup_table_ops) > 0
-
-    def _update_dist_lookup_table_vars(self, param_list, grad_list,
-                                       params_grads):
-        # TODO(wuyi): put find a way to put dist lookup table stuff all together.
-        # update self.table_param_grad and self.trainer_side_table_grad_list
-        program = self.origin_program
-        if self.has_distributed_lookup_table:
-            param_list = [
-                param for param in param_list if param.name != self.table_name
-            ]
-            grad_list = [
-                grad for grad in grad_list
-                if grad.name != grad_var_name(self.table_name)
-            ]
-            self.table_param_grad = [
-                param_grad for param_grad in params_grads
-                if param_grad[0].name == self.table_name
-            ][0]
-            table_grad_var = self.table_param_grad[1]
-            if self.sync_mode:
-                self.trainer_side_table_grad_list = [
-                    program.global_block().create_var(
-                        name="%s.trainer_%d.pserver_%d" %
-                        (table_grad_var.name, self.trainer_id, index),
-                        type=table_grad_var.type,
-                        shape=table_grad_var.shape,
-                        dtype=table_grad_var.dtype)
-                    for index in range(len(self.pserver_endpoints))
-                ]
-            else:
-                self.trainer_side_table_grad_list = [
-                    program.global_block().create_var(
-                        name="%s.pserver_%d" % (table_grad_var.name, index),
-                        type=table_grad_var.type,
-                        shape=table_grad_var.shape,
-                        dtype=table_grad_var.dtype)
-                    for index in range(len(self.pserver_endpoints))
-                ]
-        return param_list, grad_list
-
-    def _init_splited_vars(self, slice_var_up):
-        # update these mappings for further transpile:
-        # 1. param_var_mapping: param var name -> [splited params vars]
-        # 2. grad_var_mapping: grad var name -> [splited grads vars]
-        # 3. grad_param_mapping: grad.blockx -> param.blockx
-        # 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
-
-        param_list = []
-        grad_list = []
-        param_grad_set = set()
-        for p, g in self.params_grads:
-            # skip parameter marked not trainable
-            if type(p) == Parameter and p.trainable == False:
-                continue
-            if p.name not in param_grad_set:
-                param_list.append(p)
-                param_grad_set.add(p.name)
-            if g.name not in param_grad_set:
-                grad_list.append(g)
-                param_grad_set.add(g.name)
-
-        param_list, grad_list = self._update_dist_lookup_table_vars(
-            param_list, grad_list, self.params_grads)
-
-        if slice_var_up:
-            # when we slice var up into blocks, we will slice the var according to
-            # pserver services' count. A pserver may have two or more listening ports.
-            grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints))
-            param_blocks = slice_variable(param_list,
-                                          len(self.pserver_endpoints))
-        else:
-            # when we do NOT slice var up into blocks, we will always slice params
-            # grads into one block.
-            grad_blocks = slice_variable(grad_list, 1)
-            param_blocks = slice_variable(param_list, 1)
-        assert (len(grad_blocks) == len(param_blocks))
-
-        # origin_varname -> [splited_var]
-        self.param_var_mapping = self._create_vars_from_blocklist(
-            self.origin_program, param_blocks)
-        self.grad_var_mapping = self._create_vars_from_blocklist(
-            self.origin_program,
-            grad_blocks,
-            add_trainer_suffix=self.trainer_num > 1)
-        self.grad_param_mapping = dict()
-        for g, p in zip(grad_blocks, param_blocks):
-            g_name, g_bid, _ = g.split(":")
-            p_name, p_bid, _ = p.split(":")
-            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
-                self.param_var_mapping[p_name][int(p_bid)]
-
-        # create mapping of endpoint -> split var to create pserver side program
-        self.param_grad_ep_mapping = dict()
-        [
-            self.param_grad_ep_mapping.update({
-                ep: {
-                    "params": [],
-                    "grads": []
-                }
-            }) for ep in self.pserver_endpoints
-        ]
+    """
+    **DistributeTranspiler**
+
+    Convert the fluid program to distributed data-parallelism programs.
+
+    The main_program will be transformed to use a remote parameter server
+    to do parameter optimization. And the optimization graph will be put
+    into a parameter server program.
+
+    Examples:
+        .. code-block:: python
+
+            # Define your model before these codes.
+            port = os.getenv("PADDLE_PSERVER_PORT", "6174")
+            pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
+            eplist = []
+            for ip in pserver_ips.split(","):
+                eplist.append(':'.join([ip, port]))
+            pserver_endpoints = ",".join(eplist)
+            trainers = int(os.getenv("PADDLE_TRAINERS"))
+            current_endpoint = os.getenv("PADDLE_CURRENT_IP", "") + ":" + port
+            trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))
+            role = os.getenv("PADDLE_TRAINING_ROLE")
+
+            t = distribute_transpiler.DistributeTranspiler()
+            t.transpile(
+                trainer_id, pservers=pserver_endpoints, trainers=trainers)
+            if role == "PSERVER":
+                pserver_program = t.get_pserver_program(current_endpoint)
+                pserver_startup_program = t.get_startup_program(current_endpoint,
+                                                                pserver_program)
+            elif role == "TRAINER":
+                trainer_program = t.get_trainer_program()
+    """

     def transpile(self,
                   trainer_id,
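For reference alongside the new class docstring example, this is a minimal sketch of how the PADDLE_* environment variables resolve into the strings handed to transpile(); the IP addresses and counts are placeholder values, not something taken from the diff.

import os

# Placeholder values standing in for what a job launcher would export.
os.environ.setdefault("PADDLE_PSERVER_PORT", "6174")
os.environ.setdefault("PADDLE_PSERVER_IPS", "192.168.0.1,192.168.0.2")
os.environ.setdefault("PADDLE_TRAINERS", "2")
os.environ.setdefault("PADDLE_TRAINER_ID", "0")
os.environ.setdefault("PADDLE_TRAINING_ROLE", "TRAINER")

port = os.getenv("PADDLE_PSERVER_PORT", "6174")
pserver_ips = os.getenv("PADDLE_PSERVER_IPS", "")
pserver_endpoints = ",".join(
    ":".join([ip, port]) for ip in pserver_ips.split(","))
trainers = int(os.getenv("PADDLE_TRAINERS"))
trainer_id = int(os.getenv("PADDLE_TRAINER_ID", "0"))

print(pserver_endpoints)      # 192.168.0.1:6174,192.168.0.2:6174
print(trainers, trainer_id)   # 2 0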
@@ -250,20 +154,20 @@ def transpile(self,
                   split_method=RoundRobin,
                   sync_mode=True):
         """
-        :param trainer_id: one unique id for each trainer in a job.
-        :type trainer_id: int
-        :param program: program to transpile, default is default_main_program
-        :type program: Program
-        :param pservers: parameter server endpoints like "m1:6174,m2:6174"
-        :type pservers: string
-        :param trainers: total number of workers/trainers in the job
-        :type trainers: int
-        :param split_method: A function to determin how to split variables
-            to different servers equally.
-        :type split_method: function
-        :type sync_mode: boolean default True
-        :param sync_mode: if sync_mode is set True, it means that dist transpiler
-            will transpile the program into sync_mode pserver and trainer program.
+        Run the transpiler.
+
+        Args:
+            trainer_id (int): id for current trainer worker, if you have
+                n workers, the id may range from 0 ~ n-1
+            program (Program|None): program to transpile,
+                default is fluid.default_main_program().
+            pservers (str): comma separated ip:port string for the pserver
+                list.
+            trainers (int): number of trainers in the distributed job.
+            slice_var_up (bool): Do Tensor slice for pservers, default is True.
+            split_method (PSDispatcher): RoundRobin or HashName can be used
+                try to choose the best method to balance loads for pservers.
+            sync_mode (bool): Do sync training or not, default is True.
         """
         assert (split_method.__bases__[0] == PSDispatcher)
         if program is None:
390294
self._split_table_grad_and_add_send_vars(program, pserver_endpoints)
391295

392296
def get_trainer_program(self):
297+
"""
298+
Get transpiled trainer side program.
299+
300+
Returns:
301+
Program: trainer side program.
302+
"""
393303
# remove optimize ops and add a send op to main_program
394304
delete_ops(self.origin_program.global_block(), self.optimize_ops)
395305
# FIXME(typhoonzero): serialize once will fix error occurs when clone.
@@ -398,12 +308,19 @@ def get_trainer_program(self):
398308

399309
def get_pserver_program(self, endpoint):
400310
"""
401-
Get pserver side program using the endpoint.
402-
TODO(panyx0718): Revisit this assumption. what if #blocks > #pservers.
403-
NOTE: assume blocks of the same variable is not distributed
404-
on the same pserver, only change param/grad varnames for
405-
trainers to fetch.
311+
Get parameter server side program.
312+
313+
Args:
314+
endpoint (str): current parameter server endpoint.
315+
316+
Returns:
317+
Program: the program for current parameter server to run.
406318
"""
319+
# TODO(panyx0718): Revisit this assumption. what if #blocks > #pservers.
320+
# NOTE: assume blocks of the same variable is not distributed
321+
# on the same pserver, only change param/grad varnames for
322+
# trainers to fetch.
323+
407324
# step1
408325
pserver_program = Program()
409326
# step2: Create vars to receive vars at parameter servers.
@@ -556,6 +473,14 @@ def get_startup_program(self, endpoint, pserver_program):
         Get startup program for current parameter server.
         Modify operator input variables if there are variables that
         were split to several blocks.
+
+        Args:
+            endpoint (str): current pserver endpoint.
+            pserver_program (Program): call get_pserver_program first and
+                pass the result here.
+
+        Returns:
+            Program: parameter server side startup program.
         """
         s_prog = Program()
         orig_s_prog = default_startup_program()
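A sketch of the matching parameter-server-side flow; current_endpoint and the Executor usage are illustrative assumptions, with t being the transpiler from the earlier sketch rather than anything defined in this diff.

import paddle.fluid as fluid

# t is the transpiler from the earlier sketch; the endpoint is a placeholder
# that should match one entry of the pservers list passed to transpile().
current_endpoint = "192.168.0.1:6174"

pserver_program = t.get_pserver_program(current_endpoint)
pserver_startup = t.get_startup_program(current_endpoint, pserver_program)

exe = fluid.Executor(fluid.CPUPlace())
exe.run(pserver_startup)   # create and initialize this pserver's param blocks
exe.run(pserver_program)   # blocks here, serving optimize requests from trainers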
@@ -607,6 +532,129 @@ def _get_splited_name_and_shape(varname):

     # ====================== private transpiler functions =====================

+    def _has_distributed_lookup_table(self):
+        # process lookup_table_op
+        # 1. check all lookup_table_op is distributed
+        # 2. check all lookup_table_op share the same table.
+        distributed_lookup_table_ops = []
+        # support only one distributed_lookup_table now
+        self.table_name = None
+        for op in self.origin_program.global_block().ops:
+            if op.type == LOOKUP_TABLE_TYPE:
+                if op.attrs['is_distributed'] is True:
+                    if self.table_name is None:
+                        self.table_name = op.input("W")[0]
+                    if self.table_name != op.input("W")[0]:
+                        raise RuntimeError("all distributed lookup_table_ops"
+                                           " should have only one table")
+                    distributed_lookup_table_ops.append(op)
+                else:
+                    if self.table_name is not None:
+                        assert op.input("W")[0] != self.table_name
+
+        return len(distributed_lookup_table_ops) > 0
+
+    def _update_dist_lookup_table_vars(self, param_list, grad_list,
+                                       params_grads):
+        # TODO(wuyi): put find a way to put dist lookup table stuff all together.
+        # update self.table_param_grad and self.trainer_side_table_grad_list
+        program = self.origin_program
+        if self.has_distributed_lookup_table:
+            param_list = [
+                param for param in param_list if param.name != self.table_name
+            ]
+            grad_list = [
+                grad for grad in grad_list
+                if grad.name != grad_var_name(self.table_name)
+            ]
+            self.table_param_grad = [
+                param_grad for param_grad in params_grads
+                if param_grad[0].name == self.table_name
+            ][0]
+            table_grad_var = self.table_param_grad[1]
+            if self.sync_mode:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.trainer_%d.pserver_%d" %
+                        (table_grad_var.name, self.trainer_id, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
+            else:
+                self.trainer_side_table_grad_list = [
+                    program.global_block().create_var(
+                        name="%s.pserver_%d" % (table_grad_var.name, index),
+                        type=table_grad_var.type,
+                        shape=table_grad_var.shape,
+                        dtype=table_grad_var.dtype)
+                    for index in range(len(self.pserver_endpoints))
+                ]
+        return param_list, grad_list
+
+    def _init_splited_vars(self, slice_var_up):
+        # update these mappings for further transpile:
+        # 1. param_var_mapping: param var name -> [splited params vars]
+        # 2. grad_var_mapping: grad var name -> [splited grads vars]
+        # 3. grad_param_mapping: grad.blockx -> param.blockx
+        # 4. param_grad_ep_mapping: ep -> {"params": [], "grads": []}
+
+        param_list = []
+        grad_list = []
+        param_grad_set = set()
+        for p, g in self.params_grads:
+            # skip parameter marked not trainable
+            if type(p) == Parameter and p.trainable == False:
+                continue
+            if p.name not in param_grad_set:
+                param_list.append(p)
+                param_grad_set.add(p.name)
+            if g.name not in param_grad_set:
+                grad_list.append(g)
+                param_grad_set.add(g.name)
+
+        param_list, grad_list = self._update_dist_lookup_table_vars(
+            param_list, grad_list, self.params_grads)
+
+        if slice_var_up:
+            # when we slice var up into blocks, we will slice the var according to
+            # pserver services' count. A pserver may have two or more listening ports.
+            grad_blocks = slice_variable(grad_list, len(self.pserver_endpoints))
+            param_blocks = slice_variable(param_list,
+                                          len(self.pserver_endpoints))
+        else:
+            # when we do NOT slice var up into blocks, we will always slice params
+            # grads into one block.
+            grad_blocks = slice_variable(grad_list, 1)
+            param_blocks = slice_variable(param_list, 1)
+        assert (len(grad_blocks) == len(param_blocks))
+
+        # origin_varname -> [splited_var]
+        self.param_var_mapping = self._create_vars_from_blocklist(
+            self.origin_program, param_blocks)
+        self.grad_var_mapping = self._create_vars_from_blocklist(
+            self.origin_program,
+            grad_blocks,
+            add_trainer_suffix=self.trainer_num > 1)
+        self.grad_param_mapping = dict()
+        for g, p in zip(grad_blocks, param_blocks):
+            g_name, g_bid, _ = g.split(":")
+            p_name, p_bid, _ = p.split(":")
+            self.grad_param_mapping[self.grad_var_mapping[g_name][int(g_bid)]] = \
+                self.param_var_mapping[p_name][int(p_bid)]
+
+        # create mapping of endpoint -> split var to create pserver side program
+        self.param_grad_ep_mapping = dict()
+        [
+            self.param_grad_ep_mapping.update({
+                ep: {
+                    "params": [],
+                    "grads": []
+                }
+            }) for ep in self.pserver_endpoints
+        ]
+
     # transpiler function for dis lookup_table
     def _replace_lookup_table_op_with_prefetch(self, program,
                                                pserver_endpoints):
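To make the block bookkeeping in _init_splited_vars easier to follow, here is an illustrative pairing of grad blocks with param blocks; the "name:block_id:block_size" string format is inferred from the g.split(":") calls above, and the variable names and sizes are invented.

# Invented block strings in the "name:block_id:block_size" format that
# slice_variable appears to produce (inferred from g.split(":") above).
param_blocks = ["fc_0.w_0:0:262144", "fc_0.w_0:1:262144", "fc_0.b_0:0:512"]
grad_blocks = ["fc_0.w_0@GRAD:0:262144", "fc_0.w_0@GRAD:1:262144",
               "fc_0.b_0@GRAD:0:512"]

grad_param_mapping = {}
for g, p in zip(grad_blocks, param_blocks):
    g_name, g_bid, _ = g.split(":")
    p_name, p_bid, _ = p.split(":")
    # The real transpiler keys this dict by the split Variables found via
    # grad_var_mapping/param_var_mapping; plain tuples are used here only to
    # show how grad block i lines up with param block i.
    grad_param_mapping[(g_name, int(g_bid))] = (p_name, int(p_bid))

print(grad_param_mapping[("fc_0.w_0@GRAD", 1)])  # ('fc_0.w_0', 1)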
