diff --git a/src/axolotl/monkeypatch/accelerate/fsdp2.py b/src/axolotl/monkeypatch/accelerate/fsdp2.py
index d8ba02cb2c..5ef7803650 100644
--- a/src/axolotl/monkeypatch/accelerate/fsdp2.py
+++ b/src/axolotl/monkeypatch/accelerate/fsdp2.py
@@ -180,13 +180,8 @@ def get_state_dict(self, model, unwrap=True):
 
     return state_dict
 
-
-def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
-    """Helper function to process LoRA modules for FSDP2."""
-    from torch.distributed.fsdp import fully_shard
-
-    log_bias_dtype_mismatch = False
-
+def cast_lora_module(module):
+    base_layer_dtype = module.base_layer.weight.dtype
     # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
     # wrap this. Therefore we must ensure the bias has the same dtype as the weight
     if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
@@ -198,15 +193,71 @@ def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
 
     for active_adapter in module.active_adapters:
         if module.lora_A:
-            fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
+            module.lora_A[active_adapter] = module.lora_A[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_A[active_adapter], 'bias') and module.lora_A[active_adapter].bias is not None:
+                module.lora_A[active_adapter].bias.data = module.lora_A[active_adapter].bias.data.to(base_layer_dtype)
         if module.lora_B:
-            fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
+            module.lora_B[active_adapter] = module.lora_B[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_B[active_adapter], 'bias') and module.lora_B[active_adapter].bias is not None:
+                module.lora_B[active_adapter].bias.data = module.lora_B[active_adapter].bias.data.to(base_layer_dtype)
         if module.lora_embedding_A:
-            fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
+            module.lora_embedding_A[active_adapter] = module.lora_embedding_A[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_embedding_A[active_adapter], 'bias') and module.lora_embedding_A[active_adapter].bias is not None:
+                module.lora_embedding_A[active_adapter].bias.data = module.lora_embedding_A[active_adapter].bias.data.to(base_layer_dtype)
         if module.lora_embedding_B:
-            fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
+            module.lora_embedding_B[active_adapter] = module.lora_embedding_B[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_embedding_B[active_adapter], 'bias') and module.lora_embedding_B[active_adapter].bias is not None:
+                module.lora_embedding_B[active_adapter].bias.data = module.lora_embedding_B[active_adapter].bias.data.to(base_layer_dtype)
         if module.lora_magnitude_vector:
-            fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
+            module.lora_magnitude_vector[active_adapter] = module.lora_magnitude_vector[active_adapter].to(base_layer_dtype)
+            if hasattr(module.lora_magnitude_vector[active_adapter], 'bias') and module.lora_magnitude_vector[active_adapter].bias is not None:
+                module.lora_magnitude_vector[active_adapter].bias.data = module.lora_magnitude_vector[active_adapter].bias.data.to(base_layer_dtype)
+
+def _process_lora_module_for_fsdp(module, fsdp2_kwargs):
+    """Helper function to process LoRA modules for FSDP2."""
+    from torch.distributed.fsdp import fully_shard
+
+    log_bias_dtype_mismatch = False
+
+    # Linear4Bit will keep it's bias term in fp32. If the weight dtype is in bf16 we are not able to
+    # wrap this. Therefore we must ensure the bias has the same dtype as the weight
+    if hasattr(module.base_layer, "bias") and module.base_layer.bias is not None:
+        if module.base_layer.weight.dtype != module.base_layer.bias.dtype:
+            log_bias_dtype_mismatch = True
+            module.base_layer.bias.data = module.base_layer.bias.data.to(
+                module.base_layer.weight.dtype
+            )
+    fully_shard(module, **fsdp2_kwargs)
+    module.set_reshard_after_forward(False)
+    module.set_reshard_after_backward(False)
+    # for active_adapter in module.active_adapters:
+    #     for adapter_name in [
+    #         "lora_A",
+    #         "lora_B",
+    #         "lora_embedding_A",
+    #         "lora_embedding_B",
+    #         "lora_magnitude_vector",
+    #     ]:
+    #         adapter_module = getattr(module, adapter_name, None)
+    #         # print(adapter_module, adapter_name)
+    #         # torch.distributed.breakpoint()
+    #         if not adapter_module:
+    #             continue
+    #         fsdp_adapter_module = fully_shard(adapter_module[active_adapter], **fsdp2_kwargs)
+    #         # fsdp_adapter_module.unshard()
+    #         fsdp_adapter_module.set_reshard_after_backward(False)
+    #         fsdp_adapter_module.set_reshard_after_forward(False)
+    # torch.distributed.breakpoint()
+    # if module.lora_A:
+    #     fully_shard(module.lora_A[active_adapter], **fsdp2_kwargs)
+    # if module.lora_B:
+    #     fully_shard(module.lora_B[active_adapter], **fsdp2_kwargs)
+    # if module.lora_embedding_A:
+    #     fully_shard(module.lora_embedding_A[active_adapter], **fsdp2_kwargs)
+    # if module.lora_embedding_B:
+    #     fully_shard(module.lora_embedding_B[active_adapter], **fsdp2_kwargs)
+    # if module.lora_magnitude_vector:
+    #     fully_shard(module.lora_magnitude_vector[active_adapter], **fsdp2_kwargs)
 
     return log_bias_dtype_mismatch
 
@@ -320,16 +371,26 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
         model.tie_weights()
 
     is_peft_model = isinstance(model, PeftModel)
-
+    # TODO - this doesn't actually do anything
+    for name, module in model.named_children():
+        if name == "experts":
+            # torch.distributed.breakpoint()
+            for expert in module.children():
+                # torch.distributed.breakpoint()
+                print(f"expert: {expert}")
+                for lora_module in expert.children():
+                    print(f"lora {lora_module}")
+                    # torch.distributed.breakpoint()
+                    cast_lora_module(lora_module)
+                    _process_lora_module_for_fsdp(lora_module, fsdp2_kwargs)
     auto_wrap_policy = fsdp2_prepare_auto_wrap_policy(fsdp2_plugin, model)
    log_bias_dtype_mismatch = False
     if auto_wrap_policy is not None:
         for module in get_module_children_bottom_up(model)[:-1]:
-            if is_peft_model and isinstance(module, LoraLayer):
-                module_log_bias_mismatch = _process_lora_module_for_fsdp(
-                    module, fsdp2_kwargs
-                )
-                log_bias_dtype_mismatch |= module_log_bias_mismatch
+            if is_peft_model and isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
+                # torch.distributed.breakpoint()
+                cast_lora_module(module)
+                # torch.distributed.breakpoint()
 
             if auto_wrap_policy(module) and not isinstance(module, FSDPModule):
                 fully_shard(module, **fsdp2_kwargs)
@@ -346,6 +407,9 @@ def fsdp2_prepare_model(accelerator, model: torch.nn.Module) -> torch.nn.Module:
                 accelerator, model, original_sd, offload_to_cpu=offload_to_cpu
             )
 
+    # for module in model.named_modules():
+    #     if "Lora" in
+
     if fsdp2_plugin.cpu_ram_efficient_loading and not model_has_params4bit:
         # We re-register the buffers, as they may not be in the state_dict
         for fqn, buffer_tensor in original_non_persistent_buffers.items():
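Below is a minimal sketch of the cast-then-shard flow this patch moves toward: align each LoRA adapter's parameters with the base layer dtype via cast_lora_module, then apply fully_shard once to the whole LoraLayer and keep it unsharded between forward and backward. Only cast_lora_module, fully_shard, FSDPModule, and LoraLayer come from the code above; the function name shard_peft_lora_layers and the standalone loop are illustrative assumptions, not part of the patch, and fsdp2_kwargs is assumed to be the same kwargs dict built in fsdp2_prepare_model.

# Illustrative sketch only; assumes the patch above has been applied.
from axolotl.monkeypatch.accelerate.fsdp2 import cast_lora_module  # helper added in this diff
from peft.tuners.lora import LoraLayer
from torch.distributed.fsdp import FSDPModule, fully_shard


def shard_peft_lora_layers(model, fsdp2_kwargs):
    """Sketch: cast LoRA adapters to the base dtype, then shard each LoraLayer once."""
    for module in model.modules():
        if isinstance(module, LoraLayer) and not isinstance(module, FSDPModule):
            # Bring lora_A / lora_B (and embedding / DoRA magnitude) params to the
            # base_layer dtype so FSDP2 can flatten them alongside the frozen base weight.
            cast_lora_module(module)
            # Shard the whole LoraLayer instead of each adapter ModuleDict entry;
            # fully_shard returns the module with FSDPModule mixed into its class.
            sharded = fully_shard(module, **fsdp2_kwargs)
            # Keep the layer resident between steps, mirroring the patch's
            # set_reshard_after_forward/backward(False) calls.
            sharded.set_reshard_after_forward(False)
            sharded.set_reshard_after_backward(False)

The design choice being sketched is the one visible in the diff: per-adapter fully_shard calls are dropped in favor of a dtype cast plus a single fully_shard over the LoraLayer, which avoids wrapping tiny adapter submodules individually.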