Skip to content

Commit 13e99cf

Browse files
committed
add unit test
1 parent f42247e commit 13e99cf

File tree

1 file changed

+22
-3
lines changed

1 file changed

+22
-3
lines changed

python/paddle/fluid/tests/unittests/test_dist_transpiler.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
# See the License for the specific language governing permissions and
1313
# limitations under the License.
1414

15+
import math
16+
1517
import unittest
1618
import paddle.fluid as fluid
1719
from paddle.fluid.transpiler.distribute_transpiler import delete_ops
@@ -362,12 +364,13 @@ def transpiler_test_impl(self):
362364

363365
class TestDistLookupTableBase(TranspilerTest):
364366
def network_with_table(self, is_sparse, is_distributed):
367+
self.table_size = 1000
368+
self.emb_size = 64
369+
365370
def emb_pool(ids):
366-
table_size = 1000
367-
emb_size = 64
368371
emb = fluid.layers.embedding(
369372
input=ids,
370-
size=[table_size, emb_size],
373+
size=[self.table_size, self.emb_size],
371374
dtype='float32',
372375
param_attr='shared_w', # share parameter
373376
is_sparse=is_sparse,
@@ -536,6 +539,22 @@ def transpiler_test_impl(self):
536539
self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
537540

538541

542+
class TestDistLookupTableSliceSize(TestDistLookupTableBase):
543+
def net_conf(self):
544+
self.network_with_table(is_sparse=True, is_distributed=True)
545+
546+
def transpiler_test_impl(self):
547+
config = fluid.DistributeTranspilerConfig()
548+
pserver1, startup1 = self.get_pserver(self.pserver1_ep, config)
549+
550+
self.assertTrue(self.transpiler.has_distributed_lookup_table)
551+
lookup_table_var = pserver1.global_block().vars[
552+
self.transpiler.table_name]
553+
row_size = lookup_table_var.shape[0]
554+
calc_row_size = int(math.ceil(self.table_size / self.pservers))
555+
self.assertEqual(row_size, calc_row_size)
556+
557+
539558
class TestRMSPropOptimizer(TranspilerTest):
540559
def net_conf(self):
541560
x = fluid.layers.data(name='x', shape=[1000], dtype='float32')

0 commit comments

Comments
 (0)