Skip to content

Commit 42c47d9

Browse files
authored
[AutoParallel] Update dense_tensor_idx des (#71571) (#71740)
* Update api.py * test=document_fix * Update api.py * Update api.py * Update api.py * Update api.py * Update api.py * Update api.py * Update api.py * Update api.py
1 parent 9b637a7 commit 42c47d9

File tree

1 file changed

+12
-46
lines changed
  • python/paddle/distributed/auto_parallel

1 file changed

+12
-46
lines changed

python/paddle/distributed/auto_parallel/api.py

Lines changed: 12 additions & 46 deletions
Original file line numberDiff line numberDiff line change
@@ -3300,29 +3300,12 @@ class ShardDataloader:
33003300
Users can specify the shard_dim of each mesh or specify a single shard_dim for all meshes.
33013301
Default: None, which means the data loader will not be split, i.e. mp.
33023302
is_dataset_splitted (bool): Whether the dataset has been splitted.
3303-
dense_tensor_idx (list): A 2D list specifies the index of the dense_tensor in the output of dataloader.
3303+
dense_tensor_idx (list): A paired 2D list specifies the index of the dense_tensor in the output of dataloader.
33043304
It allows users to identify which elements within each output batch are dense_tensor.
3305-
Default: None, which means all the outputs are dist_tensors.
3306-
e.g.
3307-
1. If the collator function returns:
3308-
return {
3309-
"input_ids": [
3310-
features["input_ids"],
3311-
features["attention_mask"],
3312-
features["position_ids"],
3313-
],
3314-
"image": features["image"],
3315-
"labels": features["labels"],
3316-
}
3317-
2. If `dense_tensor_idx = [[1, 2], [0], []]`:
3318-
- For "input_ids":
3319-
input_ids["input_ids"] is a dist_tensor
3320-
input_ids["attention_mask"] is a dense_tensor
3321-
input_ids["position_ids"] is a dense_tensor
3322-
- For "image":
3323-
image is a dense_tensor
3324-
- For "labels":
3325-
labels is a dist_tensor
3305+
first dense_tensor: the dense_tensor return by dataloader.
3306+
second dense_tensor: num_or_sections specifies how to split first tensor: evenly (if a number) or unevenly (if a list).
3307+
Default: None, meaning all outputs are dist_tensors.
3308+
Note: For dense_tensor_idx settings, the idx must be paired.
33263309
"""
33273310

33283311
def __init__(
@@ -3332,7 +3315,7 @@ def __init__(
33323315
input_keys: list[str] | tuple[str] | None = None,
33333316
shard_dims: list | tuple | str | int | None = None,
33343317
is_dataset_splitted: bool = False,
3335-
dense_tensor_idx: list | None = None,
3318+
dense_tensor_idx: list[list[int]] | None = None,
33363319
):
33373320
# do some check
33383321
if is_dataset_splitted is True and shard_dims is None:
@@ -3615,7 +3598,7 @@ def shard_dataloader(
36153598
input_keys: Sequence[str] | None = None,
36163599
shard_dims: Sequence[str] | Sequence[int] | str | int | None = None,
36173600
is_dataset_splitted: bool = False,
3618-
dense_tensor_idx: list | None = None,
3601+
dense_tensor_idx: list[list[int]] | None = None,
36193602
) -> ShardDataloader:
36203603
"""
36213604
Convert the dataloader to a ShardDataloader which provided two capabilities:
@@ -3640,29 +3623,12 @@ def shard_dataloader(
36403623
Users can specify the shard_dim of each mesh or specify a single shard_dim for all meshes.
36413624
Default: None, which means the data loader will not be split, i.e. mp.
36423625
is_dataset_splitted (bool): Whether the dataset has been splitted, Default: False.
3643-
dense_tensor_idx (list): A 2D list specifies the index of the dense_tensor in the output of dataloader.
3626+
dense_tensor_idx (list): A paired 2D list specifies the index of the dense_tensor in the output of dataloader.
36443627
It allows users to identify which elements within each output batch are dense_tensor.
3645-
Default: None, which means all the outputs are dist_tensors.
3646-
e.g.
3647-
1. If the collator function returns:
3648-
return {
3649-
"input_ids": [
3650-
features["input_ids"],
3651-
features["attention_mask"],
3652-
features["position_ids"],
3653-
],
3654-
"image": features["image"],
3655-
"labels": features["labels"],
3656-
}
3657-
2. If `dense_tensor_idx = [[1, 2], [0], []]`:
3658-
- For "input_ids":
3659-
input_ids["input_ids"] is a dist_tensor
3660-
input_ids["attention_mask"] is a dense_tensor
3661-
input_ids["position_ids"] is a dense_tensor
3662-
- For "image":
3663-
image is a dense_tensor
3664-
- For "labels":
3665-
labels is a dist_tensor
3628+
first dense_tensor: the dense_tensor return by dataloader.
3629+
second dense_tensor: num_or_sections specifies how to split first tensor: evenly (if a number) or unevenly (if a list).
3630+
Default: None, meaning all outputs are dist_tensors.
3631+
Note: For dense_tensor_idx settings, the idx must be paired.
36663632
Returns:
36673633
ShardDataloader: The sharded dataloader.
36683634

0 commit comments

Comments
 (0)