|
14 | 14 | from __future__ import absolute_import |
15 | 15 |
|
16 | 16 | import logging |
17 | | -from typing import Optional, Union, List, Dict |
| 17 | +from typing import Callable, Optional, Union, List, Dict |
18 | 18 |
|
19 | 19 | import sagemaker |
20 | 20 | from sagemaker import image_uris, ModelMetrics |
@@ -123,7 +123,7 @@ def __init__( |
123 | 123 | pytorch_version: Optional[str] = None, |
124 | 124 | py_version: Optional[str] = None, |
125 | 125 | image_uri: Optional[Union[str, PipelineVariable]] = None, |
126 | | - predictor_cls: callable = HuggingFacePredictor, |
| 126 | + predictor_cls: Optional[Callable] = HuggingFacePredictor, |
127 | 127 | model_server_workers: Optional[Union[int, PipelineVariable]] = None, |
128 | 128 | **kwargs, |
129 | 129 | ): |
@@ -158,7 +158,7 @@ def __init__( |
158 | 158 | If not specified, a default image for PyTorch will be used. If ``framework_version`` |
159 | 159 | or ``py_version`` are ``None``, then ``image_uri`` is required. If |
160 | 160 | also ``None``, then a ``ValueError`` will be raised. |
161 | | - predictor_cls (callable[str, sagemaker.session.Session]): A function |
| 161 | + predictor_cls (Callable[[string, sagemaker.session.Session], Any]): A function |
162 | 162 | to call to create a predictor with an endpoint name and |
163 | 163 | SageMaker ``Session``. If specified, ``deploy()`` returns the |
164 | 164 | result of invoking this function on the created endpoint name. |
@@ -304,7 +304,7 @@ def deploy( |
304 | 304 | - If a wrong type of object is provided as serverless inference config or async |
305 | 305 | inference config |
306 | 306 | Returns: |
307 | | - callable[string, sagemaker.session.Session] or None: Invocation of |
| 307 | + Optional[Callable[[string, sagemaker.session.Session], Any]]: Invocation of |
308 | 308 | ``self.predictor_cls`` on the created endpoint name, if ``self.predictor_cls`` |
309 | 309 | is not None. Otherwise, return None. |
310 | 310 | """ |
|
0 commit comments