Skip to content

Commit f7ef006

Browse files
authored
refactor: use init_device_mesh for hybrid shard (#24)
1 parent 8ce08c0 commit f7ef006

File tree

1 file changed

+3
-5
lines changed

1 file changed

+3
-5
lines changed

open_diloco/train_fsdp.py

Lines changed: 3 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -33,7 +33,8 @@
3333
ShardingStrategy,
3434
MixedPrecision,
3535
)
36-
from torch.distributed.device_mesh import DeviceMesh
36+
from torch.distributed.device_mesh import init_device_mesh
37+
3738
from open_diloco.ckpt_utils import (
3839
CKPT_PREFIX,
3940
CkptConfig,
@@ -229,10 +230,7 @@ def train(config: Config):
229230
]:
230231
local_world_size = int(os.environ["LOCAL_WORLD_SIZE"])
231232
nnodes = world_size // local_world_size
232-
device_mesh = DeviceMesh(
233-
"cuda",
234-
mesh=[[i * local_world_size + j for j in range(local_world_size)] for i in range(nnodes)],
235-
)
233+
device_mesh = init_device_mesh("cuda", (nnodes, local_world_size), mesh_dim_names=("global", "local"))
236234
else:
237235
device_mesh = None
238236
model = FSDP(

0 commit comments

Comments
 (0)