Skip to content

Commit 9acb65c

Browse files
amitz-nv authored and dominicshanshan committed
[None][chore] Restore asserts in pytorch flow LoRA tests (NVIDIA#8227)
Signed-off-by: Amit Zuker <203509407+amitz-nv@users.noreply.github.com>
1 parent 3477245 commit 9acb65c

File tree

1 file changed

+15
-11
lines changed

1 file changed

+15
-11
lines changed

tests/unittest/llmapi/test_llm_pytorch.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -267,9 +267,13 @@ def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None:
267267
max_lora_rank=8,
268268
max_loras=2,
269269
max_cpu_loras=2)
270-
llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf",
271-
lora_config=lora_config,
272-
**llm_kwargs)
270+
llm = LLM(
271+
model=f"{llm_models_root()}/llama-models/llama-7b-hf",
272+
lora_config=lora_config,
273+
# Disable CUDA graph
274+
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
275+
cuda_graph_config=None,
276+
**llm_kwargs)
273277
try:
274278
prompts = [
275279
"美国的首都在哪里? \n答案:",
@@ -285,10 +289,7 @@ def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None:
285289
outputs = llm.generate(prompts,
286290
sampling_params,
287291
lora_request=lora_request)
288-
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
289-
# assert similar(outputs[0].outputs[0].text, references[0])
290-
print(f"lora output: {outputs[0].outputs[0].text}")
291-
print(f"ref output: {references[0]}")
292+
assert similar(outputs[0].outputs[0].text, references[0])
292293
finally:
293294
llm.shutdown()
294295

@@ -304,7 +305,12 @@ def test_llama_7b_lora_default_modules() -> None:
304305

305306
hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf"
306307

307-
llm = LLM(model=hf_model_dir, lora_config=lora_config)
308+
llm = LLM(
309+
model=hf_model_dir,
310+
lora_config=lora_config,
311+
# Disable CUDA graph
312+
# TODO: remove this once we have a proper fix for CUDA graph in LoRA
313+
cuda_graph_config=None)
308314

309315
hf_lora_dir = f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"
310316
try:
@@ -323,9 +329,7 @@ def test_llama_7b_lora_default_modules() -> None:
323329
sampling_params,
324330
lora_request=lora_request)
325331

326-
# assert similar(outputs[0].outputs[0].text, references[0])
327-
print(f"lora output: {outputs[0].outputs[0].text}")
328-
print(f"ref output: {references[0]}")
332+
assert similar(outputs[0].outputs[0].text, references[0])
329333
finally:
330334
llm.shutdown()
331335

0 commit comments

Comments (0)