Commit 9d384f4

[Auto-Parallel] add sharding2 in dynamic auto (#76113)
1 parent a9319a2 commit 9d384f4

File tree

1 file changed: +56 -4 lines changed
  • python/paddle/distributed/auto_parallel


python/paddle/distributed/auto_parallel/api.py

Lines changed: 56 additions & 4 deletions
@@ -1167,6 +1167,11 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
         self._set_and_check_sharding_prop_from_param()
         self._shard_fn._set_sharding_axis(self._sharding_axis)
 
+        # Invoke register hook for sharding stage 2 strategy
+        if isinstance(self._shard_fn, ShardingStage2) and not in_auto_dp_mode():
+            for param in self._inner_opt._parameter_list:
+                self._shard_fn._register_hook_for_param_grad(param)
+
         # Invoke shard_parameter in sharding stage 3 strategy
         if isinstance(self._shard_fn, ShardingStage3):
             for param in self._inner_opt._parameter_list:
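The loop above attaches a gradient hook to every parameter managed by the inner optimizer. For context, here is a minimal sketch of the `Tensor.register_hook` mechanism this registration path relies on; the toy hook, tensor shape, and printed value are illustrative assumptions, not part of the commit.

```python
import paddle

x = paddle.rand([4])
x.stop_gradient = False

def scale_grad(grad):
    # The stage-2 hook reshards the gradient instead; this toy hook just scales it.
    return grad * 2

# The hook runs when x's gradient is computed during backward; its return
# value replaces the gradient that ends up in x.grad.
x.register_hook(scale_grad)
(x * 3).sum().backward()
print(x.grad)  # 3.0 from backward, doubled by the hook -> 6.0 per element
```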
@@ -2147,7 +2152,7 @@ def __call__(self, key: str, param: Tensor, tensor: Tensor) -> Tensor:
         return self._apply_placement(tensor, param, placements)
 
 
-class ShardingStage2(ShardingStage1):
+class ShardingStage2(_ShardingStageBase):
     """
     A builtin shard_fn for shard_optimizer interface, users can pass it to shard_optimizer to implement sharding optimization with stage 2.
 
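As the docstring states, ShardingStage2 is a builtin shard_fn meant to be handed to shard_optimizer. The sketch below shows how that wiring might look, assuming ShardingStage2 is exposed under paddle.distributed like the other stages; the model, mesh layout, and file name are assumptions for illustration.

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

layer = paddle.nn.Linear(64, 64)
opt = paddle.optimizer.AdamW(parameters=layer.parameters())

# Stage 2: optimizer states (and, via the new grad hook, gradients) are
# sharded along the "dp" mesh axis while parameters stay replicated.
opt = dist.shard_optimizer(opt, dist.ShardingStage2("dp", mesh))

# Run with two ranks, e.g.:
#   python -m paddle.distributed.launch --gpus=0,1 demo.py
```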
@@ -2186,9 +2191,56 @@ class ShardingStage2(ShardingStage1):
             >>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
     """
 
-    # Note(luchang): Due to reshard optimizations in Paddle where all-reduce + slicing is fused into reduce_scatter,
-    # the current behavior of ShardingStage2 is effectively the same as ShardingStage1.
-    pass
+    def __init__(
+        self,
+        sharding_mesh_dim: int | str,
+        mesh: ProcessMesh | None = None,
+    ) -> None:
+        super().__init__(mesh, sharding_mesh_dim)
+
+    def __call__(self, key: str, param: Tensor, tensor: Tensor) -> Tensor:
+        if param.is_dist():
+            # Only deal with momentum in optimizer, beta should be replicated cross param's mesh
+            if 'beta' not in key:
+                placements = get_placement_with_sharding(
+                    param, self._sharding_axis
+                )
+            else:
+                placements = [
+                    dist.Replicate()
+                    for _ in range(len(param.process_mesh.shape))
+                ]
+            return shard_tensor(
+                tensor,
+                mesh=param.process_mesh,
+                placements=placements,
+            )
+        return tensor
+
+    @staticmethod
+    def _grad_hook(grad):
+        # do reshard only if the grad is dist tensor and in partial status
+        if grad.is_dist():
+            partial_mesh_axis = None
+            for mesh_axis, placement in enumerate(grad.placements):
+                if isinstance(placement, dist.Partial):
+                    partial_mesh_axis = mesh_axis
+            if partial_mesh_axis is not None:
+                new_placements = get_placement_with_sharding(
+                    grad, partial_mesh_axis
+                )
+                return reshard(grad, grad.process_mesh, new_placements)
+
+        return grad
+
+    def _register_hook_for_param_grad(self, param):
+        if param.is_dense() and self._mesh is not None:
+            placements = []
+            for _ in range(len(self._mesh.shape)):
+                placements.append(dist.Replicate())
+            param._to_dist_(placements, self._mesh)
+        if param.is_dist():
+            param.register_hook(ShardingStage2._grad_hook)
 
 
 class ShardingStage3(_ShardingStageBase):
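The heart of the change is `_grad_hook`: after backward, a data-parallel gradient sits in `Partial` status on the sharding axis, and resharding it to a sharded placement is what realizes stage 2 (per the removed note, Paddle fuses the all-reduce plus slicing into a reduce_scatter). Below is a conceptual sketch of that placement transition only; the 2-rank mesh, tensor shape, and direct construction of a `Partial` tensor are assumptions for illustration (in training the `Partial` status comes from the autograd engine).

```python
import paddle
import paddle.distributed as dist

mesh = dist.ProcessMesh([0, 1], dim_names=["dp"])

# A gradient in Partial status: each rank holds an unsummed contribution.
grad = dist.shard_tensor(paddle.rand([16, 8]), mesh, [dist.Partial()])

# What the hook does (via get_placement_with_sharding): Partial -> Shard(0),
# so each rank ends up with only its reduced slice of the gradient.
grad = dist.reshard(grad, mesh, [dist.Shard(0)])
print(grad.placements)  # now sharded along dim 0 instead of Partial
```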
