6767_DIR_NAME = "/opt/ml/model/code"
6868_XGBOOST_PATH = os .path .join (DATA_DIR , "xgboost_abalone" )
6969_TENSORFLOW_PATH = os .path .join (DATA_DIR , "tfs/tfs-test-entrypoint-and-dependencies" )
70+ _REPACK_OUTPUT_KEY_PREFIX = "code-output"
71+ _MODEL_CODE_LOCATION = f"s3://{ _BUCKET } /{ _REPACK_OUTPUT_KEY_PREFIX } "
7072
7173
7274@pytest .fixture
@@ -688,6 +690,7 @@ def test_conditional_model_create_and_regis(
688690 entry_point = f"{ DATA_DIR } /{ _SCRIPT_NAME } " ,
689691 role = _ROLE ,
690692 enable_network_isolation = True ,
693+ code_location = _MODEL_CODE_LOCATION ,
691694 ),
692695 2 ,
693696 ),
@@ -711,6 +714,7 @@ def test_conditional_model_create_and_regis(
711714 entry_point = f"{ DATA_DIR } /{ _SCRIPT_NAME } " ,
712715 role = _ROLE ,
713716 framework_version = "1.5.0" ,
717+ code_location = _MODEL_CODE_LOCATION ,
714718 ),
715719 2 ,
716720 ),
@@ -742,6 +746,7 @@ def test_conditional_model_create_and_regis(
742746 image_uri = _IMAGE_URI ,
743747 entry_point = f"{ DATA_DIR } /{ _SCRIPT_NAME } " ,
744748 role = _ROLE ,
749+ code_location = _MODEL_CODE_LOCATION ,
745750 ),
746751 2 ,
747752 ),
@@ -758,21 +763,45 @@ def test_conditional_model_create_and_regis(
758763 ],
759764)
760765def test_create_model_among_different_model_types (test_input , pipeline_session , model_data_param ):
766+ def assert_test_result (steps : list ):
767+ # If expected_step_num is 2, it means a runtime repack step is appended
768+ # If expected_step_num is 1, it means no runtime repack is needed
769+ assert len (steps ) == expected_step_num
770+ if expected_step_num == 2 :
771+ assert steps [0 ]["Type" ] == "Training"
772+ if model .key_prefix == _REPACK_OUTPUT_KEY_PREFIX :
773+ assert steps [0 ]["Arguments" ]["OutputDataConfig" ]["S3OutputPath" ] == (
774+ f"{ _MODEL_CODE_LOCATION } /{ model .name } "
775+ )
776+ else :
777+ assert steps [0 ]["Arguments" ]["OutputDataConfig" ]["S3OutputPath" ] == (
778+ f"s3://{ _BUCKET } /{ model .name } "
779+ )
780+
761781 model , expected_step_num = test_input
762782 model .sagemaker_session = pipeline_session
763783 model .model_data = model_data_param
764- step_args = model .create (
784+ create_model_step_args = model .create (
765785 instance_type = "c4.4xlarge" ,
766786 )
767- model_steps = ModelStep (
787+ create_model_steps = ModelStep (
768788 name = "MyModelStep" ,
769- step_args = step_args ,
789+ step_args = create_model_step_args ,
770790 )
771- steps = model_steps .request_dicts ()
791+ assert_test_result ( create_model_steps .request_dicts () )
772792
773- # If expected_step_num is 2, it means a runtime repack step is appended
774- # If expected_step_num is 1, it means no runtime repack is needed
775- assert len (steps ) == expected_step_num
793+ register_model_step_args = model .register (
794+ content_types = ["text/csv" ],
795+ response_types = ["text/csv" ],
796+ inference_instances = ["ml.t2.medium" , "ml.m5.xlarge" ],
797+ transform_instances = ["ml.m5.xlarge" ],
798+ model_package_group_name = "MyModelPackageGroup" ,
799+ )
800+ register_model_steps = ModelStep (
801+ name = "MyModelStep" ,
802+ step_args = register_model_step_args ,
803+ )
804+ assert_test_result (register_model_steps .request_dicts ())
776805
777806
778807@pytest .mark .parametrize (
0 commit comments