Skip to content

Commit e4974f3

Browse files
committed
separate private hub setup code
1 parent ed93b9e commit e4974f3

File tree

3 files changed

+45
-44
lines changed

3 files changed

+45
-44
lines changed

tests/integ/sagemaker/jumpstart/private_hub/estimator/test_jumpstart_private_hub_estimator.py

Lines changed: 2 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -34,29 +34,10 @@
3434
get_training_dataset_for_model_and_version,
3535
)
3636

37-
MAX_INIT_TIME_SECONDS = 5
38-
39-
TEST_MODEL_IDS = {
40-
"huggingface-spc-bert-base-cased",
41-
"meta-textgeneration-llama-2-7b",
42-
"catboost-regression-model",
43-
}
44-
37+
from tests.integ.sagemaker.jumpstart.private_hub.setup import add_model_references
4538

46-
@with_exponential_backoff()
47-
def create_model_reference(hub_instance, model_arn):
48-
hub_instance.create_model_reference(model_arn=model_arn)
4939

50-
51-
@pytest.fixture(scope="session")
52-
def add_model_references():
53-
# Create Model References to test in Hub
54-
hub_instance = Hub(
55-
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
56-
)
57-
for model in TEST_MODEL_IDS:
58-
model_arn = get_public_hub_model_arn(hub_instance, model)
59-
create_model_reference(hub_instance, model_arn)
40+
MAX_INIT_TIME_SECONDS = 5
6041

6142

6243
def test_jumpstart_hub_estimator(setup, add_model_references):

tests/integ/sagemaker/jumpstart/private_hub/model/test_jumpstart_private_hub_model.py

Lines changed: 2 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -35,31 +35,10 @@
3535
with_exponential_backoff,
3636
)
3737

38-
MAX_INIT_TIME_SECONDS = 5
39-
40-
TEST_MODEL_IDS = {
41-
"catboost-classification-model",
42-
"huggingface-txt2img-conflictx-complex-lineart",
43-
"meta-textgeneration-llama-2-7b",
44-
"meta-textgeneration-llama-3-2-1b",
45-
"catboost-regression-model",
46-
}
47-
38+
from tests.integ.sagemaker.jumpstart.private_hub.setup import add_model_references
4839

49-
@with_exponential_backoff()
50-
def create_model_reference(hub_instance, model_arn):
51-
hub_instance.create_model_reference(model_arn=model_arn)
5240

53-
54-
@pytest.fixture(scope="session")
55-
def add_model_references():
56-
# Create Model References to test in Hub
57-
hub_instance = Hub(
58-
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
59-
)
60-
for model in TEST_MODEL_IDS:
61-
model_arn = get_public_hub_model_arn(hub_instance, model)
62-
create_model_reference(hub_instance, model_arn)
41+
MAX_INIT_TIME_SECONDS = 5
6342

6443

6544
def test_jumpstart_hub_model(setup, add_model_references):
Lines changed: 41 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,41 @@
1+
from __future__ import absolute_import
2+
3+
import os
4+
5+
import pytest
6+
from sagemaker.jumpstart.hub.hub import Hub
7+
8+
from tests.integ.sagemaker.jumpstart.constants import (
9+
ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME,
10+
)
11+
from tests.integ.sagemaker.jumpstart.utils import (
12+
get_public_hub_model_arn,
13+
get_sm_session,
14+
with_exponential_backoff,
15+
)
16+
17+
18+
TEST_MODEL_IDS = {
19+
"catboost-classification-model",
20+
"huggingface-txt2img-conflictx-complex-lineart",
21+
"meta-textgeneration-llama-2-7b",
22+
"meta-textgeneration-llama-3-2-1b",
23+
"catboost-regression-model",
24+
"huggingface-spc-bert-base-cased",
25+
}
26+
27+
28+
@with_exponential_backoff()
29+
def create_model_reference(hub_instance, model_arn):
30+
hub_instance.create_model_reference(model_arn=model_arn)
31+
32+
33+
@pytest.fixture(scope="session")
34+
def add_model_references():
35+
# Create Model References to test in Hub
36+
hub_instance = Hub(
37+
hub_name=os.environ[ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME], sagemaker_session=get_sm_session()
38+
)
39+
for model in TEST_MODEL_IDS:
40+
model_arn = get_public_hub_model_arn(hub_instance, model)
41+
create_model_reference(hub_instance, model_arn)

0 commit comments

Comments
 (0)