Skip to content

Commit 2fa8bc4

Browse files
fix test
Signed-off-by: Zhiyu Li <[email protected]>
1 parent 4e39468 commit 2fa8bc4

File tree

2 files changed

+28
-13
lines changed

2 files changed

+28
-13
lines changed

nemo_automodel/components/distributed/parallelizer.py

Lines changed: 17 additions & 8 deletions
Original file line number | Diff line number | Diff line change
@@ -123,21 +123,30 @@ def _resolve_fsdp_mesh_dims(
123123
Returns:
124124
Tuple of dimension names to use for the FSDP mesh.
125125
"""
126-
if dp_shard_cp_mesh_name in device_mesh.mesh_dim_names:
127-
# Best case: use pre-flattened (dp_replicate, dp_shard_cp)
126+
mesh_dim_names = getattr(device_mesh, "mesh_dim_names", None)
127+
128+
# If mesh_dim_names not available, use the default passed names (backward compatibility)
129+
if mesh_dim_names is None:
130+
logger.debug(f"FSDP mesh: ({dp_replicate_mesh_name}, {dp_shard_cp_mesh_name}) [no mesh_dim_names]")
131+
return (dp_replicate_mesh_name, dp_shard_cp_mesh_name)
132+
133+
# Use flattened dimension if available
134+
if dp_shard_cp_mesh_name in mesh_dim_names:
128135
logger.debug(f"FSDP mesh: ({dp_replicate_mesh_name}, {dp_shard_cp_mesh_name}) [flattened]")
129136
return (dp_replicate_mesh_name, dp_shard_cp_mesh_name)
130-
elif device_mesh["cp"].size() == 1:
131-
# CP unused: construct 2D mesh without cp
132-
logger.debug(f"FSDP mesh: ({dp_replicate_mesh_name}, dp_shard) [cp=1, unflattened]")
137+
138+
# Check if cp dimension exists and get its size
139+
cp_size = device_mesh["cp"].size() if "cp" in mesh_dim_names else 1
140+
141+
if cp_size == 1:
142+
logger.debug(f"FSDP mesh: ({dp_replicate_mesh_name}, dp_shard) [cp={cp_size}, unflattened]")
133143
return (dp_replicate_mesh_name, "dp_shard")
134144
else:
135-
# This should be unreachable: if cp > 1, flattening should have succeeded
136145
raise RuntimeError(
137-
f"Cannot construct FSDP mesh: cp={device_mesh['cp'].size()} but '{dp_shard_cp_mesh_name}' dimension not found. "
146+
f"Cannot construct FSDP mesh: cp={cp_size} but '{dp_shard_cp_mesh_name}' dimension not found. "
138147
f"FSDP requires 1D or 2D mesh, but would need 3D: ({dp_replicate_mesh_name}, dp_shard, cp). "
139148
f"Flattening (dp_shard, cp) -> '{dp_shard_cp_mesh_name}' should occur in FSDP2Manager._get_device_mesh() (fsdp2.py). "
140-
f"Available: {device_mesh.mesh_dim_names}"
149+
f"Available: {mesh_dim_names}"
141150
)
142151

143152
def parallelize(

nemo_automodel/recipes/base_recipe.py

Lines changed: 11 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -498,12 +498,18 @@ def _resolve_dp_dim_name(self, include_cp: bool = False) -> str:
498498
Returns:
499499
The dimension name to use for DP operations.
500500
"""
501-
if include_cp and self.device_mesh["cp"].size() > 1:
502-
# Prefer flattened "dp_cp", fallback to "dp_shard"
503-
return "dp_cp" if "dp_cp" in self.device_mesh.mesh_dim_names else "dp_shard"
501+
mesh_dim_names = getattr(self.device_mesh, "mesh_dim_names", None)
502+
503+
# If mesh_dim_names not available, use default behavior
504+
if mesh_dim_names is None:
505+
mesh_dim_names = []
506+
507+
cp_size = self.device_mesh["cp"].size() if "cp" in mesh_dim_names else 1
508+
509+
if include_cp and cp_size > 1:
510+
return "dp_cp" if "dp_cp" in mesh_dim_names else "dp_shard"
504511
else:
505-
# Prefer flattened "dp", fallback to "dp_shard"
506-
return "dp" if "dp" in self.device_mesh.mesh_dim_names else "dp_shard"
512+
return "dp" if "dp" in mesh_dim_names else "dp_shard"
507513

508514
def _get_dp_group(self, include_cp: bool = False):
509515
"""Get the data parallel process group.

0 commit comments

Comments
 (0)