Skip to content

Commit c232d5c

Browse files
authored
Fixes progress bar epoch counter (#54)
* epoch bar now prints the current epoch instead of the step * removes unused imports
1 parent c1b210d commit c232d5c

File tree

3 files changed

+8
-11
lines changed

3 files changed

+8
-11
lines changed

src/mini_trainer/async_structured_logger.py

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -105,9 +105,8 @@ def log_sync(self, data: dict):
105105
if 'step' in data and 'steps_per_epoch' in data and 'epoch' in data:
106106
# Initialize tqdm on first call (lazy init to avoid early printing)
107107
if self.train_pbar is None:
108-
# Simple bar format with ANSI colors - we'll add metrics manually
108+
# Simple bar format with ANSI colors - we'll add epoch and metrics manually
109109
self.train_bar_format = (
110-
'\033[1;34mEpoch {n_fmt}:\033[0m '
111110
'{bar} '
112111
'\033[33m{percentage:3.0f}%\033[0m │ '
113112
'\033[37m{n}/{total}\033[0m'
@@ -122,15 +121,15 @@ def log_sync(self, data: dict):
122121
ascii='━╺─', # custom characters matching Rich style
123122
disable=True, # disable auto-display, we'll manually call display()
124123
)
125-
124+
126125
# Reset tqdm if we're in a new epoch
127126
current_step_in_epoch = (data['step'] - 1) % data['steps_per_epoch'] + 1
128127
if current_step_in_epoch == 1:
129128
self.train_pbar.reset(total=data['steps_per_epoch'])
130-
129+
131130
# Update tqdm position
132131
self.train_pbar.n = current_step_in_epoch
133-
132+
134133
# Manually format the complete progress line with metrics using format_meter
135134
bar_str = self.train_pbar.format_meter(
136135
n=current_step_in_epoch,
@@ -140,6 +139,10 @@ def log_sync(self, data: dict):
140139
bar_format=self.train_bar_format,
141140
ascii='━╺─',
142141
)
142+
143+
# Prepend the epoch number (1-indexed)
144+
epoch_prefix = f'\033[1;34mEpoch {data["epoch"] + 1}:\033[0m '
145+
bar_str = epoch_prefix + bar_str
143146

144147
# Add the metrics to the bar string
145148
metrics_str = (

src/mini_trainer/sampler.py

Lines changed: 0 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,12 +24,7 @@
2424
same number of minibatches.
2525
"""
2626
from deprecated import deprecated
27-
from itertools import chain
28-
import json
2927
import os
30-
import pytest
31-
import tempfile
32-
from unittest.mock import patch
3328

3429
import torch
3530
from torch.utils.data import Sampler, Dataset, DataLoader, SequentialSampler

tests/gpu_tests/conftest.py

Lines changed: 0 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,6 @@
11
"""Pytest configuration for GPU tests."""
22
import pytest
33
import torch
4-
import os
54
import sys
65
from pathlib import Path
76

0 commit comments

Comments
 (0)