File tree Expand file tree Collapse file tree 1 file changed +13
-6
lines changed
Expand file tree Collapse file tree 1 file changed +13
-6
lines changed Original file line number Diff line number Diff line change @@ -1169,13 +1169,20 @@ def no_sync(self, model):
11691169 >>> optimizer.zero_grad()
11701170 ```
11711171 """
1172- context = contextlib .nullcontext
1173- if self .use_distributed :
1174- if self .distributed_type != DistributedType .DEEPSPEED or self .state .deepspeed_plugin .zero_stage < 2 :
1175- context = getattr (model , "no_sync" , context )
1172+ if self .is_fsdp2 :
1173+ model .set_requires_gradient_sync (False )
1174+ try :
1175+ yield
1176+ finally :
1177+ model .set_requires_gradient_sync (True )
1178+ else :
1179+ context = contextlib .nullcontext
1180+ if self .use_distributed :
1181+ if self .distributed_type != DistributedType .DEEPSPEED or self .state .deepspeed_plugin .zero_stage < 2 :
1182+ context = getattr (model , "no_sync" , context )
11761183
1177- with context ():
1178- yield
1184+ with context ():
1185+ yield
11791186
11801187 @staticmethod
11811188 @contextmanager
You can’t perform that action at this time.
0 commit comments