Skip to content

Commit df54cbd

Browse files
feat: TQDM added to trainers (#1593)
Co-authored-by: Fedor Ignatov <[email protected]>
1 parent a54b265 commit df54cbd

File tree

2 files changed

+9
-4
lines changed

2 files changed

+9
-4
lines changed

deeppavlov/core/trainers/fit_trainer.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,8 @@
1919
from logging import getLogger
2020
from typing import Tuple, Dict, Union, Optional, Iterable, Any, Collection
2121

22+
from tqdm import tqdm
23+
2224
from deeppavlov.core.commands.infer import build_model
2325
from deeppavlov.core.common.chainer import Chainer
2426
from deeppavlov.core.common.params import from_params
@@ -90,7 +92,7 @@ def fit_chainer(self, iterator: Union[DataFittingIterator, DataLearningIterator]
9092
targets = [targets]
9193

9294
if self.batch_size > 0 and callable(getattr(component, 'partial_fit', None)):
93-
for i, (x, y) in enumerate(iterator.gen_batches(self.batch_size, shuffle=False)):
95+
for i, (x, y) in tqdm(enumerate(iterator.gen_batches(self.batch_size, shuffle=False))):
9496
preprocessed = self._chainer.compute(x, y, targets=targets)
9597
# noinspection PyUnresolvedReferences
9698
component.partial_fit(*preprocessed)
@@ -160,7 +162,7 @@ def test(self, data: Iterable[Tuple[Collection[Any], Collection[Any]]],
160162

161163
data = islice(data, self.max_test_batches)
162164

163-
for x, y_true in data:
165+
for x, y_true in tqdm(data):
164166
examples += len(x)
165167
y_predicted = list(self._chainer.compute(list(x), list(y_true), targets=expected_outputs))
166168
if len(expected_outputs) == 1:

deeppavlov/core/trainers/nn_trainer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -20,12 +20,15 @@
2020
from pathlib import Path
2121
from typing import List, Tuple, Union, Optional, Iterable
2222

23+
from tqdm import tqdm
24+
2325
from deeppavlov.core.common.errors import ConfigError
26+
from deeppavlov.core.common.log_events import get_tb_writer
2427
from deeppavlov.core.common.registry import register
2528
from deeppavlov.core.data.data_learning_iterator import DataLearningIterator
2629
from deeppavlov.core.trainers.fit_trainer import FitTrainer
2730
from deeppavlov.core.trainers.utils import parse_metrics, NumpyArrayEncoder
28-
from deeppavlov.core.common.log_events import get_tb_writer
31+
2932
log = getLogger(__name__)
3033
report_log = getLogger('train_report')
3134

@@ -273,7 +276,7 @@ def train_on_batches(self, iterator: DataLearningIterator) -> None:
273276
while True:
274277
impatient = False
275278
self._send_event(event_name='before_train')
276-
for x, y_true in iterator.gen_batches(self.batch_size, data_type='train'):
279+
for x, y_true in tqdm(iterator.gen_batches(self.batch_size, data_type='train')):
277280
self.last_result = self._chainer.train_on_batch(x, y_true)
278281
if self.last_result is None:
279282
self.last_result = {}

0 commit comments

Comments
 (0)