
Commit fac47e2

[https://nvbugs/5510879][fix] Fix fused LoRA adapter module weight splitting with TP>1 in the PyTorch and TRT-Python flows (#8063)
Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
1 parent a1ed03f commit fac47e2

File tree

11 files changed: +229 −78 lines

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 21 additions & 8 deletions
@@ -13,6 +13,7 @@
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
+from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
 from tensorrt_llm.sampling_params import SamplingParams
 
 from ..._utils import binding_to_str_dtype, get_size_in_bytes, nvtx_range
@@ -32,7 +33,7 @@
 KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
 KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig
 CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
-ModelConfig = tensorrt_llm.bindings.ModelConfig
+ModelConfigCpp = tensorrt_llm.bindings.ModelConfig
 DataType = tensorrt_llm.bindings.DataType
 KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEventManager
 RequestList = list[LlmRequest]
@@ -160,7 +161,7 @@ def __init__(
         spec_config: Optional["DecodingBaseConfig"] = None,
         layer_mask: Optional[List[bool]] = None,
         max_num_tokens: int = 8192,
-        model_config: Optional[ModelConfig] = None,
+        model_config: Optional[ModelConfigCpp] = None,
         max_beam_width: int = 1,
         is_draft: bool = False,
         kv_connector_manager: Optional[KvCacheConnectorManager] = None,
@@ -371,7 +372,7 @@ def shutdown(self):
 
     @classmethod
     def from_model_config(cls,
-                          model_config: ModelConfig,
+                          model_config: ModelConfigCpp,
                           kv_cache_config: KvCacheConfigCpp,
                           mapping: Mapping,
                           kv_cache_type: CacheTypeCpp = CacheTypeCpp.SELF,
@@ -772,7 +773,7 @@ def adjust_window_sizes_for_vswa(
             window_size_to_layers: Dict[int, List[int]],
             max_attention_window_vec: List[int],
             kv_cache_config: KvCacheConfigCpp,
-            model_config: ModelConfig,
+            model_config: ModelConfigCpp,
             pool_memory_bytes: int,
             kv_factor: int,
             dtype: DataType,
@@ -887,7 +888,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int:
     def calculate_max_num_blocks_from_cpp(
             self,
             kv_cache_config: KvCacheConfigCpp,
-            model_config: ModelConfig,
+            model_config: ModelConfigCpp,
             extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]:
         """
         This function is a wrapper of KVCacheManagerCpp.calculate_max_num_blocks.
@@ -1133,7 +1134,7 @@ class PeftCacheManager(BaseResourceManager):
     def __init__(self,
                  peft_cache_config: PeftCacheConfig,
                  lora_config: LoraConfig,
-                 model_config: ModelConfig,
+                 model_config: ModelConfigCpp,
                  world_config: WorldConfig | None = None):
         import tensorrt_llm.bindings as _tb
 
@@ -1169,7 +1170,20 @@ def __init__(self,
             lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
             binding_to_str_dtype(model_config.data_type),
             lora_config.swap_gate_up_proj_lora_b_weight)
-        self._lora_manager = LoraManager()
+        mapping = Mapping(
+            world_size=world_config.size,
+            rank=world_config.rank,
+            tp_size=world_config.tensor_parallelism,
+            pp_size=world_config.pipeline_parallelism,
+            gpus_per_node=world_config.gpus_per_node,
+        )
+        self._lora_manager = LoraManager(
+            mapping=mapping,
+            model_config=ModelConfigPython.from_model_config_cpp(model_config),
+            cpp_peft_cache_manager=self.impl)
+
+    def get_lora_manager(self) -> LoraManager:
+        return self._lora_manager
 
     def add_request_peft(self, request: LlmRequest):
         if request.lora_task_id is not None:
@@ -1183,7 +1197,6 @@ def add_request_peft(self, request: LlmRequest):
             self._lora_manager.load_from_ckpt(
                 [request.py_lora_path],
                 model_config=self._lora_model_config,
-                runtime_mapping=None,
                 uids=[request.lora_task_id],
                 ckpt_source=self._lora_config.lora_ckpt_source)
         request.lora_weights = self._lora_manager.cpp_lora_weights[

tensorrt_llm/_utils.py

Lines changed: 5 additions & 1 deletion
@@ -42,7 +42,7 @@
 import tensorrt as trt
 # isort: on
 
-from tensorrt_llm.bindings import DataType, GptJsonConfig
+from tensorrt_llm.bindings import DataType, GptJsonConfig, LayerType
 from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
 from tensorrt_llm.logger import logger
 
@@ -198,6 +198,10 @@ def str_dtype_to_torch(dtype):
 }
 
 
+def binding_layer_type_to_str(layer_type: LayerType) -> str:
+    return layer_type.name.lower()
+
+
 def binding_to_str_dtype(binding_dtype) -> str:
     ret = _binding_to_str_dtype.get(binding_dtype)
     assert ret is not None, f'Unsupported binding dtype: {binding_dtype}'

tensorrt_llm/executor/base_worker.py

Lines changed: 5 additions & 4 deletions
@@ -205,7 +205,10 @@ def _create_engine(executor_config):
             # point in the TRT flow is currently not supported (it's at the CPP
             # Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA
             # optimization is not available in TRT-python flow.
-            self._lora_manager = LoraManager(cpp_peft_cache_manager=None)
+            self._lora_manager = LoraManager(
+                mapping=engine_config.pretrained_config.mapping,
+                model_config=self._runtime_model_config,
+                cpp_peft_cache_manager=None)
             if engine_config.build_config.max_prompt_embedding_table_size > 0:
                 self._prompt_adapter_manager = PromptAdapterManager()
 
@@ -216,8 +219,7 @@ def _create_engine(executor_config):
                 ResourceManagerType
             peft_cache_manager = self.engine.resource_manager.resource_managers.get(
                 ResourceManagerType.PEFT_CACHE_MANAGER)
-            self._lora_manager = LoraManager(
-                cpp_peft_cache_manager=peft_cache_manager.impl)
+            self._lora_manager = peft_cache_manager.get_lora_manager()
             lora_model_config = self.engine.model_engine.lora_model_config
             assert lora_model_config is not None
             self._lora_model_config = lora_model_config
@@ -302,7 +304,6 @@ def _load_lora_adapter(self, lora_request: LoRARequest) -> bool:
             [lora_request.path],
             model_config=self._runtime_model_config if
             self._runtime_model_config is not None else self._lora_model_config,
-            runtime_mapping=None,
             uids=[adapter_id],
             ckpt_source=lora_request.ckpt_source)
         return adapter_id in newly_loaded_uids

tensorrt_llm/lora_helper.py

Lines changed: 1 addition & 0 deletions
@@ -46,6 +46,7 @@ def get_default_trtllm_modules_to_hf_modules():
         "attn_q": "q_proj",
         "attn_k": "k_proj",
         "attn_v": "v_proj",
+        "attn_qkv": "qkv_proj",
         "attn_dense": "o_proj",
         "mlp_h_to_4h": "gate_proj",
         "mlp_4h_to_h": "down_proj",

tensorrt_llm/lora_manager.py

Lines changed: 76 additions & 49 deletions
@@ -1,4 +1,5 @@
 import io
+import itertools
 import json
 import logging
 import re
@@ -660,11 +661,17 @@ class LoraManager(object):
     }
 
     def __init__(
-        self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None
+        self,
+        *,
+        mapping: Mapping,
+        model_config: "ModelConfig",
+        cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None,
     ):
         """Constructor.
 
         Args:
+            mapping (Mapping): Parallelism related information.
+            model_config (ModelConfig): model configuration python class instance.
             cpp_peft_cache_manager (PeftCacheManager, optional): used by is_adapter_in_cpu_cache method, that's used for
                 a performance optimization with LoRA of not sending the LoRA adapter weights with every LLM request when
                 the adapter is already loaded in the LoRA CPU cache.
@@ -704,6 +711,8 @@ def __init__(
         self._cpp_lora_weights: Dict[str, torch.Tensor] = {}  # on cpu
         self._cpp_lora_config: Dict[str, torch.Tensor] = {}  # on cpu
         self.lora_target_modules: List[str] = []
+        self._mapping = mapping
+        self._model_config = model_config
         self._cpp_peft_cache_manager = cpp_peft_cache_manager
 
     def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:
@@ -730,7 +739,6 @@ def load_from_ckpt(
         self,
         model_dirs_or_files: List[str],
         model_config: Union["ModelConfig", LoraModelConfig],
-        runtime_mapping: Optional[Mapping] = None,
        uids: Optional[List[str]] = None,
         ckpt_source: str = "hf",
     ) -> List[str]:
@@ -743,7 +751,6 @@
             return self.load_from_hf(
                 model_dirs=model_dirs_or_files,
                 model_config=model_config,
-                runtime_mapping=runtime_mapping,
                 uids=uids,
             )
         elif ckpt_source == "nemo":
@@ -754,7 +761,6 @@
             return self.load_from_nemo(
                 model_files=nemo_files,
                 model_config=model_config,
-                runtime_mapping=runtime_mapping,
                 uids=uids,
             )
         else:
@@ -764,19 +770,13 @@ def load_from_nemo(
         self,
         model_files: List[str],
         model_config: Union["ModelConfig", LoraModelConfig],
-        runtime_mapping: Optional[Mapping] = None,
         uids: Optional[List[str]] = None,
     ) -> List[str]:
         """Returns the adapter UIDs that were loaded by this call.
 
         Note that when an adapter was already loaded before this call, it would not be
         included in the returned list of UIDs.
         """
-        if runtime_mapping is None:
-            runtime_mapping = Mapping()
-        tp_size = runtime_mapping.tp_size
-        tp_rank = runtime_mapping.tp_rank
-
         if uids is None:
             uids = [self._generate_uid() for _ in range(len(model_files))]
         assert len(uids) == len(model_files)
@@ -829,10 +829,6 @@ def load_from_model_file(uid, model_file):
 
                 t_in = all_lora_weights[layer_idx]["in"]
                 t_out = all_lora_weights[layer_idx]["out"]
-                assert t_out.shape[0] % tp_size == 0
-                t_out = torch.split(t_out, t_out.shape[0] // tp_size, dim=0)[
-                    tp_rank
-                ].contiguous()
             else:
                 t_in = None
                 t_out = None
@@ -882,7 +878,6 @@ def load_from_hf(
         self,
         model_dirs: List[str],
         model_config: Union["ModelConfig", LoraModelConfig],
-        runtime_mapping: Optional[Mapping] = None,
         uids: Optional[List[str]] = None,
         component: Optional[str] = None,
     ) -> List[str]:
@@ -939,11 +934,6 @@
             ...
 
         """
-        if runtime_mapping is None:
-            runtime_mapping = Mapping()
-        tp_size = runtime_mapping.tp_size
-        tp_rank = runtime_mapping.tp_rank
-
         if uids is None:
             uids = [self._generate_uid() for _ in range(len(model_dirs))]
         assert len(uids) == len(model_dirs)
@@ -983,6 +973,70 @@ def preprocess_lora_weights(lora_model, model_config):
                     lora_model[key] = value
             return lora_model
 
+        def interleave_fused_lora_weights_for_tp(
+            weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: List[int]
+        ) -> List[torch.Tensor]:
+            """Interleaves fused LoRA modules weights for TP.
+            e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to
+            torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN])
+            where N=TP size.
+            """  # noqa: D205
+            assert weight.shape[rank_dim] == sum(part_sizes)
+
+            # Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv.
+            weight_parts = [
+                weight.narrow(rank_dim, sum(part_sizes[:i]), part_sizes[i])
+                for i in range(len(part_sizes))
+            ]
+            for i in range(len(part_sizes)):
+                assert weight_parts[i].shape[rank_dim] % tp_size == 0
+
+            # Split each part into tp_size chunks.
+            # e.g. [Wq, Wk, Wv] -> [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
+            # where N is TP size, for attn_qkv.
+            weight_parts_tp_weights = [
+                torch.split(
+                    weight_parts[i], weight_parts[i].shape[rank_dim] // tp_size, dim=rank_dim
+                )
+                for i in range(len(part_sizes))
+            ]
+
+            # Interleave the parts across TP ranks and flatten the list of lists into a single list.
+            # e.g. [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
+            # -> [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN] where N is TP size, for attn_qkv.
+            return list(itertools.chain.from_iterable(zip(*weight_parts_tp_weights)))
+
+        def prepare_fused_lora_modules_for_tp(
+            lora_module: str, t_out: torch.Tensor, rank_dim: int
+        ) -> torch.Tensor:
+            """Reorders fused LoRA modules weights for TP. This is required since HF stores the parts weights
+            sequentially, whereas with TP>1 we need them to be interleaved so they would be sharded correctly.
+
+            See interleave_fused_lora_weights_for_tp for more details.
+            """  # noqa: D205
+            tp_size = self._mapping.tp_size
+            if tp_size == 1:
+                return t_out
+            part_sizes = []
+            if lora_module == "mlp_gate_up":
+                assert t_out.shape[rank_dim] % 2 == 0
+                half_size = t_out.shape[rank_dim] // 2
+                part_sizes = [half_size, half_size]
+            elif lora_module == "attn_qkv":
+                # The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already
+                # divided by tp_size in tensorrt_llm/_torch/model_config.py::ModelConfig.get_bindings_model_config
+                q_size = self._model_config.head_size * self._model_config.num_heads * tp_size
+                kv_size = self._model_config.head_size * self._model_config.num_kv_heads * tp_size
+                part_sizes = [q_size, kv_size, kv_size]
+
+            if part_sizes:
+                interleaved_parts = interleave_fused_lora_weights_for_tp(
+                    t_out, rank_dim, tp_size, part_sizes
+                )
+                # Concatenate them all after interleaving, as the CPP implementation expects the full non-split weights.
+                t_out = torch.cat(interleaved_parts, dim=rank_dim)
+            return t_out
+
         def load_from_model_dir(uid, model_dir, hf_config):
             if uid not in self._cpp_lora_weights:
                 self._cpp_lora_weights[uid] = []  # Will be converted to tensor later
@@ -1060,36 +1114,9 @@ def load_from_model_dir(uid, model_dir, hf_config):
                 t_mag = module_weights.get("magnitude", None)
 
                 is_dora = t_mag is not None
-
-                if lora_module in ["moe_router", "mlp_router"]:
-                    pass
-                elif "moe" in lora_module and runtime_mapping.has_moe_ep():
-                    pass
-                elif lora_module in [
-                        "attn_dense",
-                        "cross_attn_dense",
-                        "mlp_4h_to_h",
-                        "moe_4h_to_h",
-                ]:
-                    # split by row
-                    dim = 2 if has_expert_indices else 1
-                    assert t_in.shape[dim] % tp_size == 0
-                    t_in = torch.split(t_in, t_in.shape[dim] // tp_size, dim=dim)[
-                        tp_rank
-                    ].contiguous()
-                else:
-                    # split by column
-                    dim = 1 if has_expert_indices else 0
-                    assert t_out.shape[dim] % tp_size == 0
-                    t_out = torch.split(t_out, t_out.shape[dim] // tp_size, dim=dim)[
-                        tp_rank
-                    ].contiguous()
-                    if dim == 0 and is_dora and t_mag is not None:
-                        t_mag = torch.split(t_mag, t_mag.shape[0] // tp_size, dim=0)[
-                            tp_rank
-                        ].contiguous()
-
                 rank_dim = 1 if has_expert_indices else 0
+                t_out = prepare_fused_lora_modules_for_tp(lora_module, t_out, rank_dim)
+
                 effective_rank = t_in.shape[rank_dim]
 
                 t_in = t_in.cuda().contiguous()
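
To see why the reordering matters, here is a small self-contained sketch (not part of the commit; toy dimensions chosen for illustration) that mimics interleave_fused_lora_weights_for_tp on a fused attn_qkv LoRA B weight and checks that a plain contiguous row split per TP rank, the sharding ultimately applied to the weights, now hands each rank exactly its own Q/K/V slices:

    import itertools

    import torch

    tp_size = 2
    q_size, kv_size, lora_rank = 8, 4, 3  # toy sizes; fused out-dim = q_size + 2 * kv_size

    # Fused LoRA B weight as stored in the HF checkpoint: rows are [Wq; Wk; Wv], sequentially.
    wq = torch.arange(q_size * lora_rank, dtype=torch.float32).reshape(q_size, lora_rank)
    wk = 100 + torch.arange(kv_size * lora_rank, dtype=torch.float32).reshape(kv_size, lora_rank)
    wv = 200 + torch.arange(kv_size * lora_rank, dtype=torch.float32).reshape(kv_size, lora_rank)
    t_out = torch.cat([wq, wk, wv], dim=0)

    # Split into parts and interleave per TP rank, as the new helper does:
    # [Wq_r0, Wk_r0, Wv_r0, Wq_r1, Wk_r1, Wv_r1].
    part_sizes = [q_size, kv_size, kv_size]
    parts = [t_out.narrow(0, sum(part_sizes[:i]), part_sizes[i]) for i in range(len(part_sizes))]
    chunks = [torch.split(p, p.shape[0] // tp_size, dim=0) for p in parts]
    interleaved = torch.cat(list(itertools.chain.from_iterable(zip(*chunks))), dim=0)

    # A contiguous row split of the interleaved tensor gives each rank its own Q/K/V slices;
    # splitting the original t_out this way would give rank 0 the first half of the fused rows
    # (all of Wq here) instead of its Q/K/V slices.
    per_rank = torch.split(interleaved, interleaved.shape[0] // tp_size, dim=0)
    for r in range(tp_size):
        expected = torch.cat([
            wq[r * (q_size // tp_size):(r + 1) * (q_size // tp_size)],
            wk[r * (kv_size // tp_size):(r + 1) * (kv_size // tp_size)],
            wv[r * (kv_size // tp_size):(r + 1) * (kv_size // tp_size)],
        ], dim=0)
        assert torch.equal(per_rank[r], expected)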
