File tree Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Expand file tree Collapse file tree 1 file changed +19
-0
lines changed Original file line number Diff line number Diff line change @@ -725,6 +725,25 @@ def maybe_enable_compiled_autograd(
725725 ):
726726 loss .backward (retain_graph = self .loss_backward_retain_graph )
727727
728+ # l1 grad nom
729+ if isinstance (self .module , DDP ):
730+ _module = self .module .module
731+ else :
732+ _module = self .module
733+ if hasattr (_module , "gradient_l1_loss" ):
734+ gradient_l1_loss = _module .gradient_l1_loss ()
735+ with maybe_enable_compiled_autograd (self .enable_compiled_autograd ):
736+ if grad_scaler :
737+ gradient_l1_loss = grad_scaler .scale (gradient_l1_loss )
738+ gradient_l1_loss .backward (
739+ retain_graph = self .loss_backward_retain_graph
740+ )
741+ else :
742+ gradient_l1_loss .backward (
743+ retain_graph = self .loss_backward_retain_graph
744+ )
745+ loss = loss + gradient_l1_loss
746+
728747 total_grad_norm = None
729748 if should_update_weights :
730749 total_grad_norm = self ._update_weights (state )
You can’t perform that action at this time.
0 commit comments