@@ -201,6 +201,7 @@ def deploy(
201201 vpc_config = None ,
202202 enable_network_isolation = False ,
203203 model_kms_key = None ,
204+ predictor_cls = None ,
204205 ):
205206 """Deploy a candidate to a SageMaker Inference Pipeline and return a Predictor
206207
@@ -237,10 +238,15 @@ def deploy(
237238 training cluster for distributed training. Default: False
238239 model_kms_key (str): KMS key ARN used to encrypt the repacked
239240 model archive file if the model is repacked
241+ predictor_cls (callable[string, sagemaker.session.Session]): A
242+ function to call to create a predictor (default: None). If
243+ specified, ``deploy()`` returns the result of invoking this
244+ function on the created endpoint name.
240245
241246 Returns:
242- callable[string, sagemaker.session.Session]: Invocation of
243- ``self.predictor_cls`` on the created endpoint name.
247+ callable[string, sagemaker.session.Session] or ``None``:
248+ If ``predictor_cls`` is specified, the invocation of ``self.predictor_cls`` on
249+ the created endpoint name. Otherwise, ``None``.
244250 """
245251 if candidate is None :
246252 candidate_dict = self .best_candidate ()
@@ -264,6 +270,7 @@ def deploy(
264270 vpc_config = vpc_config ,
265271 enable_network_isolation = enable_network_isolation ,
266272 model_kms_key = model_kms_key ,
273+ predictor_cls = predictor_cls ,
267274 )
268275
269276 def _check_problem_type_and_job_objective (self , problem_type , job_objective ):
@@ -299,6 +306,7 @@ def _deploy_inference_pipeline(
299306 vpc_config = None ,
300307 enable_network_isolation = False ,
301308 model_kms_key = None ,
309+ predictor_cls = None ,
302310 ):
303311 """Deploy a SageMaker Inference Pipeline.
304312
@@ -329,6 +337,10 @@ def _deploy_inference_pipeline(
329337 contains "SecurityGroupIds", "Subnets"
330338 model_kms_key (str): KMS key ARN used to encrypt the repacked
331339 model archive file if the model is repacked
340+ predictor_cls (callable[string, sagemaker.session.Session]): A
341+ function to call to create a predictor (default: None). If
342+ specified, ``deploy()`` returns the result of invoking this
343+ function on the created endpoint name.
332344 """
333345 # construct Model objects
334346 models = []
@@ -352,6 +364,7 @@ def _deploy_inference_pipeline(
352364 pipeline = PipelineModel (
353365 models = models ,
354366 role = self .role ,
367+ predictor_cls = predictor_cls ,
355368 name = name ,
356369 vpc_config = vpc_config ,
357370 sagemaker_session = sagemaker_session or self .sagemaker_session ,
0 commit comments