3535 },
3636}
3737
38- DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py310"
38+ PARAM_SCRIPT_SOURCE_DIR = f"{ DATA_DIR } /modules/params_script"
39+ PARAM_SCRIPT_SOURCE_CODE = SourceCode (
40+ source_dir = PARAM_SCRIPT_SOURCE_DIR ,
41+ requirements = "requirements.txt" ,
42+ entry_script = "train.py" ,
43+ )
3944
45+ DEFAULT_CPU_IMAGE = "763104351884.dkr.ecr.us-west-2.amazonaws.com/pytorch-training:2.0.0-cpu-py31"
4046
41- def test_hp_contract_basic_py_script (modules_sagemaker_session ):
42- source_code = SourceCode (
43- source_dir = f"{ DATA_DIR } /modules/params_script" ,
44- requirements = "requirements.txt" ,
45- entry_script = "train.py" ,
46- )
4747
48+ def test_hp_contract_basic_py_script (modules_sagemaker_session ):
4849 model_trainer = ModelTrainer (
4950 sagemaker_session = modules_sagemaker_session ,
5051 training_image = DEFAULT_CPU_IMAGE ,
5152 hyperparameters = EXPECTED_HYPERPARAMETERS ,
52- source_code = source_code ,
53+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
5354 base_job_name = "hp-contract-basic-py-script" ,
5455 )
5556
@@ -59,6 +60,7 @@ def test_hp_contract_basic_py_script(modules_sagemaker_session):
5960def test_hp_contract_basic_sh_script (modules_sagemaker_session ):
6061 source_code = SourceCode (
6162 source_dir = f"{ DATA_DIR } /modules/params_script" ,
63+ requirements = "requirements.txt" ,
6264 entry_script = "train.sh" ,
6365 )
6466 model_trainer = ModelTrainer (
@@ -73,17 +75,13 @@ def test_hp_contract_basic_sh_script(modules_sagemaker_session):
7375
7476
7577def test_hp_contract_mpi_script (modules_sagemaker_session ):
76- source_code = SourceCode (
77- source_dir = f"{ DATA_DIR } /modules/params_script" ,
78- entry_script = "train.py" ,
79- )
8078 compute = Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 )
8179 model_trainer = ModelTrainer (
8280 sagemaker_session = modules_sagemaker_session ,
8381 training_image = DEFAULT_CPU_IMAGE ,
8482 compute = compute ,
8583 hyperparameters = EXPECTED_HYPERPARAMETERS ,
86- source_code = source_code ,
84+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
8785 distributed = MPI (),
8886 base_job_name = "hp-contract-mpi-script" ,
8987 )
@@ -92,17 +90,13 @@ def test_hp_contract_mpi_script(modules_sagemaker_session):
9290
9391
9492def test_hp_contract_torchrun_script (modules_sagemaker_session ):
95- source_code = SourceCode (
96- source_dir = f"{ DATA_DIR } /modules/params_script" ,
97- entry_script = "train.py" ,
98- )
9993 compute = Compute (instance_type = "ml.m5.xlarge" , instance_count = 2 )
10094 model_trainer = ModelTrainer (
10195 sagemaker_session = modules_sagemaker_session ,
10296 training_image = DEFAULT_CPU_IMAGE ,
10397 compute = compute ,
10498 hyperparameters = EXPECTED_HYPERPARAMETERS ,
105- source_code = source_code ,
99+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
106100 distributed = Torchrun (),
107101 base_job_name = "hp-contract-torchrun-script" ,
108102 )
@@ -111,33 +105,23 @@ def test_hp_contract_torchrun_script(modules_sagemaker_session):
111105
112106
113107def test_hp_contract_hyperparameter_json (modules_sagemaker_session ):
114- source_dir = f"{ DATA_DIR } /modules/params_script"
115- source_code = SourceCode (
116- source_dir = source_dir ,
117- entry_script = "train.py" ,
118- )
119108 model_trainer = ModelTrainer (
120109 sagemaker_session = modules_sagemaker_session ,
121110 training_image = DEFAULT_CPU_IMAGE ,
122- hyperparameters = f"{ source_dir } /hyperparameters.json" ,
123- source_code = source_code ,
111+ hyperparameters = f"{ PARAM_SCRIPT_SOURCE_DIR } /hyperparameters.json" ,
112+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
124113 base_job_name = "hp-contract-hyperparameter-json" ,
125114 )
126115 assert model_trainer .hyperparameters == EXPECTED_HYPERPARAMETERS
127116 model_trainer .train ()
128117
129118
130119def test_hp_contract_hyperparameter_yaml (modules_sagemaker_session ):
131- source_dir = f"{ DATA_DIR } /modules/params_script"
132- source_code = SourceCode (
133- source_dir = source_dir ,
134- entry_script = "train.py" ,
135- )
136120 model_trainer = ModelTrainer (
137121 sagemaker_session = modules_sagemaker_session ,
138122 training_image = DEFAULT_CPU_IMAGE ,
139- hyperparameters = f"{ source_dir } /hyperparameters.yaml" ,
140- source_code = source_code ,
123+ hyperparameters = f"{ PARAM_SCRIPT_SOURCE_DIR } /hyperparameters.yaml" ,
124+ source_code = PARAM_SCRIPT_SOURCE_CODE ,
141125 base_job_name = "hp-contract-hyperparameter-yaml" ,
142126 )
143127 assert model_trainer .hyperparameters == EXPECTED_HYPERPARAMETERS
0 commit comments