3434import sagemaker
3535from sagemaker .utils import get_config_value
3636
37- CONTAINER_PREFIX = " algo"
37+ CONTAINER_PREFIX = ' algo'
3838DOCKER_COMPOSE_FILENAME = 'docker-compose.yaml'
3939
40+ # Environment variables to be set during training
41+ REGION_ENV_NAME = 'AWS_REGION'
42+ TRAINING_JOB_NAME_ENV_NAME = 'TRAINING_JOB_NAME'
43+
4044logger = logging .getLogger (__name__ )
4145logger .setLevel (logging .WARNING )
4246
@@ -102,7 +106,12 @@ def train(self, input_data_config, hyperparameters):
102106 self .write_config_files (host , hyperparameters , input_data_config )
103107 shutil .copytree (data_dir , os .path .join (self .container_root , host , 'input' , 'data' ))
104108
105- compose_data = self ._generate_compose_file ('train' , additional_volumes = volumes )
109+ training_env_vars = {
110+ REGION_ENV_NAME : self .sagemaker_session .boto_region_name ,
111+ TRAINING_JOB_NAME_ENV_NAME : json .loads (hyperparameters .get (sagemaker .model .JOB_NAME_PARAM_NAME )),
112+ }
113+ compose_data = self ._generate_compose_file ('train' , additional_volumes = volumes ,
114+ additional_env_vars = training_env_vars )
106115 compose_command = self ._compose ()
107116
108117 _ecr_login_if_needed (self .sagemaker_session .boto_session , self .image )
@@ -149,7 +158,6 @@ def serve(self, model_dir, environment):
149158 logger .info ('creating hosting dir in {}' .format (self .container_root ))
150159
151160 volumes = self ._prepare_serving_volumes (model_dir )
152- env_vars = ['{}={}' .format (k , v ) for k , v in environment .items ()]
153161
154162 # If the user script was passed as a file:// mount it to the container.
155163 if sagemaker .estimator .DIR_PARAM_NAME .upper () in environment :
@@ -161,7 +169,7 @@ def serve(self, model_dir, environment):
161169 _ecr_login_if_needed (self .sagemaker_session .boto_session , self .image )
162170
163171 self ._generate_compose_file ('serve' ,
164- additional_env_vars = env_vars ,
172+ additional_env_vars = environment ,
165173 additional_volumes = volumes )
166174 compose_command = self ._compose ()
167175 self .container = _HostingContainer (compose_command )
@@ -384,7 +392,8 @@ def _generate_compose_file(self, command, additional_volumes=None, additional_en
384392 if aws_creds is not None :
385393 environment .extend (aws_creds )
386394
387- environment .extend (additional_env_vars )
395+ additional_env_var_list = ['{}={}' .format (k , v ) for k , v in additional_env_vars .items ()]
396+ environment .extend (additional_env_var_list )
388397
389398 if command == 'train' :
390399 optml_dirs = {'output' , 'output/data' , 'input' }
0 commit comments