@@ -879,19 +879,13 @@ def _update_weights(self, state: State) -> Optional[torch.Tensor]:
         total_grad_norm = None
         # gradient norm clipping
         if clip_grad_norm:
-            if _is_fsdp_module(module):
-                if isinstance(module, FSDP):
-                    with get_timing_context(
-                        state, f"{self.__class__.__name__}.clip_grad_norm"
-                    ):
-                        total_grad_norm = module.clip_grad_norm_(
-                            max_norm=clip_grad_norm
-                        )
-                else:
-                    raise RuntimeError(
-                        "Composable FSDP clip_grad_norm is not yet implemented: https://github.com/pytorch/pytorch/issues/97271"
-                    )
+            if isinstance(module, FSDP):
+                with get_timing_context(
+                    state, f"{self.__class__.__name__}.clip_grad_norm"
+                ):
+                    total_grad_norm = module.clip_grad_norm_(max_norm=clip_grad_norm)
             else:
+                # strategy=None, DDP, and FSDP2 will work with this
                 with get_timing_context(
                     state, f"{self.__class__.__name__}.clip_grad_norm"
                 ):
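For context, here is roughly the clipping logic this hunk leaves behind, written out as a standalone sketch. The helper name `_clip_gradients` is hypothetical, and the generic branch is assumed to call `torch.nn.utils.clip_grad_norm_`, which the truncated hunk does not show; the assumption follows from the added "strategy=None, DDP, and FSDP2 will work with this" comment.

# A minimal standalone sketch of the post-change logic. _clip_gradients is a
# hypothetical name; the else branch is an assumption, since the hunk above
# is truncated before the generic clipping call.
from typing import Optional

import torch
from torch.distributed.fsdp import FullyShardedDataParallel as FSDP


def _clip_gradients(
    module: torch.nn.Module, clip_grad_norm: Optional[float]
) -> Optional[torch.Tensor]:
    total_grad_norm = None
    if clip_grad_norm:
        if isinstance(module, FSDP):
            # FSDP1 shards gradients across ranks, so clipping must go
            # through the wrapper's own all-reducing clip_grad_norm_.
            total_grad_norm = module.clip_grad_norm_(max_norm=clip_grad_norm)
        else:
            # strategy=None, DDP, and FSDP2 expose regular (or DTensor)
            # parameters, so the stock utility handles all of them; this is
            # why the composable-FSDP RuntimeError above could be dropped.
            total_grad_norm = torch.nn.utils.clip_grad_norm_(
                module.parameters(), max_norm=clip_grad_norm
            )
    return total_grad_norm

Note that FSDP2 (`fully_shard`) modules are not instances of the FSDP1 wrapper class, so they deliberately fall through to the generic branch rather than hitting the removed RuntimeError.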