@@ -3300,29 +3300,12 @@ class ShardDataloader:
3300
3300
Users can specify the shard_dim of each mesh or specify a single shard_dim for all meshes.
3301
3301
Default: None, which means the data loader will not be split, i.e. mp.
3302
3302
is_dataset_splitted (bool): Whether the dataset has already been split.
3303
- dense_tensor_idx (list): A 2D list specifies the index of the dense_tensor in the output of dataloader.
3303
+ dense_tensor_idx (list): A paired 2D list that specifies the index of the dense_tensor in the output of the dataloader.
3304
3304
It allows users to identify which elements within each output batch are dense_tensor.
3305
- Default: None, which means all the outputs are dist_tensors.
3306
- e.g.
3307
- 1. If the collator function returns:
3308
- return {
3309
- "input_ids": [
3310
- features["input_ids"],
3311
- features["attention_mask"],
3312
- features["position_ids"],
3313
- ],
3314
- "image": features["image"],
3315
- "labels": features["labels"],
3316
- }
3317
- 2. If `dense_tensor_idx = [[1, 2], [0], []]`:
3318
- - For "input_ids":
3319
- input_ids["input_ids"] is a dist_tensor
3320
- input_ids["attention_mask"] is a dense_tensor
3321
- input_ids["position_ids"] is a dense_tensor
3322
- - For "image":
3323
- image is a dense_tensor
3324
- - For "labels":
3325
- labels is a dist_tensor
3305
+ first dense_tensor: the dense_tensor returned by the dataloader.
3306
+ second dense_tensor: num_or_sections specifies how to split the first tensor: evenly (if a number) or unevenly (if a list).
3307
+ Default: None, meaning all outputs are dist_tensors.
3308
+ Note: For dense_tensor_idx settings, the idx must be paired.
3326
3309
"""
3327
3310
3328
3311
def __init__ (
@@ -3332,7 +3315,7 @@ def __init__(
3332
3315
input_keys : list [str ] | tuple [str ] | None = None ,
3333
3316
shard_dims : list | tuple | str | int | None = None ,
3334
3317
is_dataset_splitted : bool = False ,
3335
- dense_tensor_idx : list | None = None ,
3318
+ dense_tensor_idx : list [ list [ int ]] | None = None ,
3336
3319
):
3337
3320
# do some check
3338
3321
if is_dataset_splitted is True and shard_dims is None :
@@ -3615,7 +3598,7 @@ def shard_dataloader(
3615
3598
input_keys : Sequence [str ] | None = None ,
3616
3599
shard_dims : Sequence [str ] | Sequence [int ] | str | int | None = None ,
3617
3600
is_dataset_splitted : bool = False ,
3618
- dense_tensor_idx : list | None = None ,
3601
+ dense_tensor_idx : list [ list [ int ]] | None = None ,
3619
3602
) -> ShardDataloader :
3620
3603
"""
3621
3604
Convert the dataloader to a ShardDataloader which provides two capabilities:
@@ -3640,29 +3623,12 @@ def shard_dataloader(
3640
3623
Users can specify the shard_dim of each mesh or specify a single shard_dim for all meshes.
3641
3624
Default: None, which means the data loader will not be split, i.e. mp.
3642
3625
is_dataset_splitted (bool): Whether the dataset has already been split. Default: False.
3643
- dense_tensor_idx (list): A 2D list specifies the index of the dense_tensor in the output of dataloader.
3626
+ dense_tensor_idx (list): A paired 2D list that specifies the index of the dense_tensor in the output of the dataloader.
3644
3627
It allows users to identify which elements within each output batch are dense_tensor.
3645
- Default: None, which means all the outputs are dist_tensors.
3646
- e.g.
3647
- 1. If the collator function returns:
3648
- return {
3649
- "input_ids": [
3650
- features["input_ids"],
3651
- features["attention_mask"],
3652
- features["position_ids"],
3653
- ],
3654
- "image": features["image"],
3655
- "labels": features["labels"],
3656
- }
3657
- 2. If `dense_tensor_idx = [[1, 2], [0], []]`:
3658
- - For "input_ids":
3659
- input_ids["input_ids"] is a dist_tensor
3660
- input_ids["attention_mask"] is a dense_tensor
3661
- input_ids["position_ids"] is a dense_tensor
3662
- - For "image":
3663
- image is a dense_tensor
3664
- - For "labels":
3665
- labels is a dist_tensor
3628
+ first dense_tensor: the dense_tensor returned by the dataloader.
3629
+ second dense_tensor: num_or_sections specifies how to split the first tensor: evenly (if a number) or unevenly (if a list).
3630
+ Default: None, meaning all outputs are dist_tensors.
3631
+ Note: For dense_tensor_idx settings, the idx must be paired.
3666
3632
Returns:
3667
3633
ShardDataloader: The sharded dataloader.
3668
3634
0 commit comments