@@ -593,3 +593,75 @@ def train_image(self):
             )
 
         return super(TensorFlow, self).train_image()
+
+    def transformer(
+        self,
+        instance_count,
+        instance_type,
+        strategy=None,
+        assemble_with=None,
+        output_path=None,
+        output_kms_key=None,
+        accept=None,
+        env=None,
+        max_concurrent_transforms=None,
+        max_payload=None,
+        tags=None,
+        role=None,
+        model_server_workers=None,
+        volume_kms_key=None,
+        endpoint_type=None,
+    ):
615+ """Return a ``Transformer`` that uses a SageMaker Model based on the training job. It reuses the
616+ SageMaker Session and base job name used by the Estimator.
617+
618+ Args:
619+ instance_count (int): Number of EC2 instances to use.
620+ instance_type (str): Type of EC2 instance to use, for example, 'ml.c4.xlarge'.
621+ strategy (str): The strategy used to decide how to batch records in a single request (default: None).
622+ Valid values: 'MULTI_RECORD' and 'SINGLE_RECORD'.
623+ assemble_with (str): How the output is assembled (default: None). Valid values: 'Line' or 'None'.
624+ output_path (str): S3 location for saving the transform result. If not specified, results are stored to
625+ a default bucket.
626+ output_kms_key (str): Optional. KMS key ID for encrypting the transform output (default: None).
627+ accept (str): The content type accepted by the endpoint deployed during the transform job.
628+ env (dict): Environment variables to be set for use during the transform job (default: None).
629+ max_concurrent_transforms (int): The maximum number of HTTP requests to be made to
630+ each individual transform container at one time.
631+ max_payload (int): Maximum size of the payload in a single HTTP request to the container in MB.
632+ tags (list[dict]): List of tags for labeling a transform job. If none specified, then the tags used for
633+ the training job are used for the transform job.
634+ role (str): The ``ExecutionRoleArn`` IAM Role ARN for the ``Model``, which is also used during
635+ transform jobs. If not specified, the role from the Estimator will be used.
636+ model_server_workers (int): Optional. The number of worker processes used by the inference server.
637+ If None, server will use one worker per vCPU.
638+ volume_kms_key (str): Optional. KMS key ID for encrypting the volume attached to the ML
639+ compute instance (default: None).
640+ endpoint_type (str): Optional. Selects the software stack used by the inference server.
641+ If not specified, the model will be configured to use the default
642+ SageMaker model server.
643+ If 'tensorflow-serving', the model will be configured to
644+ use the SageMaker Tensorflow Serving container.
645+ """
+
+        role = role or self.role
+        model = self.create_model(
+            model_server_workers=model_server_workers,
+            role=role,
+            vpc_config_override=VPC_CONFIG_DEFAULT,
+            endpoint_type=endpoint_type,
+        )
+        return model.transformer(
+            instance_count,
+            instance_type,
+            strategy=strategy,
+            assemble_with=assemble_with,
+            output_path=output_path,
+            output_kms_key=output_kms_key,
+            accept=accept,
+            env=env,
+            max_concurrent_transforms=max_concurrent_transforms,
+            max_payload=max_payload,
+            tags=tags,
+            volume_kms_key=volume_kms_key,
+        )
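
For reference, a minimal usage sketch of the new method (not part of this commit): the entry point script, role ARN, and S3 locations below are placeholders, and the estimator arguments assume the 1.x Python SDK interface.

from sagemaker.tensorflow import TensorFlow

# Hypothetical estimator setup; entry_point, role, and S3 paths are placeholders.
estimator = TensorFlow(
    entry_point="train.py",
    role="arn:aws:iam::111122223333:role/SageMakerRole",
    train_instance_count=1,
    train_instance_type="ml.p2.xlarge",
    framework_version="1.12",
    py_version="py3",
)
estimator.fit("s3://my-bucket/training-data")

# Reuse the fitted estimator for batch transform; endpoint_type selects the
# TensorFlow Serving container as described in the docstring above.
transformer = estimator.transformer(
    instance_count=1,
    instance_type="ml.c4.xlarge",
    endpoint_type="tensorflow-serving",
)
transformer.transform("s3://my-bucket/batch-input", content_type="application/json")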