File tree (expand/collapse): 2 files changed, +5 −1 lines changed.
2 files changed, +5 −1 lines changed.
Original file line number / Diff line number / Diff line change: @@ -329,7 +329,10 @@ def __init__(
329329
330330 if deepspeed_plugins is None :
331331 # First check if we're creating another `Accelerator` w/o setting `deepspeed_plugin`
332- if AcceleratorState ._shared_state != {} and AcceleratorState ().distributed_type == DistributedType .DEEPSPEED :
332+ if (
333+ AcceleratorState ._shared_state != {}
334+ and AcceleratorState ().distributed_type == DistributedType .DEEPSPEED
335+ ):
333336 deepspeed_plugins = AcceleratorState ().deepspeed_plugins
334337 else :
335338 # init from env variables
Original file line number / Diff line number / Diff line change: @@ -1113,6 +1113,7 @@ def prepare_data_loader(
11131113 # Given a device mesh (dp, tp) = (2, 3):
11141114 # - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1
11151115 # - Processes with the same DP rank will receive the same batch.
1116+ submesh_tp_size = 1
11161117 if "tp" in torch_device_mesh .mesh_dim_names :
11171118 submesh_tp_size = torch_device_mesh ["tp" ].size ()
11181119 process_index = process_index // submesh_tp_size
You can’t perform that action at this time.
0 commit comments