100 changes: 82 additions & 18 deletions src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -180,13 +180,8 @@ def get_state_dict(self, model, unwrap=True):

return state_dict


def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
from torch.distributed.fsdp import fully_shard

log_bias_dtype_mismatch = False

def cast_lora_module(module):
base_layer_dtype = module.base_layer.weight.dtype
# Linear4Bit will keep its bias term in fp32. If the weight dtype is bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight.
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
@@ -198,15 +193,71 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):

for active_adapter in module.active_adapters:
if module.lora_A:
fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_B:
fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_embedding_A:
fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_embedding_B:
fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
if module.lora_magnitude_vector:
fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)

def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
"""Helper function to process LoRA modules for FSDP2."""
from torch.distributed.fsdp import fully_shard

log_bias_dtype_mismatch = False

# Linear4Bit will keep its bias term in fp32. If the weight dtype is bf16 we are not able to
# wrap this. Therefore we must ensure the bias has the same dtype as the weight.
if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
log_bias_dtype_mismatch = True
module.base_layer.bias.data = module.base_layer.bias.data.to(
module.base_layer.weight.dtype
)
fully_shard(module, **fsdp2_kwargs)
module.set_reshard_after_forward(False)
module.set_reshard_after_backward(False)
# for active_adapter in module.active_adapters:
# for adapter_name in [
# "lora_A",
# "lora_B",
# "lora_embedding_A",
# "lora_embedding_B",
# "lora_magnitude_vector",
# ]:
# adapter_module = getattr(module, adapter_name, None)
# # print(adapter_module, adapter_name)
# # torch.distributed.breakpoint()
# if not adapter_module:
# continue
# fsdp_adapter_module = fully_shard(adapter_module[active_adapter], **fsdp2_kwargs)
# # fsdp_adapter_module.unshard()
# fsdp_adapter_module.set_reshard_after_backward(False)
# fsdp_adapter_module.set_reshard_after_forward(False)
# torch.distributed.breakpoint()
# if module.lora_A:
# fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
# if module.lora_B:
# fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
# if module.lora_embedding_A:
# fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
# if module.lora_embedding_B:
# fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
# if module.lora_magnitude_vector:
# fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
return log_bias_dtype_mismatch


@@ -320,16 +371,26 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
model.tie_weights()

is_peft_model = isinstance(model, PeftModel)

# TODO - this doesn't actually do anything
for name, module in model.named_children():
if name == "experts":
# torch.distributed.breakpoint()
for expert in module.children():
# torch.distributed.breakpoint()
print(f"expert: {expert}")
for lora_module in expert.children():
print(f"lora {lora_module}")
# torch.distributed.breakpoint()
cast_lora_module(lora_module)
_process_lora_module_for_fsdp(lora_module, fsdp2_kwargs)
auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
log_bias_dtype_mismatch = False
if auto_wrap_policy is not None:
for module in get_module_children_bottom_up(model)[:-1]:
if is_peft_model and isinstance(module, LoraLayer):
module_log_bias_mismatch = _process_lora_module_for_fsdp(
module, fsdp2_kwargs
)
log_bias_dtype_mismatch |= module_log_bias_mismatch
if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
# torch.distributed.breakpoint()
cast_lora_module(module)
Contributor comment: We could put this behind a config field - the only downside is that upstreaming it is going to be a pain. (A possible config-gate sketch follows the diff below.)
# torch.distributed.breakpoint()
if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
fully_shard(module, **fsdp2_kwargs)

@@ -346,6 +407,9 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
)

# for module in model.named_modules():
# if "Lora" in

if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
# We re-register the buffers, as they may not be in the state_dict
for fqn, buffer_tensor in original_non_persistent_buffers.items():
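Following up on the contributor comment above: if the cast-before-wrap behavior were gated behind a config field, a minimal sketch could look like the following. The flag name `fsdp_cast_lora_to_base_dtype` and the `cfg` object are assumptions for illustration, not part of this PR; `cast_lora_module` is the helper added in this diff.

```python
from peft.tuners.lora import LoraLayer


def maybe_cast_lora_modules(model, cfg):
    # Hypothetical config gate: `fsdp_cast_lora_to_base_dtype` is an assumed
    # field name, not an existing axolotl option.
    if not getattr(cfg, "fsdp_cast_lora_to_base_dtype", False):
        return
    # Cast every LoRA layer's adapter weights to the base layer dtype before
    # FSDP2 wrapping, reusing the cast_lora_module helper from this diff.
    for module in model.modules():
        if isinstance(module, LoraLayer):
            cast_lora_module(module)
```

This keeps the current behavior opt-in, at the cost the reviewer notes: an extra axolotl-specific flag that would need to be reconciled when upstreaming the change to accelerate.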