@@ -42,20 +42,21 @@ def test_checkpointing():
4242 algorithm = GradientDescent (cost = cost , parameters = [W ]),
4343 extensions = [Load ('myweirdmodel.tar' )]
4444 )
45- main_loop .extensions [0 ].main_loop = main_loop
46- main_loop ._run_extensions ('before_training' )
47- assert_allclose (W .get_value (), old_value )
45+ with main_loop .log :
46+ main_loop .extensions [0 ].main_loop = main_loop
47+ main_loop ._run_extensions ('before_training' )
48+ assert_allclose (W .get_value (), old_value )
4849
49- # Make sure things work too if the model was never saved before
50- main_loop = MainLoop (
51- model = Model (cost ),
52- data_stream = data_stream ,
53- algorithm = GradientDescent (cost = cost , parameters = [W ]),
54- extensions = [Load ('mynonexisting.tar' )]
55- )
56- main_loop .extensions [0 ].main_loop = main_loop
57- main_loop ._run_extensions ('before_training' )
50+ # Make sure things work too if the model was never saved before
51+ main_loop = MainLoop (
52+ model = Model (cost ),
53+ data_stream = data_stream ,
54+ algorithm = GradientDescent (cost = cost , parameters = [W ]),
55+ extensions = [Load ('mynonexisting.tar' )]
56+ )
57+ main_loop .extensions [0 ].main_loop = main_loop
58+ main_loop ._run_extensions ('before_training' )
5859
59- # Cleaning
60- if os .path .exists ('myweirdmodel.tar' ):
61- os .remove ('myweirdmodel.tar' )
60+ # Cleaning
61+ if os .path .exists ('myweirdmodel.tar' ):
62+ os .remove ('myweirdmodel.tar' )
0 commit comments