Commit 801a454

Author: Hossein Kavianihamedani (committed)
Simplify epoch tracking and fix metric extraction
- Fix extract_epoch_from_batch() to use 'key' attribute instead of 'metric_name'
- Simplify epoch tracking: compare consecutive batches instead of tracking from start
- Remove starting_epoch variable - no longer needed
- Update start_epoch_sync() to use boolean epoch_changed instead of epoch_increment
- Add better logging for epoch changes and tracking status
- Epoch sync now works correctly with the actual metric structure
1 parent db35980 commit 801a454

File tree: 1 file changed (+39 −14 lines)


apps/sft/eval_utils.py

Lines changed: 39 additions & 14 deletions
@@ -35,21 +35,30 @@ def extract_epoch_from_batch(batch: dict) -> int | None:
         Epoch number from metrics, or None if not found
     """
     if "metrics" in batch:
+        # Look for num_epochs in metric keys
+        for metric in batch["metrics"]:
+            # Metrics have a 'key' attribute with paths like:
+            # 'dataset/yahma_alpaca-cleaned_train[:1%]/num_epochs'
+            if hasattr(metric, "key") and "num_epochs" in metric.key:
+                return int(metric.value)
+
+        # Fallback: check for old-style metric_name attribute
         for metric in batch["metrics"]:
             if hasattr(metric, "metric_name") and metric.metric_name == "num_epochs":
-                return metric.value
+                return int(metric.value)
+
     return None


 def start_epoch_sync(
-    epoch_increment: int,
+    epoch_changed: bool,
     device: torch.device,
     dp_process_group: Any = None,
 ) -> tuple[torch.Tensor | None, Any]:
     """Start async all_reduce for epoch synchronization across ranks.

     Args:
-        epoch_increment: Difference between current and starting epoch
+        epoch_changed: Whether the epoch changed on this rank
         device: Device for tensor
         dp_process_group: Data parallel process group (None = default group)

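For reference, the new key-based lookup assumes each entry in batch["metrics"] is an object exposing a path-like string in key and a numeric value. A minimal sketch of that shape, using a hypothetical Metric dataclass (the repo's actual metric class may differ):

from dataclasses import dataclass

@dataclass
class Metric:
    # Hypothetical stand-in for the real metric objects; only 'key' and 'value' matter here.
    key: str
    value: float

batch = {
    "metrics": [
        Metric(key="dataset/yahma_alpaca-cleaned_train[:1%]/num_epochs", value=1.0),
    ],
}

# With the updated function, extract_epoch_from_batch(batch) returns int(1.0) == 1.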
@@ -59,7 +68,8 @@ def start_epoch_sync(
     if not torch.distributed.is_initialized():
         return None, None

-    epoch_tensor = torch.tensor([epoch_increment], dtype=torch.long, device=device)
+    # Convert bool to tensor: 1 if epoch changed, 0 otherwise
+    epoch_tensor = torch.tensor([int(epoch_changed)], dtype=torch.long, device=device)
     pending_work = torch.distributed.all_reduce(
         epoch_tensor,
         op=torch.distributed.ReduceOp.MAX,
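The MAX reduction is what makes the per-rank boolean global: each rank contributes 0 or 1, and after the reduce every rank sees 1 if any rank observed an epoch change. A sketch of how a caller might consume the returned tensor and async work handle (illustrative only, not the repo's exact call site):

# Illustrative consumption of the values returned by start_epoch_sync().
epoch_tensor, pending_work = start_epoch_sync(epoch_changed, device, dp_process_group)
if pending_work is not None:
    pending_work.wait()                        # wait for the async all_reduce to finish
    should_break = bool(epoch_tensor.item())   # 1 on any rank => every rank stops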
@@ -117,7 +127,7 @@ def eval_loop(
         Tuple of (avg_loss, num_batches)
     """
     total_loss = torch.tensor(0.0, device=device)
-    num_batches, starting_epoch = 0, None
+    num_batches = 0

     # Prefetch first batch
     next_batch = next(dataloader_iter)
@@ -142,26 +152,41 @@ def eval_loop(

         batch = next_batch

-        # Track starting epoch
+        # Get current batch epoch
         current_epoch = extract_epoch_fn(batch)
-        if starting_epoch is None:
-            starting_epoch = current_epoch

-        # Prefetch next batch and start async epoch check
+        # Prefetch next batch and check for epoch change
         try:
            next_batch = next(dataloader_iter)
            next_epoch = extract_epoch_fn(next_batch)

-            # Only check epochs if both are available
-            if next_epoch is not None and starting_epoch is not None:
-                epoch_increment = next_epoch - starting_epoch
+            # Simple check: did epoch change between consecutive batches?
+            if next_epoch is not None and current_epoch is not None:
+                epoch_changed = next_epoch > current_epoch
+
+                if epoch_changed:
+                    logger.info(
+                        f"[{dataset_name}] Epoch change detected: "
+                        f"{current_epoch} → {next_epoch}"
+                    )
+
                 if torch.distributed.is_initialized():
+                    # All-reduce: if ANY rank's epoch changed, all ranks should stop
                     epoch_tensor, pending_work = start_epoch_sync(
-                        epoch_increment, device, dp_process_group
+                        epoch_changed, device, dp_process_group
                     )
                 else:
-                    should_break = epoch_increment > 0
+                    # Single process: stop immediately if epoch changed
+                    should_break = epoch_changed
+            else:
+                # No epoch tracking available - rely on eval_steps
+                if num_batches == 0:
+                    logger.info(
+                        f"[{dataset_name}] No epoch tracking available "
+                        f"(current={current_epoch}, next={next_epoch})"
+                    )
         except StopIteration:
+            logger.info(f"[{dataset_name}] StopIteration - dataloader exhausted")
             should_break = True

         # Process current batch (overlaps with async all_reduce)
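Stripped of prefetching and distributed sync, the new stopping rule is just a comparison between the epoch tags of consecutive batches. A minimal single-process sketch of that rule (the epoch values here are made up):

# Each entry stands for the epoch tag extracted from one batch.
epochs = [0, 0, 0, 1, 1]
for current_epoch, next_epoch in zip(epochs, epochs[1:]):
    epoch_changed = next_epoch > current_epoch
    if epoch_changed:
        # Stop after the last batch of epoch 0 - exactly one full pass was evaluated.
        break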
