From 6daed7d060d758e6c2ed490a52b855fa18866ed3 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Tue, 9 Sep 2025 17:11:13 +0100
Subject: [PATCH 1/4] don't keep adapter weights in fp32

---
 src/axolotl/monkeypatch/accelerate/fsdp2.py | 41 +++++++++++++++++++--
 1 file changed, 37 insertions(+), 4 deletions(-)

diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index 3b38a33b70..9351846525 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -178,6 +178,38 @@ def get_state_dict(self, model, unwrap=True):
 
         return state_dict
 
+def cast_lora_module(module):
+    base_layer_dtype = module.base_layer.weight.dtype
+    # Linear4Bit will keep its bias term in fp32. If the weight dtype is bf16 we are not able to
+    # wrap this. Therefore we must ensure the bias has the same dtype as the weight.
+    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
+        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
+            log_bias_dtype_mismatch = True
+            module.base_layer.bias.data = module.base_layer.bias.data.to(
+                module.base_layer.weight.dtype
+            )
+
+    for active_adapter in module.active_adapters:
+        if module.lora_A:
+            module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
+                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
+        if module.lora_B:
+            module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
+                module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
+        if module.lora_embedding_A:
+            module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
+                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
+        if module.lora_embedding_B:
+            module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
+                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
+        if module.lora_magnitude_vector:
+            module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
+                module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
 
 def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
     """Helper function to process LoRA modules for FSDP2."""
@@ -324,10 +356,11 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     if auto_wrap_policy is not None:
         for module in get_module_children_bottom_up(model)[:-1]:
             if is_peft_model and isinstance(module, LoraLayer):
-                module_log_bias_mismatch = _process_lora_module_for_fsdp(
-                    module, fsdp2_kwargs
-                )
-                log_bias_dtype_mismatch |= module_log_bias_mismatch
+                cast_lora_module(module)
+                # module_log_bias_mismatch = _process_lora_module_for_fsdp(
+                #     module, fsdp2_kwargs
+                # )
+                # log_bias_dtype_mismatch |= module_log_bias_mismatch
 
             if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                 fully_shard(module, **fsdp2_kwargs)

From 6874d32e0c9749997b627cdf5c5a7029b10646c9 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi <“salman.mohammadi@outlook.com”>
Date: Fri, 12 Sep 2025 15:26:12 +0000
Subject: [PATCH 2/4] more lora handling

---
 src/axolotl/monkeypatch/accelerate/fsdp2.py | 55 +++++++++++++++------
 1 file changed, 39 insertions(+), 16 deletions(-)

diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index 9351846525..68306a689d 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -225,18 +225,37 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
             module.base_layer.bias.data = module.base_layer.bias.data.to(
                 module.base_layer.weight.dtype
             )
-
-    for active_adapter in module.active_adapters:
-        if module.lora_A:
-            fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
-        if module.lora_B:
-            fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
-        if module.lora_embedding_A:
-            fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
-        if module.lora_embedding_B:
-            fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
-        if module.lora_magnitude_vector:
-            fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
+    fully_shard(module, **fsdp2_kwargs)
+    module.set_reshard_after_forward(False)
+    module.set_reshard_after_backward(False)
+    # for active_adapter in module.active_adapters:
+    #     for adapter_name in [
+    #         "lora_A",
+    #         "lora_B",
+    #         "lora_embedding_A",
+    #         "lora_embedding_B",
+    #         "lora_magnitude_vector",
+    #     ]:
+    #         adapter_module = getattr(module, adapter_name, None)
+    #         # print(adapter_module, adapter_name)
+    #         # torch.distributed.breakpoint()
+    #         if not adapter_module:
+    #             continue
+    #         fsdp_adapter_module = fully_shard(adapter_module[active_adapter], **fsdp2_kwargs)
+    #         # fsdp_adapter_module.unshard()
+    #         fsdp_adapter_module.set_reshard_after_backward(False)
+    #         fsdp_adapter_module.set_reshard_after_forward(False)
+    # torch.distributed.breakpoint()
+    # if module.lora_A:
+    #     fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
+    # if module.lora_B:
+    #     fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
+    # if module.lora_embedding_A:
+    #     fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
+    # if module.lora_embedding_B:
+    #     fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
+    # if module.lora_magnitude_vector:
+    #     fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
 
     return log_bias_dtype_mismatch
 
@@ -356,11 +375,12 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     if auto_wrap_policy is not None:
         for module in get_module_children_bottom_up(model)[:-1]:
             if is_peft_model and isinstance(module, LoraLayer):
+                # torch.distributed.breakpoint()
                 cast_lora_module(module)
-                # module_log_bias_mismatch = _process_lora_module_for_fsdp(
-                #     module, fsdp2_kwargs
-                # )
-                # log_bias_dtype_mismatch |= module_log_bias_mismatch
+                module_log_bias_mismatch = _process_lora_module_for_fsdp(
+                    module, fsdp2_kwargs
+                )
+                log_bias_dtype_mismatch |= module_log_bias_mismatch
 
             if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                 fully_shard(module, **fsdp2_kwargs)
@@ -377,6 +397,9 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
         accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
     )
 
+    # for module in model.named_modules():
+    #     if "Lora" in
+
     if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
         # We re-register the buffers, as they may not be in the state_dict
         for fqn, buffer_tensor in original_non_persistent_buffers.items():

From 850489405b167fdc7b7feef3f3d2f376c36d3af5 Mon Sep 17 00:00:00 2001
From: Salman Mohammadi <“salman.mohammadi@outlook.com”>
Date: Fri, 12 Sep 2025 17:34:41 +0000
Subject: [PATCH 3/4] working?

---
 src/axolotl/monkeypatch/accelerate/fsdp2.py | 19 +++++++++++++------
 1 file changed, 13 insertions(+), 6 deletions(-)

diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index 68306a689d..c5429e05ea 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -369,18 +369,25 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     model.tie_weights()
 
     is_peft_model = isinstance(model, PeftModel)
-
+    for name, module in model.named_children():
+        if name == "experts":
+            # torch.distributed.breakpoint()
+            for expert in module.children():
+                # torch.distributed.breakpoint()
+                print(f"expert: {expert}")
+                for lora_module in expert.children():
+                    print(f"lora {lora_module}")
+                    # torch.distributed.breakpoint()
+                    cast_lora_module(lora_module)
+                    _process_lora_module_for_fsdp(lora_module, fsdp2_kwargs)
     auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
     log_bias_dtype_mismatch = False
     if auto_wrap_policy is not None:
         for module in get_module_children_bottom_up(model)[:-1]:
-            if is_peft_model and isinstance(module, LoraLayer):
+            if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
                 # torch.distributed.breakpoint()
                 cast_lora_module(module)
-                module_log_bias_mismatch = _process_lora_module_for_fsdp(
-                    module, fsdp2_kwargs
-                )
-                log_bias_dtype_mismatch |= module_log_bias_mismatch
+                # torch.distributed.breakpoint()
 
             if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                 fully_shard(module, **fsdp2_kwargs)

From a7676af44dd1e438888652d2792b4ccf53eec5cc Mon Sep 17 00:00:00 2001
From: Salman Mohammadi
Date: Fri, 12 Sep 2025 18:51:10 +0100
Subject: [PATCH 4/4] hmmm

---
 src/axolotl/monkeypatch/accelerate/fsdp2.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index d65f7472e5..5ef7803650 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -371,6 +371,7 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
     model.tie_weights()
 
     is_peft_model = isinstance(model, PeftModel)
+    # TODO - this doesn't actually do anything
     for name, module in model.named_children():
         if name == "experts":
             # torch.distributed.breakpoint()
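
Note (sketch, not part of the patches above): the common thread in this series is to stop keeping LoRA adapter weights in fp32 and instead cast every adapter parameter to the base layer's dtype before the module is wrapped with fully_shard, so FSDP2 does not have to flatten mixed-dtype parameters into one group. A minimal standalone version of that idea, assuming a PEFT-style LoraLayer that exposes base_layer; the helper name and the wrapping snippet are illustrative only, and it mirrors the patches' choice of base_layer.weight.dtype as the target dtype:

    import torch.nn as nn

    def cast_adapters_to_base_dtype(lora_layer: nn.Module) -> None:
        # Cast every adapter parameter to the dtype of the frozen base layer,
        # leaving the base layer's own weights and bias untouched.
        base_dtype = lora_layer.base_layer.weight.dtype
        for name, param in lora_layer.named_parameters():
            if name.startswith("base_layer"):
                continue
            param.data = param.data.to(base_dtype)

    # Illustrative use, mirroring the wrapping loop in fsdp2_prepare_model:
    # if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
    #     cast_adapters_to_base_dtype(module)
    #     fully_shard(module, **fsdp2_kwargs)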