Commit 080740b

Merge pull request #14300 from jacquesqiao/dist-table-support-optimizer-regular
dist table support other optimize and regular config
2 parents 2d98599 + 04da1dc commit 080740b

File tree

5 files changed: +114 −36 lines changed

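What the change enables, roughly: a model that contains a distributed lookup table can now pair its remaining dense parameters with a non-SGD optimizer and with regularization, while the table itself still gets a dedicated sgd op inside minimize(). The sketch below is illustrative only — the layer sizes, the parameter name 'dist_table_w', and the hyperparameters are assumptions, not values from this PR — and uses the fluid 1.x APIs of this era:

import paddle.fluid as fluid

# Illustrative network with a distributed lookup table (names/sizes are assumptions).
ids = fluid.layers.data(name='ids', shape=[1], dtype='int64', lod_level=1)
label = fluid.layers.data(name='label', shape=[1], dtype='int64')

emb = fluid.layers.embedding(
    input=ids,
    size=[100000, 64],
    is_sparse=True,
    is_distributed=True,  # marks the parameter as a distributed lookup table
    param_attr=fluid.ParamAttr(name='dist_table_w'))
pooled = fluid.layers.sequence_pool(input=emb, pool_type='sum')
predict = fluid.layers.fc(input=pooled, size=10, act='softmax')
loss = fluid.layers.mean(fluid.layers.cross_entropy(input=predict, label=label))

# After this commit, the dense parameters can use a non-SGD optimizer plus
# regularization; the distributed table parameter is skipped by regularization
# and updated by a separate 'sgd' op that minimize() appends for it.
optimizer = fluid.optimizer.Adam(
    learning_rate=0.001,
    regularization=fluid.regularizer.L2Decay(1e-4))
optimizer.minimize(loss)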

python/paddle/fluid/__init__.py

Lines changed: 1 addition & 0 deletions
@@ -34,6 +34,7 @@
 from . import average
 from . import metrics
 from . import transpiler
+from . import distribute_lookup_table
 from .param_attr import ParamAttr, WeightNormParamAttr
 from .data_feeder import DataFeeder
 from .core import LoDTensor, LoDTensorArray, CPUPlace, CUDAPlace, CUDAPinnedPlace, Scope

python/paddle/fluid/distribute_lookup_table.py

Lines changed: 39 additions & 0 deletions
@@ -0,0 +1,39 @@
+# Copyright (c) 2018 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+LOOKUP_TABLE_TYPE = "lookup_table"
+
+
+def find_distributed_lookup_table(program):
+    """
+    Find distribute lookup table in program.
+    We only support one distribute table now.
+    :param program:
+    :return: table_name or None
+    """
+    table_name = None
+
+    for op in program.global_block().ops:
+        if op.type == LOOKUP_TABLE_TYPE:
+            if op.attr('is_distributed') is True:
+                if table_name is None:
+                    table_name = op.input("W")[0]
+                if table_name != op.input("W")[0]:
+                    raise RuntimeError("all distributed lookup_table_ops"
+                                       " should have only one table")
+            else:
+                if table_name is not None:
+                    assert op.input("W")[0] != table_name
+
+    return table_name
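The helper scans the global block of a program and returns the name of the single distributed lookup table parameter, or None if there is none; two lookup_table ops pointing at different distributed tables raise a RuntimeError. A minimal usage sketch — the surrounding network construction is assumed, not part of this commit, and the snippet simply prints None on an empty program:

import paddle.fluid as fluid
from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table

main_program = fluid.default_main_program()
# ... build a network containing fluid.layers.embedding(..., is_distributed=True) ...

table_name = find_distributed_lookup_table(main_program)
if table_name is None:
    print("no distributed lookup table in this program")
else:
    print("distributed lookup table parameter:", table_name)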

python/paddle/fluid/optimizer.py

Lines changed: 59 additions & 7 deletions
@@ -13,21 +13,23 @@
 # limitations under the License.

 from __future__ import print_function
-import re
-import sys
+
 from collections import defaultdict
+from contextlib import contextmanager
+
 from paddle.fluid.framework import Program, Variable, name_scope, default_main_program
+from paddle.fluid.distribute_lookup_table import find_distributed_lookup_table
+
 from . import framework
 from . import layers
+from . import unique_name
 from .backward import append_backward
+from .clip import append_gradient_clip_ops, error_clip_callback
 from .framework import program_guard
-from . import unique_name
 from .initializer import Constant
 from .layer_helper import LayerHelper
-from .regularizer import append_regularization_ops
-from .clip import append_gradient_clip_ops, error_clip_callback
-from contextlib import contextmanager
 from .layers import ops
+from .regularizer import append_regularization_ops

 __all__ = [
     'SGD', 'Momentum', 'Adagrad', 'Adam', 'Adamax', 'DecayedAdagrad', 'Ftrl',
@@ -85,7 +87,7 @@ def _create_global_learning_rate(self):
                 name=unique_name.generate("learning_rate"),
                 shape=[1],
                 value=float(self._learning_rate),
-                dtype='float32' if self._dtype == None else self._dtype,
+                dtype='float32' if self._dtype is None else self._dtype,
                 persistable=True)

     def _global_learning_rate(self, program=None):
@@ -245,6 +247,50 @@ def _create_optimization_pass(self,
         end = len(global_block.ops)
         return global_block._slice_ops(start, end)

+    def _process_distribute_lookuptable(self, param_grads, loss,
+                                        startup_program):
+        """
+        Because distribute lookup table only support SGD optimizer for now, not support
+        other optimizer and regularization, so we should find the table parameter out,
+        and avoid to add regularization and other op for it, and add sgd optimize op
+        for it independently.
+        :param param_grads(list((Var, Var))): list of (param, grad) pair.
+        :param loss: the loss variable.
+        :param startup_program: the startup program
+        """
+        program = loss.block.program
+        table_name = find_distributed_lookup_table(program)
+        table_param = None
+        table_grad = None
+        new_param_grads = []
+        for p, g in param_grads:
+            if p.name == table_name:
+                if table_param is not None:
+                    raise RuntimeError(
+                        "multi dist table var found, only support one now!")
+                table_param = p
+                table_grad = g
+            else:
+                new_param_grads.append((p, g))
+        sgd_op = None
+        if table_param is not None:
+            with program_guard(program, startup_program):
+                param_and_grad = [table_param, table_grad]
+                with table_param.block.program._optimized_guard(param_and_grad), \
+                        framework.name_scope("optimizer"):
+                    self._create_global_learning_rate()
+                    # create the optimize op
+                    sgd_op = loss.block.append_op(
+                        type='sgd',
+                        inputs={
+                            "Param": table_param,
+                            "Grad": table_grad,
+                            "LearningRate":
+                            self._create_param_lr(param_and_grad)
+                        },
+                        outputs={"ParamOut": param_and_grad[0]})
+        return new_param_grads, (table_param, table_grad), sgd_op
+
     def minimize(self,
                  loss,
                  startup_program=None,
@@ -260,6 +306,9 @@ def minimize(self,

         params_grads = sorted(params_grads, key=lambda x: x[0].name)

+        params_grads, table_param_and_grad, table_optimize_op = \
+            self._process_distribute_lookuptable(params_grads, loss, startup_program)
+
         params_grads = append_gradient_clip_ops(params_grads)

         # Add regularization if any
@@ -268,6 +317,9 @@

         optimize_ops = self._create_optimization_pass(params_grads, loss,
                                                       startup_program)
+        if table_optimize_op is not None:
+            optimize_ops.append(table_optimize_op)
+            params_grads.append(table_param_and_grad)
         return optimize_ops, params_grads

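The docstring above captures the idea: the (table_param, table_grad) pair is pulled out of params_grads so gradient clipping, regularization, and the configured optimizer never touch it, and a dedicated sgd op is appended for the table instead. A framework-free sketch of that partitioning step — plain strings stand in for the param/grad Variables, and split_out_table is a made-up helper, not a Paddle API:

def split_out_table(param_grads, table_name):
    # Separate the distributed-table pair from the rest of the (param, grad) list.
    table_pair = None
    rest = []
    for p, g in param_grads:
        if p == table_name:
            if table_pair is not None:
                raise RuntimeError("only one distributed table is supported")
            table_pair = (p, g)
        else:
            rest.append((p, g))
    return rest, table_pair


rest, table_pair = split_out_table(
    [("fc_0.w_0", "fc_0.w_0@GRAD"), ("dist_table_w", "dist_table_w@GRAD")],
    table_name="dist_table_w")
print(rest)        # [('fc_0.w_0', 'fc_0.w_0@GRAD')]
print(table_pair)  # ('dist_table_w', 'dist_table_w@GRAD')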

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 11 additions & 2 deletions
@@ -567,7 +567,6 @@ def transpiler_test_impl(self):
567567
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
568568
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
569569
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
570-
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
571570
'fill_constant', 'fill_constant', 'uniform_random',
572571
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
573572
'fake_init'
@@ -639,7 +638,7 @@ def transpiler_test_impl(self):
639638
# 5 save table
640639
self.assertEqual([op.type for op in pserver1.blocks[5].ops], ["save"])
641640

642-
trainer, _ = self.get_trainer(config)
641+
trainer, trainer_startup = self.get_trainer(config)
643642
self.assertEqual(len(trainer.blocks), 1)
644643
ops = [
645644
'split_ids', 'prefetch', 'merge_ids', 'sequence_pool',
@@ -653,6 +652,16 @@ def transpiler_test_impl(self):
653652
'recv', 'concat'
654653
]
655654
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
655+
startup_ops = [
656+
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
657+
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
658+
'fill_constant', 'fill_constant', 'fill_constant', 'fill_constant',
659+
'fill_constant', 'fill_constant', 'uniform_random',
660+
'uniform_random', 'recv', 'recv', 'recv', 'fetch_barrier', 'concat',
661+
'fake_init'
662+
]
663+
self.assertEqual([op.type for op in trainer_startup.blocks[0].ops],
664+
startup_ops)
656665

657666

658667
class TestDistLookupTableSliceSize(TestDistLookupTableBase):

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 4 additions & 27 deletions
@@ -31,18 +31,17 @@
 """

 import math
-import sys
 import numpy as np
 import collections
-import six
 import logging

-from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
+from .ps_dispatcher import RoundRobin, PSDispatcher
 from .. import core, framework, unique_name
 from ..framework import Program, default_main_program, \
                         default_startup_program, Block, \
                         Parameter, grad_var_name
 from .details import *
+from ..distribute_lookup_table import find_distributed_lookup_table
 from functools import reduce

 LOOKUP_TABLE_TYPE = "lookup_table"
@@ -292,7 +291,8 @@ def transpile(self,
         self.optimize_ops, self.params_grads = self._get_optimize_pass()

         ps_dispatcher = self.config.split_method(self.pserver_endpoints)
-        self.has_distributed_lookup_table = self._has_distributed_lookup_table()
+        self.table_name = find_distributed_lookup_table(self.origin_program)
+        self.has_distributed_lookup_table = self.table_name != None
         self.param_name_to_grad_name = dict()
         self.grad_name_to_param_name = dict()
         for param_var, grad_var in self.params_grads:
@@ -966,28 +966,6 @@ def _get_slice_vars_and_attrs(self, endpoint):

     # ====================== private transpiler functions =====================

-    def _has_distributed_lookup_table(self):
-        # process lookup_table_op
-        # 1. check all lookup_table_op is distributed
-        # 2. check all lookup_table_op share the same table.
-        distributed_lookup_table_ops = []
-        # support only one distributed_lookup_table now
-        self.table_name = None
-        for op in self.origin_program.global_block().ops:
-            if op.type == LOOKUP_TABLE_TYPE:
-                if op.attr('is_distributed') is True:
-                    if self.table_name is None:
-                        self.table_name = op.input("W")[0]
-                    if self.table_name != op.input("W")[0]:
-                        raise RuntimeError("all distributed lookup_table_ops"
-                                           " should have only one table")
-                    distributed_lookup_table_ops.append(op)
-                else:
-                    if self.table_name is not None:
-                        assert op.input("W")[0] != self.table_name
-
-        return len(distributed_lookup_table_ops) > 0
-
     def _update_dist_lookup_table_vars(self, param_list, grad_list,
                                        params_grads):
         # TODO(wuyi): put find a way to put dist lookup table stuff all together.
@@ -1341,7 +1319,6 @@ def _create_checkpoint_save_block(self, pserver_program, pre_block_idx):
         """
         create a new block to handle save checkpoint.
         """
-        import os

         pserver_program.global_block().create_var(
             name="kLookupTablePath",
