@@ -21,6 +21,7 @@
 
 import ray
 import torch
+from nemo_automodel.components._peft.lora import LinearLoRA
 from nemo_automodel.components.distributed.cp_utils import (
     create_context_parallel_ctx,
 )
@@ -85,19 +86,66 @@ def dtensor_params_generator(
     Args:
         model: The model whose parameters to generate.
         target_dtype: The dtype to convert tensors to.
+        peft_config: Optional LoRA config for filtering which layers to merge.
 
     Yields:
         Tuples of (fully_qualified_name, tensor) where tensors are converted to target dtype and made contiguous.
     """
+    module_map = dict(model.named_modules())
     for name, tensor in model.state_dict().items():
+        if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"):
+            continue
         full_tensor = tensor.full_tensor() if isinstance(tensor, DTensor) else tensor
-        adapted_fqn_tensors = _maybe_adapt_tensor_to_hf(model, name, full_tensor)
+        merged_tensor = _maybe_merge_lora_weight(module_map, name, full_tensor)
+
+        adapted_fqn_tensors = _maybe_adapt_tensor_to_hf(model, name, merged_tensor)
         for adapted_fqn, adapted_tensor in adapted_fqn_tensors:
             # Convert to target dtype
             yield (
                 adapted_fqn,
                 adapted_tensor.to(target_dtype, non_blocking=True).contiguous(),
             )
+        del adapted_tensor
+        del adapted_fqn_tensors
+        del merged_tensor
+        del full_tensor
+
+
+@torch.no_grad()
+def _maybe_merge_lora_weight(
+    module_map: dict[str, nn.Module],
+    fqn: str,
+    tensor: torch.Tensor,
+) -> torch.Tensor:
+    if not fqn.endswith(".weight"):
+        return tensor
+    module_name = fqn[: -len(".weight")]
+    module = module_map.get(module_name)
+    if not isinstance(module, LinearLoRA):
+        return tensor
+    if not (hasattr(module, "lora_A") and hasattr(module, "lora_B")):
+        return tensor
+
+    lora_a = (
+        module.lora_A.weight.full_tensor()
+        if isinstance(module.lora_A.weight, DTensor)
+        else module.lora_A.weight
+    )
+    lora_b = (
+        module.lora_B.weight.full_tensor()
+        if isinstance(module.lora_B.weight, DTensor)
+        else module.lora_B.weight
+    )
+    lora_a = lora_a.to(device=tensor.device, dtype=tensor.dtype)
+    lora_b = lora_b.to(device=tensor.device, dtype=tensor.dtype)
+    scale = getattr(module, "scale", None)
+
+    if scale is None and hasattr(module, "alpha") and hasattr(module, "dim"):
+        scale = module.alpha / module.dim
+    if scale is None:
+        scale = 1.0
+
+    return tensor + torch.matmul(lora_b, lora_a) * scale
 
 
 def _maybe_adapt_tensor_to_hf(
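For reference, here is a minimal, self-contained sketch (not part of the diff) of the algebra _maybe_merge_lora_weight applies: the adapter delta scale * (lora_B @ lora_A) is folded into the frozen base weight, so a single merged matrix reproduces the base output plus the scaled LoRA path. The shapes, alpha, and rank below are made-up stand-ins rather than values read from LinearLoRA.

# Illustrative sketch only; shapes, alpha, and rank are assumptions.
import torch

out_features, in_features, rank, alpha = 8, 16, 4, 32
base_w = torch.randn(out_features, in_features)   # stands in for module.weight
lora_a = torch.randn(rank, in_features)           # stands in for module.lora_A.weight
lora_b = torch.randn(out_features, rank)          # stands in for module.lora_B.weight
scale = alpha / rank                              # mirrors the alpha / dim fallback above

# Fold the adapter delta into the base weight, as the helper does.
merged_w = base_w + torch.matmul(lora_b, lora_a) * scale

# The merged matrix matches base output + scaled LoRA output.
x = torch.randn(2, in_features)
expected = x @ base_w.T + (x @ lora_a.T @ lora_b.T) * scale
assert torch.allclose(x @ merged_w.T, expected, atol=1e-4)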
@@ -1208,6 +1256,8 @@ def prepare_refit_info(self) -> Optional[dict[str, Any]]:
         """Prepare state dict metadata for weight refitting and IPC streaming."""
         state_dict_info = {}
         for name, tensor in self.model.state_dict().items():
+            if name.endswith(".lora_A.weight") or name.endswith(".lora_B.weight"):
+                continue
             full_tensor = (
                 tensor.full_tensor() if isinstance(tensor, DTensor) else tensor
             )
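The same suffix filter in prepare_refit_info keeps adapter tensors out of the refit metadata, so downstream consumers only ever see base-shaped .weight entries whose LoRA delta has already been merged. A small illustration with made-up state-dict keys (not taken from the PR):

# Made-up keys for illustration; only the merged base ".weight" survives the filter.
keys = [
    "model.layers.0.self_attn.q_proj.weight",
    "model.layers.0.self_attn.q_proj.lora_A.weight",
    "model.layers.0.self_attn.q_proj.lora_B.weight",
]
kept = [k for k in keys if not k.endswith((".lora_A.weight", ".lora_B.weight"))]
assert kept == ["model.layers.0.self_attn.q_proj.weight"]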