Commit bbae7a0

[https://nvbugs/5521949][fix] Replace test_codellama_fp8_with_bf16_lora with test_llama_3_1_8b_fp8_with_bf16_lora (#8199)
Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
Parent: 1e0fbb7

1 file changed: +27 -59 lines


tests/unittest/llmapi/test_llm_pytorch.py

@@ -26,14 +26,13 @@
                        prompts, run_llm_abort_request,
                        run_llm_with_postprocess_parallel_and_result_handler,
                        tinyllama_logits_processor_test_harness)
-from utils.util import (force_ampere, similar, skip_gpu_memory_less_than_40gb,
+from utils.util import (force_ampere, similar, skip_fp8_pre_ada,
+                        skip_gpu_memory_less_than_40gb,
                         skip_gpu_memory_less_than_80gb,
                         skip_gpu_memory_less_than_138gb, skip_ray)
 from utils.llm_data import llm_models_root
 from tensorrt_llm.lora_helper import LoraConfig
 from tensorrt_llm.executor.request import LoRARequest
-from tensorrt_llm.models.modeling_utils import QuantConfig
-from tensorrt_llm.quantization.mode import QuantAlgo
 import tempfile
 
 import torch
@@ -508,64 +507,33 @@ def test_nemotron_nas_lora() -> None:
 
 
 @skip_gpu_memory_less_than_80gb
-@pytest.mark.skip(reason="https://nvbugs/5521949")
-def test_codellama_fp8_with_bf16_lora() -> None:
-    model_dir = f"{llm_models_root()}/codellama/CodeLlama-7b-Instruct-hf/"
-    quant_config = QuantConfig(quant_algo=QuantAlgo.FP8,
-                               kv_cache_quant_algo=QuantAlgo.FP8)
-
-    target_modules = ['attn_q', 'attn_k', 'attn_v']
-
-    # Set up temporary directory for LoRA adapters
-    with tempfile.TemporaryDirectory() as lora_dir:
-        print("Creating dummy LoRAs...")
-
-        model = AutoModelForCausalLM.from_pretrained(
-            model_dir,
-            torch_dtype=torch.bfloat16,
-            device_map="auto",
-            trust_remote_code=True,
-        )
-
-        hf_modules = ["q_proj", "k_proj", "v_proj"]
-
-        lora_config = PeftLoraConfig(r=8,
-                                     target_modules=hf_modules,
-                                     bias="none",
-                                     task_type="CAUSAL_LM")
-
-        lora_paths = []
-        for i in range(2):
-            lora_model = get_peft_model(model, lora_config)
-            for param in lora_model.parameters():
-                param.data.zero_()
-            lora_path = f"{lora_dir}/lora_{i}"
-            lora_model.save_pretrained(lora_path)
-            lora_paths.append(lora_path)
-
-        lora_config = LoraConfig(lora_dir=lora_paths,
-                                 lora_target_modules=target_modules,
-                                 max_lora_rank=8,
-                                 max_loras=2,
-                                 max_cpu_loras=2)
-
-        llm = LLM(model_dir, quant_config=quant_config, lora_config=lora_config)
-
-        prompts = [
-            "Write a function that calculates the Fibonacci sequence.",
-            "Convert this C++ code to Python: int x = 0; x++;",
-        ]
-
-        lora_req1 = LoRARequest("lora-1", 0, lora_paths[0])
-        lora_req2 = LoRARequest("lora-2", 1, lora_paths[1])
-        lora_requests = [lora_req1, lora_req2]
-        sampling_params = SamplingParams(max_tokens=200)
+def test_llama_3_1_8b_fp8_with_bf16_lora() -> None:
+    skip_fp8_pre_ada(use_fp8=True)
+    model_dir = f"{llm_models_root()}/llama-3.1-model/Llama-3.1-8B-Instruct-FP8"
+    lora_dir = f"{llm_models_root()}/lora/llama-3-chinese-8b-instruct-v2-lora"
+    prompt = "美国的首都是哪里?"
+    reference = "华盛顿特区。华盛顿特区是美国的首都和一个行政区"
+
+    lora_config = LoraConfig(lora_dir=[lora_dir],
+                             max_lora_rank=64,
+                             max_loras=2,
+                             max_cpu_loras=2)
+    lora_req = LoRARequest("lora-chinese", 0, lora_dir)
 
-        outputs = llm.generate(prompts,
-                               sampling_params,
-                               lora_request=lora_requests)
+    llm = LLM(
+        model_dir,
+        lora_config=lora_config,
+        # Disable CUDA graph
+        # TODO: remove this once we have a proper fix for CUDA graph in LoRA
+        cuda_graph_config=None)
 
-        assert len(outputs) == 2
+    try:
+        output = llm.generate(prompt,
+                              SamplingParams(max_tokens=20),
+                              lora_request=[lora_req])
+    finally:
+        llm.shutdown()
+    assert similar(output.outputs[0].text, reference)
 
 
 @skip_gpu_memory_less_than_80gb
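
For context, the replacement test follows the standard LLM-API LoRA flow. Below is a minimal standalone sketch of that flow, assuming LLM and SamplingParams are importable from the top-level tensorrt_llm package (the test file gets them through its own imports) and that the FP8 Llama 3.1 checkpoint and the Chinese LoRA adapter exist at the hypothetical local paths shown. The prompt translates to "Where is the capital of the United States?" and the reference answer begins "Washington, D.C. Washington, D.C. is the capital of the United States and a federal district".

    # Minimal sketch of the pattern the new test exercises; not the test itself.
    # Assumptions: top-level tensorrt_llm imports, local model/adapter paths.
    from tensorrt_llm import LLM, SamplingParams
    from tensorrt_llm.executor.request import LoRARequest
    from tensorrt_llm.lora_helper import LoraConfig

    model_dir = "Llama-3.1-8B-Instruct-FP8"           # hypothetical local path
    lora_dir = "llama-3-chinese-8b-instruct-v2-lora"  # hypothetical local path

    # Budget for at most two adapters in the GPU and CPU caches, rank <= 64.
    lora_config = LoraConfig(lora_dir=[lora_dir],
                             max_lora_rank=64,
                             max_loras=2,
                             max_cpu_loras=2)
    lora_req = LoRARequest("lora-chinese", 0, lora_dir)

    # cuda_graph_config=None disables CUDA graphs, mirroring the test's
    # temporary workaround for CUDA graph + LoRA.
    llm = LLM(model_dir, lora_config=lora_config, cuda_graph_config=None)
    try:
        output = llm.generate("美国的首都是哪里?",  # "What is the capital of the USA?"
                              SamplingParams(max_tokens=20),
                              lora_request=[lora_req])
        print(output.outputs[0].text)
    finally:
        llm.shutdown()

The new test can be run on its own with, e.g., pytest tests/unittest/llmapi/test_llm_pytorch.py -k test_llama_3_1_8b_fp8_with_bf16_lora.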
