2424import string
2525import subprocess
2626import sys
27+ import tarfile
2728import tempfile
2829from fcntl import fcntl , F_GETFL , F_SETFL
2930from six .moves .urllib .parse import urlparse
@@ -137,7 +138,7 @@ def serve(self, primary_container):
137138 Args:
138139 primary_container (dict): dictionary containing the container runtime settings
139140 for serving. Expected keys:
140- - 'ModelDataUrl' pointing to a local file
141+ - 'ModelDataUrl' pointing to a file or s3:// location.
141142 - 'Environment' a dictionary of environment variables to be passed to the hosting container.
142143
143144 """
@@ -147,22 +148,17 @@ def serve(self, primary_container):
147148 logger .info ('creating hosting dir in {}' .format (self .container_root ))
148149
149150 model_dir = primary_container ['ModelDataUrl' ]
150- if not model_dir .lower ().startswith ("s3://" ):
151- for h in self .hosts :
152- host_dir = os .path .join (self .container_root , h )
153- os .makedirs (host_dir )
154- shutil .copytree (model_dir , os .path .join (self .container_root , h , 'model' ))
155-
151+ volumes = self ._prepare_serving_volumes (model_dir )
156152 env_vars = ['{}={}' .format (k , v ) for k , v in primary_container ['Environment' ].items ()]
157153
158- _ecr_login_if_needed (self .sagemaker_session .boto_session , self .image )
159-
160154 # If the user script was passed as a file:// mount it to the container.
161- script_dir = primary_container ['Environment' ][sagemaker .estimator .DIR_PARAM_NAME .upper ()]
162- parsed_uri = urlparse (script_dir )
163- volumes = []
164- if parsed_uri .scheme == 'file' :
165- volumes .append (_Volume (parsed_uri .path , '/opt/ml/code' ))
155+ if sagemaker .estimator .DIR_PARAM_NAME .upper () in primary_container ['Environment' ]:
156+ script_dir = primary_container ['Environment' ][sagemaker .estimator .DIR_PARAM_NAME .upper ()]
157+ parsed_uri = urlparse (script_dir )
158+ if parsed_uri .scheme == 'file' :
159+ volumes .append (_Volume (parsed_uri .path , '/opt/ml/code' ))
160+
161+ _ecr_login_if_needed (self .sagemaker_session .boto_session , self .image )
166162
167163 self ._generate_compose_file ('serve' ,
168164 additional_env_vars = env_vars ,
@@ -278,9 +274,20 @@ def _download_folder(self, bucket_name, prefix, target):
278274 pass
279275 obj .download_file (file_path )
280276
277+ def _download_file (self , bucket_name , path , target ):
278+ path = path .lstrip ('/' )
279+ boto_session = self .sagemaker_session .boto_session
280+
281+ s3 = boto_session .resource ('s3' )
282+ bucket = s3 .Bucket (bucket_name )
283+ bucket .download_file (path , target )
284+
281285 def _prepare_training_volumes (self , data_dir , input_data_config , hyperparameters ):
282286 shared_dir = os .path .join (self .container_root , 'shared' )
287+ model_dir = os .path .join (self .container_root , 'model' )
283288 volumes = []
289+
290+ volumes .append (_Volume (model_dir , '/opt/ml/model' ))
284291 # Set up the channels for the containers. For local data we will
285292 # mount the local directory to the container. For S3 Data we will download the S3 data
286293 # first.
@@ -321,6 +328,32 @@ def _prepare_training_volumes(self, data_dir, input_data_config, hyperparameters
321328
322329 return volumes
323330
331+ def _prepare_serving_volumes (self , model_location ):
332+ volumes = []
333+ host = self .hosts [0 ]
334+ # Make the model available to the container. If this is a local file just mount it to
335+ # the container as a volume. If it is an S3 location download it and extract the tar file.
336+ host_dir = os .path .join (self .container_root , host )
337+ os .makedirs (host_dir )
338+
339+ if model_location .startswith ('s3' ):
340+ container_model_dir = os .path .join (self .container_root , host , 'model' )
341+ os .makedirs (container_model_dir )
342+
343+ parsed_uri = urlparse (model_location )
344+ filename = os .path .basename (parsed_uri .path )
345+ tar_location = os .path .join (container_model_dir , filename )
346+ self ._download_file (parsed_uri .netloc , parsed_uri .path , tar_location )
347+
348+ if tarfile .is_tarfile (tar_location ):
349+ with tarfile .open (tar_location ) as tar :
350+ tar .extractall (path = container_model_dir )
351+ volumes .append (_Volume (container_model_dir , '/opt/ml/model' ))
352+ else :
353+ volumes .append (_Volume (model_location , '/opt/ml/model' ))
354+
355+ return volumes
356+
324357 def _generate_compose_file (self , command , additional_volumes = None , additional_env_vars = None ):
325358 """Writes a config file describing a training/hosting environment.
326359
@@ -452,10 +485,6 @@ def _build_optml_volumes(self, host, subdirs):
452485 """
453486 volumes = []
454487
455- # Ensure that model is in the subdirs
456- if 'model' not in subdirs :
457- subdirs .add ('model' )
458-
459488 for subdir in subdirs :
460489 host_dir = os .path .join (self .container_root , host , subdir )
461490 container_dir = '/opt/ml/{}' .format (subdir )
0 commit comments