Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 6 additions & 6 deletions blocks/extensions/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -447,13 +447,14 @@ def do(self, which_callback, *args):

class Printing(SimpleExtension):
"""Prints log messages to the screen."""
def __init__(self, **kwargs):
def __init__(self, print_status=True, **kwargs):
kwargs.setdefault("before_first_epoch", True)
kwargs.setdefault("on_resumption", True)
kwargs.setdefault("after_training", True)
kwargs.setdefault("after_epoch", True)
kwargs.setdefault("on_interrupt", True)
super(Printing, self).__init__(**kwargs)
self.print_status = print_status

def _print_attributes(self, attribute_tuples):
for attr, value in sorted(attribute_tuples.items(), key=first):
Expand All @@ -462,7 +463,6 @@ def _print_attributes(self, attribute_tuples):

def do(self, which_callback, *args):
log = self.main_loop.log
print_status = True

print()
print("".join(79 * "-"))
Expand All @@ -478,12 +478,12 @@ def do(self, which_callback, *args):
print("TRAINING HAS BEEN INTERRUPTED")
print_status = False
print("".join(79 * "-"))
if print_status:
if self.print_status:
print("Training status:")
self._print_attributes(log.status)
print("Log records from the iteration {}:".format(
log.status['iterations_done']))
self._print_attributes(log.current_row)
print("Log records from the iteration {}:".format(
log.status['iterations_done']))
self._print_attributes(log.current_row)
print()


Expand Down
40 changes: 40 additions & 0 deletions tests/extensions/test_printing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,40 @@
import sys
from six import StringIO

from blocks.extensions import FinishAfter, Printing
from blocks.utils.testing import MockMainLoop


def setup_mainloop(print_status):
epochs = 1
main_loop = MockMainLoop(extensions=[Printing(print_status),
FinishAfter(after_n_epochs=epochs)])
return main_loop

def test_printing_status():
main_loop = setup_mainloop(print_status=True)

stdout = sys.stdout
try:
saved_out = StringIO()
sys.stdout = saved_out

main_loop.run()
output = saved_out.getvalue()
assert 'Training status' in output
finally:
sys.stdout = stdout

def test_printing_no_status():
main_loop = setup_mainloop(print_status=False)

stdout = sys.stdout
try:
saved_out = StringIO()
sys.stdout = saved_out

main_loop.run()
output = saved_out.getvalue()
assert 'Training status' not in output
finally:
sys.stdout = stdout