Skip to content
Open
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 20 additions & 3 deletions src/accelerate/utils/fsdp_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -657,14 +657,31 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
fsdp2_plugin.ignored_modules, model, accelerator.device
)

model_has_params4bit = False
params4bit = []
for name, param in model.named_parameters():
# this is a temporary fix whereby loading models with bnb params cannot be moved from
# GPU to a meta device with FSDP2 because torch operations don't return the original class type
# bypassing the move to meta will still cause the VRAM spike, but at least it still will load
if param.__class__.__name__ == "Params4bit":
model_has_params4bit = True
break
params4bit.append(param)

model_has_params4bit = len(params4bit) > 0

# Exclude non-floating frozen Params4bit from FSDP sharding.
# Default uint8 quant_storage cannot survive fully_shard's DTensor conversion.
if model_has_params4bit and is_torch_version(">=", "2.7.0"):
incompatible_params4bit = {
p for p in params4bit
if (not p.requires_grad) and (not p.is_floating_point()) and (not p.is_complex())
}
if incompatible_params4bit:
ignored = set(fsdp2_kwargs.get("ignored_params", set()))
fsdp2_kwargs["ignored_params"] = ignored | incompatible_params4bit
if accelerator.is_main_process:
logger.info(
f"Found {len(incompatible_params4bit)} non-floating frozen Params4bit. "
"Excluding from FSDP2 sharding to prevent quant_state corruption."
)
Comment on lines +680 to +686
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

then maybe we should warn the user to set the type for the storage to get the right memory saving

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Good point, updated the message to guide users toward the floating quant_storage path.


if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# Context: `fully_shard` moves the model to GPU if it was on CPU, however it can also be on `meta` and then it stays there even after `fully_shard`
Expand Down
88 changes: 88 additions & 0 deletions tests/fsdp/test_fsdp.py
Original file line number Diff line number Diff line change
Expand Up @@ -471,6 +471,94 @@ def test_param_mapping_error_handling(self):

AcceleratorState._reset_state(True)

def test_fsdp2_ignored_params_non_floating_params4bit(self):
    """Non-floating frozen Params4bit must be auto-added to FSDP2's ignored_params."""
    from unittest.mock import Mock, patch

    from accelerate.utils.fsdp_utils import fsdp2_prepare_model

    net = torch.nn.Sequential(
        torch.nn.Linear(4, 4),
        torch.nn.Linear(4, 4),
    )
    # Stand-in for bnb's Params4bit: a frozen, uint8 (non-floating) parameter.
    quant_weight = torch.nn.Parameter(
        torch.randint(0, 255, (4, 4), dtype=torch.uint8), requires_grad=False
    )
    quant_weight.__class__ = type("Params4bit", (torch.nn.Parameter,), {})
    net[0].weight = quant_weight

    # Minimal accelerator stub — only the attributes fsdp2_prepare_model reads.
    accelerator = Mock()
    accelerator.mixed_precision = "no"
    accelerator.torch_device_mesh = None
    accelerator.device = torch.device("cpu")
    accelerator.is_main_process = True

    plugin = Mock()
    plugin.mixed_precision_policy = None
    plugin.reshard_after_forward = True
    plugin.cpu_offload = None
    plugin.cpu_ram_efficient_loading = False
    plugin.ignored_modules = None
    accelerator.state.fsdp_plugin = plugin

    seen_kwargs = {}

    def record_fully_shard(module, **kwargs):
        # Capture the kwargs that would have been forwarded to fully_shard.
        seen_kwargs.update(kwargs)

    with (
        patch("torch.distributed.fsdp.fully_shard", side_effect=record_fully_shard),
        patch("accelerate.utils.fsdp_utils.is_compiled_module", return_value=False),
        patch("accelerate.utils.fsdp_utils.fsdp2_prepare_auto_wrap_policy", return_value=None),
        patch("accelerate.utils.fsdp_utils.is_torch_version", return_value=True),
        patch("accelerate.utils.fsdp_utils.logger"),
    ):
        fsdp2_prepare_model(accelerator, net)

    # The fake Params4bit must have been excluded from sharding.
    ignored = seen_kwargs.get("ignored_params", set())
    assert quant_weight in ignored, f"Expected Params4bit in ignored_params, got {ignored}"

def test_fsdp2_floating_params4bit_not_ignored(self):
    """Floating-point Params4bit must NOT be excluded from FSDP2 sharding."""
    from unittest.mock import Mock, patch

    from accelerate.utils.fsdp_utils import fsdp2_prepare_model

    net = torch.nn.Sequential(
        torch.nn.Linear(4, 4),
        torch.nn.Linear(4, 4),
    )
    # Stand-in for a floating Params4bit: frozen, but bf16 storage.
    quant_weight = torch.nn.Parameter(
        torch.randn(4, 4, dtype=torch.bfloat16), requires_grad=False
    )
    quant_weight.__class__ = type("Params4bit", (torch.nn.Parameter,), {})
    net[0].weight = quant_weight

    # Minimal accelerator stub — only the attributes fsdp2_prepare_model reads.
    accelerator = Mock()
    accelerator.mixed_precision = "no"
    accelerator.torch_device_mesh = None
    accelerator.device = torch.device("cpu")
    accelerator.is_main_process = True

    plugin = Mock()
    plugin.mixed_precision_policy = None
    plugin.reshard_after_forward = True
    plugin.cpu_offload = None
    plugin.cpu_ram_efficient_loading = False
    plugin.ignored_modules = None
    accelerator.state.fsdp_plugin = plugin

    seen_kwargs = {}

    def record_fully_shard(module, **kwargs):
        # Capture the kwargs that would have been forwarded to fully_shard.
        seen_kwargs.update(kwargs)

    with (
        patch("torch.distributed.fsdp.fully_shard", side_effect=record_fully_shard),
        patch("accelerate.utils.fsdp_utils.is_compiled_module", return_value=False),
        patch("accelerate.utils.fsdp_utils.fsdp2_prepare_auto_wrap_policy", return_value=None),
        patch("accelerate.utils.fsdp_utils.is_torch_version", return_value=True),
    ):
        fsdp2_prepare_model(accelerator, net)

    # A floating Params4bit must not land in ignored_params.
    ignored = seen_kwargs.get("ignored_params", set())
    assert quant_weight not in ignored, f"Floating Params4bit should not be ignored, got {ignored}"

@run_first
# Skip this test when TorchXLA is available because accelerate.launch does not support TorchXLA FSDP.
Expand Down