
Commit 935b07b

Update sharded axis

Signed-off-by: Jingyu Xin <[email protected]>

1 parent 1d38784 commit 935b07b

File tree

modelopt/torch/peft/convert.py
modelopt/torch/peft/lora/layer.py
modelopt/torch/peft/lora/tp_layer.py

3 files changed, +74 -3 lines


modelopt/torch/peft/convert.py

Lines changed: 0 additions & 3 deletions

@@ -44,11 +44,8 @@ def update_model(
     Returns:
         The updated model with LoRA adapters
     """
-
-    # Validate config by converting to PEFTConfig if needed
     # Check if model is already in PEFT mode by looking for LoRA modules
     if not is_peft_model(model):
-        # First time - need to convert to PEFT mode
         apply_mode(model, mode=[("peft", config)], registry=PEFTModeRegistry)
     else:
         if not isinstance(config, PEFTConfig):
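For orientation, a minimal usage sketch of update_model (not part of this commit); the config dictionary shown is an assumption for illustration, not the verified PEFTConfig schema:

    from modelopt.torch.peft.convert import update_model

    # Hedged sketch: the first call converts the model to PEFT mode and attaches
    # LoRA adapters; a later call on the same model takes the `else` branch above
    # and updates the existing adapters instead of re-converting.
    lora_cfg = {"adapter_cfg": {"*": {"rank": 16}}}  # illustrative keys only
    model = update_model(model, lora_cfg)   # first call: convert + attach adapters
    model = update_model(model, lora_cfg)   # second call: update existing adapters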

modelopt/torch/peft/lora/layer.py

Lines changed: 2 additions & 0 deletions

@@ -64,6 +64,8 @@ def _register_adapter(
         self.add_module(f"lora_b_{adapter_name}", lora_b)
 
         # Store in adapter dictionary with explicit rank
+        if adapter_name in self._lora_adapters:
+            raise ValueError(f"Adapter '{adapter_name}' already exists.")
         self._lora_adapters[adapter_name] = {
             "lora_a": lora_a,
             "lora_b": lora_b,

modelopt/torch/peft/lora/tp_layer.py

Lines changed: 72 additions & 0 deletions

@@ -6,6 +6,7 @@
 import torch.nn as nn
 import torch.nn.init as init
 from megatron.core.tensor_parallel.layers import ColumnParallelLinear, RowParallelLinear
+from megatron.core.transformer.utils import make_sharded_tensors_for_checkpoint
 
 from ..config import PEFTAttributeConfig
 from .layer import LoRAModule, LoRAModuleRegistry
@@ -129,6 +130,40 @@ def update_layer_lora(
             adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable
         )
 
+    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        """Sharding along axis 0 for ColumnParallelLinear, bias not sharded.
+
+        For ColumnParallelLinear:
+        - lora_a weight: sharded at dim 0
+        - lora_b weight: sharded at dim 0
+        """
+        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
+
+        if hasattr(self, "_lora_adapters"):
+            lora_state_dict = {}
+            state_dict = self.state_dict(prefix="", keep_vars=True)
+
+            for adapter_name in self._lora_adapters:
+                lora_a_key = f"lora_a_{adapter_name}.weight"
+                lora_b_key = f"lora_b_{adapter_name}.weight"
+
+                if lora_a_key in state_dict:
+                    lora_state_dict[lora_a_key] = state_dict[lora_a_key]
+                if lora_b_key in state_dict:
+                    lora_state_dict[lora_b_key] = state_dict[lora_b_key]
+
+            lora_sharding_dims = {}
+            for key in lora_state_dict:
+                lora_sharding_dims[key] = 0
+
+            if lora_state_dict:
+                lora_sharded = make_sharded_tensors_for_checkpoint(
+                    lora_state_dict, prefix, lora_sharding_dims, sharded_offsets
+                )
+                sharded_state_dict.update(lora_sharded)
+
+        return sharded_state_dict
+
 
 @LoRAModuleRegistry.register({RowParallelLinear: "megatron_RowParallelLinear"})
 class _LoRAMegatronRowParallelLinear(_MegatronParallelLoRABase):
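As a side note, a hedged sketch (not part of the diff) of the axis map that the column-parallel method above hands to make_sharded_tensors_for_checkpoint for a single adapter named "default":

    # For the column-parallel wrapper, every LoRA weight is declared tensor-parallel
    # along dim 0; Megatron-Core wraps each entry in a ShardedTensor so distributed
    # checkpointing can save per-rank shards and reassemble the full weight on load.
    axis_map = {
        "lora_a_default.weight": 0,  # split across TP ranks along dim 0
        "lora_b_default.weight": 0,  # split across TP ranks along dim 0
    }
    # make_sharded_tensors_for_checkpoint(lora_state_dict, prefix, axis_map, sharded_offsets)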
@@ -172,6 +207,43 @@ def update_layer_lora(
             adapter_name, lora_a, lora_b, attr_config.rank, attr_config.scale, attr_config.enable
         )
 
+    def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
+        """Sharding along axis 1 for RowParallelLinear, bias not sharded.
+
+        For RowParallelLinear:
+        - lora_a weight: sharded at dim 1 (RowParallelLinear)
+        - lora_b weight: sharded at dim 0 (ColumnParallelLinear)
+        """
+        sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
+
+        if hasattr(self, "_lora_adapters"):
+            lora_state_dict = {}
+            state_dict = self.state_dict()
+
+            for adapter_name in self._lora_adapters:
+                lora_a_key = f"lora_a_{adapter_name}.weight"
+                lora_b_key = f"lora_b_{adapter_name}.weight"
+
+                if lora_a_key in state_dict:
+                    lora_state_dict[lora_a_key] = state_dict[lora_a_key]
+                if lora_b_key in state_dict:
+                    lora_state_dict[lora_b_key] = state_dict[lora_b_key]
+
+            lora_sharding_dims = {}
+            for key in lora_state_dict:
+                if "lora_a_" in key:
+                    lora_sharding_dims[key] = 1
+                elif "lora_b_" in key:
+                    lora_sharding_dims[key] = 0
+
+            if lora_state_dict:
+                lora_sharded = make_sharded_tensors_for_checkpoint(
+                    lora_state_dict, prefix, lora_sharding_dims, sharded_offsets
+                )
+                sharded_state_dict.update(lora_sharded)
+
+        return sharded_state_dict
+
 
 # Register quantized versions if available
 if QUANT_MODULES_AVAILABLE:
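Finally, a hedged end-to-end sketch of how these sharded_state_dict entries are typically consumed, assuming Megatron-Core's dist_checkpointing API; the checkpoint path and the top-level model.sharded_state_dict() call are illustrative:

    from megatron.core import dist_checkpointing

    # Each LoRA-wrapped parallel linear now contributes its adapter weights with the
    # correct tensor-parallel axis (dim 0 for column-parallel; dims 1/0 for row-parallel).
    sharded_sd = model.sharded_state_dict()
    dist_checkpointing.save(sharded_sd, "/path/to/ckpt")  # hypothetical path

    # On load, each rank pulls only its slice of every LoRA weight.
    state_dict = dist_checkpointing.load(model.sharded_state_dict(), "/path/to/ckpt")
    model.load_state_dict(state_dict)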
