         sync_zenflow_optimizer_lr)
 
 from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
+from deepspeed.runtime.fp16.loss_scaler import LossScaleConfig, LossScaleProfile
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
 
@@ -1773,7 +1774,6 @@ def _configure_quantization(self):
         return quantizer
 
     def _configure_fp16_optimizer(self, optimizer, low_precision_dtype):
-        initial_dynamic_scale = self.initial_dynamic_scale()
         dynamic_loss_args = self.dynamic_loss_scale_args()
         clip_grad = self.gradient_clipping()
 
@@ -1782,46 +1782,47 @@ def _configure_fp16_optimizer(self, optimizer, low_precision_dtype):
         else:
             fused_opts = FusedAdam
 
-        if isinstance(optimizer, fused_opts) \
-                or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]:
-            if self.dynamic_loss_scale():
+        use_fused_optimizer = isinstance(optimizer, fused_opts) \
+            or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]
+        loss_scale_profile = LossScaleProfile.FUSED if use_fused_optimizer else LossScaleProfile.UNFUSED
+        initial_dynamic_scale = self.initial_dynamic_scale() if loss_scale_profile == LossScaleProfile.FUSED else None
+        loss_scale_config = LossScaleConfig(
+            low_precision_dtype=low_precision_dtype,
+            dynamic_loss_scale=self.dynamic_loss_scale(),
+            static_loss_scale=self.loss_scale(),
+            dynamic_loss_args=dynamic_loss_args,
+            profile=loss_scale_profile,
+            initial_dynamic_scale=initial_dynamic_scale,
+        )
+
+        if use_fused_optimizer:
+            if loss_scale_config.dynamic_loss_scale:
                 log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0])
-                timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
-                optimizer = FP16_Optimizer(
-                    optimizer,
-                    deepspeed=self,
-                    low_precision_dtype=low_precision_dtype,
-                    dynamic_loss_scale=True,
-                    initial_dynamic_scale=initial_dynamic_scale,
-                    dynamic_loss_args=dynamic_loss_args,
-                    mpu=self.mpu,
-                    clip_grad=clip_grad,
-                    fused_adam_legacy=self.optimizer_legacy_fusion(),
-                    timers=timers,
-                    has_moe_layers=self.has_moe_layers,
-                )
             else:
-                log_dist(f'Creating fp16 optimizer with static loss scale: {self.loss_scale()}', ranks=[0])
-                timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
-                optimizer = FP16_Optimizer(
-                    optimizer,
-                    deepspeed=self,
-                    low_precision_dtype=low_precision_dtype,
-                    static_loss_scale=self.loss_scale(),
-                    mpu=self.mpu,
-                    clip_grad=clip_grad,
-                    fused_adam_legacy=self.optimizer_legacy_fusion(),
-                    timers=timers,
-                    has_moe_layers=self.has_moe_layers,
-                )
+                log_dist(f'Creating fp16 optimizer with static loss scale: {loss_scale_config.cur_scale}', ranks=[0])
+            timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
+            optimizer = FP16_Optimizer(
+                optimizer,
+                deepspeed=self,
+                loss_scale_config=loss_scale_config,
+                low_precision_dtype=low_precision_dtype,
+                mpu=self.mpu,
+                clip_grad=clip_grad,
+                fused_adam_legacy=self.optimizer_legacy_fusion(),
+                timers=timers,
+                has_moe_layers=self.has_moe_layers,
+            )
         else:
-            log_dist('Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0])
+            if loss_scale_config.dynamic_loss_scale:
+                log_dist('Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0])
+            else:
+                log_dist(f'Creating fp16 unfused optimizer with static loss scale: {loss_scale_config.cur_scale}',
+                         ranks=[0])
             optimizer = FP16_UnfusedOptimizer(
                 optimizer,
                 deepspeed=self,
-                static_loss_scale=self.loss_scale(),
-                dynamic_loss_scale=self.dynamic_loss_scale(),
-                dynamic_loss_args=dynamic_loss_args,
+                loss_scale_config=loss_scale_config,
+                low_precision_dtype=low_precision_dtype,
                 mpu=self.mpu,
                 clip_grad=clip_grad,
                 fused_lamb_legacy=self.optimizer_name() == LAMB_OPTIMIZER,
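
Note: the `LossScaleConfig` / `LossScaleProfile` pair used above is imported from `deepspeed/runtime/fp16/loss_scaler.py`, whose definition is not shown in this hunk. As orientation only, a container consistent with the fields this diff passes in could look like the following sketch; the real definition in the loss scaler module may differ.

```python
# Sketch only -- inferred from the fields the diff passes in, not the actual
# deepspeed/runtime/fp16/loss_scaler.py definition, which may differ.
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import torch


class LossScaleProfile(Enum):
    FUSED = "fused"
    UNFUSED = "unfused"


@dataclass
class LossScaleConfig:
    low_precision_dtype: torch.dtype
    dynamic_loss_scale: bool
    static_loss_scale: float
    dynamic_loss_args: Optional[dict]
    profile: LossScaleProfile
    initial_dynamic_scale: Optional[float] = None

    @property
    def cur_scale(self) -> float:
        # Assumption: before any overflow has been observed, the current scale is the
        # initial dynamic scale when dynamic scaling is enabled, else the static scale.
        if self.dynamic_loss_scale and self.initial_dynamic_scale is not None:
            return self.initial_dynamic_scale
        return self.static_loss_scale
```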
@@ -2682,10 +2683,10 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
         # the behavior that we want
         if self.bfloat16_enabled():
             # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated
-            if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"):
+            if hasattr(self.optimizer, "zero_grad"):
                 self.optimizer.zero_grad()
             else:
-                pass
+                self.zero_grad()
         elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled():
             self.optimizer.zero_grad()
         else:
@@ -2757,8 +2758,8 @@ def step(self, lr_kwargs=None):
         if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
                 and self.quantizer.any_precision_switch()):
             log_dist("computing eigenvalue...", ranks=[0])
-            self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device,
-                                                                       self.optimizer.cur_scale)
+            loss_scale = self._get_optimizer_loss_scale() or 1.0
+            self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device, loss_scale)
 
         if self.progressive_layer_drop:
             self.progressive_layer_drop.update_state(self.global_steps)
@@ -2784,10 +2785,11 @@ def step(self, lr_kwargs=None):
             if self.global_rank == 0:
                 self.summary_events = [("Train/Samples/lr", self.get_lr()[0], self.global_samples)]
 
-                if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"):
+                loss_scale = self._get_optimizer_loss_scale() if self.fp16_enabled() else None
+                if loss_scale is not None:
                     self.summary_events.append((
                         "Train/Samples/loss_scale",
-                        self.optimizer.cur_scale,
+                        loss_scale,
                         self.global_samples,
                     ))
 
@@ -2930,6 +2932,13 @@ def _get_optimizer_param(self, param_name):
                 result.append(0.0)
         return result
 
+    def _get_optimizer_loss_scale(self):
+        if not self.optimizer:
+            return None
+        if hasattr(self.optimizer, "loss_scale_config"):
+            return self.optimizer.loss_scale_config.cur_scale
+        return getattr(self.optimizer, "cur_scale", None)
+
     def get_lr(self):
         return self._get_optimizer_param("lr")
 
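
For the loss-scale lookup introduced by `_get_optimizer_loss_scale`, a small standalone illustration of the fallback order (the dummy objects below are invented for the example, not DeepSpeed classes): new-style optimizers expose the scale through `loss_scale_config`, legacy ones through `cur_scale`, and anything else yields `None`, which the eigenvalue caller above turns into `1.0`.

```python
# Illustration only: the dummy optimizer objects are made up for this example.
from types import SimpleNamespace


def get_optimizer_loss_scale(optimizer):
    # Mirrors the fallback order of the engine helper shown above.
    if not optimizer:
        return None
    if hasattr(optimizer, "loss_scale_config"):
        return optimizer.loss_scale_config.cur_scale
    return getattr(optimizer, "cur_scale", None)


new_style = SimpleNamespace(loss_scale_config=SimpleNamespace(cur_scale=65536.0))
legacy = SimpleNamespace(cur_scale=4096.0)
plain = SimpleNamespace()  # e.g. a bare optimizer with no loss scaling

assert get_optimizer_loss_scale(new_style) == 65536.0
assert get_optimizer_loss_scale(legacy) == 4096.0
assert get_optimizer_loss_scale(plain) is None  # eigenvalue caller falls back to 1.0
```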