Skip to content

Commit 33f7fa4

Browse files
committed
Merge pull request #14 from lucasb-eyer/example-pb
Factor out progress-bar to examples.utils.
2 parents 5ef5d77 + 6c4621b commit 33f7fa4

File tree

9 files changed

+55
-24
lines changed

9 files changed

+55
-24
lines changed

examples/Kaggle-Otto/progress_bar.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

examples/Kaggle-Otto/run.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -43,6 +43,9 @@ def nnet():
4343
return model
4444

4545
if __name__ == "__main__":
46+
if __package__ is None: # PEP366
47+
__package__ = "beacon8.examples.KaggleOtto"
48+
4649
train_data_x, train_data_y = load_train_data()
4750

4851
train_data_x, test_data_x, train_data_y, test_data_y = train_test_split(train_data_x, train_data_y, train_size=0.85)

examples/Kaggle-Otto/test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,12 @@
11
import numpy as np
2-
from progress_bar import *
32
import theano as _th
43
from sklearn.metrics import log_loss
54
from kaggle_utils import *
65

6+
from examples.utils import make_progressbar
7+
78
def validate(dataset_x, dataset_y, model, epoch, batch_size):
8-
progress = make_progressbar('Testing', epoch, len(dataset_x))
9+
progress = make_progressbar('Testing epoch #{}'.format(epoch), len(dataset_x))
910
progress.start()
1011

1112
mini_batch_input = np.empty(shape=(batch_size, 93), dtype=_th.config.floatX)

examples/Kaggle-Otto/train.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
11
import numpy as np
2-
from progress_bar import *
32
import theano as _th
43

4+
from examples.utils import make_progressbar
55

66
def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size, mode=None):
7-
progress = make_progressbar('Training', epoch, len(dataset_x))
7+
progress = make_progressbar('Training epoch #{}'.format(epoch), len(dataset_x))
88
progress.start()
99

1010
shuffle = np.random.permutation(len(dataset_x))
@@ -24,6 +24,6 @@ def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size,
2424
else:
2525
model.accumulate_statistics(mini_batch_input)
2626

27-
progress.update(j * batch_size)
27+
progress.update((j+1) * batch_size)
2828

2929
progress.finish()

examples/MNIST/progress_bar.py

Lines changed: 0 additions & 7 deletions
This file was deleted.

examples/MNIST/run.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,9 @@ def main(params):
2626

2727

2828
if __name__ == "__main__":
29+
if __package__ is None: # PEP366
30+
__package__ = "beacon8.examples.MNIST"
31+
2932
params = {}
3033
params['lr'] = 0.1
3134
params['batch_size'] = 64

examples/MNIST/test.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,10 @@
11
import numpy as np
2-
from progress_bar import *
32
import theano as _th
43

4+
from examples.utils import make_progressbar
5+
56
def validate(dataset_x, dataset_y, model, epoch, batch_size):
6-
progress = make_progressbar('Testing', epoch, len(dataset_x))
7+
progress = make_progressbar('Testing epoch #{}'.format(epoch), len(dataset_x))
78
progress.start()
89

910
mini_batch_input = np.empty(shape=(batch_size, 28*28), dtype=_th.config.floatX)

examples/MNIST/train.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,11 @@
11
import numpy as np
2-
from progress_bar import *
32
import theano as _th
43

4+
from examples.utils import make_progressbar
5+
56

67
def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size, mode='train'):
7-
progress = make_progressbar('Training ({})'.format(mode), epoch, len(dataset_x))
8+
progress = make_progressbar('Training ({}) epoch #{}'.format(mode, epoch), len(dataset_x))
89
progress.start()
910

1011
shuffle = np.random.permutation(len(dataset_x))
@@ -26,6 +27,6 @@ def train(dataset_x, dataset_y, model, optimiser, criterion, epoch, batch_size,
2627
else:
2728
assert False, "Mode should be either 'train' or 'stats'"
2829

29-
progress.update(j * batch_size)
30+
progress.update((j+1) * batch_size)
3031

3132
progress.finish()

examples/utils.py

Lines changed: 36 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,36 @@
1+
import sys as _sys
2+
3+
4+
# Progressbar
5+
##############
6+
7+
8+
try:
9+
import progressbar as _pb
10+
11+
def make_progressbar(prefix, data_size):
12+
widgets = [prefix, ', processed ', _pb.Counter(), ' of ', str(data_size),
13+
' (', _pb.Percentage(), ')', ' ', _pb.Bar(), ' ', _pb.ETA()]
14+
return _pb.ProgressBar(maxval=data_size, widgets=widgets)
15+
16+
except ImportError:
17+
18+
class SimpleProgressBar(object):
19+
def __init__(self, tot, fmt):
20+
self.tot = tot
21+
self.fmt = fmt
22+
23+
def start(self):
24+
self.update(0)
25+
26+
def update(self, i):
27+
_sys.stdout.write("\r" + self.fmt.format(i=i, tot=self.tot, pct=float(i)/self.tot))
28+
_sys.stdout.flush()
29+
self.lasti = i
30+
31+
def finish(self):
32+
_sys.stdout.write("\r" + self.fmt.format(i=self.lasti, tot=self.tot, pct=1.0) + "\n")
33+
_sys.stdout.flush()
34+
35+
def make_progressbar(prefix, data_size):
36+
return SimpleProgressBar(data_size, prefix + ", processed {i} of {tot} ({pct:.2%})")

0 commit comments

Comments
 (0)