Commit 6df1954 (parent: 81f8d06)

Update the comment

Signed-off-by: Jingyu Xin <[email protected]>

1 file changed, 2 insertions(+), 2 deletions(-)

modelopt/torch/peft/lora/plugins/megatron.py (2 additions, 2 deletions)

@@ -157,7 +157,7 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
         """Sharding along axis 0 for ColumnParallelLinear, bias not sharded.
 
         For ColumnParallelLinear:
-        - lora_a weight: sharded at dim 0
+        (lora_a is a regular nn.Linear and is not sharded)
         - lora_b weight: sharded at dim 0
         """
         sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
@@ -233,7 +233,7 @@ def sharded_state_dict(self, prefix="", sharded_offsets=(), metadata=None):
 
         For RowParallelLinear:
         - lora_a weight: sharded at dim 1 (RowParallelLinear)
-        - lora_b weight: sharded at dim 0 (ColumnParallelLinear)
+        (lora_b is a regular nn.Linear and is not sharded)
         """
         sharded_state_dict = super().sharded_state_dict(prefix, sharded_offsets, metadata)
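For context, the sharding scheme these corrected docstrings describe can be illustrated with a minimal, self-contained sketch. This is plain PyTorch, not modelopt's or Megatron-LM's actual API; the `shard` helper, the class pairing, and all dimension values are hypothetical, chosen only to show which LoRA weight stays replicated and which one is split across tensor-parallel ranks.

```python
# Hypothetical illustration of the sharding described in the docstrings above.
# None of these names come from modelopt; the dims and world size are made up.
import torch
import torch.nn as nn


def shard(weight: torch.Tensor, dim: int, rank: int, world: int) -> torch.Tensor:
    """Return this rank's slice of `weight` along `dim` (tensor parallelism)."""
    return weight.chunk(world, dim=dim)[rank]


in_f, out_f, r = 1024, 4096, 16  # base layer dims and LoRA rank
rank, world = 0, 4               # tensor-parallel rank / world size

# ColumnParallelLinear + LoRA:
#   lora_a is a regular nn.Linear -> replicated on every rank (not sharded)
#   lora_b follows the base layer -> weight sharded at dim 0
col_lora_a = nn.Linear(in_f, r, bias=False)  # full copy on every rank
col_lora_b = nn.Linear(r, out_f, bias=False)
col_lora_b_local = shard(col_lora_b.weight, dim=0, rank=rank, world=world)

# RowParallelLinear + LoRA:
#   lora_a follows the base layer -> weight sharded at dim 1
#   lora_b is a regular nn.Linear -> replicated on every rank (not sharded)
row_lora_a = nn.Linear(in_f, r, bias=False)
row_lora_a_local = shard(row_lora_a.weight, dim=1, rank=rank, world=world)
row_lora_b = nn.Linear(r, out_f, bias=False)  # full copy on every rank

print(col_lora_b_local.shape)  # torch.Size([1024, 16]) -- out_f/world rows
print(row_lora_a_local.shape)  # torch.Size([16, 256])  -- in_f/world cols
```

The asymmetry is the point of the commit: only the LoRA matrix that sits on the parallel dimension of the base layer needs a sharded entry in `sharded_state_dict`; the other adapter matrix is an ordinary `nn.Linear` and is saved whole.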
