@@ -16,56 +16,63 @@ def collate_padded(batch: list[dict[str, Any]]) -> dict[str, Any]:
1616 """
1717 Collate function that pads sequences to the longest sample in the batch.
1818
-    Pads 'tokens' with 0 and 'labels' with CROSS_ENTROPY_IGNORE_IDX (-100).
-    Non-tensor fields (like metrics) are collected into lists and flattened
-    if all items are lists.
+    Pads every tensor key to the longest sequence for that key in the
+    batch. Uses 0 as the default padding value, and
+    CROSS_ENTROPY_IGNORE_IDX (-100) for the 'labels' key.
+
+    Non-tensor fields are collected into lists. The 'metrics' field is
+    special-cased to be flattened (extended) rather than nested.

     Args:
-        batch: List of samples, each containing 'tokens' and 'labels' tensors
+        batch: List of samples, each containing tensor and non-tensor fields

     Returns:
-        Batched dict with padded tensors
+        Batched dict with padded tensors and collected non-tensor fields
+
+    Raises:
+        ValueError: If the samples do not all have the same keys
     """
     if not batch:
         return {}

-    # Find max length in batch
-    max_len = max(sample["tokens"].size(0) for sample in batch)
+    # Verify all samples have the same keys
+    first_sample_keys = batch[0].keys()
+    for sample in batch:
+        if sample.keys() != first_sample_keys:
+            raise ValueError(
+                f"All samples must have the same keys. Expected {first_sample_keys}, got {sample.keys()}"
+            )

-    # Initialize lists for batched tensors
-    tokens_list = []
-    labels_list = []
+    collated = {}

-    # Pad each sample to max_len
-    for sample in batch:
-        seq_len = sample["tokens"].size(0)
-        pad_len = max_len - seq_len
-
-        # Pad tokens with 0
-        padded_tokens = F.pad(sample["tokens"], (0, pad_len), value=0)
-        tokens_list.append(padded_tokens)
-
-        # Pad labels with CROSS_ENTROPY_IGNORE_IDX (-100)
-        padded_labels = F.pad(
-            sample["labels"], (0, pad_len), value=CROSS_ENTROPY_IGNORE_IDX
-        )
-        labels_list.append(padded_labels)
-
-    # Stack into batch
-    result = {
-        "tokens": torch.stack(tokens_list),
-        "labels": torch.stack(labels_list),
-    }
-
-    # Collect non-tensor fields (like metrics)
-    for key in batch[0].keys():
-        if key not in ["tokens", "labels"]:
-            result[key] = [sample[key] for sample in batch]
-            # Flatten if all are lists
-            if all(isinstance(item, list) for item in result[key]):
-                result[key] = [item for sublist in result[key] for item in sublist]
-
-    return result
+    for key in first_sample_keys:
+        if isinstance(batch[0][key], torch.Tensor):
+            # Find max length for this tensor key
+            max_len = max(sample[key].size(0) for sample in batch)
+
+            # Determine padding value
+            pad_value = CROSS_ENTROPY_IGNORE_IDX if key == "labels" else 0
+
+            # Pad each sample to max_len
+            padded_tensors = []
+            for sample in batch:
+                seq_len = sample[key].size(0)
+                pad_len = max_len - seq_len
+                padded = F.pad(sample[key], (0, pad_len), value=pad_value)
+                padded_tensors.append(padded)
+
+            # Stack into batch
+            collated[key] = torch.stack(padded_tensors)
+        elif key == "metrics":
+            # Flatten metrics lists
+            collated[key] = []
+            for sample in batch:
+                collated[key].extend(sample[key])
+        else:
+            # Collect other non-tensor fields as lists
+            collated[key] = [sample[key] for sample in batch]
+
+    return collated


 def collate_packed(
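For reference, a minimal usage sketch of the new collate_padded (not part of the diff). It assumes torch is imported, that CROSS_ENTROPY_IGNORE_IDX is -100 as the docstring states, and that each sample carries 'tokens' and 'labels' tensors plus a 'metrics' list; the sample values are made up for illustration.

import torch

batch = [
    {"tokens": torch.tensor([1, 2, 3]), "labels": torch.tensor([1, 2, 3]), "metrics": [{"n_tokens": 3}]},
    {"tokens": torch.tensor([4, 5]), "labels": torch.tensor([4, 5]), "metrics": [{"n_tokens": 2}]},
]

out = collate_padded(batch)
# out["tokens"]  -> tensor([[1, 2, 3], [4, 5, 0]])     padded with 0
# out["labels"]  -> tensor([[1, 2, 3], [4, 5, -100]])  padded with CROSS_ENTROPY_IGNORE_IDX
# out["metrics"] -> [{"n_tokens": 3}, {"n_tokens": 2}] flattened, not nested per sample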