11import io
2+ import itertools
23import json
34import logging
45import re
@@ -660,11 +661,17 @@ class LoraManager(object):
660661 }
661662
662663 def __init__ (
663- self , cpp_peft_cache_manager : tb_internal .batch_manager .PeftCacheManager | None = None
664+ self ,
665+ * ,
666+ mapping : Mapping ,
667+ model_config : "ModelConfig" ,
668+ cpp_peft_cache_manager : tb_internal .batch_manager .PeftCacheManager | None = None ,
664669 ):
665670 """Constructor.
666671
667672 Args:
673+ mapping (Mapping): Parallelism related information.
674+ model_config (ModelConfig): model configuration python class instance.
668675 cpp_peft_cache_manager (PeftCacheManager, optional): used by is_adapter_in_cpu_cache method, that's used for
669676 a performance optimization with LoRA of not sending the LoRA adapter weights with every LLM request when
670677 the adapter is already loaded in the LoRA CPU cache.
@@ -704,6 +711,8 @@ def __init__(
704711 self ._cpp_lora_weights : Dict [str , torch .Tensor ] = {} # on cpu
705712 self ._cpp_lora_config : Dict [str , torch .Tensor ] = {} # on cpu
706713 self .lora_target_modules : List [str ] = []
714+ self ._mapping = mapping
715+ self ._model_config = model_config
707716 self ._cpp_peft_cache_manager = cpp_peft_cache_manager
708717
709718 def is_adapter_in_cpu_cache (self , adapter_uid : int ) -> bool :
@@ -730,7 +739,6 @@ def load_from_ckpt(
730739 self ,
731740 model_dirs_or_files : List [str ],
732741 model_config : Union ["ModelConfig" , LoraModelConfig ],
733- runtime_mapping : Optional [Mapping ] = None ,
734742 uids : Optional [List [str ]] = None ,
735743 ckpt_source : str = "hf" ,
736744 ) -> List [str ]:
@@ -743,7 +751,6 @@ def load_from_ckpt(
743751 return self .load_from_hf (
744752 model_dirs = model_dirs_or_files ,
745753 model_config = model_config ,
746- runtime_mapping = runtime_mapping ,
747754 uids = uids ,
748755 )
749756 elif ckpt_source == "nemo" :
@@ -754,7 +761,6 @@ def load_from_ckpt(
754761 return self .load_from_nemo (
755762 model_files = nemo_files ,
756763 model_config = model_config ,
757- runtime_mapping = runtime_mapping ,
758764 uids = uids ,
759765 )
760766 else :
@@ -764,19 +770,13 @@ def load_from_nemo(
764770 self ,
765771 model_files : List [str ],
766772 model_config : Union ["ModelConfig" , LoraModelConfig ],
767- runtime_mapping : Optional [Mapping ] = None ,
768773 uids : Optional [List [str ]] = None ,
769774 ) -> List [str ]:
770775 """Returns the adapter UIDs that were loaded by this call.
771776
772777 Note that when an adapter was already loaded before this call, it would not be
773778 included in the returned list of UIDs.
774779 """
775- if runtime_mapping is None :
776- runtime_mapping = Mapping ()
777- tp_size = runtime_mapping .tp_size
778- tp_rank = runtime_mapping .tp_rank
779-
780780 if uids is None :
781781 uids = [self ._generate_uid () for _ in range (len (model_files ))]
782782 assert len (uids ) == len (model_files )
@@ -829,10 +829,6 @@ def load_from_model_file(uid, model_file):
829829
830830 t_in = all_lora_weights [layer_idx ]["in" ]
831831 t_out = all_lora_weights [layer_idx ]["out" ]
832- assert t_out .shape [0 ] % tp_size == 0
833- t_out = torch .split (t_out , t_out .shape [0 ] // tp_size , dim = 0 )[
834- tp_rank
835- ].contiguous ()
836832 else :
837833 t_in = None
838834 t_out = None
@@ -882,7 +878,6 @@ def load_from_hf(
882878 self ,
883879 model_dirs : List [str ],
884880 model_config : Union ["ModelConfig" , LoraModelConfig ],
885- runtime_mapping : Optional [Mapping ] = None ,
886881 uids : Optional [List [str ]] = None ,
887882 component : Optional [str ] = None ,
888883 ) -> List [str ]:
@@ -939,11 +934,6 @@ def load_from_hf(
939934 ...
940935
941936 """
942- if runtime_mapping is None :
943- runtime_mapping = Mapping ()
944- tp_size = runtime_mapping .tp_size
945- tp_rank = runtime_mapping .tp_rank
946-
947937 if uids is None :
948938 uids = [self ._generate_uid () for _ in range (len (model_dirs ))]
949939 assert len (uids ) == len (model_dirs )
@@ -983,6 +973,70 @@ def preprocess_lora_weights(lora_model, model_config):
983973 lora_model [key ] = value
984974 return lora_model
985975
def interleave_fused_lora_weights_for_tp(
    weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: List[int]
) -> List[torch.Tensor]:
    """Interleaves fused LoRA modules weights for TP.
    e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to
    torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN])
    where N=TP size.
    """  # noqa: D205
    assert weight.shape[rank_dim] == sum(part_sizes)

    # Walk the fused dimension part by part (e.g. Wq, Wk, Wv for attn_qkv),
    # slice each part out, and cut it into one equally-sized chunk per TP rank.
    per_part_rank_chunks = []
    offset = 0
    for size in part_sizes:
        assert size % tp_size == 0
        part = weight.narrow(rank_dim, offset, size)
        per_part_rank_chunks.append(torch.split(part, size // tp_size, dim=rank_dim))
        offset += size

    # Transpose [part][rank] -> [rank][part] and flatten, so chunks come out
    # rank-major: all parts of rank 0 first, then rank 1, ... e.g.
    # [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN].
    return [
        chunk
        for rank_chunks in zip(*per_part_rank_chunks)
        for chunk in rank_chunks
    ]
1008+
def prepare_fused_lora_modules_for_tp(
    lora_module: str, t_out: torch.Tensor, rank_dim: int
) -> torch.Tensor:
    """Reorders fused LoRA modules weights for TP. This is required since HF stores the parts weights
    sequentially, whereas with TP>1 we need them to be interleaved so they would be sharded correctly.

    See interleave_fused_lora_weights_for_tp for more details.
    """  # noqa: D205
    tp_size = self._mapping.tp_size
    # With a single rank there is no sharding, hence nothing to reorder.
    if tp_size == 1:
        return t_out

    if lora_module == "mlp_gate_up":
        fused_dim = t_out.shape[rank_dim]
        assert fused_dim % 2 == 0
        part_sizes = [fused_dim // 2, fused_dim // 2]
    elif lora_module == "attn_qkv":
        # The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already
        # divided by tp_size in tensorrt_llm/_torch/model_config.py::ModelConfig.get_bindings_model_config
        head_size = self._model_config.head_size
        q_size = head_size * self._model_config.num_heads * tp_size
        kv_size = head_size * self._model_config.num_kv_heads * tp_size
        part_sizes = [q_size, kv_size, kv_size]
    else:
        # Not a fused module - no interleaving needed.
        return t_out

    # Concatenate them all after interleaving, as the CPP implementation expects the full non-split weights.
    return torch.cat(
        interleave_fused_lora_weights_for_tp(t_out, rank_dim, tp_size, part_sizes),
        dim=rank_dim,
    )
1039+
9861040 def load_from_model_dir (uid , model_dir , hf_config ):
9871041 if uid not in self ._cpp_lora_weights :
9881042 self ._cpp_lora_weights [uid ] = [] # Will be converted to tensor later
@@ -1060,36 +1114,9 @@ def load_from_model_dir(uid, model_dir, hf_config):
10601114 t_mag = module_weights .get ("magnitude" , None )
10611115
10621116 is_dora = t_mag is not None
1063-
1064- if lora_module in ["moe_router" , "mlp_router" ]:
1065- pass
1066- elif "moe" in lora_module and runtime_mapping .has_moe_ep ():
1067- pass
1068- elif lora_module in [
1069- "attn_dense" ,
1070- "cross_attn_dense" ,
1071- "mlp_4h_to_h" ,
1072- "moe_4h_to_h" ,
1073- ]:
1074- # split by row
1075- dim = 2 if has_expert_indices else 1
1076- assert t_in .shape [dim ] % tp_size == 0
1077- t_in = torch .split (t_in , t_in .shape [dim ] // tp_size , dim = dim )[
1078- tp_rank
1079- ].contiguous ()
1080- else :
1081- # split by column
1082- dim = 1 if has_expert_indices else 0
1083- assert t_out .shape [dim ] % tp_size == 0
1084- t_out = torch .split (t_out , t_out .shape [dim ] // tp_size , dim = dim )[
1085- tp_rank
1086- ].contiguous ()
1087- if dim == 0 and is_dora and t_mag is not None :
1088- t_mag = torch .split (t_mag , t_mag .shape [0 ] // tp_size , dim = 0 )[
1089- tp_rank
1090- ].contiguous ()
1091-
10921117 rank_dim = 1 if has_expert_indices else 0
1118+ t_out = prepare_fused_lora_modules_for_tp (lora_module , t_out , rank_dim )
1119+
10931120 effective_rank = t_in .shape [rank_dim ]
10941121
10951122 t_in = t_in .cuda ().contiguous ()
0 commit comments