44This module provides a wrapper collator that extracts dataset source information
55for progress tracking during training.
66"""
7- from typing import Any , Callable , Dict , List
7+ from typing import Any , Callable , Dict , List , Optional , Tuple
88
99
1010class ProgressTrackingCollator :
@@ -32,6 +32,30 @@ def __init__(self, collator: Callable, track_progress: bool = True):
3232 self .collator = collator
3333 self .track_progress = track_progress
3434
35+ def _extract_info (self , item : Any ) -> Tuple [Optional [Any ], Optional [int ]]:
36+ """Extract and remove _dataset_source, extract length from item."""
37+ if isinstance (item , dict ):
38+ sources = item .pop ('_dataset_source' , None )
39+ length = item .get ('length' )
40+ return sources , length
41+ return None , None
42+
43+ def _collect_sources_and_lengths (
44+ self ,
45+ sources : Optional [Any ],
46+ length : Optional [int ],
47+ batch_sources : List [str ],
48+ batch_lengths : List [int ],
49+ ) -> None :
50+ """Accumulate one sample's info into the caller's lists, mutating them in place: when self.track_progress is set and sources is truthy, a str source is appended to batch_sources and a list source is extended into it; a non-None length is appended to batch_lengths. NOTE(review): indentation is lost in this flattened diff view, so confirm whether the length check is nested under the track_progress guard or independent of it."""
51+ if self .track_progress and sources :
52+ if isinstance (sources , str ):
53+ batch_sources .append (sources )
54+ elif isinstance (sources , list ):
55+ batch_sources .extend (sources )
56+ if length is not None :
57+ batch_lengths .append (length )
58+
3559 def __call__ (self , batch : List [Dict [str , Any ]]) -> Dict [str , Any ]:
3660 """Process batch and extract dataset sources and token lengths.
3761
@@ -43,41 +67,15 @@ def __call__(self, batch: List[Dict[str, Any]]) -> Dict[str, Any]:
4367 """
4468 # 1. Collect sources and lengths before calling original collator
4569 # (original collator may modify batch in place)
46- batch_sources = []
47- batch_lengths = []
48-
49- def _extract_info (item ):
50- if isinstance (item , dict ):
51- sources = item .get ('_dataset_source' )
52- if sources is not None :
53- del item ['_dataset_source' ]
54- # Extract length but don't delete it (may be needed later)
55- length = item .get ('length' )
56- return sources , length
57- return None , None
70+ batch_sources : List [str ] = []
71+ batch_lengths : List [int ] = []
5872
5973 for b in batch :
60- # Handle Packing scenario where batch item is a list of samples
61- if isinstance (b , list ):
62- for sub_item in b :
63- sources , length = _extract_info (sub_item )
64- if self .track_progress and sources :
65- if isinstance (sources , str ):
66- batch_sources .append (sources )
67- elif isinstance (sources , list ):
68- batch_sources .extend (sources )
69- if length is not None :
70- batch_lengths .append (length )
71- # Handle normal scenario where batch item is a single sample dict
72- elif isinstance (b , dict ):
73- sources , length = _extract_info (b )
74- if self .track_progress and sources :
75- if isinstance (sources , str ):
76- batch_sources .append (sources )
77- elif isinstance (sources , list ):
78- batch_sources .extend (sources )
79- if length is not None :
80- batch_lengths .append (length )
74+ # Handle both Packing scenario (list) and normal scenario (dict)
75+ items = b if isinstance (b , list ) else [b ]
76+ for item in items :
77+ sources , length = self ._extract_info (item )
78+ self ._collect_sources_and_lengths (sources , length , batch_sources , batch_lengths )
8179
8280 # 2. Call original collator
8381 result = self .collator (batch )
0 commit comments