@@ -105,9 +105,8 @@ def log_sync(self, data: dict):
105105 if 'step' in data and 'steps_per_epoch' in data and 'epoch' in data :
106106 # Initialize tqdm on first call (lazy init to avoid early printing)
107107 if self .train_pbar is None :
108- # Simple bar format with ANSI colors - we'll add metrics manually
108+ # Simple bar format with ANSI colors - we'll add epoch and metrics manually
109109 self .train_bar_format = (
110- '\033 [1;34mEpoch {n_fmt}:\033 [0m '
111110 '{bar} '
112111 '\033 [33m{percentage:3.0f}%\033 [0m │ '
113112 '\033 [37m{n}/{total}\033 [0m'
@@ -122,15 +121,15 @@ def log_sync(self, data: dict):
122121 ascii = '━╺─' , # custom characters matching Rich style
123122 disable = True , # disable auto-display, we'll manually call display()
124123 )
125-
124+
126125 # Reset tqdm if we're in a new epoch
127126 current_step_in_epoch = (data ['step' ] - 1 ) % data ['steps_per_epoch' ] + 1
128127 if current_step_in_epoch == 1 :
129128 self .train_pbar .reset (total = data ['steps_per_epoch' ])
130-
129+
131130 # Update tqdm position
132131 self .train_pbar .n = current_step_in_epoch
133-
132+
134133 # Manually format the complete progress line with metrics using format_meter
135134 bar_str = self .train_pbar .format_meter (
136135 n = current_step_in_epoch ,
@@ -140,6 +139,10 @@ def log_sync(self, data: dict):
140139 bar_format = self .train_bar_format ,
141140 ascii = '━╺─' ,
142141 )
142+
143+ # Prepend the epoch number (1-indexed)
144+ epoch_prefix = f'\033 [1;34mEpoch { data ["epoch" ] + 1 } :\033 [0m '
145+ bar_str = epoch_prefix + bar_str
143146
144147 # Add the metrics to the bar string
145148 metrics_str = (
0 commit comments