Commit d560054

[None][chore] Restore asserts in pytorch flow LoRA tests (#8227)
Signed-off-by: Amit Zuker <[email protected]>
1 parent: e101213

1 file changed, 15 insertions(+), 11 deletions(-)

tests/unittest/llmapi/test_llm_pytorch.py

@@ -268,9 +268,13 @@ def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None:
         max_lora_rank=8,
         max_loras=2,
         max_cpu_loras=2)
-    llm = LLM(model=f"{llm_models_root()}/llama-models/llama-7b-hf",
-              lora_config=lora_config,
-              **llm_kwargs)
+    llm = LLM(
+        model=f"{llm_models_root()}/llama-models/llama-7b-hf",
+        lora_config=lora_config,
+        # Disable CUDA graph
+        # TODO: remove this once we have a proper fix for CUDA graph in LoRA
+        cuda_graph_config=None,
+        **llm_kwargs)
     try:
         prompts = [
             "美国的首都在哪里? \n答案:",
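For reference, the pattern this hunk applies: in the PyTorch-flow LLM API, passing cuda_graph_config=None at construction disables CUDA graph capture for the run. A minimal sketch, assuming the tensorrt_llm LLM entry point used by these tests and a hypothetical local model path:

from tensorrt_llm import LLM

# cuda_graph_config=None turns CUDA graph capture off entirely; omitting the
# argument keeps the runtime's default CUDA graph behavior.
llm = LLM(
    model="/models/llama-7b-hf",  # hypothetical path; the tests resolve it via llm_models_root()
    cuda_graph_config=None,
)
llm.shutdown()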
@@ -286,10 +290,7 @@ def llama_7b_lora_from_dir_test_harness(**llm_kwargs) -> None:
         outputs = llm.generate(prompts,
                                sampling_params,
                                lora_request=lora_request)
-        # TODO: remove this once we have a proper fix for CUDA graph in LoRA
-        # assert similar(outputs[0].outputs[0].text, references[0])
-        print(f"lora output: {outputs[0].outputs[0].text}")
-        print(f"ref output: {references[0]}")
+        assert similar(outputs[0].outputs[0].text, references[0])
     finally:
         llm.shutdown()

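The restored asserts go through a similar() fuzzy-match helper imported by the test file. A minimal sketch of what such a helper commonly looks like, assuming a difflib SequenceMatcher ratio with a threshold; the real helper in the repo's test utilities may use a different metric or cutoff:

import difflib

def similar(a: str, b: str, threshold: float = 0.8) -> bool:
    # True when the two strings are close enough by SequenceMatcher ratio;
    # the 0.8 cutoff here is an assumption, not the repo's actual value.
    return difflib.SequenceMatcher(None, a, b).ratio() >= threshold

assert similar("The capital is Washington, D.C.", "The capital is Washington D.C.")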
@@ -305,7 +306,12 @@ def test_llama_7b_lora_default_modules() -> None:

     hf_model_dir = f"{llm_models_root()}/llama-models/llama-7b-hf"

-    llm = LLM(model=hf_model_dir, lora_config=lora_config)
+    llm = LLM(
+        model=hf_model_dir,
+        lora_config=lora_config,
+        # Disable CUDA graph
+        # TODO: remove this once we have a proper fix for CUDA graph in LoRA
+        cuda_graph_config=None)

     hf_lora_dir = f"{llm_models_root()}/llama-models/luotuo-lora-7b-0.1"
     try:
@@ -324,9 +330,7 @@ def test_llama_7b_lora_default_modules() -> None:
                                sampling_params,
                                lora_request=lora_request)

-        # assert similar(outputs[0].outputs[0].text, references[0])
-        print(f"lora output: {outputs[0].outputs[0].text}")
-        print(f"ref output: {references[0]}")
+        assert similar(outputs[0].outputs[0].text, references[0])
     finally:
         llm.shutdown()

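To exercise the two affected tests locally, a plain pytest run over the file should suffice; a hedged example, assuming model weights are staged where llm_models_root() looks for them (commonly the LLM_MODELS_ROOT environment variable), written via pytest's Python entry point so all example code here stays in one language:

import pytest

# Hypothetical invocation: select the LoRA tests in this file by keyword.
pytest.main(["tests/unittest/llmapi/test_llm_pytorch.py", "-k", "lora"])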