Skip to content

Commit bf97648

Browse files
committed
add TestEmptyPserverOptimizeBlocks
1 parent 11b5c44 commit bf97648

File tree

2 files changed

+27
-1
lines changed

2 files changed

+27
-1
lines changed

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,31 @@ def transpiler_test_impl(self):
405405
["sum", "scale", "scale", "elementwise_add", "momentum"])
406406

407407

408+
class TestEmptyPserverOptimizeBlocks(TranspilerTest):
409+
def net_conf(self):
410+
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
411+
# only one parameter
412+
y_predict = fluid.layers.fc(input=x,
413+
size=1000,
414+
act=None,
415+
param_attr=fluid.ParamAttr(name='fc_w'),
416+
bias_attr=False)
417+
y = fluid.layers.data(name='y', shape=[1], dtype='float32')
418+
cost = fluid.layers.square_error_cost(input=y_predict, label=y)
419+
avg_cost = fluid.layers.mean(cost)
420+
sgd_optimizer = fluid.optimizer.SGD(learning_rate=1.0)
421+
sgd_optimizer.minimize(avg_cost)
422+
423+
def transpiler_test_impl(self):
424+
config = fluid.DistributeTranspilerConfig()
425+
config.slice_var_up = False
426+
427+
pserver, startup = self.get_pserver(ep=self.pserver2_ep, config=config)
428+
429+
self.assertEqual(len(pserver.blocks), 2)
430+
self.assertEqual(len(pserver.blocks[1].ops), 0)
431+
432+
408433
class TestDistLookupTableBase(TranspilerTest):
409434
def network_with_table(self, is_sparse, is_distributed):
410435
self.table_size = 1000

python/paddle/fluid/transpiler/distribute_transpiler.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,6 +35,7 @@
3535
import numpy as np
3636
import collections
3737
import six
38+
import logging
3839

3940
from .ps_dispatcher import RoundRobin, HashName, PSDispatcher
4041
from .. import core, framework
@@ -768,6 +769,7 @@ def __clone_lr_op_sub_block__(op, program, lr_block):
768769
lookup_table_var_name_to_block_id)
769770

770771
if len(optimize_blocks) == 0:
772+
logging.warn("pserver [" + str(endpoint) + "] has no optimize block!!")
771773
pre_block_idx = pserver_program.num_blocks - 1
772774
empty_block = pserver_program._create_block(pre_block_idx)
773775
optimize_blocks.append(empty_block)
@@ -1282,7 +1284,6 @@ def _create_table_optimize_block(self, pserver_index, pserver_program,
12821284
}
12831285
outputs = {"ParamOut": [param_var]}
12841286
# only support sgd now
1285-
import logging
12861287
logging.warn(
12871288
"distribute lookup table only support sgd optimizer, change it's optimizer to sgd instead of "
12881289
+ table_opt_op.type)

0 commit comments

Comments
 (0)