Skip to content

Commit 6e0cca2

Browse files
[ci] convert nxdi tests to aot compiled (#2700)
1 parent a9fafa7 commit 6e0cca2

File tree

1 file changed

+10
-4
lines changed

1 file changed

+10
-4
lines changed

tests/integration/tests.py

Lines changed: 10 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -902,15 +902,21 @@ def test_llama_speculative_compiled(self):
902902
"transformers_neuronx_rolling_batch llama-speculative-compiled-rb"
903903
.split())
904904

905-
def test_llama_vllm_nxdi(self):
905+
def test_llama_8b_vllm_nxdi(self):
906906
# For neuron, the handler is named transformers_neuronx, but this handler supports TNX, NXDI and optimum.
907907
with Runner('pytorch-inf2', 'llama-3-1-8b-instruct-vllm-nxdi') as r:
908908
prepare.build_transformers_neuronx_handler_model(
909909
"llama-3-1-8b-instruct-vllm-nxdi")
910-
r.launch(container='pytorch-inf2-4')
910+
r.launch(
911+
container="pytorch-inf2-4",
912+
cmd=
913+
"partition --model-dir /opt/ml/input/data/training --save-mp-checkpoint-path /opt/ml/input/data/training/aot --skip-copy"
914+
)
915+
r.launch(container="pytorch-inf2-4",
916+
cmd="serve -m test=file:/opt/ml/model/test/aot")
911917
client.run(
912918
"transformers_neuronx_rolling_batch llama-3-1-8b-instruct-vllm-nxdi"
913-
)
919+
.split())
914920

915921
def test_llama_vllm_nxdi_aot(self):
916922
with Runner('pytorch-inf2',
@@ -926,7 +932,7 @@ def test_llama_vllm_nxdi_aot(self):
926932
cmd="serve -m test=file:/opt/ml/model/test/aot")
927933
client.run(
928934
"transformers_neuronx_rolling_batch llama-3-2-1b-instruct-vllm-nxdi-aot"
929-
)
935+
.split())
930936

931937

932938
@pytest.mark.correctness

0 commit comments

Comments
 (0)