14
14
15
15
import math
16
16
import unittest
17
- from paddle .fluid .transpiler .distribute_transpiler import split_variable
17
+ from paddle .fluid .transpiler .distribute_transpiler import slice_variable
18
18
import paddle .fluid as fluid
19
19
import paddle .fluid .core as core
20
20
import random
21
21
22
22
23
- class TestSplitVar (unittest .TestCase ):
24
- def check_split_output (self , shapes , expected_sizes , min_size ):
23
+ class TestSliceVar (unittest .TestCase ):
24
+ def check_slice_output (self , shapes , expected_sizes , min_size ):
25
25
var_list = []
26
26
program = fluid .Program ()
27
27
for shape in shapes :
@@ -31,7 +31,7 @@ def check_split_output(self, shapes, expected_sizes, min_size):
31
31
# dtype=core.VarDesc.VarType.LOD_TENSOR,
32
32
shape = shape )
33
33
var_list .append (var )
34
- blocks = split_variable (var_list , 10 , min_size )
34
+ blocks = slice_variable (var_list , 10 , min_size )
35
35
all_sizes = []
36
36
for s in expected_sizes :
37
37
for s2 in s :
@@ -49,15 +49,15 @@ def test_1k(self):
49
49
[1150 , 1150 , 1150 , 1150 , 1150 , 1150 , 1100 ]
50
50
]
51
51
52
- self .check_split_output (shapes , expected_sizes , 1024 )
52
+ self .check_slice_output (shapes , expected_sizes , 1024 )
53
53
54
54
def test_check_output_8k (self ):
55
55
shapes = [[3 , 5 ], [1024 ], [28 , 784 ], [8 , 1020 ], [800 , 10 ],
56
56
[6 , 33 , 33 , 33 ]]
57
57
expected_sizes = [[15 ], [1024 ], [10976 , 10976 ], [8160 ], [8000 ],
58
58
[35937 , 35937 , 35937 , 35937 , 35937 , 35937 ]]
59
59
60
- self .check_split_output (shapes , expected_sizes , 8192 )
60
+ self .check_slice_output (shapes , expected_sizes , 8192 )
61
61
62
62
63
63
if __name__ == '__main__' :
0 commit comments