File tree Expand file tree Collapse file tree 3 files changed +50
-45
lines changed
tests/integ/sagemaker/jumpstart/private_hub Expand file tree Collapse file tree 3 files changed +50
-45
lines changed Original file line number Diff line number Diff line change 34
34
get_training_dataset_for_model_and_version ,
35
35
)
36
36
37
- from tests . integ . sagemaker . jumpstart . private_hub . setup import add_model_references
37
+ MAX_INIT_TIME_SECONDS = 5
38
38
39
+ TEST_MODEL_IDS = {
40
+ "huggingface-spc-bert-base-cased" ,
41
+ "meta-textgeneration-llama-2-7b" ,
42
+ "catboost-regression-model" ,
43
+ }
39
44
40
- MAX_INIT_TIME_SECONDS = 5
45
+
46
+ @with_exponential_backoff ()
47
+ def create_model_reference (hub_instance , model_arn ):
48
+ try :
49
+ hub_instance .create_model_reference (model_arn = model_arn )
50
+ except :
51
+ pass
52
+
53
+
54
+ @pytest .fixture (scope = "session" )
55
+ def add_model_references ():
56
+ # Create Model References to test in Hub
57
+ hub_instance = Hub (
58
+ hub_name = os .environ [ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME ], sagemaker_session = get_sm_session ()
59
+ )
60
+ for model in TEST_MODEL_IDS :
61
+ model_arn = get_public_hub_model_arn (hub_instance , model )
62
+ create_model_reference (hub_instance , model_arn )
41
63
42
64
43
65
def test_jumpstart_hub_estimator (setup , add_model_references ):
Original file line number Diff line number Diff line change 35
35
with_exponential_backoff ,
36
36
)
37
37
38
- from tests . integ . sagemaker . jumpstart . private_hub . setup import add_model_references
38
+ MAX_INIT_TIME_SECONDS = 5
39
39
40
+ TEST_MODEL_IDS = {
41
+ "catboost-classification-model" ,
42
+ "huggingface-txt2img-conflictx-complex-lineart" ,
43
+ "meta-textgeneration-llama-2-7b" ,
44
+ "meta-textgeneration-llama-3-2-1b" ,
45
+ "catboost-regression-model" ,
46
+ }
40
47
41
- MAX_INIT_TIME_SECONDS = 5
48
+
49
+ @with_exponential_backoff ()
50
+ def create_model_reference (hub_instance , model_arn ):
51
+ try :
52
+ hub_instance .create_model_reference (model_arn = model_arn )
53
+ except :
54
+ pass
55
+
56
+
57
+ @pytest .fixture (scope = "session" )
58
+ def add_model_references ():
59
+ # Create Model References to test in Hub
60
+ hub_instance = Hub (
61
+ hub_name = os .environ [ENV_VAR_JUMPSTART_SDK_TEST_HUB_NAME ], sagemaker_session = get_sm_session ()
62
+ )
63
+ for model in TEST_MODEL_IDS :
64
+ model_arn = get_public_hub_model_arn (hub_instance , model )
65
+ create_model_reference (hub_instance , model_arn )
42
66
43
67
44
68
def test_jumpstart_hub_model (setup , add_model_references ):
Load Diff This file was deleted.
You can’t perform that action at this time.
0 commit comments