@@ -123,21 +123,30 @@ def _resolve_fsdp_mesh_dims(
123123 Returns:
124124 Tuple of dimension names to use for the FSDP mesh.
125125 """
126- if dp_shard_cp_mesh_name in device_mesh .mesh_dim_names :
127- # Best case: use pre-flattened (dp_replicate, dp_shard_cp)
126+ mesh_dim_names = getattr (device_mesh , "mesh_dim_names" , None )
127+
128+ # If mesh_dim_names not available, use the default passed names (backward compatibility)
129+ if mesh_dim_names is None :
130+ logger .debug (f"FSDP mesh: ({ dp_replicate_mesh_name } , { dp_shard_cp_mesh_name } ) [no mesh_dim_names]" )
131+ return (dp_replicate_mesh_name , dp_shard_cp_mesh_name )
132+
133+ # Use flattened dimension if available
134+ if dp_shard_cp_mesh_name in mesh_dim_names :
128135 logger .debug (f"FSDP mesh: ({ dp_replicate_mesh_name } , { dp_shard_cp_mesh_name } ) [flattened]" )
129136 return (dp_replicate_mesh_name , dp_shard_cp_mesh_name )
130- elif device_mesh ["cp" ].size () == 1 :
131- # CP unused: construct 2D mesh without cp
132- logger .debug (f"FSDP mesh: ({ dp_replicate_mesh_name } , dp_shard) [cp=1, unflattened]" )
137+
138+ # Check if cp dimension exists and get its size
139+ cp_size = device_mesh ["cp" ].size () if "cp" in mesh_dim_names else 1
140+
141+ if cp_size == 1 :
142+ logger .debug (f"FSDP mesh: ({ dp_replicate_mesh_name } , dp_shard) [cp={ cp_size } , unflattened]" )
133143 return (dp_replicate_mesh_name , "dp_shard" )
134144 else :
135- # This should be unreachable: if cp > 1, flattening should have succeeded
136145 raise RuntimeError (
137- f"Cannot construct FSDP mesh: cp={ device_mesh [ 'cp' ]. size () } but '{ dp_shard_cp_mesh_name } ' dimension not found. "
146+ f"Cannot construct FSDP mesh: cp={ cp_size } but '{ dp_shard_cp_mesh_name } ' dimension not found. "
138147 f"FSDP requires 1D or 2D mesh, but would need 3D: ({ dp_replicate_mesh_name } , dp_shard, cp). "
139148 f"Flattening (dp_shard, cp) -> '{ dp_shard_cp_mesh_name } ' should occur in FSDP2Manager._get_device_mesh() (fsdp2.py). "
140- f"Available: { device_mesh . mesh_dim_names } "
149+ f"Available: { mesh_dim_names } "
141150 )
142151
143152 def parallelize (
0 commit comments