2929
3030import six
3131
32- import sagemaker
33-
# Matches ECR image URIs of the form
# <account>.dkr.ecr.<region>.<domain>/<repository>:<tag>.
# The dots inside the domain alternatives are escaped so they only match a
# literal '.', not an arbitrary character; capture-group numbering is unchanged.
ECR_URI_PATTERN = r'^(\d+)(\.)dkr(\.)ecr(\.)(.+)(\.)(amazonaws\.com|c2s\.ic\.gov)(/)(.*:.*)$'
3533
3634
@@ -300,7 +298,12 @@ def _tmpdir(suffix='', prefix='tmp'):
300298 shutil .rmtree (tmp )
301299
302300
def repack_model(inference_script,
                 source_directory,
                 dependencies,
                 model_uri,
                 repacked_model_uri,
                 sagemaker_session):
    """Unpack a model tarball and create a new model tarball with the provided code.

    The original model is extracted into a temporary directory, its ``code``
    directory is rebuilt from the given script / source directory /
    dependencies, the result is compressed into a new gzip tarball, and that
    tarball is saved at ``repacked_model_uri``.

    Args:
        inference_script (str): path or basename of the inference script that will be
            packed into the model
        source_directory (str): path including all the files that will be packed into
            the model
        dependencies (list[str]): A list of paths to directories (absolute or relative)
            with any additional libraries that will be exported to the container
            (default: []). The library folders will be copied to SageMaker in the same
            folder where the entrypoint is copied. Example:

                The following call

                    Estimator(entry_point='train.py',
                              dependencies=['my/libs/common', 'virtual-env'])

                results in the following inside the container:

                    opt/ml/code
                    |------ train.py
                    |------ common
                    |------ virtual-env

        model_uri (str): S3 or file system location of the original model tar
        repacked_model_uri (str): path or file system location where the new model
            will be saved
        sagemaker_session (:class:`sagemaker.session.Session`): a sagemaker session to
            interact with S3.

    Returns:
        str: location of the new packed model (``repacked_model_uri`` as given)
    """
    # Callers may pass None; treat it as "no extra dependencies".
    dependencies = dependencies or []

    with _tmpdir() as tmp:
        model_dir = _extract_model(model_uri, sagemaker_session, tmp)

        _create_or_update_code_dir(model_dir,
                                   inference_script,
                                   source_directory,
                                   dependencies,
                                   sagemaker_session,
                                   tmp)

        # Build the archive inside the temp dir; it is moved/uploaded to its
        # final destination before the temp dir is cleaned up.
        tmp_model_path = os.path.join(tmp, 'temp-model.tar.gz')
        with tarfile.open(tmp_model_path, mode='w:gz') as t:
            t.add(model_dir, arcname=os.path.sep)

        _save_model(repacked_model_uri, tmp_model_path, sagemaker_session)

    # Honor the documented contract: report where the repacked model lives.
    return repacked_model_uri
def _save_model(repacked_model_uri, tmp_model_path, sagemaker_session):
    """Store the repacked tarball at its destination.

    Args:
        repacked_model_uri (str): destination of the tarball. A value starting with
            ``s3://`` (case-insensitive) is uploaded to S3; anything else is treated
            as a local path after removing ``file://`` from it.
        tmp_model_path (str): local path of the tarball to store.
        sagemaker_session: session whose ``boto_session`` supplies the S3 resource;
            only used for S3 destinations.
    """
    is_s3_destination = repacked_model_uri.lower().startswith('s3://')

    if not is_s3_destination:
        local_destination = repacked_model_uri.replace('file://', '')
        shutil.move(tmp_model_path, local_destination)
        return

    parsed_uri = parse.urlparse(repacked_model_uri)
    bucket = parsed_uri.netloc
    key = parsed_uri.path.lstrip('/')
    # Swap the key's final component for the basename of the requested URI.
    new_key = key.replace(os.path.basename(key), os.path.basename(repacked_model_uri))

    s3_resource = sagemaker_session.boto_session.resource('s3')
    s3_resource.Object(bucket, new_key).upload_file(tmp_model_path)
def _create_or_update_code_dir(model_dir, inference_script, source_directory,
                               dependencies, sagemaker_session, tmp):
    """Rebuild the ``code`` directory inside an extracted model directory.

    Any existing ``code`` directory is removed first. It is then repopulated from
    one of three sources, and finally the dependencies are copied in.

    Args:
        model_dir (str): directory holding the extracted model; ``code`` is created
            directly under it.
        inference_script (str): script copied into ``code`` when no
            ``source_directory`` is given.
        source_directory (str): an ``s3://`` URI of a gzip tarball to download and
            extract into ``code``, or a local directory to copy as ``code``. When
            falsy, only ``inference_script`` is copied.
        dependencies (list[str]): files and/or directories to add to ``code``. Each
            directory is copied to ``code/<directory name>``; each file is copied
            directly into ``code``. ``None`` is treated as empty.
        sagemaker_session: session passed through to ``download_file_from_url``.
        tmp (str): scratch directory used for the downloaded source tarball.
    """
    code_dir = os.path.join(model_dir, 'code')
    if os.path.exists(code_dir):
        shutil.rmtree(code_dir, ignore_errors=True)

    if source_directory and source_directory.lower().startswith('s3://'):
        local_code_path = os.path.join(tmp, 'local_code.tar.gz')
        download_file_from_url(source_directory, local_code_path, sagemaker_session)

        # NOTE(review): extractall() trusts the member paths in the downloaded
        # archive; confirm the source is trusted or add member filtering.
        with tarfile.open(name=local_code_path, mode='r:gz') as t:
            t.extractall(path=code_dir)
    elif source_directory:
        shutil.copytree(source_directory, code_dir)
    else:
        os.mkdir(code_dir)
        shutil.copy2(inference_script, code_dir)

    for dependency in dependencies or []:
        if os.path.isdir(dependency):
            # copytree() needs a destination that does not exist yet, and code_dir
            # already exists here, so copy into a sub-directory named after the
            # dependency. normpath() drops a trailing separator so basename() is
            # never empty.
            dependency_name = os.path.basename(os.path.normpath(dependency))
            shutil.copytree(dependency, os.path.join(code_dir, dependency_name))
        else:
            shutil.copy2(dependency, code_dir)
389+
390+
def _extract_model(model_uri, sagemaker_session, tmp):
    """Extract a gzip model tarball into ``<tmp>/model`` and return that directory.

    Args:
        model_uri (str): location of the model tarball. A value starting with
            ``s3://`` (case-insensitive) is first downloaded to ``<tmp>/tar_file``;
            anything else is used as a local path after removing ``file://``.
        sagemaker_session: session passed through to ``download_file_from_url``.
        tmp (str): existing scratch directory; ``model`` is created inside it.

    Returns:
        str: path of the directory containing the extracted model files.
    """
    extraction_dir = os.path.join(tmp, 'model')
    os.mkdir(extraction_dir)

    stored_in_s3 = model_uri.lower().startswith('s3://')
    if stored_in_s3:
        tarball_path = os.path.join(tmp, 'tar_file')
        download_file_from_url(model_uri, tarball_path, sagemaker_session)
    else:
        tarball_path = model_uri.replace('file://', '')

    # NOTE(review): extractall() trusts the member paths stored in the archive;
    # confirm model tarballs come from a trusted source.
    with tarfile.open(name=tarball_path, mode='r:gz') as archive:
        archive.extractall(path=extraction_dir)

    return extraction_dir
368402
369403
370404def download_file_from_url (url , dst , sagemaker_session ):
0 commit comments