Skip to content

Commit 757aa01

Browse files
Wrap test
1 parent 2cb2585 commit 757aa01

File tree

1 file changed

+16
-15
lines changed

1 file changed

+16
-15
lines changed

tests/extensions/test_saveload.py

Lines changed: 16 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -42,20 +42,21 @@ def test_checkpointing():
4242
algorithm=GradientDescent(cost=cost, parameters=[W]),
4343
extensions=[Load('myweirdmodel.tar')]
4444
)
45-
main_loop.extensions[0].main_loop = main_loop
46-
main_loop._run_extensions('before_training')
47-
assert_allclose(W.get_value(), old_value)
45+
with main_loop.log:
46+
main_loop.extensions[0].main_loop = main_loop
47+
main_loop._run_extensions('before_training')
48+
assert_allclose(W.get_value(), old_value)
4849

49-
# Make sure things work too if the model was never saved before
50-
main_loop = MainLoop(
51-
model=Model(cost),
52-
data_stream=data_stream,
53-
algorithm=GradientDescent(cost=cost, parameters=[W]),
54-
extensions=[Load('mynonexisting.tar')]
55-
)
56-
main_loop.extensions[0].main_loop = main_loop
57-
main_loop._run_extensions('before_training')
50+
# Make sure things work too if the model was never saved before
51+
main_loop = MainLoop(
52+
model=Model(cost),
53+
data_stream=data_stream,
54+
algorithm=GradientDescent(cost=cost, parameters=[W]),
55+
extensions=[Load('mynonexisting.tar')]
56+
)
57+
main_loop.extensions[0].main_loop = main_loop
58+
main_loop._run_extensions('before_training')
5859

59-
# Cleaning
60-
if os.path.exists('myweirdmodel.tar'):
61-
os.remove('myweirdmodel.tar')
60+
# Cleaning
61+
if os.path.exists('myweirdmodel.tar'):
62+
os.remove('myweirdmodel.tar')

0 commit comments

Comments
 (0)