34 changes: 31 additions & 3 deletions swift/trainers/mixin.py
@@ -48,8 +48,8 @@
 from swift.template import Template, update_generation_config_eos_token
 from swift.tuner_plugin import tuners_map
 from swift.tuners import SwiftModel
-from swift.utils import (HfConfigFactory, copy_files_by_pattern, deep_getattr, get_current_device, get_logger,
-                         get_packed_seq_params, is_dist, is_mp, is_mp_ddp, ms_logger_context, seed_worker)
+from swift.utils import (HfConfigFactory, copy_files_by_pattern, deep_getattr, get_current_device, get_dist_setting,
+                         get_logger, get_packed_seq_params, is_dist, is_mp, is_mp_ddp, ms_logger_context, seed_worker)
 from .arguments import TrainingArguments
 from .utils import (can_return_loss, dynamic_gradient_checkpointing, find_labels, get_function, get_resume_dir,
                     is_instance_of_ms_model, patch_modelscope_hub_timeout, replace_index_file)
@@ -658,6 +658,34 @@ def clip_grad_norm_(self, parameters, *args, **kwargs):
         finally:
             Accelerator.clip_grad_norm_ = origin_clip_grad_norm_
 
+    def _get_reduced_grad_norm_for_logging(self, grad_norm) -> Optional[float]:
+        """Reduce grad_norm across processes for consistent logging (fix #6815).
+
+        Under DeepSpeed ZeRO-0/1/2 or plain DDP, each rank may report a different
+        gradient norm (e.g. the local view before reduction); ZeRO-3 reports a
+        global norm. We all-reduce (average) grad_norm when not on ZeRO-3 and
+        world_size > 1, so the logged grad_norm is consistent across ZeRO stages.
+        """
+        if grad_norm is None:
+            return None
+        if not isinstance(grad_norm, torch.Tensor):
+            return float(grad_norm)
+        if not is_dist():
+            return grad_norm.item()
+        _, _, world_size, _ = get_dist_setting()
+        if world_size <= 1:
+            return grad_norm.item()
+        if is_deepspeed_zero3_enabled():
+            return grad_norm.item()
+        try:
+            gn = grad_norm.clone().detach().float()
+            if not gn.is_cuda:
+                gn = gn.to(self.accelerator.device)
+            dist.all_reduce(gn, op=dist.ReduceOp.AVG)
+            return gn.item()
+        except Exception:
+            return grad_norm.item()
Comment on lines +686 to +687

Contributor (severity: medium):

The broad `except Exception:` without logging can hide issues during gradient-norm reduction. If an error occurs, the code silently falls back to the un-reduced gradient norm, which can be misleading for monitoring. It's better to log the exception to make debugging easier.

Suggested change:

-        except Exception:
-            return grad_norm.item()
+        except Exception as e:
+            logger.warning(f'Failed to reduce grad_norm for logging: {e}. Returning un-reduced value.')
+            return grad_norm.item()


     def _patch_tasks(self):
         if isinstance(self.model, PeftModel):
             model = self.model.model
@@ -953,7 +981,7 @@ def _maybe_log_save_evaluate(self, tr_loss, *args, **kwargs):
         if version.parse(transformers.__version__) >= version.parse('4.38'):
             grad_norm = args[0]
             if grad_norm is not None:
-                logs['grad_norm'] = grad_norm.item() if isinstance(grad_norm, torch.Tensor) else grad_norm
+                logs['grad_norm'] = self._get_reduced_grad_norm_for_logging(grad_norm)
             logs['learning_rate'] = self._get_learning_rate()
             tr_loss -= tr_loss
             self._total_loss_scalar += tr_loss_scalar
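For intuition, below is a minimal standalone sketch (not part of this PR) of the averaging that `_get_reduced_grad_norm_for_logging` performs. It spawns two CPU processes with the `gloo` backend; since `gloo` does not support `ReduceOp.AVG` (the helper above assumes an NCCL-style backend and falls back to the local value on failure), the sketch emulates the average with a SUM all-reduce followed by division by `world_size`. The port number and the per-rank norm values are arbitrary.

```python
import os

import torch
import torch.distributed as dist
import torch.multiprocessing as mp


def _worker(rank: int, world_size: int):
    os.environ['MASTER_ADDR'] = '127.0.0.1'
    os.environ['MASTER_PORT'] = '29501'  # arbitrary free port
    dist.init_process_group('gloo', rank=rank, world_size=world_size)

    # Each rank sees a different local grad norm, as under ZeRO-0/1/2 or plain DDP.
    local_norm = torch.tensor(1656.0 if rank == 0 else 0.0)

    # Emulate ReduceOp.AVG (unsupported on gloo): SUM, then divide by world_size.
    dist.all_reduce(local_norm, op=dist.ReduceOp.SUM)
    local_norm /= world_size

    # Both ranks now log the same value.
    print(f'rank {rank}: grad_norm = {local_norm.item()}')
    dist.destroy_process_group()


if __name__ == '__main__':
    mp.spawn(_worker, args=(2,), nprocs=2)
```

Both ranks print 828.0, matching the expectation encoded in the unit test below.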
56 changes: 56 additions & 0 deletions tests/train/test_grad_norm_reduce.py
@@ -0,0 +1,56 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+"""Tests for grad_norm all-reduce under ZeRO-0/DDP (fix #6815)."""
+import unittest
+from unittest.mock import MagicMock, patch
+
+import torch
+
+from swift.trainers.mixin import SwiftMixin
+
+
+def _make_trainer():
+    trainer = MagicMock()
+    trainer.accelerator = MagicMock()
+    trainer.accelerator.device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
+    return trainer
+
+
+class TestGradNormReduce(unittest.TestCase):
+    """Test _get_reduced_grad_norm_for_logging for consistent grad_norm logging under ZeRO-0."""
+
+    def test_grad_norm_none(self):
+        trainer = _make_trainer()
+        self.assertIsNone(SwiftMixin._get_reduced_grad_norm_for_logging(trainer, None))
+
+    def test_grad_norm_float(self):
+        trainer = _make_trainer()
+        self.assertEqual(SwiftMixin._get_reduced_grad_norm_for_logging(trainer, 1.5), 1.5)
+
+    def test_grad_norm_tensor_single_process(self):
+        trainer = _make_trainer()
+        with patch('swift.trainers.mixin.is_dist', return_value=False):
+            gn = torch.tensor(2.0)
+            self.assertEqual(SwiftMixin._get_reduced_grad_norm_for_logging(trainer, gn), 2.0)
+
+    def test_grad_norm_tensor_dist_zero3_no_reduce(self):
+        trainer = _make_trainer()
+        with patch('swift.trainers.mixin.is_dist', return_value=True), \
+                patch('swift.trainers.mixin.get_dist_setting', return_value=(0, 0, 2, 2)), \
+                patch('swift.trainers.mixin.is_deepspeed_zero3_enabled', return_value=True):
+            gn = torch.tensor(0.025)
+            out = SwiftMixin._get_reduced_grad_norm_for_logging(trainer, gn)
+            self.assertAlmostEqual(out, 0.025)
+
+    def test_grad_norm_tensor_dist_zero0_reduce(self):
+        trainer = _make_trainer()
+        with patch('swift.trainers.mixin.is_dist', return_value=True), \
+                patch('swift.trainers.mixin.get_dist_setting', return_value=(0, 0, 2, 2)), \
+                patch('swift.trainers.mixin.is_deepspeed_zero3_enabled', return_value=False), \
+                patch('torch.distributed.all_reduce') as mock_all_reduce:
+            gn = torch.tensor(1656.0)
+            def _side_effect(tensor, *args, **kwargs):
+                tensor.fill_(tensor.item() / 2)
+            mock_all_reduce.side_effect = _side_effect
+            out = SwiftMixin._get_reduced_grad_norm_for_logging(trainer, gn)
+            self.assertEqual(mock_all_reduce.call_count, 1)
+            self.assertAlmostEqual(out, 828.0)
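Since the distributed calls are mocked, these tests should run on a CPU-only machine. A likely invocation, assuming `pytest` is available and you run from the repository root (plain `unittest` discovery would work as well):

```bash
python -m pytest tests/train/test_grad_norm_reduce.py -v
```

Note that the tests call `SwiftMixin._get_reduced_grad_norm_for_logging` unbound on a `MagicMock` stand-in for `self`, which keeps them independent of full trainer construction.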