|
12 | 12 | # See the License for the specific language governing permissions and
|
13 | 13 | # limitations under the License.
|
14 | 14 |
|
| 15 | +import math |
| 16 | + |
15 | 17 | import unittest
|
16 | 18 | import paddle.fluid as fluid
|
17 | 19 | from paddle.fluid.transpiler.distribute_transpiler import delete_ops
|
@@ -362,12 +364,13 @@ def transpiler_test_impl(self):
|
362 | 364 |
|
363 | 365 | class TestDistLookupTableBase(TranspilerTest):
|
364 | 366 | def network_with_table(self, is_sparse, is_distributed):
|
| 367 | + self.table_size = 1000 |
| 368 | + self.emb_size = 64 |
| 369 | + |
365 | 370 | def emb_pool(ids):
|
366 |
| - table_size = 1000 |
367 |
| - emb_size = 64 |
368 | 371 | emb = fluid.layers.embedding(
|
369 | 372 | input=ids,
|
370 |
| - size=[table_size, emb_size], |
| 373 | + size=[self.table_size, self.emb_size], |
371 | 374 | dtype='float32',
|
372 | 375 | param_attr='shared_w', # share parameter
|
373 | 376 | is_sparse=is_sparse,
|
@@ -536,6 +539,22 @@ def transpiler_test_impl(self):
|
536 | 539 | self.assertEqual([op.type for op in trainer.blocks[0].ops], ops)
|
537 | 540 |
|
538 | 541 |
|
| 542 | +class TestDistLookupTableSliceSize(TestDistLookupTableBase): |
| 543 | + def net_conf(self): |
| 544 | + self.network_with_table(is_sparse=True, is_distributed=True) |
| 545 | + |
| 546 | + def transpiler_test_impl(self): |
| 547 | + config = fluid.DistributeTranspilerConfig() |
| 548 | + pserver1, startup1 = self.get_pserver(self.pserver1_ep, config) |
| 549 | + |
| 550 | + self.assertTrue(self.transpiler.has_distributed_lookup_table) |
| 551 | + lookup_table_var = pserver1.global_block().vars[ |
| 552 | + self.transpiler.table_name] |
| 553 | + row_size = lookup_table_var.shape[0] |
| 554 | + calc_row_size = int(math.ceil(self.table_size / self.pservers)) |
| 555 | + self.assertEqual(row_size, calc_row_size) |
| 556 | + |
| 557 | + |
539 | 558 | class TestRMSPropOptimizer(TranspilerTest):
|
540 | 559 | def net_conf(self):
|
541 | 560 | x = fluid.layers.data(name='x', shape=[1000], dtype='float32')
|
|
0 commit comments