@@ -470,8 +470,12 @@ def create_model(
470470 role = None ,
471471 vpc_config_override = VPC_CONFIG_DEFAULT ,
472472 endpoint_type = None ,
473+ entry_point = None ,
474+ source_dir = None ,
475+ dependencies = None ,
473476 ):
474- """Create a SageMaker ``TensorFlowModel`` object that can be deployed to an ``Endpoint``.
477+ """Create a ``Model`` object that can be used for creating SageMaker model entities,
478+ deploying to a SageMaker endpoint, or starting SageMaker Batch Transform jobs.
475479
476480 Args:
477481 role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
@@ -482,27 +486,55 @@ def create_model(
482486 Default: use subnets and security groups from this Estimator.
483487 * 'Subnets' (list[str]): List of subnet ids.
484488 * 'SecurityGroupIds' (list[str]): List of security group ids.
485- endpoint_type: Optional. Selects the software stack used by the inference server.
489+ endpoint_type (str) : Optional. Selects the software stack used by the inference server.
486490 If not specified, the model will be configured to use the default
487491 SageMaker model server. If 'tensorflow-serving', the model will be configured to
488492 use the SageMaker Tensorflow Serving container.
493+ entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
494+ as the entry point to training. If not specified and ``endpoint_type`` is 'tensorflow-serving',
495+ no entry point is used. If ``endpoint_type`` is also ``None``, then the training entry point is used.
496+ source_dir (str): Path (absolute or relative) to a directory with any other serving
497+ source code dependencies aside from the entry point file. If not specified and
498+ ``endpoint_type`` is 'tensorflow-serving', no source_dir is used. If ``endpoint_type`` is also ``None``,
499+ then the model source directory from training is used.
500+ dependencies (list[str]): A list of paths to directories (absolute or relative) with
501+ any additional libraries that will be exported to the container.
502+ If not specified and ``endpoint_type`` is 'tensorflow-serving', ``dependencies`` is set to ``None``.
503+ If ``endpoint_type`` is also ``None``, then the dependencies from training are used.
489504
490505 Returns:
491- sagemaker.tensorflow.model.TensorFlowModel: A SageMaker ``TensorFlowModel`` object.
492- See :func:`~sagemaker.tensorflow.model.TensorFlowModel` for full details.
506+ sagemaker.tensorflow.model.TensorFlowModel or sagemaker.tensorflow.serving.Model: A ``Model`` object.
507+ See :class:`~sagemaker.tensorflow.serving.Model` or :class:`~sagemaker.tensorflow.model.TensorFlowModel`
508+ for full details.
493509 """
494-
495510 role = role or self .role
511+
496512 if endpoint_type == "tensorflow-serving" or self ._script_mode_enabled ():
497- return self ._create_tfs_model (role = role , vpc_config_override = vpc_config_override )
513+ return self ._create_tfs_model (
514+ role = role ,
515+ vpc_config_override = vpc_config_override ,
516+ entry_point = entry_point ,
517+ source_dir = source_dir ,
518+ dependencies = dependencies ,
519+ )
498520
499521 return self ._create_default_model (
500522 model_server_workers = model_server_workers ,
501523 role = role ,
502524 vpc_config_override = vpc_config_override ,
525+ entry_point = entry_point ,
526+ source_dir = source_dir ,
527+ dependencies = dependencies ,
503528 )
504529
505- def _create_tfs_model (self , role = None , vpc_config_override = VPC_CONFIG_DEFAULT ):
530+ def _create_tfs_model (
531+ self ,
532+ role = None ,
533+ vpc_config_override = VPC_CONFIG_DEFAULT ,
534+ entry_point = None ,
535+ source_dir = None ,
536+ dependencies = None ,
537+ ):
506538 """Placeholder docstring"""
507539 return Model (
508540 model_data = self .model_data ,
@@ -513,15 +545,26 @@ def _create_tfs_model(self, role=None, vpc_config_override=VPC_CONFIG_DEFAULT):
513545 framework_version = utils .get_short_version (self .framework_version ),
514546 sagemaker_session = self .sagemaker_session ,
515547 vpc_config = self .get_vpc_config (vpc_config_override ),
548+ entry_point = entry_point ,
549+ source_dir = source_dir ,
550+ dependencies = dependencies ,
516551 )
517552
518- def _create_default_model (self , model_server_workers , role , vpc_config_override ):
553+ def _create_default_model (
554+ self ,
555+ model_server_workers ,
556+ role ,
557+ vpc_config_override ,
558+ entry_point = None ,
559+ source_dir = None ,
560+ dependencies = None ,
561+ ):
519562 """Placeholder docstring"""
520563 return TensorFlowModel (
521564 self .model_data ,
522565 role ,
523- self .entry_point ,
524- source_dir = self ._model_source_dir (),
566+ entry_point or self .entry_point ,
567+ source_dir = source_dir or self ._model_source_dir (),
525568 enable_cloudwatch_metrics = self .enable_cloudwatch_metrics ,
526569 env = {"SAGEMAKER_REQUIREMENTS" : self .requirements_file },
527570 image = self .image_name ,
@@ -533,7 +576,7 @@ def _create_default_model(self, model_server_workers, role, vpc_config_override)
533576 model_server_workers = model_server_workers ,
534577 sagemaker_session = self .sagemaker_session ,
535578 vpc_config = self .get_vpc_config (vpc_config_override ),
536- dependencies = self .dependencies ,
579+ dependencies = dependencies or self .dependencies ,
537580 )
538581
539582 def hyperparameters (self ):
@@ -625,6 +668,7 @@ def transformer(
625668 model_server_workers = None ,
626669 volume_kms_key = None ,
627670 endpoint_type = None ,
671+ entry_point = None ,
628672 ):
629673 """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
630674 SageMaker Session and base job name used by the Estimator.
@@ -656,6 +700,9 @@ def transformer(
656700 SageMaker model server.
657701 If 'tensorflow-serving', the model will be configured to
658702 use the SageMaker Tensorflow Serving container.
703+ entry_point (str): Path (absolute or relative) to the local Python source file which should be executed
704+ as the entry point to training. If not specified and ``endpoint_type`` is 'tensorflow-serving',
705+ no entry point is used. If ``endpoint_type`` is also ``None``, then the training entry point is used.
659706 """
660707
661708 role = role or self .role
@@ -664,6 +711,7 @@ def transformer(
664711 role = role ,
665712 vpc_config_override = VPC_CONFIG_DEFAULT ,
666713 endpoint_type = endpoint_type ,
714+ entry_point = entry_point ,
667715 )
668716 return model .transformer (
669717 instance_count ,
0 commit comments