 import torch.nn as nn
 import torch.nn.init as init
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 
 from ..config import PEFTAttributeConfig
 from .layer import LoRAModule, LoRAModuleRegistry
@@ -129,6 +130,40 @@ def update_layer_lora(
             adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable
         )
 
+    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        """Sharding along axis 0 for ColumnParallelLinear, bias not sharded.
+
+        For ColumnParallelLinear:
+        - lora_a weight: sharded at dim 0
+        - lora_b weight: sharded at dim 0
+        """
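+        # The base ColumnParallelLinear tensors come from the parent implementation;
+        # the LoRA adapter weights are added to that dict below.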
+        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
+
+        if hasattr(self, "_lora_adapters"):
+            lora_state_dict = {}
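+            # keep_vars=True hands back the live Parameters instead of detached copies.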
+            state_dict = self.state_dict(prefix="", keep_vars=True)
+
+            for adapter_name in self._lora_adapters:
+                lora_a_key = f"lora_a_{adapter_name}.weight"
+                lora_b_key = f"lora_b_{adapter_name}.weight"
+
+                if lora_a_key in state_dict:
+                    lora_state_dict[lora_a_key] = state_dict[lora_a_key]
+                if lora_b_key in state_dict:
+                    lora_state_dict[lora_b_key] = state_dict[lora_b_key]
+
+            lora_sharding_dims = {}
+            for key in lora_state_dict:
+                lora_sharding_dims[key] = 0
+
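+            # Wrap the collected LoRA tensors as ShardedTensors so Megatron's
+            # distributed checkpointing saves/loads them with the right TP split.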
+            if lora_state_dict:
+                lora_sharded = make_sharded_tensors_for_checkpoint(
+                    lora_state_dict, prefix, lora_sharding_dims, sharded_offsets
+                )
+                sharded_state_dict.update(lora_sharded)
+
+        return sharded_state_dict
+
 
 @LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"})
 class _LoRAMegatronRowParallelLinear(_MegatronParallelLoRABase):
@@ -172,6 +207,43 @@ def update_layer_lora(
             adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable
         )
 
+    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        """Sharding along axis 1 for RowParallelLinear, bias not sharded.
+
+        For RowParallelLinear:
+        - lora_a weight: sharded at dim 1 (RowParallelLinear)
+        - lora_b weight: sharded at dim 0 (ColumnParallelLinear)
+        """
+        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
+
+        if hasattr(self, "_lora_adapters"):
+            lora_state_dict = {}
+            state_dict = self.state_dict(prefix="", keep_vars=True)
+
+            for adapter_name in self._lora_adapters:
+                lora_a_key = f"lora_a_{adapter_name}.weight"
+                lora_b_key = f"lora_b_{adapter_name}.weight"
+
+                if lora_a_key in state_dict:
+                    lora_state_dict[lora_a_key] = state_dict[lora_a_key]
+                if lora_b_key in state_dict:
+                    lora_state_dict[lora_b_key] = state_dict[lora_b_key]
+
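+            # lora_a mirrors the row-parallel split of its input dim (1); lora_b mirrors
+            # a column-parallel split of its output dim (0), as noted in the docstring.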
+            lora_sharding_dims = {}
+            for key in lora_state_dict:
+                if "lora_a_" in key:
+                    lora_sharding_dims[key] = 1
+                elif "lora_b_" in key:
+                    lora_sharding_dims[key] = 0
+
+            if lora_state_dict:
+                lora_sharded = make_sharded_tensors_for_checkpoint(
+                    lora_state_dict, prefix, lora_sharding_dims, sharded_offsets
+                )
+                sharded_state_dict.update(lora_sharded)
+
+        return sharded_state_dict
+
 
 # Register quantized versions if available
 if QUANT_MODULES_AVAILABLE:
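
A minimal sketch (not part of the diff, plain Python) of the key-to-shard-dim mapping that the two sharded_state_dict overrides build before handing the tensors to make_sharded_tensors_for_checkpoint. The adapter name "default" and the helper name lora_sharding_dims are illustrative assumptions, not code from the PR.

def lora_sharding_dims(adapter_names, column_parallel):
    """Return {state-dict key: shard dim} for one LoRA-wrapped parallel linear layer."""
    dims = {}
    for name in adapter_names:
        if column_parallel:
            # ColumnParallelLinear wrapper: both LoRA weights are split along dim 0.
            dims[f"lora_a_{name}.weight"] = 0
            dims[f"lora_b_{name}.weight"] = 0
        else:
            # RowParallelLinear wrapper: lora_a is split along dim 1, lora_b along dim 0.
            dims[f"lora_a_{name}.weight"] = 1
            dims[f"lora_b_{name}.weight"] = 0
    return dims

print(lora_sharding_dims(["default"], column_parallel=True))
# {'lora_a_default.weight': 0, 'lora_b_default.weight': 0}
print(lora_sharding_dims(["default"], column_parallel=False))
# {'lora_a_default.weight': 1, 'lora_b_default.weight': 0}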