@@ -1167,6 +1167,11 @@ def __init__(self, optimizer, shard_fn=None, gradient_accumulation_steps=1):
         self._set_and_check_sharding_prop_from_param()
         self._shard_fn._set_sharding_axis(self._sharding_axis)
 
+        # Register gradient hooks for the sharding stage 2 strategy
+        if isinstance(self._shard_fn, ShardingStage2) and not in_auto_dp_mode():
+            for param in self._inner_opt._parameter_list:
+                self._shard_fn._register_hook_for_param_grad(param)
+
         # Invoke shard_parameter in sharding stage 3 strategy
         if isinstance(self._shard_fn, ShardingStage3):
             for param in self._inner_opt._parameter_list:
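For readers less familiar with the mechanism this hunk relies on: `paddle.Tensor.register_hook` attaches a callback that fires when the tensor's gradient is produced during backward, and the value it returns replaces that gradient. A minimal standalone sketch (the doubling hook below is purely illustrative, not the stage-2 hook itself):

```python
import paddle

# Leaf tensor that requires gradients.
x = paddle.to_tensor([1.0, 2.0, 3.0], stop_gradient=False)

def double_grad(grad):
    # The returned tensor replaces x's gradient.
    return grad * 2

x.register_hook(double_grad)

loss = (x * x).sum()
loss.backward()

# Without the hook x.grad would be 2 * x = [2., 4., 6.];
# the hook doubles it to [4., 8., 12.].
print(x.grad)
```

The stage-2 `_grad_hook` added below uses exactly this entry point to reshard a `Partial` gradient into a sharded placement as soon as it is computed.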
@@ -2147,7 +2152,7 @@ def __call__(self, key: str, param: Tensor, tensor: Tensor) -> Tensor:
         return self._apply_placement(tensor, param, placements)
 
 
-class ShardingStage2(ShardingStage1):
+class ShardingStage2(_ShardingStageBase):
     """
     A builtin shard_fn for shard_optimizer interface, users can pass it to shard_optimizer to implement sharding optimization with stage 2.
 
@@ -2186,9 +2191,56 @@ class ShardingStage2(ShardingStage1):
         >>> # python -m paddle.distributed.launch --gpus=0,1 {test_case}.py
     """
 
-    # Note(luchang): Due to reshard optimizations in Paddle where all-reduce + slicing is fused into reduce_scatter,
-    # the current behavior of ShardingStage2 is effectively the same as ShardingStage1.
-    pass
+    def __init__(
+        self,
+        sharding_mesh_dim: int | str,
+        mesh: ProcessMesh | None = None,
+    ) -> None:
+        super().__init__(mesh, sharding_mesh_dim)
+
+    def __call__(self, key: str, param: Tensor, tensor: Tensor) -> Tensor:
+        if param.is_dist():
+            # Only shard momentum states; beta should be replicated across the param's mesh
+            if 'beta' not in key:
+                placements = get_placement_with_sharding(
+                    param, self._sharding_axis
+                )
+            else:
+                placements = [
+                    dist.Replicate()
+                    for _ in range(len(param.process_mesh.shape))
+                ]
+            return shard_tensor(
+                tensor,
+                mesh=param.process_mesh,
+                placements=placements,
+            )
+        return tensor
+
+    @staticmethod
+    def _grad_hook(grad):
+        # Reshard only if the grad is a dist tensor in partial status
+        if grad.is_dist():
+            partial_mesh_axis = None
+            for mesh_axis, placement in enumerate(grad.placements):
+                if isinstance(placement, dist.Partial):
+                    partial_mesh_axis = mesh_axis
+            if partial_mesh_axis is not None:
+                new_placements = get_placement_with_sharding(
+                    grad, partial_mesh_axis
+                )
+                return reshard(grad, grad.process_mesh, new_placements)
+
+        return grad
+
+    def _register_hook_for_param_grad(self, param):
+        if param.is_dense() and self._mesh is not None:
+            placements = []
+            for _ in range(len(self._mesh.shape)):
+                placements.append(dist.Replicate())
+            param._to_dist_(placements, self._mesh)
+        if param.is_dist():
+            param.register_hook(ShardingStage2._grad_hook)
 
 
 class ShardingStage3(_ShardingStageBase):
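For reference, a minimal usage sketch of the new stage-2 `shard_fn` with `dist.shard_optimizer`, following the pattern in the class docstring. The MLP shapes and the AdamW optimizer are illustrative only, and the script is assumed to be launched on two cards with `python -m paddle.distributed.launch --gpus=0,1`:

```python
import paddle
import paddle.distributed as dist

# One-dimensional mesh over two cards; "x" is the sharding mesh dim.
mesh = dist.ProcessMesh([0, 1], dim_names=["x"])

class MLP(paddle.nn.Layer):
    def __init__(self):
        super().__init__()
        self.fc1 = paddle.nn.Linear(8, 8)
        self.fc2 = paddle.nn.Linear(8, 8)

    def forward(self, x):
        return self.fc2(self.fc1(x))

layer = MLP()
batch = paddle.rand([4, 8])

opt = paddle.optimizer.AdamW(parameters=layer.parameters())
# Stage 2: optimizer states and gradients are sharded along "x".
# shard_optimizer registers the per-parameter grad hook shown above,
# which reshards Partial gradients into sharded placements.
opt = dist.shard_optimizer(opt, dist.ShardingStage2("x", mesh))

for _ in range(3):
    loss = layer(batch).mean()
    loss.backward()
    opt.step()
    opt.clear_grad()
```

The design change relative to stage 1 is the gradient hook: a gradient that arrives in `Partial` status is resharded along the sharding axis as soon as it is produced, matching reduce-scatter rather than all-reduce-then-slice semantics.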