Skip to content

Commit 2857625

Browse files
committed
Mark test to run only when there are 2 GPUs, improve documentation
Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
1 parent 633e838 commit 2857625

File tree

2 files changed

+10
-5
lines changed

2 files changed

+10
-5
lines changed

tensorrt_llm/lora_manager.py

Lines changed: 9 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -976,6 +976,11 @@ def preprocess_lora_weights(lora_model, model_config):
976976
def interleave_fused_lora_weights_for_tp(
977977
weight: torch.Tensor, rank_dim: int, tp_size: int, part_sizes: List[int]
978978
) -> List[torch.Tensor]:
979+
"""Interleaves fused LoRA modules' weights for TP.
980+
e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to
981+
torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN])
982+
where N=TP size.
983+
""" # noqa: D205
979984
assert weight.shape[rank_dim] == sum(part_sizes)
980985

981986
# Split the weights into their respective parts. e.g. weight -> [Wq, Wk, Wv] for attn_qkv.
@@ -1004,11 +1009,10 @@ def interleave_fused_lora_weights_for_tp(
10041009
def prepare_fused_lora_modules_for_tp(
10051010
lora_module: str, t_out: torch.Tensor, rank_dim: int
10061011
) -> torch.Tensor:
1007-
"""Interleaves fused LoRA modules weights for TP. This is required since HF stores the parts weights
1008-
sequentially, whereas with TP>1 we need them to be interleaved.
1009-
e.g. In case of attn_qkv: Convert t_out=torch.cat([Wq, Wk, Wv]) to
1010-
torch.cat([Wq_rank0, Wk_rank0, Wv_rank0, ..., Wq_rankN, Wk_rankN, Wv_rankN])
1011-
where N=TP size.
1012+
"""Reorders fused LoRA modules' weights for TP. This is required since HF stores the part weights
1013+
sequentially, whereas with TP>1 we need them to be interleaved so that they are sharded correctly.
1014+
1015+
See interleave_fused_lora_weights_for_tp for more details.
10121016
""" # noqa: D205
10131017
tp_size = self._mapping.tp_size
10141018
if tp_size == 1:

tests/unittest/llmapi/test_llm_multi_gpu_pytorch.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -62,6 +62,7 @@ def test_llama_7b_multi_lora_tp2():
6262
cuda_graph_config=None)
6363

6464

65+
@pytest.mark.gpu2
6566
def test_phi3_lora_fused_modules_output_on_tp2_identical_to_tp1() -> None:
6667
check_phi3_lora_fused_modules_output_tp2_identical_to_tp1(
6768
LLM,

0 commit comments

Comments
 (0)