
Commit 8627ef3 (1 parent: 964d631)

refactor: simplify unittest function

1 file changed: 8 additions (+), 19 deletions (-)


python/paddle/fluid/tests/unittests/test_memory_usage.py

@@ -19,7 +19,7 @@
 import unittest
 
 
-def train_simulator(use_cuda, test_batch_size=10):
+def train_simulator(test_batch_size=10):
     if test_batch_size <= 0:
         raise ValueError("batch_size should be a positive integeral value, "
                          "but got batch_size={}".format(test_batch_size))
@@ -34,14 +34,7 @@ def train_simulator(use_cuda, test_batch_size=10):
     sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
     sgd_optimizer.minimize(avg_cost)
 
-    train_reader = paddle.batch(
-        paddle.reader.shuffle(
-            paddle.dataset.uci_housing.train(), buf_size=500),
-        batch_size=test_batch_size)
-
-    place = fluid.CUDAPlace(0) if use_cuda else fluid.CPUPlace()
-    exe = fluid.Executor(place)
-
+    # Calculate memory usage in current network config
     lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
         fluid.default_main_program(), batch_size=test_batch_size)
 
@@ -50,21 +43,17 @@ def train_simulator(use_cuda, test_batch_size=10):
 
 
 class TestMemoryUsage(unittest.TestCase):
-    def test_cpu(self):
-        with self.program_scope_guard():
-            train_simulator(use_cuda=False)
-
-    def test_cpu_with_unit_KB(self):
+    def test_with_unit_B(self):
         with self.program_scope_guard():
-            train_simulator(use_cuda=False, test_batch_size=1000)
+            train_simulator()
 
-    def test_cpu_with_unit_MB(self):
+    def test_with_unit_KB(self):
         with self.program_scope_guard():
-            train_simulator(use_cuda=False, test_batch_size=100000)
+            train_simulator(test_batch_size=1000)
 
-    def test_cuda(self):
+    def test_with_unit_MB(self):
         with self.program_scope_guard():
-            train_simulator(use_cuda=True)
+            train_simulator(test_batch_size=100000)
 
     @contextlib.contextmanager
     def program_scope_guard(self):
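
For orientation, below is a minimal sketch of how the simplified helper reads after this commit. Only the hunks shown above are from the actual change; the network-definition lines hidden by the diff context are an assumption, reconstructed as the usual linear regression on the 13 UCI housing features (the dataset referenced by the removed reader lines), and the final print is likewise assumed for illustration.

    import paddle.fluid as fluid


    def train_simulator(test_batch_size=10):
        if test_batch_size <= 0:
            raise ValueError("batch_size should be a positive integeral value, "
                             "but got batch_size={}".format(test_batch_size))

        # Assumed network definition (elided by the diff context): a plain
        # linear regression over the 13 UCI housing features.
        x = fluid.layers.data(name='x', shape=[13], dtype='float32')
        y = fluid.layers.data(name='y', shape=[1], dtype='float32')
        y_predict = fluid.layers.fc(input=x, size=1, act=None)
        cost = fluid.layers.square_error_cost(input=y_predict, label=y)
        avg_cost = fluid.layers.mean(cost)

        sgd_optimizer = fluid.optimizer.SGD(learning_rate=0.001)
        sgd_optimizer.minimize(avg_cost)

        # Calculate memory usage in current network config
        lower_usage, upper_usage, unit = fluid.contrib.memory_usage(
            fluid.default_main_program(), batch_size=test_batch_size)

        # Assumed reporting step: surface the estimated range and its unit.
        print("memory usage is about %.3f - %.3f %s" %
              (lower_usage, upper_usage, unit))

Each test_with_unit_* case then simply wraps train_simulator() in the program_scope_guard, choosing a batch size large enough to push the reported estimate into B, KB, or MB; the use_cuda flag is gone because the memory estimate no longer depends on building an executor or a data reader.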
