import pytest
import torch
-from transformers import Starcoder2Config
+from peft import LoraConfig as PeftLoraConfig
+from peft import get_peft_model
+from transformers import AutoModelForCausalLM, Starcoder2Config
from transformers import Starcoder2ForCausalLM as HFStarcoder2ForCausalLM
+from utils.llm_data import llm_models_root
from utils.util import default_dtype

import tensorrt_llm
+from tensorrt_llm import LLM
from tensorrt_llm._torch.attention_backend.utils import get_attention_backend
from tensorrt_llm._torch.metadata import KVCacheParams
from tensorrt_llm._torch.model_config import ModelConfig
from tensorrt_llm._torch.models.modeling_starcoder2 import Starcoder2ForCausalLM
from tensorrt_llm._torch.modules.layer_norm import LayerNorm
from tensorrt_llm._torch.pyexecutor.resource_manager import KVCacheManager
from tensorrt_llm.bindings.executor import KvCacheConfig
+from tensorrt_llm.executor.request import LoRARequest
+from tensorrt_llm.lora_manager import LoraConfig
from tensorrt_llm.mapping import Mapping
+from tensorrt_llm.sampling_params import SamplingParams

# Base config for all StarCoder2 models (based on HuggingFace configs)
_STARCODER2_BASE_CONFIG = {
@@ -311,3 +318,109 @@ def test_starcoder2_allclose_to_hf(scenario: Scenario) -> None:
    if graph_runner is not None:
        graph_runner.clear()
    kv_cache_manager.shutdown()
+
+
+@torch.no_grad()
+def test_starcoder2_multi_lora(tmp_path) -> None:
+    """
+    Test the StarCoder2-3B model with multiple synthetic LoRA adapters created with PEFT.
+
+    This test creates dummy LoRA adapters for StarCoder2 and verifies that:
+    1. Multiple LoRA adapters can be loaded and used simultaneously
+    2. Different requests can use different LoRA adapters
+    3. The model produces non-empty outputs, and with zeroed adapter weights they match
+       the base model's outputs
+    """
+
+
+    # Check if we have enough GPU memory (need ~10GB for StarCoder2-3B + LoRA)
+    _, total_mem = torch.cuda.mem_get_info()
+    min_mem_required = 10 * (2**30)  # 10 GB
+    if total_mem < min_mem_required:
+        pytest.skip("Insufficient GPU memory for StarCoder2 with LoRA test")
+
+    # Path to the pretrained StarCoder2-3B checkpoint
+    model_path = f"{llm_models_root()}/starcoder2-3b"
+
+    # Target modules for LoRA - attention projections
+    target_modules = ["attn_q", "attn_k", "attn_v", "attn_dense"]
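+    # These TensorRT-LLM LoRA module names correspond to the HF q_proj/k_proj/v_proj/o_proj
+    # projections targeted below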
| 345 | + |
| 346 | + # Load the pretrained model to create LoRA adapters |
| 347 | + model = AutoModelForCausalLM.from_pretrained( |
| 348 | + model_path, torch_dtype=torch.bfloat16, device_map="auto", trust_remote_code=True |
| 349 | + ) |
| 350 | + |
| 351 | + # HuggingFace module names for StarCoder2 attention |
| 352 | + hf_modules = ["q_proj", "k_proj", "v_proj", "o_proj"] |
| 353 | + |
| 354 | + peft_lora_config = PeftLoraConfig( |
| 355 | + r=8, # LoRA rank |
| 356 | + lora_alpha=16, |
| 357 | + target_modules=hf_modules, |
| 358 | + lora_dropout=0.0, |
| 359 | + bias="none", |
| 360 | + task_type="CAUSAL_LM", |
| 361 | + ) |
| 362 | + |
| 363 | + # Create two synthetic LoRA adapters with zeroed weights |
| 364 | + lora_paths = [] |
| 365 | + for i in range(2): |
| 366 | + lora_model = get_peft_model(model, peft_lora_config) |
| 367 | + |
| 368 | + # Zero out all LoRA parameters for deterministic testing |
| 369 | + for name, param in lora_model.named_parameters(): |
| 370 | + if "lora_" in name: |
| 371 | + param.data.zero_() |
| 372 | + |
| 373 | + # Save the LoRA adapter |
| 374 | + lora_path = tmp_path / f"lora_{i}" |
| 375 | + lora_model.save_pretrained(lora_path) |
| 376 | + lora_paths.append(str(lora_path)) |
| 377 | + |
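+    # Release the HF model before constructing the TensorRT-LLM LLM to free GPU memory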
+    del model
+    del lora_model
+    torch.cuda.empty_cache()
+
+    # Configure TensorRT-LLM LoRA
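+    # max_lora_rank must cover the PEFT rank (r=8); max_loras / max_cpu_loras bound the
+    # GPU and host adapter caches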
+    trtllm_lora_config = LoraConfig(
+        lora_target_modules=target_modules, max_lora_rank=8, max_loras=2, max_cpu_loras=2
+    )
+
+    llm = LLM(
+        model_path,
+        lora_config=trtllm_lora_config,
+        # Disable CUDA graph for LoRA (LoRA is not supported with CUDA graphs yet)
+        cuda_graph_config=None,
+    )
+
+    with llm:
+        prompts = [
+            "def fibonacci(n):",
+            "def quick_sort(arr):",
+        ]
+
+        lora_req1 = LoRARequest("lora-1", 0, lora_paths[0])
+        lora_req2 = LoRARequest("lora-2", 1, lora_paths[1])
+        lora_requests = [lora_req1, lora_req2]
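+        # A list of LoRA requests is paired with the prompts by position, so each prompt
+        # is served with its own adapter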
+
+        # Sampling parameters
+        sampling_params = SamplingParams(
+            max_tokens=50,
+            temperature=0.0,  # Greedy decoding for deterministic output
+        )
+
+        outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
+
+        # Verify we got outputs for both prompts
+        assert len(outputs) == 2, f"Expected 2 outputs, got {len(outputs)}"
+
+        # Verify each output has text
+        for i, output in enumerate(outputs):
+            assert len(output.outputs) > 0, f"Output {i} has no results"
+            assert len(output.outputs[0].text) > 0, f"Output {i} generated empty text"
+
+        # Test without LoRA for comparison
+        outputs_no_lora = llm.generate(prompts, sampling_params, lora_request=None)
+
+        assert len(outputs_no_lora) == 2
+
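+        # The adapter weights were zeroed out, so LoRA and base-model outputs should match exactly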
+        assert outputs[0].outputs[0].text == outputs_no_lora[0].outputs[0].text
+        assert outputs[1].outputs[0].text == outputs_no_lora[1].outputs[0].text