Skip to content

Commit 2bc05b2

Browse files
author
liuxiaoming
committed
refactor(collator): improve ProgressTrackingCollator per code review
- Move _extract_info from inner function to class method to avoid redefinition overhead on each __call__ invocation - Extract duplicated sources/lengths collection logic into _collect_sources_and_lengths method - Use item.pop() instead of get() + del for cleaner code - Add type hints for better code clarity
1 parent 7af05da commit 2bc05b2

File tree

1 file changed

+32
-34
lines changed

1 file changed

+32
-34
lines changed

swift/llm/dataset/collator.py

Lines changed: 32 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
This module provides a wrapper collator that extracts dataset source information
55
for progress tracking during training.
66
"""
7-
from typing import Any, Callable, Dict, List
7+
from typing import Any, Callable, Dict, List, Optional, Tuple
88

99

1010
class ProgressTrackingCollator:
@@ -32,6 +32,30 @@ def __init__(self, collator: Callable, track_progress: bool = True):
3232
self.collator = collator
3333
self.track_progress = track_progress
3434

35+
def _extract_info(self, item: Any) -> Tuple[Optional[Any], Optional[int]]:
36+
"""Extract and remove _dataset_source, extract length from item."""
37+
if isinstance(item, dict):
38+
sources = item.pop('_dataset_source', None)
39+
length = item.get('length')
40+
return sources, length
41+
return None, None
42+
43+
def _collect_sources_and_lengths(
44+
self,
45+
sources: Optional[Any],
46+
length: Optional[int],
47+
batch_sources: List[str],
48+
batch_lengths: List[int],
49+
) -> None:
50+
"""Collect sources and lengths into batch lists."""
51+
if self.track_progress and sources:
52+
if isinstance(sources, str):
53+
batch_sources.append(sources)
54+
elif isinstance(sources, list):
55+
batch_sources.extend(sources)
56+
if length is not None:
57+
batch_lengths.append(length)
58+
3559
def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
3660
"""Process batch and extract dataset sources and token lengths.
3761
@@ -43,41 +67,15 @@ def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
4367
"""
4468
# 1. Collect sources and lengths before calling original collator
4569
# (original collator may modify batch in place)
46-
batch_sources = []
47-
batch_lengths = []
48-
49-
def _extract_info(item):
50-
if isinstance(item, dict):
51-
sources = item.get('_dataset_source')
52-
if sources is not None:
53-
del item['_dataset_source']
54-
# Extract length but don't delete it (may be needed later)
55-
length = item.get('length')
56-
return sources, length
57-
return None, None
70+
batch_sources: List[str] = []
71+
batch_lengths: List[int] = []
5872

5973
for b in batch:
60-
# Handle Packing scenario where batch item is a list of samples
61-
if isinstance(b, list):
62-
for sub_item in b:
63-
sources, length = _extract_info(sub_item)
64-
if self.track_progress and sources:
65-
if isinstance(sources, str):
66-
batch_sources.append(sources)
67-
elif isinstance(sources, list):
68-
batch_sources.extend(sources)
69-
if length is not None:
70-
batch_lengths.append(length)
71-
# Handle normal scenario where batch item is a single sample dict
72-
elif isinstance(b, dict):
73-
sources, length = _extract_info(b)
74-
if self.track_progress and sources:
75-
if isinstance(sources, str):
76-
batch_sources.append(sources)
77-
elif isinstance(sources, list):
78-
batch_sources.extend(sources)
79-
if length is not None:
80-
batch_lengths.append(length)
74+
# Handle both Packing scenario (list) and normal scenario (dict)
75+
items = b if isinstance(b, list) else [b]
76+
for item in items:
77+
sources, length = self._extract_info(item)
78+
self._collect_sources_and_lengths(sources, length, batch_sources, batch_lengths)
8179

8280
# 2. Call original collator
8381
result = self.collator(batch)

0 commit comments

Comments
 (0)