@@ -976,6 +976,11 @@ def preprocess_lora_weights(lora_model, model_config):
976976 def interleave_fused_lora_weights_for_tp (
977977 weight : torch .Tensor , rank_dim : int , tp_size : int , part_sizes : List [int ]
978978 ) -> List [torch .Tensor ]:
979+ """Interleaves fused LoRA module weights for TP.
980+ e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to
981+ torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN])
982+ where N=TP size.
983+ """ # noqa: D205
979984 assert weight .shape [rank_dim ] == sum (part_sizes )
980985
981986 # Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv.
@@ -1004,11 +1009,10 @@ def interleave_fused_lora_weights_for_tp(
10041009 def prepare_fused_lora_modules_for_tp (
10051010 lora_module : str , t_out : torch .Tensor , rank_dim : int
10061011 ) -> torch .Tensor :
1007- """Interleaves fused LoRA modules weights for TP. This is required since HF stores the parts weights
1008- sequentially, whereas with TP>1 we need them to be interleaved.
1009- e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to
1010- torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN])
1011- where N=TP size.
1012+ """Reorders fused LoRA module weights for TP. This is required since HF stores the weights of the parts
1013+ sequentially, whereas with TP>1 they need to be interleaved so that they are sharded correctly.
1014+
1015+ See interleave_fused_lora_weights_for_tp for more details.
10121016 """ # noqa: D205
10131017 tp_size = self ._mapping .tp_size
10141018 if tp_size == 1 :
0 commit comments