Commit 1752c2a

Fix bf16 gradient norm divergence with ZeRO stage 0 (#7839)

Fixes: #7837

ZeRO-0 + bf16 has two bugs in `engine.py`:

1. `FP16_UnfusedOptimizer` applies `dynamic_loss_scale` with `cur_scale=65536`, but `engine.backward()` never scales the loss, so `step()` divides the gradients by 65536.
2. `_take_model_step` skips `zero_grad` for bf16 without ZeRO, causing gradient accumulation.

Fix: disable loss scaling for bf16 and remove the `zero_optimization()` gate on `zero_grad`.

---------

Signed-off-by: Masahiro Tanaka <mtanaka@anyscale.com>
1 parent 2c36283 commit 1752c2a
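
To see why the first bug distorts the gradient norm: a loss scaler is only consistent if the loss is multiplied by the scale before `backward()` and the gradients are divided by the same scale in `step()`. The snippet below is a minimal plain-PyTorch sketch of that invariant and of the broken ZeRO-0 + bf16 path described above; the `cur_scale=65536` value comes from the commit message, while the toy tensor and loss are illustrative and not DeepSpeed code.

```python
import torch

cur_scale = 65536.0  # default initial dynamic loss scale, per the commit message

w = torch.ones(4, requires_grad=True)

# Correct loss scaling: multiply the loss before backward(), divide the grads in step().
loss = (w * 2.0).sum()
(loss * cur_scale).backward()
unscaled = w.grad / cur_scale
assert torch.allclose(unscaled, torch.full((4,), 2.0))  # true gradient recovered

# Buggy ZeRO-0 + bf16 path (pre-fix): backward() never scaled the loss,
# but step() still divided by cur_scale, so every gradient shrank ~65536x
# and the reported gradient norm diverged from the fp32 / ZeRO>0 baseline.
w.grad = None
loss = (w * 2.0).sum()
loss.backward()                 # unscaled backward
wrong = w.grad / cur_scale      # unconditional unscale in step()
print(wrong.norm().item())      # ~6.1e-05 instead of the true norm of 4.0
```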

File tree

8 files changed: +351 -212 lines changed

deepspeed/runtime/engine.py

Lines changed: 50 additions & 41 deletions
@@ -37,6 +37,7 @@
                               sync_zenflow_optimizer_lr)
 
 from deepspeed.runtime.fp16.fused_optimizer import FP16_Optimizer
+from deepspeed.runtime.fp16.loss_scaler import LossScaleConfig, LossScaleProfile
 from deepspeed.runtime.fp16.unfused_optimizer import FP16_UnfusedOptimizer
 from deepspeed.runtime.bf16_optimizer import BF16_Optimizer
 
@@ -1773,7 +1774,6 @@ def _configure_quantization(self):
         return quantizer
 
     def _configure_fp16_optimizer(self, optimizer, low_precision_dtype):
-        initial_dynamic_scale = self.initial_dynamic_scale()
         dynamic_loss_args = self.dynamic_loss_scale_args()
         clip_grad = self.gradient_clipping()
 
@@ -1782,46 +1782,47 @@ def _configure_fp16_optimizer(self, optimizer, low_precision_dtype):
         else:
             fused_opts = FusedAdam
 
-        if isinstance(optimizer, fused_opts) \
-                or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]:
-            if self.dynamic_loss_scale():
+        use_fused_optimizer = isinstance(optimizer, fused_opts) \
+            or self.optimizer_name() in [ONEBIT_ADAM_OPTIMIZER, ZERO_ONE_ADAM_OPTIMIZER]
+        loss_scale_profile = LossScaleProfile.FUSED if use_fused_optimizer else LossScaleProfile.UNFUSED
+        initial_dynamic_scale = self.initial_dynamic_scale() if loss_scale_profile == LossScaleProfile.FUSED else None
+        loss_scale_config = LossScaleConfig(
+            low_precision_dtype=low_precision_dtype,
+            dynamic_loss_scale=self.dynamic_loss_scale(),
+            static_loss_scale=self.loss_scale(),
+            dynamic_loss_args=dynamic_loss_args,
+            profile=loss_scale_profile,
+            initial_dynamic_scale=initial_dynamic_scale,
+        )
+
+        if use_fused_optimizer:
+            if loss_scale_config.dynamic_loss_scale:
                 log_dist('Creating fp16 optimizer with dynamic loss scale', ranks=[0])
-                timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
-                optimizer = FP16_Optimizer(
-                    optimizer,
-                    deepspeed=self,
-                    low_precision_dtype=low_precision_dtype,
-                    dynamic_loss_scale=True,
-                    initial_dynamic_scale=initial_dynamic_scale,
-                    dynamic_loss_args=dynamic_loss_args,
-                    mpu=self.mpu,
-                    clip_grad=clip_grad,
-                    fused_adam_legacy=self.optimizer_legacy_fusion(),
-                    timers=timers,
-                    has_moe_layers=self.has_moe_layers,
-                )
             else:
-                log_dist(f'Creating fp16 optimizer with static loss scale: {self.loss_scale()}', ranks=[0])
-                timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
-                optimizer = FP16_Optimizer(
-                    optimizer,
-                    deepspeed=self,
-                    low_precision_dtype=low_precision_dtype,
-                    static_loss_scale=self.loss_scale(),
-                    mpu=self.mpu,
-                    clip_grad=clip_grad,
-                    fused_adam_legacy=self.optimizer_legacy_fusion(),
-                    timers=timers,
-                    has_moe_layers=self.has_moe_layers,
-                )
+                log_dist(f'Creating fp16 optimizer with static loss scale: {loss_scale_config.cur_scale}', ranks=[0])
+            timers = self.timers if self.wall_clock_breakdown() else NoopTimer()
+            optimizer = FP16_Optimizer(
+                optimizer,
+                deepspeed=self,
+                loss_scale_config=loss_scale_config,
+                low_precision_dtype=low_precision_dtype,
+                mpu=self.mpu,
+                clip_grad=clip_grad,
+                fused_adam_legacy=self.optimizer_legacy_fusion(),
+                timers=timers,
+                has_moe_layers=self.has_moe_layers,
+            )
         else:
-            log_dist('Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0])
+            if loss_scale_config.dynamic_loss_scale:
+                log_dist('Creating fp16 unfused optimizer with dynamic loss scale', ranks=[0])
+            else:
+                log_dist(f'Creating fp16 unfused optimizer with static loss scale: {loss_scale_config.cur_scale}',
+                         ranks=[0])
             optimizer = FP16_UnfusedOptimizer(
                 optimizer,
                 deepspeed=self,
-                static_loss_scale=self.loss_scale(),
-                dynamic_loss_scale=self.dynamic_loss_scale(),
-                dynamic_loss_args=dynamic_loss_args,
+                loss_scale_config=loss_scale_config,
+                low_precision_dtype=low_precision_dtype,
                 mpu=self.mpu,
                 clip_grad=clip_grad,
                 fused_lamb_legacy=self.optimizer_name() == LAMB_OPTIMIZER,
@@ -2682,10 +2683,10 @@ def _take_model_step(self, lr_kwargs, block_eigenvalue={}):
         # the behavior that we want
         if self.bfloat16_enabled():
             # TODO: Temporary until bf16_optimizer and zero_optimizer are integrated
-            if self.zero_optimization() and hasattr(self.optimizer, "zero_grad"):
+            if hasattr(self.optimizer, "zero_grad"):
                 self.optimizer.zero_grad()
             else:
-                pass
+                self.zero_grad()
         elif self.zero_optimization() or self.fp16_enabled() or self.amp_enabled():
             self.optimizer.zero_grad()
         else:
@@ -2757,8 +2758,8 @@ def step(self, lr_kwargs=None):
         if (self.eigenvalue_enabled() and (self.gas_boundary_ctr % self.eigenvalue_gas_boundary_resolution() == 0)
                 and self.quantizer.any_precision_switch()):
             log_dist("computing eigenvalue...", ranks=[0])
-            self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device,
-                                                                       self.optimizer.cur_scale)
+            loss_scale = self._get_optimizer_loss_scale() or 1.0
+            self.block_eigenvalue = self.eigenvalue.compute_eigenvalue(self.module, self.device, loss_scale)
 
         if self.progressive_layer_drop:
             self.progressive_layer_drop.update_state(self.global_steps)
@@ -2784,10 +2785,11 @@
         if self.global_rank == 0:
             self.summary_events = [("Train/Samples/lr", self.get_lr()[0], self.global_samples)]
 
-            if self.fp16_enabled() and hasattr(self.optimizer, "cur_scale"):
+            loss_scale = self._get_optimizer_loss_scale() if self.fp16_enabled() else None
+            if loss_scale is not None:
                 self.summary_events.append((
                     "Train/Samples/loss_scale",
-                    self.optimizer.cur_scale,
+                    loss_scale,
                     self.global_samples,
                 ))
 
@@ -2930,6 +2932,13 @@ def _get_optimizer_param(self, param_name):
             result.append(0.0)
         return result
 
+    def _get_optimizer_loss_scale(self):
+        if not self.optimizer:
+            return None
+        if hasattr(self.optimizer, "loss_scale_config"):
+            return self.optimizer.loss_scale_config.cur_scale
+        return getattr(self.optimizer, "cur_scale", None)
+
     def get_lr(self):
         return self._get_optimizer_param("lr")
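
The remaining seven files in this commit (including `deepspeed/runtime/fp16/loss_scaler.py`, where `LossScaleConfig` and `LossScaleProfile` are defined) are not shown on this page. Going only by the commit message ("disable loss scaling for bf16") and by the constructor arguments and `cur_scale` attribute used in the hunks above, a hypothetical sketch of the config's behavior could look like the following; treat it as an assumption about intent, not the actual DeepSpeed implementation.

```python
from dataclasses import dataclass
from enum import Enum
from typing import Optional

import torch


class LossScaleProfile(Enum):
    # Assumed variants, mirroring how engine.py selects FUSED vs. UNFUSED above.
    FUSED = "fused"
    UNFUSED = "unfused"


@dataclass
class LossScaleConfig:
    """Hypothetical sketch of the config object passed to the fp16 optimizers above.

    Assumption: bf16 shares fp32's exponent range, so loss scaling is unnecessary;
    the commit message says the fix disables it for bf16, which here pins
    cur_scale to 1.0 so step() no longer divides unscaled gradients by 65536.
    """
    low_precision_dtype: torch.dtype
    dynamic_loss_scale: bool
    static_loss_scale: float
    dynamic_loss_args: Optional[dict] = None
    profile: LossScaleProfile = LossScaleProfile.UNFUSED
    initial_dynamic_scale: Optional[float] = None

    def __post_init__(self):
        if self.low_precision_dtype == torch.bfloat16:
            # bf16: no loss scaling at all (the core of this fix).
            self.dynamic_loss_scale = False
            self.cur_scale = 1.0
        elif self.dynamic_loss_scale:
            # fp16 with dynamic scaling keeps its usual starting scale.
            self.cur_scale = self.initial_dynamic_scale or 65536.0
        else:
            # fp16 with a static scale uses the user-configured value.
            self.cur_scale = self.static_loss_scale
```

Under that assumption, the ZeRO-0 + bf16 path above would see `loss_scale_config.cur_scale == 1.0`, and the new `_get_optimizer_loss_scale()` helper would report the same, which is consistent with the first fix described in the commit message.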

0 commit comments
