Skip to content

Commit 34038b0

Browse files
committed
[https://nvbugs/5510879][fix] Fix fused LoRA adapter module weight splitting with TP>1 in the PyTorch & TRT-Python flows (NVIDIA#8063)
Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
1 parent d5b7926 commit 34038b0

File tree

11 files changed

+265
-78
lines changed

11 files changed

+265
-78
lines changed

tensorrt_llm/_torch/pyexecutor/resource_manager.py

Lines changed: 21 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
1313
from tensorrt_llm.lora_helper import LoraConfig
1414
from tensorrt_llm.lora_manager import LoraManager, LoraModelConfig
15+
from tensorrt_llm.runtime import ModelConfig as ModelConfigPython
1516
from tensorrt_llm.sampling_params import SamplingParams
1617

1718
from ..._utils import binding_to_str_dtype, get_size_in_bytes, nvtx_range
@@ -31,7 +32,7 @@
3132
KVCacheManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheManager
3233
KvCacheConfigCpp = tensorrt_llm.bindings.executor.KvCacheConfig
3334
CacheTypeCpp = tensorrt_llm.bindings.internal.batch_manager.CacheType
34-
ModelConfig = tensorrt_llm.bindings.ModelConfig
35+
ModelConfigCpp = tensorrt_llm.bindings.ModelConfig
3536
DataType = tensorrt_llm.bindings.DataType
3637
KVCacheEventManagerCpp = tensorrt_llm.bindings.internal.batch_manager.KVCacheEventManager
3738
RequestList = list[LlmRequest]
@@ -159,7 +160,7 @@ def __init__(
159160
spec_config: Optional["DecodingBaseConfig"] = None,
160161
layer_mask: Optional[List[bool]] = None,
161162
max_num_tokens: int = 8192,
162-
model_config: Optional[ModelConfig] = None,
163+
model_config: Optional[ModelConfigCpp] = None,
163164
max_beam_width: int = 1,
164165
is_draft: bool = False,
165166
kv_connector_manager: Optional[KvCacheConnectorManager] = None,
@@ -370,7 +371,7 @@ def shutdown(self):
370371

371372
@classmethod
372373
def from_model_config(cls,
373-
model_config: ModelConfig,
374+
model_config: ModelConfigCpp,
374375
kv_cache_config: KvCacheConfigCpp,
375376
mapping: Mapping,
376377
kv_cache_type: CacheTypeCpp = CacheTypeCpp.SELF,
@@ -753,7 +754,7 @@ def adjust_window_sizes_for_vswa(
753754
window_size_to_layers: Dict[int, List[int]],
754755
max_attention_window_vec: List[int],
755756
kv_cache_config: KvCacheConfigCpp,
756-
model_config: ModelConfig,
757+
model_config: ModelConfigCpp,
757758
pool_memory_bytes: int,
758759
kv_factor: int,
759760
dtype: DataType,
@@ -868,7 +869,7 @@ def calculate_cache_size_per_token(layers: Set[int]) -> int:
868869
def calculate_max_num_blocks_from_cpp(
869870
self,
870871
kv_cache_config: KvCacheConfigCpp,
871-
model_config: ModelConfig,
872+
model_config: ModelConfigCpp,
872873
extra_cost_memory: int = 0) -> dict[int, tuple[int, int]]:
873874
"""
874875
This function is a wrapper of KVCacheManagerCpp.calculate_max_num_blocks.
@@ -1111,7 +1112,7 @@ class PeftCacheManager(BaseResourceManager):
11111112
def __init__(self,
11121113
peft_cache_config: PeftCacheConfig,
11131114
lora_config: LoraConfig,
1114-
model_config: ModelConfig,
1115+
model_config: ModelConfigCpp,
11151116
world_config: WorldConfig | None = None):
11161117
import tensorrt_llm.bindings as _tb
11171118

@@ -1147,7 +1148,20 @@ def __init__(self,
11471148
lora_config.trtllm_modules_to_hf_modules, model_config.hidden_size,
11481149
binding_to_str_dtype(model_config.data_type),
11491150
lora_config.swap_gate_up_proj_lora_b_weight)
1150-
self._lora_manager = LoraManager()
1151+
mapping = Mapping(
1152+
world_size=world_config.size,
1153+
rank=world_config.rank,
1154+
tp_size=world_config.tensor_parallelism,
1155+
pp_size=world_config.pipeline_parallelism,
1156+
gpus_per_node=world_config.gpus_per_node,
1157+
)
1158+
self._lora_manager = LoraManager(
1159+
mapping=mapping,
1160+
model_config=ModelConfigPython.from_model_config_cpp(model_config),
1161+
cpp_peft_cache_manager=self.impl)
1162+
1163+
def get_lora_manager(self) -> LoraManager:
1164+
return self._lora_manager
11511165

11521166
def add_request_peft(self, request: LlmRequest):
11531167
if request.lora_task_id is not None:
@@ -1161,7 +1175,6 @@ def add_request_peft(self, request: LlmRequest):
11611175
self._lora_manager.load_from_ckpt(
11621176
[request.py_lora_path],
11631177
model_config=self._lora_model_config,
1164-
runtime_mapping=None,
11651178
uids=[request.lora_task_id],
11661179
ckpt_source=self._lora_config.lora_ckpt_source)
11671180
request.lora_weights = self._lora_manager.cpp_lora_weights[

tensorrt_llm/_utils.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@
4141
import tensorrt as trt
4242
# isort: on
4343

44-
from tensorrt_llm.bindings import DataType, GptJsonConfig
44+
from tensorrt_llm.bindings import DataType, GptJsonConfig, LayerType
4545
from tensorrt_llm.bindings.BuildInfo import ENABLE_MULTI_DEVICE
4646
from tensorrt_llm.logger import logger
4747

@@ -197,6 +197,10 @@ def str_dtype_to_torch(dtype):
197197
}
198198

199199

200+
def binding_layer_type_to_str(layer_type: LayerType) -> str:
201+
return layer_type.name.lower()
202+
203+
200204
def binding_to_str_dtype(binding_dtype) -> str:
201205
ret = _binding_to_str_dtype.get(binding_dtype)
202206
assert ret is not None, f'Unsupported binding dtype: {binding_dtype}'

tensorrt_llm/executor/base_worker.py

Lines changed: 5 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -205,7 +205,10 @@ def _create_engine(executor_config):
205205
# point in the TRT flow is currently not supported (it's at the CPP
206206
# Executor->ExecutorImpl->TrtGptModel->mPeftCacheManager) therefore for now this LoRA
207207
# optimization is not available in TRT-python flow.
208-
self._lora_manager = LoraManager(cpp_peft_cache_manager=None)
208+
self._lora_manager = LoraManager(
209+
mapping=engine_config.pretrained_config.mapping,
210+
model_config=self._runtime_model_config,
211+
cpp_peft_cache_manager=None)
209212
if engine_config.build_config.max_prompt_embedding_table_size > 0:
210213
self._prompt_adapter_manager = PromptAdapterManager()
211214

@@ -216,8 +219,7 @@ def _create_engine(executor_config):
216219
ResourceManagerType
217220
peft_cache_manager = self.engine.resource_manager.resource_managers.get(
218221
ResourceManagerType.PEFT_CACHE_MANAGER)
219-
self._lora_manager = LoraManager(
220-
cpp_peft_cache_manager=peft_cache_manager.impl)
222+
self._lora_manager = peft_cache_manager.get_lora_manager()
221223
lora_model_config = self.engine.model_engine.lora_model_config
222224
assert lora_model_config is not None
223225
self._lora_model_config = lora_model_config
@@ -296,7 +298,6 @@ def _load_lora_adapter(self, lora_request: LoRARequest) -> bool:
296298
[lora_request.path],
297299
model_config=self._runtime_model_config if
298300
self._runtime_model_config is not None else self._lora_model_config,
299-
runtime_mapping=None,
300301
uids=[adapter_id],
301302
ckpt_source=lora_request.ckpt_source)
302303
return adapter_id in newly_loaded_uids

tensorrt_llm/lora_helper.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -46,6 +46,7 @@ def get_default_trtllm_modules_to_hf_modules():
4646
"attn_q": "q_proj",
4747
"attn_k": "k_proj",
4848
"attn_v": "v_proj",
49+
"attn_qkv": "qkv_proj",
4950
"attn_dense": "o_proj",
5051
"mlp_h_to_4h": "gate_proj",
5152
"mlp_4h_to_h": "down_proj",

tensorrt_llm/lora_manager.py

Lines changed: 76 additions & 49 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,5 @@
11
import io
2+
import itertools
23
import json
34
import logging
45
import re
@@ -660,11 +661,17 @@ class LoraManager(object):
660661
}
661662

662663
def __init__(
663-
self, cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None
664+
self,
665+
*,
666+
mapping: Mapping,
667+
model_config: "ModelConfig",
668+
cpp_peft_cache_manager: tb_internal.batch_manager.PeftCacheManager | None = None,
664669
):
665670
"""Constructor.
666671

667672
Args:
673+
mapping (Mapping): Parallelism related information.
674+
model_config (ModelConfig): model configuration python class instance.
668675
cpp_peft_cache_manager (PeftCacheManager, optional): used by is_adapter_in_cpu_cache method, that's used for
669676
a performance optimization with LoRA of not sending the LoRA adapter weights with every LLM request when
670677
the adapter is already loaded in the LoRA CPU cache.
@@ -704,6 +711,8 @@ def __init__(
704711
self._cpp_lora_weights: Dict[str, torch.Tensor] = {} # on cpu
705712
self._cpp_lora_config: Dict[str, torch.Tensor] = {} # on cpu
706713
self.lora_target_modules: List[str] = []
714+
self._mapping = mapping
715+
self._model_config = model_config
707716
self._cpp_peft_cache_manager = cpp_peft_cache_manager
708717

709718
def is_adapter_in_cpu_cache(self, adapter_uid: int) -> bool:
@@ -730,7 +739,6 @@ def load_from_ckpt(
730739
self,
731740
model_dirs_or_files: List[str],
732741
model_config: Union["ModelConfig", LoraModelConfig],
733-
runtime_mapping: Optional[Mapping] = None,
734742
uids: Optional[List[str]] = None,
735743
ckpt_source: str = "hf",
736744
) -> List[str]:
@@ -743,7 +751,6 @@ def load_from_ckpt(
743751
return self.load_from_hf(
744752
model_dirs=model_dirs_or_files,
745753
model_config=model_config,
746-
runtime_mapping=runtime_mapping,
747754
uids=uids,
748755
)
749756
elif ckpt_source == "nemo":
@@ -754,7 +761,6 @@ def load_from_ckpt(
754761
return self.load_from_nemo(
755762
model_files=nemo_files,
756763
model_config=model_config,
757-
runtime_mapping=runtime_mapping,
758764
uids=uids,
759765
)
760766
else:
@@ -764,19 +770,13 @@ def load_from_nemo(
764770
self,
765771
model_files: List[str],
766772
model_config: Union["ModelConfig", LoraModelConfig],
767-
runtime_mapping: Optional[Mapping] = None,
768773
uids: Optional[List[str]] = None,
769774
) -> List[str]:
770775
"""Returns the adapter UIDs that were loaded by this call.
771776

772777
Note that when an adapter was already loaded before this call, it would not be
773778
included in the returned list of UIDs.
774779
"""
775-
if runtime_mapping is None:
776-
runtime_mapping = Mapping()
777-
tp_size = runtime_mapping.tp_size
778-
tp_rank = runtime_mapping.tp_rank
779-
780780
if uids is None:
781781
uids = [self._generate_uid() for _ in range(len(model_files))]
782782
assert len(uids) == len(model_files)
@@ -829,10 +829,6 @@ def load_from_model_file(uid, model_file):
829829

830830
t_in = all_lora_weights[layer_idx]["in"]
831831
t_out = all_lora_weights[layer_idx]["out"]
832-
assert t_out.shape[0] % tp_size == 0
833-
t_out = torch.split(t_out, t_out.shape[0] // tp_size, dim=0)[
834-
tp_rank
835-
].contiguous()
836832
else:
837833
t_in = None
838834
t_out = None
@@ -882,7 +878,6 @@ def load_from_hf(
882878
self,
883879
model_dirs: List[str],
884880
model_config: Union["ModelConfig", LoraModelConfig],
885-
runtime_mapping: Optional[Mapping] = None,
886881
uids: Optional[List[str]] = None,
887882
component: Optional[str] = None,
888883
) -> List[str]:
@@ -939,11 +934,6 @@ def load_from_hf(
939934
...
940935

941936
"""
942-
if runtime_mapping is None:
943-
runtime_mapping = Mapping()
944-
tp_size = runtime_mapping.tp_size
945-
tp_rank = runtime_mapping.tp_rank
946-
947937
if uids is None:
948938
uids = [self._generate_uid() for _ in range(len(model_dirs))]
949939
assert len(uids) == len(model_dirs)
@@ -983,6 +973,70 @@ def preprocess_lora_weights(lora_model, model_config):
983973
lora_model[key] = value
984974
return lora_model
985975

976+
def interleave_fused_lora_weights_for_tp(
977+
weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: List[int]
978+
) -> List[torch.Tensor]:
979+
"""Interleaves fused LoRA modules weights for TP.
980+
e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to
981+
torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN])
982+
where N=TP size.
983+
""" # noqa: D205
984+
assert weight.shape[rank_dim] == sum(part_sizes)
985+
986+
# Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv.
987+
weight_parts = [
988+
weight.narrow(rank_dim, sum(part_sizes[:i]), part_sizes[i])
989+
for i in range(len(part_sizes))
990+
]
991+
for i in range(len(part_sizes)):
992+
assert weight_parts[i].shape[rank_dim] % tp_size == 0
993+
994+
# Split each part into tp_size chunks.
995+
# e.g. [Wq, Wk, Wv] -> [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
996+
# where N is TP size, for attn_qkv.
997+
weight_parts_tp_weights = [
998+
torch.split(
999+
weight_parts[i], weight_parts[i].shape[rank_dim] // tp_size, dim=rank_dim
1000+
)
1001+
for i in range(len(part_sizes))
1002+
]
1003+
1004+
# Interleave the parts across TP ranks and flatten the list of lists into a single list.
1005+
# e.g. [[Wq_rank0, ..., Wq_rankN], [Wk_rank0, ..., Wk_rankN], [Wv_rank0, ..., Wv_rankN]]
1006+
# -> [Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN] where N is TP size, for attn_qkv.
1007+
return list(itertools.chain.from_iterable(zip(*weight_parts_tp_weights)))
1008+
1009+
def prepare_fused_lora_modules_for_tp(
1010+
lora_module: str, t_out: torch.Tensor, rank_dim: int
1011+
) -> torch.Tensor:
1012+
"""Reorders fused LoRA modules weights for TP. This is required since HF stores the parts weights
1013+
sequentially, whereas with TP>1 we need them to be interleaved so they would be sharded correctly.
1014+
1015+
See interleave_fused_lora_weights_for_tp for more details.
1016+
""" # noqa: D205
1017+
tp_size = self._mapping.tp_size
1018+
if tp_size == 1:
1019+
return t_out
1020+
part_sizes = []
1021+
if lora_module == "mlp_gate_up":
1022+
assert t_out.shape[rank_dim] % 2 == 0
1023+
half_size = t_out.shape[rank_dim] // 2
1024+
part_sizes = [half_size, half_size]
1025+
elif lora_module == "attn_qkv":
1026+
# The sizes are multiplied by tp_size because num_heads and num_kv_heads here were already
1027+
# divided by tp_size in tensorrt_llm/_torch/model_config.py::ModelConfig.get_bindings_model_config
1028+
q_size = self._model_config.head_size * self._model_config.num_heads * tp_size
1029+
kv_size = self._model_config.head_size * self._model_config.num_kv_heads * tp_size
1030+
part_sizes = [q_size, kv_size, kv_size]
1031+
1032+
if part_sizes:
1033+
interleaved_parts = interleave_fused_lora_weights_for_tp(
1034+
t_out, rank_dim, tp_size, part_sizes
1035+
)
1036+
# Concatenate them all after interleaving, as the CPP implementation expects the full non-split weights.
1037+
t_out = torch.cat(interleaved_parts, dim=rank_dim)
1038+
return t_out
1039+
9861040
def load_from_model_dir(uid, model_dir, hf_config):
9871041
if uid not in self._cpp_lora_weights:
9881042
self._cpp_lora_weights[uid] = [] # Will be converted to tensor later
@@ -1060,36 +1114,9 @@ def load_from_model_dir(uid, model_dir, hf_config):
10601114
t_mag = module_weights.get("magnitude", None)
10611115

10621116
is_dora = t_mag is not None
1063-
1064-
if lora_module in ["moe_router", "mlp_router"]:
1065-
pass
1066-
elif "moe" in lora_module and runtime_mapping.has_moe_ep():
1067-
pass
1068-
elif lora_module in [
1069-
"attn_dense",
1070-
"cross_attn_dense",
1071-
"mlp_4h_to_h",
1072-
"moe_4h_to_h",
1073-
]:
1074-
# split by row
1075-
dim = 2 if has_expert_indices else 1
1076-
assert t_in.shape[dim] % tp_size == 0
1077-
t_in = torch.split(t_in, t_in.shape[dim] // tp_size, dim=dim)[
1078-
tp_rank
1079-
].contiguous()
1080-
else:
1081-
# split by column
1082-
dim = 1 if has_expert_indices else 0
1083-
assert t_out.shape[dim] % tp_size == 0
1084-
t_out = torch.split(t_out, t_out.shape[dim] // tp_size, dim=dim)[
1085-
tp_rank
1086-
].contiguous()
1087-
if dim == 0 and is_dora and t_mag is not None:
1088-
t_mag = torch.split(t_mag, t_mag.shape[0] // tp_size, dim=0)[
1089-
tp_rank
1090-
].contiguous()
1091-
10921117
rank_dim = 1 if has_expert_indices else 0
1118+
t_out = prepare_fused_lora_modules_for_tp(lora_module, t_out, rank_dim)
1119+
10931120
effective_rank = t_in.shape[rank_dim]
10941121

10951122
t_in = t_in.cuda().contiguous()

0 commit comments

Comments (0)