Commit 2f075c7

set default submesh_tp_size to prevent unset local variable error (#3687)
* set default submesh_tp_size to prevent unset local variable error

* Apply style fixes

---------

Co-authored-by: github-actions[bot] <github-actions[bot]@users.noreply.github.com>
1 parent 7ecc2d7 commit 2f075c7

File tree

src/accelerate/accelerator.py
src/accelerate/data_loader.py

2 files changed: +5, -1 lines changed

src/accelerate/accelerator.py

Lines changed: 4 additions & 1 deletion
@@ -329,7 +329,10 @@ def __init__(
 
         if deepspeed_plugins is None:
             # First check if we're creating another `Accelerator` w/o setting `deepspeed_plugin`
-            if AcceleratorState._shared_state != {} and AcceleratorState().distributed_type == DistributedType.DEEPSPEED:
+            if (
+                AcceleratorState._shared_state != {}
+                and AcceleratorState().distributed_type == DistributedType.DEEPSPEED
+            ):
                 deepspeed_plugins = AcceleratorState().deepspeed_plugins
             else:
                 # init from env variables
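For readers unfamiliar with the reuse pattern above, here is a small self-contained sketch of the guard that decides whether an `Accelerator` created without `deepspeed_plugin` should reuse plugins from an already-initialized shared state. The names (`SharedState`, `reuse_deepspeed_plugins`) are illustrative stand-ins, not accelerate's real classes; the point is that the `and` short-circuits, so the distributed-type lookup only runs once shared state actually exists.

```python
# Self-contained sketch of the guard pattern in the hunk above.
# `SharedState` and `reuse_deepspeed_plugins` are illustrative stand-ins,
# not accelerate's real classes or functions.
class SharedState:
    _shared_state: dict = {}

    @property
    def distributed_type(self) -> str:
        # Stands in for AcceleratorState().distributed_type, which is only
        # meaningful after the shared state has been populated.
        return self._shared_state.get("distributed_type", "NO")


def reuse_deepspeed_plugins() -> bool:
    # Same shape as the condition in the diff: confirm shared state exists
    # first, then (and only then) inspect its distributed type.
    return (
        SharedState._shared_state != {}
        and SharedState().distributed_type == "DEEPSPEED"
    )


print(reuse_deepspeed_plugins())  # False: no shared state yet
SharedState._shared_state["distributed_type"] = "DEEPSPEED"
print(reuse_deepspeed_plugins())  # True: an existing DeepSpeed state is reused
```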

src/accelerate/data_loader.py

Lines changed: 1 addition & 0 deletions
@@ -1113,6 +1113,7 @@ def prepare_data_loader(
         # Given a device mesh (dp, tp) = (2, 3):
         # - From the data parallel perspective, ranks should be structured as: 0 0 0 1 1 1
         # - Processes with the same DP rank will receive the same batch.
+        submesh_tp_size = 1
         if "tp" in torch_device_mesh.mesh_dim_names:
             submesh_tp_size = torch_device_mesh["tp"].size()
         process_index = process_index // submesh_tp_size
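For context, here is a minimal self-contained sketch of the bug this hunk fixes (the helper functions and arguments below are illustrative stand-ins, not the real signature of `prepare_data_loader`): when the device mesh has no `tp` dimension, `submesh_tp_size` was only assigned inside the `if` branch, so the later integer division raised `UnboundLocalError`. Defaulting it to 1 makes that division a no-op in the data-parallel-only case.

```python
# Minimal reproduction of the pattern fixed above. `mesh_dim_names` and
# `process_index` are illustrative stand-ins for the device-mesh attributes,
# not the actual arguments of prepare_data_loader.
def dp_rank_buggy(process_index: int, mesh_dim_names: tuple) -> int:
    if "tp" in mesh_dim_names:
        submesh_tp_size = 3  # e.g. a (dp, tp) = (2, 3) mesh
    # Raises UnboundLocalError when "tp" is absent: the name was never bound.
    return process_index // submesh_tp_size


def dp_rank_fixed(process_index: int, mesh_dim_names: tuple) -> int:
    submesh_tp_size = 1  # safe default: no tensor parallelism
    if "tp" in mesh_dim_names:
        submesh_tp_size = 3
    return process_index // submesh_tp_size


print(dp_rank_fixed(4, ("dp",)))       # 4 -> dividing by the default of 1
print(dp_rank_fixed(4, ("dp", "tp")))  # 1 -> ranks 3, 4, 5 share DP rank 1
try:
    dp_rank_buggy(4, ("dp",))
except UnboundLocalError as exc:
    print("buggy version raises:", exc)
```

Defaulting to 1 rather than guarding the division keeps the DP-rank computation on a single code path whether or not tensor parallelism is in use.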
