3535 },
3636}
3737
38- DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
38+ PARAM_SCRIPT_SOURCE_DIR = f"{ DATA_DIR } /modules/params_script"
39+ PARAM_SCRIPT_SOURCE_CODE = SourceCode (
40+ source_dir = PARAM_SCRIPT_SOURCE_DIR ,
41+ requirements = "requirements.txt" ,
42+ entry_script = "train.py" ,
43+ )
44+
45+ DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py31"
3946
4047
4148def test_hp_contract_basic_py_script (modules_sagemaker_session ):
@@ -59,6 +66,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
5966def test_hp_contract_basic_sh_script (modules_sagemaker_session ):
6067 source_code = SourceCode (
6168 source_dir = f"{ DATA_DIR } /modules/params_script" ,
69+ requirements = "requirements.txt" ,
6270 entry_script = "train.sh" ,
6371 )
6472 model_trainer = ModelTrainer (
@@ -73,17 +81,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):
7381
7482
7583def test_hp_contract_mpi_script (modules_sagemaker_session ):
76- source_code = SourceCode (
77- source_dir = f"{ DATA_DIR } /modules/params_script" ,
78- entry_script = "train.py" ,
79- )
8084 compute = Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 )
8185 model_trainer = ModelTrainer (
8286 sagemaker_session = modules_sagemaker_session ,
8387 training_image = DEFAULT_CPU_IMAGE ,
8488 compute = compute ,
8589 hyperparameters = EXPECTED_HYPERPARAMETERS ,
86- source_code = source_code ,
90+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
8791 distributed = MPI (),
8892 base_job_name = "hp-contract-mpi-script" ,
8993 )
@@ -92,17 +96,13 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):
9296
9397
9498def test_hp_contract_torchrun_script (modules_sagemaker_session ):
95- source_code = SourceCode (
96- source_dir = f"{ DATA_DIR } /modules/params_script" ,
97- entry_script = "train.py" ,
98- )
9999 compute = Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 )
100100 model_trainer = ModelTrainer (
101101 sagemaker_session = modules_sagemaker_session ,
102102 training_image = DEFAULT_CPU_IMAGE ,
103103 compute = compute ,
104104 hyperparameters = EXPECTED_HYPERPARAMETERS ,
105- source_code = source_code ,
105+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
106106 distributed = Torchrun (),
107107 base_job_name = "hp-contract-torchrun-script" ,
108108 )
@@ -111,33 +111,23 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session):
111111
112112
113113def test_hp_contract_hyperparameter_json (modules_sagemaker_session ):
114- source_dir = f"{ DATA_DIR } /modules/params_script"
115- source_code = SourceCode (
116- source_dir = source_dir ,
117- entry_script = "train.py" ,
118- )
119114 model_trainer = ModelTrainer (
120115 sagemaker_session = modules_sagemaker_session ,
121116 training_image = DEFAULT_CPU_IMAGE ,
122- hyperparameters = f"{ source_dir } /hyperparameters.json" ,
123- source_code = source_code ,
117+ hyperparameters = f"{ PARAM_SCRIPT_SOURCE_DIR } /hyperparameters.json" ,
118+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
124119 base_job_name = "hp-contract-hyperparameter-json" ,
125120 )
126121 assert model_trainer .hyperparameters == EXPECTED_HYPERPARAMETERS
127122 model_trainer .train ()
128123
129124
130125def test_hp_contract_hyperparameter_yaml (modules_sagemaker_session ):
131- source_dir = f"{ DATA_DIR } /modules/params_script"
132- source_code = SourceCode (
133- source_dir = source_dir ,
134- entry_script = "train.py" ,
135- )
136126 model_trainer = ModelTrainer (
137127 sagemaker_session = modules_sagemaker_session ,
138128 training_image = DEFAULT_CPU_IMAGE ,
139- hyperparameters = f"{ source_dir } /hyperparameters.yaml" ,
140- source_code = source_code ,
129+ hyperparameters = f"{ PARAM_SCRIPT_SOURCE_DIR } /hyperparameters.yaml" ,
130+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
141131 base_job_name = "hp-contract-hyperparameter-yaml" ,
142132 )
143133 assert model_trainer .hyperparameters == EXPECTED_HYPERPARAMETERS
0 commit comments