batched_data_dict.py
# Copyright (c) 2025, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import UserDict
from copy import deepcopy
from typing import (
Any,
Generic,
Iterator,
Mapping,
Optional,
Sequence,
Type,
TypedDict,
TypeVar,
Union,
)
import torch
from typing_extensions import Self
from nemo_rl.data.multimodal_utils import (
PackedTensor,
)
from nemo_rl.data.packing import get_packer
from nemo_rl.distributed.collectives import (
gather_jagged_object_lists,
rebalance_nd_tensor,
)
DictT = TypeVar("DictT", bound=Mapping[str, Any])
class SequencePackingArgs(TypedDict):
"""Configuration settings for sequence packing.
Pass this to 'shard_by_batch_size()' to preprocess batches for sequence packing.
"""
max_tokens_per_microbatch: int
input_key: str
input_lengths_key: str
algorithm: str
sequence_length_pad_multiple: (
int # pad each sequence to a multiple of this value (for CP/TP alignment)
)
class DynamicBatchingArgs(TypedDict):
"""Configuration settings for dynamic batching.
Pass this to 'shard_by_batch_size()' to preprocess batches for dynamic batching.
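Example (illustrative sketch; the key values below are placeholders, not defaults):
```{doctest}
>>> from nemo_rl.distributed.batched_data_dict import DynamicBatchingArgs
>>> dynamic_batching_args: DynamicBatchingArgs = {
... "max_tokens_per_microbatch": 8192,
... "sequence_length_round": 64,
... "input_key": "input_ids",
... "input_lengths_key": "input_lengths",
... }
```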
"""
max_tokens_per_microbatch: int # Each microbatch contains at most this many tokens
sequence_length_round: (
int # Round each microbatch's sequence length to this multiple
)
input_key: str # The key in the data dict that specifies the input ids
input_lengths_key: (
str # The key in the data dict that specifies the sequence length per datum
)
class BatchedDataDict(UserDict, Generic[DictT]):
# keys that are model-specific but not part of the PackedTensor
ADDITIONAL_OPTIONAL_KEY_TENSORS = [
"token_type_ids", # specific to gemma3; marks where the image tokens are in the sequence; not required for LLM-only inference/training
]
def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.micro_batch_indices = None
self.micro_batch_lengths = None
self.elem_counts_per_gb = None
def get_multimodal_dict(
self, as_tensors: bool = False, device: Optional[torch.device] = None
) -> dict[str, Any]:
"""Return a regular dict of tensors or packed multimodal data items."""
multimodal_dict = {}
for k, v in self.data.items():
if isinstance(v, PackedTensor):
multimodal_dict[k] = v.as_tensor(device=device) if as_tensors else v
elif k in self.ADDITIONAL_OPTIONAL_KEY_TENSORS:
multimodal_dict[k] = v
return multimodal_dict
@classmethod
def from_batches(
cls: Type[Self],
batches: Sequence[Mapping[Any, Any]],
pad_value_dict: Optional[dict[str, int | float]] = None,
) -> Self:
"""Given a list of batches, stack the tensors/lists within and put them in a single dictionary.
Pad sequences to the max length in the batch using either 0 (the default) or a non-default value provided for a given key in pad_value_dict.
Args:
batches (list[Dict]): A list of dictionaries, each containing a batch of data.
pad_value_dict (Optional[dict[str, int]]): An optional dict mapping keys to non-default (non-zero) padding values.
Returns:
BatchedDataDict: A new BatchedDataDict containing the stacked data.
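Example (minimal sketch; each row is padded to the longest row using the default pad value of 0):
```{doctest}
>>> import torch
>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> stacked = BatchedDataDict.from_batches(
... [{"ids": torch.tensor([[1, 2, 3]])}, {"ids": torch.tensor([[4, 5]])}],
... )
>>> stacked["ids"].tolist()
[[1, 2, 3], [4, 5, 0]]
```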
"""
stacked_dict: Self = cls()
pad_value_dict = pad_value_dict or {}
for k in sorted(batches[0]):
list_of_tensors = [item[k] for item in batches]
if isinstance(list_of_tensors[0], list):
tensor_or_list: list[Any] | torch.Tensor = [
item for sublist in list_of_tensors for item in sublist
]
elif isinstance(list_of_tensors[0], PackedTensor):
tensor_or_list = PackedTensor.concat(list_of_tensors)
elif all(x.ndim == 1 for x in list_of_tensors):
tensor_or_list = torch.cat(list_of_tensors)
elif isinstance(list_of_tensors[0], torch.Tensor):
pad_value = pad_value_dict.get(k, 0)
# Handle the 3D case used in distillation (e.g., teacher top-k logits and indices);
# the else branch handles the original 1D/2D case.
if list_of_tensors[0].ndim == 3:
# For 3D tensors, pad only along the sequence dimension (the 1st dimension here),
# keeping the feature dimension.
max_seq_len = max(tensor.shape[1] for tensor in list_of_tensors)
padded_tensors = []
for tensor in list_of_tensors:
# Pad along the 1st dimension to max_seq_len.
pad_length = max_seq_len - tensor.shape[1]
padded = torch.nn.functional.pad(
tensor,
# Only pad the last two dimensions (sequence length).
(0, 0, 0, pad_length),
mode="constant",
value=pad_value,
)
padded_tensors.append(padded)
tensor_or_list = torch.cat(
padded_tensors, dim=0
) # concatenate along the batch dimension
else:
list_of_tensors = [
row.flatten() for tensor in list_of_tensors for row in tensor
]
# TODO: can we avoid padding locally then padding globally?
tensor_or_list = torch.nn.utils.rnn.pad_sequence(
list_of_tensors, batch_first=True, padding_value=pad_value
)
else:
raise NotImplementedError(
(
f"Attempted to stack for unsupported type {type(list_of_tensors[0])} with key {k}."
"Please provide either a tensor or a list of picklable objects."
)
)
stacked_dict[k] = tensor_or_list
return stacked_dict
def all_gather(self, group: torch.distributed.ProcessGroup) -> Self:
"""Gathers batches with possibly jagged leading dimensions across the DP ranks.
If using reshard, it will treat PP as DP ranks.
Works with values that are either tensors or lists of picklable objects.
"""
global_rollout_batch: Self = type(self)()
for k, value in self.data.items():
if isinstance(value, torch.Tensor):
value = rebalance_nd_tensor(value, group=group)
global_rollout_batch[k] = value
elif isinstance(value, list):
value = gather_jagged_object_lists(value, group=group)
global_rollout_batch[k] = value
else:
raise NotImplementedError(
(
f"Attempted to gather_and_balance_globally for unsupported type {type(value)} with key {k}."
"Please provide either a tensor or a list of picklable objects."
)
)
return global_rollout_batch
def chunk(self, rank: int, chunks: int) -> "SlicedDataDict":
"""Chunks a global batch into 'chunks' splits and returns the 'rank'th split batch=[A A A B B B D D E], rank=2, chunks=3 -> [D D E].
Requires all leading dimensions of tensors and lengths of lists to be the same over the batch
and the chunks must divide batch size.
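Example (minimal sketch with list values):
```{doctest}
>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> batch = BatchedDataDict({'x': [0, 1, 2, 3, 4, 5]})
>>> batch.chunk(rank=1, chunks=3)
{'x': [2, 3]}
```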
"""
chunked_batch = SlicedDataDict()
batch_set = set()
for val in self.data.values():
if isinstance(val, torch.Tensor):
batch_set.add(val.size(0))
else:
batch_set.add(len(val))
assert len(batch_set) == 1, (
"batch sizes are not the same across the rollout batch"
)
B = batch_set.pop()
assert B % chunks == 0, (
f"batch size ({B}) is not a multiple of chunks ({chunks})"
)
assert B // chunks > rank, (
f"index OOB: not enough splits for this rank. rollout_batch_size: {B}, chunks ({chunks}), rank_idx ({rank})"
)
indices = torch.arange(B).tensor_split(chunks)[rank]
for k in self.data:
if torch.is_tensor(self.data[k]):
chunked_batch[k] = self.data[k][indices].clone()
elif isinstance(self.data[k], PackedTensor):
chunked_batch[k] = self.data[k].slice(indices)
else:
chunked_batch[k] = [self.data[k][i] for i in indices]
return chunked_batch
def reorder_data(self, reorded_indices: list[int]):
"""Reorders the data along the batch dimension by the given indices."""
batch_sizes = set()
for val in self.data.values():
if isinstance(val, torch.Tensor):
batch_sizes.add(val.size(0))
else:
batch_sizes.add(len(val))
assert len(batch_sizes) == 1, (
"Batch sizes are not the same across the rollout batch"
)
total_batch_size = batch_sizes.pop()
indices = range(total_batch_size)
reordered = sorted(zip(reorded_indices, indices), key=lambda pair: pair[0])
reordered_indices = [idx[1] for idx in reordered]
for k, v in self.data.items():
sorted_v: torch.Tensor | list[Any]
if torch.is_tensor(v):
sorted_v = v.index_select(
dim=0, index=torch.IntTensor(reordered_indices)
)
elif isinstance(v, PackedTensor):
sorted_v = v.slice(reordered_indices)
else:
sorted_v = [v[i] for i in reordered_indices]
self.data[k] = sorted_v
def shard_by_batch_size(
self,
shards: int,
batch_size: Optional[int] = None,
allow_uneven_shards: bool = False,
dynamic_batching_args: Optional[DynamicBatchingArgs] = None,
sequence_packing_args: Optional[SequencePackingArgs] = None,
) -> list["SlicedDataDict"] | tuple[list["SlicedDataDict"], list[int]]:
"""Shards a batch by first dividing it into chunks of size batch_size, then further dividing each chunk into shards equal parts. Finally aggregates the sub-shards by their position.
If batch_size is None, there will be no chunking beforehand (will default to the total batch size).
For example, with data [A A B B C C D D], batch_size=2, shards=2:
- Element 0: [A B C D] (first elements from each chunk)
- Element 1: [A B C D] (second elements from each chunk)
Args:
shards (int): The number of shards to divide each batch_size chunk into.
batch_size (int): The size of each initial chunk.
allow_uneven_shards (bool): Whether to allow shards to be unevenly sized.
If True, the last shard may be smaller than the others.
dynamic_batching_args (dict): If passed, preprocess batch for dynamic batching. This
dict requires four keys:
1. max_tokens_per_microbatch (int): the maximum
number of tokens in a microbatch
2. sequence_length_round (int): round all
sequence lengths to this multiple
3. input_key (str): the key in the batch
which holds input ids.
4. input_lengths_key (str): the key in the batch
which holds the sequence length per value.
The sequence dim index is assumed to be 1.
Cannot be passed with sequence_packing_args.
sequence_packing_args (dict): If passed, preprocess batch for sequence packing. This
dict requires five keys:
1. max_tokens_per_microbatch (int): the maximum
number of tokens in a microbatch
2. input_key (str): the key in the batch
which holds input ids.
3. input_lengths_key (str): the key in the batch
which holds the sequence length per value.
The sequence dim index is assumed to be 1.
4. algorithm (str): the algorithm to use for sequence packing.
5. sequence_length_pad_multiple (int): the multiple to pad each sequence to.
With CP enabled, this should be set to a multiple of 2*CP and SP.
Cannot be passed with dynamic_batching_args.
Returns:
list[BatchedDataDict]: A list of BatchedDataDicts, length equal to shards.
If dynamic_batching_args is passed, returns a tuple of the list of BatchedDataDicts and the sorted indices.
Examples:
```{doctest}
>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> # Create a batch of two message logs with different lengths
>>> batch = BatchedDataDict({
... 'problem_id': [0, 0, 1, 1, 2, 2, 3, 3],
... 'arbitrary_data': [1, 2, 3, 4, 5, 6, 7, 8]
... })
>>> shards = batch.shard_by_batch_size(shards=2)
>>> shards
[{'problem_id': [0, 0, 1, 1], 'arbitrary_data': [1, 2, 3, 4]}, {'problem_id': [2, 2, 3, 3], 'arbitrary_data': [5, 6, 7, 8]}]
>>> # Now say that I'm training with a GBS of 4 and I want to take gradient steps on problems 0 and 1 before 2 and 3 (problems are repeated because GRPO)
>>> # In the current case, problems 0 and 2 will be trained on first since they're the first elements in each DP rank's batch.
>>> # So, we'll use the batch_size argument to split the batch into chunks of size 4 first.
>>> shards = batch.shard_by_batch_size(shards=2, batch_size=4)
>>> shards
[{'problem_id': [0, 0, 2, 2], 'arbitrary_data': [1, 2, 5, 6]}, {'problem_id': [1, 1, 3, 3], 'arbitrary_data': [3, 4, 7, 8]}]
>>> # Now, the ranks have 0 and 1 first so when they split their batches into microbatches (of size 2 since GBS=4 and DP=2), they'll train on 0 and 1 first.
>>> # Another way to use this function is with the 'allow_uneven_shards' flag, which allows the last shard to be smaller than the others when necessary.
>>> # This is necessary in multi-turn rollouts when some sequences terminate early, leaving unclean batch sizes.
>>> batch = BatchedDataDict({
... 'problem_id': [0, 1, 2, 3, 4],
... 'arbitrary_data': [10, 11, 12, 13, 14]
... })
>>> shards = batch.shard_by_batch_size(shards=2, allow_uneven_shards=True)
>>> shards
[{'problem_id': [0, 1, 2], 'arbitrary_data': [10, 11, 12]}, {'problem_id': [3, 4], 'arbitrary_data': [13, 14]}]
>>> # This is incompatible with the batch_size argument
```
"""
if allow_uneven_shards:
assert batch_size is None, (
"batch_size must be None if allow_uneven_shards is True"
)
assert dynamic_batching_args is None or sequence_packing_args is None, (
"dynamic_batching_args and sequence_packing_args cannot be passed together"
)
# Get the total batch size
batch_sizes = set()
for val in self.data.values():
if isinstance(val, torch.Tensor):
batch_sizes.add(val.size(0))
else:
batch_sizes.add(len(val))
assert len(batch_sizes) == 1, (
"Batch sizes are not the same across the rollout batch, found sizes: "
+ f"[{','.join(str(size) for size in batch_sizes)}]"
)
total_batch_size = batch_sizes.pop()
if batch_size is None:
batch_size = total_batch_size
# Validate that our batch size parameters are compatible with the data dimensions
assert total_batch_size % batch_size == 0, (
f"Total batch size ({total_batch_size}) is not a multiple of batch_size ({batch_size})"
)
if not allow_uneven_shards:
assert batch_size % shards == 0, (
f"Batch size ({batch_size}) is not a multiple of shards ({shards})"
)
num_chunks = total_batch_size // batch_size
# Calculate shard size, rounding up if not evenly divisible
shard_size = (
(batch_size + shards - 1) // shards
if allow_uneven_shards
else batch_size // shards
)
# if using dynamic microbatching, preprocess the data by sorting the data
# by the sequence lengths. This ensures each DP rank receives samples of about
# equal sequence lengths which improves load balancing
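# For example (hypothetical lengths): with shards=2 and one chunk whose sequence
# lengths are [5, 1, 9, 3], the sorted order is [1, 3, 5, 9] and the striding below
# assigns the length-[1, 5] samples to shard 0 and the length-[3, 9] samples to shard 1.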
if dynamic_batching_args is not None:
data = {}
batch_sorted_indices = []
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * batch_size
chunk_end = (chunk_idx + 1) * batch_size
chunk_seqlens = self.data[dynamic_batching_args["input_lengths_key"]][
chunk_start:chunk_end
]
# sort the indices by sequence lengths
chunk_idx_indices = sorted(
range(batch_size), key=lambda i: chunk_seqlens[i]
)
chunk_idx_indices = [i + chunk_start for i in chunk_idx_indices]
# stride the sorted sequence lengths along the shards
chunk_idx_indices = [
chunk_idx_indices[i::shards] for i in range(shards)
]
chunk_idx_indices = sum(chunk_idx_indices, [])
# append the sorted sequence lengths for the chunk
batch_sorted_indices.extend(chunk_idx_indices)
# finally reorder the data along the sorted sequence len indices
for k, v in self.data.items():
sorted_v: torch.Tensor | list[Any] | PackedTensor
if torch.is_tensor(v):
sorted_v = v.index_select(
dim=0, index=torch.IntTensor(batch_sorted_indices)
)
elif isinstance(v, PackedTensor):
sorted_v = v.slice(batch_sorted_indices)
else:
sorted_v = [v[i] for i in batch_sorted_indices]
data[k] = sorted_v
elif sequence_packing_args is not None:
bin_packer = get_packer(
algorithm=sequence_packing_args["algorithm"],
bin_capacity=sequence_packing_args["max_tokens_per_microbatch"],
collect_metrics=False, # TODO(ahmadki): make configurable
min_bin_count=shards,
bin_count_multiple=shards,
)
input_lengths_key = sequence_packing_args["input_lengths_key"]
input_lens = self.data[input_lengths_key]
if not isinstance(input_lens, torch.Tensor):
input_lens = torch.tensor(input_lens)
pad_multiple = sequence_packing_args["sequence_length_pad_multiple"]
def _get_padded_seqlen(seqlen: int) -> int:
return (seqlen + pad_multiple - 1) // pad_multiple * pad_multiple
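# For example (hypothetical value): with pad_multiple=16, a sequence of length 20 counts as 32 tokens for packing purposes.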
# Store bin assignments for each chunk to reuse later
all_chunk_bin_assignments = []
# Process each chunk separately to respect chunk boundaries
for chunk_idx in range(num_chunks):
chunk_start = chunk_idx * batch_size
chunk_end = (chunk_idx + 1) * batch_size
# Get sequence lengths for this chunk
chunk_seqlens = input_lens[chunk_start:chunk_end]
chunk_padded_seqlens_list = [
_get_padded_seqlen(seq_len.item()) for seq_len in chunk_seqlens
]
# Pack sequences in this chunk into bins
chunk_bin_assignments = bin_packer.pack(
sequence_lengths=chunk_padded_seqlens_list,
)
all_chunk_bin_assignments.append(chunk_bin_assignments)
# create shards with the packed bins
sharded_data: list[list[dict]] = [[] for _ in range(shards)]
sharded_micro_indices: list = [[] for _ in range(shards)]
sharded_micro_lengths: list = [[] for _ in range(shards)]
sharded_elem_counts_per_gb: list = [[] for _ in range(shards)]
global_indices_per_shard: list[list[int]] = [[] for _ in range(shards)]
for chunk_idx in range(num_chunks):
chunk_sharded_micro_indices: list[list[list[int]]] = [
[] for _ in range(shards)
]
chunk_sharded_micro_lengths: list[list[int]] = [
[] for _ in range(shards)
]
num_bins = len(all_chunk_bin_assignments[chunk_idx])
chunk_start = chunk_idx * batch_size
for bin_idx in range(num_bins):
shard_idx = bin_idx % shards
bin_indices = all_chunk_bin_assignments[chunk_idx][bin_idx]
global_bin_indices = [i + chunk_start for i in bin_indices]
sharded_data[shard_idx].append(
self.select_indices(global_bin_indices)
)
global_indices_per_shard[shard_idx].extend(global_bin_indices)
bin_seqlen = sum(
[
_get_padded_seqlen(input_lens[i].item())
for i in global_bin_indices
]
)
if chunk_sharded_micro_indices[shard_idx] == []:
chunk_sharded_micro_indices[shard_idx].append(
[0, len(bin_indices)]
)
else:
prev_bin_end = chunk_sharded_micro_indices[shard_idx][-1][1]
chunk_sharded_micro_indices[shard_idx].append(
[prev_bin_end, prev_bin_end + len(bin_indices)]
)
chunk_sharded_micro_lengths[shard_idx].append(bin_seqlen)
for shard_idx in range(shards):
sharded_micro_indices[shard_idx].append(
chunk_sharded_micro_indices[shard_idx]
)
sharded_micro_lengths[shard_idx].append(
chunk_sharded_micro_lengths[shard_idx]
)
sharded_elem_counts_per_gb[shard_idx].append(
chunk_sharded_micro_indices[shard_idx][-1][1]
)
# flatten global_indices_per_shard
batch_sorted_indices = []
for shard_idx in range(shards):
batch_sorted_indices.extend(global_indices_per_shard[shard_idx])
aggregated_shards = []
for shard_idx in range(shards):
shard = SlicedDataDict.from_batches(sharded_data[shard_idx])
shard.micro_batch_indices = sharded_micro_indices[shard_idx]
shard.micro_batch_lengths = sharded_micro_lengths[shard_idx]
shard.elem_counts_per_gb = sharded_elem_counts_per_gb[shard_idx]
aggregated_shards.append(shard)
return aggregated_shards, batch_sorted_indices
else:
data = self.data
aggregated_shards = [SlicedDataDict() for _ in range(shards)]
# Group data by shard position across all chunks
for shard_idx in range(shards):
shard_ranges: list[tuple[int, int]] = []
for chunk_idx in range(num_chunks):
# Calculate indices for this particular sub-shard within the chunk
chunk_start = chunk_idx * batch_size
shard_start = chunk_start + shard_idx * shard_size
shard_end = chunk_start + (shard_idx + 1) * shard_size
if allow_uneven_shards:
# Cap the end index at the total batch size for the last shard
# or if shard_end calculation goes beyond total_batch_size
shard_start = min(shard_start, total_batch_size)
shard_end = min(shard_end, total_batch_size)
if shard_start < shard_end:
shard_ranges.append((shard_start, shard_end))
for k, v in data.items():
if torch.is_tensor(v):
# Pre-allocate and copy each chunk once
rows = sum(end - start for start, end in shard_ranges)
shard_tensor = torch.empty(
(rows, *v.shape[1:]),
dtype=v.dtype,
device=v.device,
)
offset = 0
for start, end in shard_ranges:
span = end - start
shard_tensor[offset : offset + span].copy_(v[start:end])
offset += span
aggregated_shards[shard_idx][k] = shard_tensor
elif isinstance(v, PackedTensor):
# PackedTensor is collected per chunk then concatenated once
packed_slices = [
v.slice(list(range(start, end))) for start, end in shard_ranges
]
aggregated_shards[shard_idx][k] = (
PackedTensor.concat(packed_slices)
if packed_slices
else PackedTensor.empty_like(v)
)
else:
shard_values = []
for start, end in shard_ranges:
shard_values.extend([v[i] for i in range(start, end)])
aggregated_shards[shard_idx][k] = shard_values
# map inputs to microbatches such that the total number of tokens in
# a microbatch (including padding tokens) is as close as possible to 'max_tokens_per_microbatch'
if dynamic_batching_args is not None:
max_tokens_per_microbatch = dynamic_batching_args[
"max_tokens_per_microbatch"
]
micro_batch_indices = []
micro_batch_lengths = []
# loop through each chunk, dividing the chunk into microbatches
for chunk_idx in range(num_chunks):
chunk_micro_batch_indices = [[0, 0]]
chunk_micro_batch_lengths = [0]
max_seqlen_this_mb = 0
# for each index in the shard, map it to a microbatch
for shard_indice in range(shard_size):
# use the max seqlen of all shards to calculate the total number of tokens in the mb
# this ensures each DP rank has the same batch size each iteration which is
# required for FSDP2 and megatron policies.
max_seqlen_this_shard_indice = 0
chunk_start = chunk_idx * shard_size
chunk_end = (chunk_idx + 1) * shard_size
for shard in aggregated_shards:
input_lengths = shard[
dynamic_batching_args["input_lengths_key"]
]
seq_len = input_lengths[chunk_start:chunk_end][
shard_indice
].item()
max_seqlen_this_shard_indice = max(
max_seqlen_this_shard_indice, seq_len
)
# pad to nearest multiple specified
sequence_length_round = dynamic_batching_args[
"sequence_length_round"
]
unpadded_seqlen = data[dynamic_batching_args["input_key"]].shape[1]
padded_seqlen = (
(max_seqlen_this_shard_indice + sequence_length_round - 1)
// sequence_length_round
) * sequence_length_round
max_seqlen_this_shard_indice = min(padded_seqlen, unpadded_seqlen)
assert max_seqlen_this_shard_indice <= max_tokens_per_microbatch, (
f"got an input of padded ({sequence_length_round}) sequence length of {max_seqlen_this_shard_indice}, however max microbatch size is {max_tokens_per_microbatch} tokens"
)
# check if the sample at shard_indice can be added to the current microbatch for all shards;
# the total token count of a microbatch = number of samples in the microbatch * the max sequence length in the microbatch
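# For example (hypothetical sizes): a microbatch holding 4 samples whose max padded sequence length is 512 counts as 4 * 512 = 2048 tokens.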
curr_mbs_size = (
chunk_micro_batch_indices[-1][1]
- chunk_micro_batch_indices[-1][0]
+ 1
)
max_seqlen_this_mb = max(
max_seqlen_this_mb, max_seqlen_this_shard_indice
)
total_tokens_in_mbs = curr_mbs_size * max_seqlen_this_mb
# if the current microbatch can accommodate this index, add it
if total_tokens_in_mbs <= max_tokens_per_microbatch:
chunk_micro_batch_indices[-1][-1] += 1
chunk_micro_batch_lengths[-1] = max_seqlen_this_mb
# otherwise start a new microbatch
else:
chunk_micro_batch_indices.append(
[shard_indice, shard_indice + 1]
)
max_seqlen_this_mb = max_seqlen_this_shard_indice
chunk_micro_batch_lengths.append(max_seqlen_this_mb)
micro_batch_indices.append(chunk_micro_batch_indices)
micro_batch_lengths.append(chunk_micro_batch_lengths)
for shard in aggregated_shards:
shard.micro_batch_indices = micro_batch_indices
shard.micro_batch_lengths = micro_batch_lengths
return aggregated_shards, batch_sorted_indices
return aggregated_shards
def get_batch(self, batch_idx, batch_size=None) -> "SlicedDataDict":
"""Slices a subbatch from the batch.
Args:
batch_idx: the batch index to slice
batch_size: the size of the batch to be sliced
Returns:
BatchedDataDict: A new BatchedDataDict containing the sliced data
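Example (minimal sketch):
```{doctest}
>>> import torch
>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> batch = BatchedDataDict({'x': torch.arange(8)})
>>> batch.get_batch(batch_idx=1, batch_size=4)
{'x': tensor([4, 5, 6, 7])}
```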
"""
if self.elem_counts_per_gb is not None:
assert self.micro_batch_indices is not None, (
"micro_batch_indices must be provided if sequence_packing is True"
)
elem_count = self.elem_counts_per_gb[batch_idx]
cum_elem_count = [0]
for i in range(len(self.elem_counts_per_gb)):
cum_elem_count.append(cum_elem_count[i] + self.elem_counts_per_gb[i])
batch = self.slice(cum_elem_count[batch_idx], cum_elem_count[batch_idx + 1])
batch.micro_batch_indices = [self.micro_batch_indices[batch_idx]]
batch.micro_batch_lengths = [self.micro_batch_lengths[batch_idx]] # type: ignore # This exists if idxs do
batch.elem_counts_per_gb = [elem_count]
return batch
start = batch_size * batch_idx
end = batch_size * (batch_idx + 1)
batch = self.slice(start, end)
if self.micro_batch_indices is not None:
batch.micro_batch_indices = [self.micro_batch_indices[batch_idx]]
batch.micro_batch_lengths = [self.micro_batch_lengths[batch_idx]] # type: ignore # This exists if idxs do
return batch
def slice(self, start: int, end: int) -> "SlicedDataDict":
"""Slices the batch from start to end.
Args:
start: Starting index (inclusive)
end: Ending index (exclusive)
Returns:
BatchedDataDict: A new BatchedDataDict containing the sliced data
"""
sliced_batch = SlicedDataDict()
for k in self.data:
if isinstance(self.data[k], PackedTensor):
sliced_batch[k] = self.data[k].slice(list(range(start, end)))
continue
if isinstance(self.data[k], torch.Tensor):
assert end <= self.data[k].shape[0], (
f"end: {end} is greater than the shape of the tensor: {self.data[k].shape[0]} for key: {k}"
)
sliced_batch[k] = self.data[k][start:end]
return sliced_batch
def repeat_interleave(self, num_repeats: int) -> Self:
"""Repeats the batch num_repeats times.
For each element in the batch, repeat each value num_repeats times.
e.g.:
{"key": torch.tensor([1, 2, 3]), "other_key": [1, 2, 3]} -> {"key": torch.tensor([1, 1, 2, 2, 3, 3]), "other_key": [1, 1, 2, 2, 3, 3]}
"""
repeated_batch: Self = type(self)()
for k, v in self.data.items():
if torch.is_tensor(v):
# For tensors, use repeat_interleave to repeat each element
repeated_batch[k] = v.repeat_interleave(num_repeats, dim=0)
elif isinstance(v, PackedTensor):
raise NotImplementedError(
"PackedTensor does not currently support repeat_interleave"
)
else:
# For lists or other sequences, use a list comprehension to repeat each element
repeated_batch[k] = [
deepcopy(item) for item in v for _ in range(num_repeats)
]
return repeated_batch
def truncate_tensors(self, dim: int, truncated_len: int):
"""Truncates tensors in this dict of a given dim to a given length."""
for k, v in self.items():
if torch.is_tensor(v) and len(v.shape) >= dim + 1:
self.data[k] = torch.narrow(v, dim=dim, start=0, length=truncated_len)
def make_microbatch_iterator_with_dynamic_shapes(
self,
sequence_dim: int = 1,
) -> Iterator["SlicedDataDict"]:
"""Makes an iterator that yields microbatchs of dynamic batch and sequence sizes.
Args:
sequence_dim: the index of the sequence dim for all tensors in the data dict
Returns:
Iterator["SlicedDataDict"]: An iterator that yield dynamic microbatches
"""
assert (
self.micro_batch_indices is not None
and len(self.micro_batch_indices) == 1
and self.micro_batch_lengths is not None
)
for seqlen, (start_idx, end_idx) in zip(
self.micro_batch_lengths[0], self.micro_batch_indices[0]
):
mb = self.slice(start_idx, end_idx)
mb.truncate_tensors(dim=sequence_dim, truncated_len=seqlen)
yield mb
def get_microbatch_iterator_dynamic_shapes_len(self) -> int:
"""Get the length of the microbatch iterator for dynamic shapes."""
return len(self.micro_batch_indices[0])
def make_microbatch_iterator_for_packable_sequences(
self,
) -> Iterator["SlicedDataDict"]:
"""Make an iterator over the batch that yields microbatches that can be packed into a given max_tokens_per_microbatch."""
assert (
self.micro_batch_indices is not None
and len(self.micro_batch_indices) == 1
and self.micro_batch_lengths is not None
)
for seqlen, (start_idx, end_idx) in zip(
self.micro_batch_lengths[0], self.micro_batch_indices[0]
):
mb = self.slice(start_idx, end_idx)
yield mb
def get_microbatch_iterator_for_packable_sequences_len(self) -> tuple[int, int]:
"""Get the length of the microbatch iterator for sequence packing and the max packed seqlen."""
return len(self.micro_batch_indices[0]), max(self.micro_batch_lengths[0])
def make_microbatch_iterator(
self, microbatch_size: int
) -> Iterator["SlicedDataDict"]:
"""Make an iterator over the batch that yields microbatches of size microbatch_size."""
bsize = self.size
assert bsize % microbatch_size == 0, (
f"Data dict size ({bsize}) is not a multiple of the provided microbatch size ({microbatch_size})"
)
for i in range(0, bsize, microbatch_size):
yield self.slice(i, i + microbatch_size)
@property
def size(self) -> int:
"""Get the batch size of the batch."""
# Get the first key and use its size as the batch size.
# This assumes all keys have the same batch size.
if not self.data:
return 0
key = next(iter(self.data))
if not torch.is_tensor(self.data[key]):
return len(self.data[key])
return self.data[key].shape[0] # type: ignore # it's a tensor here
def to(self, device: str | torch.device) -> Self:
"""Move tensors in batched dict to device."""
for k, v in self.data.items():
if torch.is_tensor(v):
self.data[k] = v.to(device)
elif isinstance(v, PackedTensor):
self.data[k] = v.to(device)
return self
def select_indices(self, indices: Union[list[int], torch.Tensor]) -> Self:
"""Selects specific rows from the batch based on indices.
Args:
indices: A list or tensor of integer indices to select.
Returns:
BatchedDataDict: A new BatchedDataDict containing only the selected rows.
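Example (minimal sketch with tensor and list values):
```{doctest}
>>> import torch
>>> from nemo_rl.distributed.batched_data_dict import BatchedDataDict
>>> batch = BatchedDataDict({'x': torch.tensor([10, 20, 30]), 'y': ['a', 'b', 'c']})
>>> batch.select_indices([2, 0])
{'x': tensor([30, 10]), 'y': ['c', 'a']}
```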
"""
selected_batch: Self = type(self)()
for k, v in self.data.items():
if torch.is_tensor(v):
selected_batch[k] = v[indices]
elif isinstance(v, PackedTensor):
selected_batch[k] = v.slice(indices)
elif isinstance(v, list):
selected_batch[k] = [v[i] for i in indices]
else:
# Handle other potential types if necessary, or raise error
raise TypeError(
f"Unsupported type {type(v)} for index selection in BatchedDataDict"
)
return selected_batch
def get_dict(self) -> dict[Any, Any]:
"""Get the underlying data dictionary."""
return self.data
class SlicedDataDict(BatchedDataDict):
"""A specialized subclass of BatchedDataDict that represents a slice or shard of a larger batch.
This class provides a distinct type to differentiate between full batches and sliced/sharded batches, which can be helpful for
type checking.
"""
pass