Skip to content

Commit 3b93c33

Browse files
chad119Chad Chiang
andauthored
integration test for jumpstart with mig profile (#334)
* integration test for jumpstart with mig profile * template fix for mig with jumpstart * skipped mig tests until instances setup finished --------- Co-authored-by: Chad Chiang <[email protected]>
1 parent 36140e3 commit 3b93c33

File tree

3 files changed

+248
-2
lines changed

3 files changed

+248
-2
lines changed

hyperpod-jumpstart-inference-template/hyperpod_jumpstart_inference_template/v1_1/template.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
11
TEMPLATE_CONTENT = """
2-
apiVersion: inference.sagemaker.aws.amazon.com/v1alpha1
2+
apiVersion: inference.sagemaker.aws.amazon.com/v1
33
kind: JumpStartModel
44
metadata:
5-
name: {{ model_id }}
5+
name: {{ metadata_name or endpoint_name }}
66
namespace: {{ namespace or "default" }}
77
spec:
88
model:
@@ -18,4 +18,6 @@
1818
{% if accelerator_partition_validation is not none %}validations:
1919
{% if accelerator_partition_validation is not none %} acceleratorPartitionValidation: {{ accelerator_partition_validation }}{% endif %}
2020
{% endif %}
21+
tlsConfig:
22+
tlsCertificateOutputS3Uri: {{ tls_certificate_output_s3_uri or "" }}
2123
"""
Lines changed: 121 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,121 @@
1+
import time
2+
import pytest
3+
import boto3
4+
from click.testing import CliRunner
5+
from sagemaker.hyperpod.cli.commands.inference import (
6+
js_create, custom_invoke, js_list, js_describe, js_delete, js_get_operator_logs, js_list_pods
7+
)
8+
from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint
9+
from test.integration_tests.utils import get_time_str
10+
11+
# --------- Test Configuration ---------
12+
NAMESPACE = "integration"
13+
VERSION = "1.1"
14+
REGION = "us-east-2"
15+
TIMEOUT_MINUTES = 20
16+
POLL_INTERVAL_SECONDS = 30
17+
18+
@pytest.fixture(scope="module")
19+
def runner():
20+
return CliRunner()
21+
22+
@pytest.fixture(scope="module")
23+
def js_endpoint_name():
24+
return "js-mig-cli-integration-" + get_time_str()
25+
26+
@pytest.fixture(scope="module")
27+
def sagemaker_client():
28+
return boto3.client("sagemaker", region_name=REGION)
29+
30+
# --------- JumpStart Endpoint Tests ---------
31+
@pytest.mark.skip(reason="Temporarily disabled")
32+
@pytest.mark.dependency(name="create")
33+
def test_js_create(runner, js_endpoint_name):
34+
result = runner.invoke(js_create, [
35+
"--namespace", NAMESPACE,
36+
"--version", VERSION,
37+
"--model-id", "deepseek-llm-r1-distill-qwen-1-5b",
38+
"--instance-type", "ml.p4d.24xlarge",
39+
"--endpoint-name", js_endpoint_name,
40+
"--accelerator-partition-type", "mig-7g.40gb",
41+
"--accelerator-partition-validation", "true",
42+
])
43+
assert result.exit_code == 0, result.output
44+
45+
46+
@pytest.mark.dependency(depends=["create"])
47+
def test_js_list(runner, js_endpoint_name):
48+
result = runner.invoke(js_list, ["--namespace", NAMESPACE])
49+
assert result.exit_code == 0
50+
assert js_endpoint_name in result.output
51+
52+
53+
@pytest.mark.dependency(name="describe", depends=["create"])
54+
def test_js_describe(runner, js_endpoint_name):
55+
result = runner.invoke(js_describe, [
56+
"--name", js_endpoint_name,
57+
"--namespace", NAMESPACE,
58+
"--full"
59+
])
60+
assert result.exit_code == 0
61+
assert js_endpoint_name in result.output
62+
63+
64+
@pytest.mark.dependency(depends=["create", "describe"])
65+
def test_wait_until_inservice(js_endpoint_name):
66+
"""Poll SDK until specific JumpStart endpoint reaches DeploymentComplete"""
67+
print(f"[INFO] Waiting for JumpStart endpoint '{js_endpoint_name}' to be DeploymentComplete...")
68+
deadline = time.time() + (TIMEOUT_MINUTES * 60)
69+
poll_count = 0
70+
71+
while time.time() < deadline:
72+
poll_count += 1
73+
print(f"[DEBUG] Poll #{poll_count}: Checking endpoint status...")
74+
75+
try:
76+
ep = HPJumpStartEndpoint.get(name=js_endpoint_name, namespace=NAMESPACE)
77+
state = ep.status.endpoints.sagemaker.state
78+
print(f"[DEBUG] Current state: {state}")
79+
if state == "CreationCompleted":
80+
print("[INFO] Endpoint is in CreationCompleted state.")
81+
return
82+
83+
deployment_state = ep.status.deploymentStatus.deploymentObjectOverallState
84+
if deployment_state == "DeploymentFailed":
85+
pytest.fail("Endpoint deployment failed.")
86+
87+
except Exception as e:
88+
print(f"[ERROR] Exception during polling: {e}")
89+
90+
time.sleep(POLL_INTERVAL_SECONDS)
91+
92+
pytest.fail("[ERROR] Timed out waiting for endpoint to be DeploymentComplete")
93+
94+
95+
@pytest.mark.dependency(depends=["create"])
96+
def test_custom_invoke(runner, js_endpoint_name):
97+
result = runner.invoke(custom_invoke, [
98+
"--endpoint-name", js_endpoint_name,
99+
"--body", '{"inputs": "What is the capital of USA?"}'
100+
])
101+
assert result.exit_code == 0
102+
assert "error" not in result.output.lower()
103+
104+
@pytest.mark.skip(reason="Temporarily disabled")
105+
def test_js_get_operator_logs(runner):
106+
result = runner.invoke(js_get_operator_logs, ["--since-hours", "1"])
107+
assert result.exit_code == 0
108+
109+
@pytest.mark.skip(reason="Temporarily disabled")
110+
def test_js_list_pods(runner):
111+
result = runner.invoke(js_list_pods, ["--namespace", NAMESPACE])
112+
assert result.exit_code == 0
113+
114+
115+
@pytest.mark.dependency(depends=["create"])
116+
def test_js_delete(runner, js_endpoint_name):
117+
result = runner.invoke(js_delete, [
118+
"--name", js_endpoint_name,
119+
"--namespace", NAMESPACE
120+
])
121+
assert result.exit_code == 0
Lines changed: 123 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,123 @@
1+
import time
2+
import pytest
3+
import boto3
4+
from sagemaker.hyperpod.inference.hp_jumpstart_endpoint import HPJumpStartEndpoint
5+
from sagemaker.hyperpod.inference.config.hp_jumpstart_endpoint_config import (
6+
Model, Server, SageMakerEndpoint, Validations
7+
)
8+
import sagemaker_core.main.code_injection.codec as codec
9+
from test.integration_tests.utils import get_time_str
10+
from sagemaker.hyperpod.common.config.metadata import Metadata
11+
12+
# --------- Config ---------
13+
NAMESPACE = "integration"
14+
REGION = "us-east-2"
15+
ENDPOINT_NAME = "js-mig-sdk-integration-" + get_time_str()
16+
17+
INSTANCE_TYPE = "ml.p4d.24xlarge"
18+
MODEL_ID = "deepseek-llm-r1-distill-qwen-1-5b"
19+
20+
TIMEOUT_MINUTES = 20
21+
POLL_INTERVAL_SECONDS = 30
22+
23+
@pytest.fixture(scope="module")
24+
def sagemaker_client():
25+
return boto3.client("sagemaker", region_name=REGION)
26+
27+
@pytest.fixture(scope="module")
28+
def endpoint_obj():
29+
model = Model(model_id=MODEL_ID)
30+
validations = Validations(accelerator_partition_validation=True)
31+
server = Server(
32+
instance_type=INSTANCE_TYPE,
33+
accelerator_partition_type="mig-7g.40gb",
34+
validations=validations
35+
)
36+
sm_endpoint = SageMakerEndpoint(name=ENDPOINT_NAME)
37+
metadata = Metadata(name=ENDPOINT_NAME, namespace=NAMESPACE)
38+
39+
return HPJumpStartEndpoint(metadata=metadata, model=model, server=server, sage_maker_endpoint=sm_endpoint)
40+
41+
@pytest.mark.skip(reason="Temporarily disabled")
42+
@pytest.mark.dependency(name="create")
43+
def test_create_endpoint(endpoint_obj):
44+
endpoint_obj.create()
45+
assert endpoint_obj.metadata.name == ENDPOINT_NAME
46+
47+
@pytest.mark.dependency(depends=["create"])
48+
def test_list_endpoint():
49+
endpoints = HPJumpStartEndpoint.list(namespace=NAMESPACE)
50+
names = [ep.metadata.name for ep in endpoints]
51+
assert ENDPOINT_NAME in names
52+
53+
@pytest.mark.dependency(name="describe", depends=["create"])
54+
def test_get_endpoint():
55+
ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
56+
assert ep.metadata.name == ENDPOINT_NAME
57+
assert ep.model.modelId == MODEL_ID
58+
59+
@pytest.mark.dependency(depends=["create", "describe"])
60+
def test_wait_until_inservice():
61+
"""Poll SDK until specific JumpStart endpoint reaches DeploymentComplete"""
62+
print(f"[INFO] Waiting for JumpStart endpoint '{ENDPOINT_NAME}' to be DeploymentComplete...")
63+
deadline = time.time() + (TIMEOUT_MINUTES * 60)
64+
poll_count = 0
65+
66+
while time.time() < deadline:
67+
poll_count += 1
68+
print(f"[DEBUG] Poll #{poll_count}: Checking endpoint status...")
69+
70+
try:
71+
ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
72+
state = ep.status.endpoints.sagemaker.state
73+
print(f"[DEBUG] Current state: {state}")
74+
if state == "CreationCompleted":
75+
print("[INFO] Endpoint is in CreationCompleted state.")
76+
return
77+
78+
deployment_state = ep.status.deploymentStatus.deploymentObjectOverallState
79+
if deployment_state == "DeploymentFailed":
80+
pytest.fail("Endpoint deployment failed.")
81+
82+
except Exception as e:
83+
print(f"[ERROR] Exception during polling: {e}")
84+
85+
time.sleep(POLL_INTERVAL_SECONDS)
86+
87+
pytest.fail("[ERROR] Timed out waiting for endpoint to be DeploymentComplete")
88+
89+
90+
@pytest.mark.dependency(depends=["create"])
91+
def test_invoke_endpoint(monkeypatch):
92+
original_transform = codec.transform # Save original
93+
94+
def mock_transform(data, shape, object_instance=None):
95+
if "Body" in data:
96+
return {"body": data["Body"].read().decode("utf-8")}
97+
return original_transform(data, shape, object_instance) # Call original
98+
99+
monkeypatch.setattr("sagemaker_core.main.resources.transform", mock_transform)
100+
101+
ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
102+
data = '{"inputs":"What is the capital of USA?"}'
103+
response = ep.invoke(body=data)
104+
105+
assert "error" not in response.body.lower()
106+
107+
108+
@pytest.mark.skip(reason="Temporarily disabled")
109+
def test_get_operator_logs():
110+
ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
111+
logs = ep.get_operator_logs(since_hours=1)
112+
assert logs
113+
114+
@pytest.mark.skip(reason="Temporarily disabled")
115+
def test_list_pods():
116+
ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
117+
pods = ep.list_pods(NAMESPACE)
118+
assert pods
119+
120+
@pytest.mark.dependency(depends=["create"])
121+
def test_delete_endpoint():
122+
ep = HPJumpStartEndpoint.get(name=ENDPOINT_NAME, namespace=NAMESPACE)
123+
ep.delete()

0 commit comments

Comments
 (0)