|
17 | 17 | import pkg_resources |
18 | 18 |
|
19 | 19 | import sagemaker |
20 | | -from sagemaker.fw_utils import create_image_uri, model_code_key_prefix, python_deprecation_warning |
| 20 | +from sagemaker.fw_utils import ( |
| 21 | + create_image_uri, |
| 22 | + model_code_key_prefix, |
| 23 | + python_deprecation_warning, |
| 24 | + empty_framework_version_warning, |
| 25 | +) |
21 | 26 | from sagemaker.model import FrameworkModel, MODEL_SERVER_WORKERS_PARAM_NAME |
22 | | -from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION |
| 27 | +from sagemaker.pytorch.defaults import PYTORCH_VERSION, PYTHON_VERSION, LATEST_VERSION |
23 | 28 | from sagemaker.predictor import RealTimePredictor, npy_serializer, numpy_deserializer |
24 | 29 |
|
25 | 30 | logger = logging.getLogger("sagemaker") |
@@ -63,7 +68,7 @@ def __init__( |
63 | 68 | entry_point, |
64 | 69 | image=None, |
65 | 70 | py_version=PYTHON_VERSION, |
66 | | - framework_version=PYTORCH_VERSION, |
| 71 | + framework_version=None, |
67 | 72 | predictor_cls=PyTorchPredictor, |
68 | 73 | model_server_workers=None, |
69 | 74 | **kwargs |
@@ -110,9 +115,11 @@ def __init__( |
110 | 115 |
|
111 | 116 | if py_version == "py2": |
112 | 117 | logger.warning(python_deprecation_warning(self.__framework_name__)) |
| 118 | + if framework_version is None: |
| 119 | + logger.warning(empty_framework_version_warning(PYTORCH_VERSION, LATEST_VERSION)) |
113 | 120 |
|
114 | 121 | self.py_version = py_version |
115 | | - self.framework_version = framework_version |
| 122 | + self.framework_version = framework_version or PYTORCH_VERSION |
116 | 123 | self.model_server_workers = model_server_workers |
117 | 124 |
|
118 | 125 | def prepare_container_def(self, instance_type, accelerator_type=None): |
|
0 commit comments