|
53 | 53 | ) |
54 | 54 |
|
55 | 55 | VALID_PY_VERSIONS = ["py2", "py3"] |
56 | | -VALID_EIA_FRAMEWORKS = ["tensorflow", "tensorflow-serving", "mxnet", "mxnet-serving"] |
| 56 | +VALID_EIA_FRAMEWORKS = [ |
| 57 | + "tensorflow", |
| 58 | + "tensorflow-serving", |
| 59 | + "mxnet", |
| 60 | + "mxnet-serving", |
| 61 | + "pytorch-serving", |
| 62 | +] |
| 63 | +PY2_RESTRICTED_EIA_FRAMEWORKS = ["pytorch-serving"] |
57 | 64 | VALID_ACCOUNTS_BY_REGION = {"us-gov-west-1": "246785580436", "us-iso-east-1": "744548109606"} |
58 | 65 | ASIMOV_VALID_ACCOUNTS_BY_REGION = {"us-iso-east-1": "886529160074"} |
59 | 66 | OPT_IN_ACCOUNTS_BY_REGION = {"ap-east-1": "057415533634", "me-south-1": "724002660598"} |
|
71 | 78 | "mxnet-serving-eia": "mxnet-inference-eia", |
72 | 79 | "pytorch": "pytorch-training", |
73 | 80 | "pytorch-serving": "pytorch-inference", |
| 81 | + "pytorch-serving-eia": "pytorch-inference-eia", |
74 | 82 | } |
75 | 83 |
|
76 | 84 | MERGED_FRAMEWORKS_LOWEST_VERSIONS = { |
|
82 | 90 | "mxnet-serving-eia": [1, 4, 1], |
83 | 91 | "pytorch": [1, 2, 0], |
84 | 92 | "pytorch-serving": [1, 2, 0], |
| 93 | + "pytorch-serving-eia": [1, 3, 1], |
85 | 94 | } |
86 | 95 |
|
87 | 96 | DEBUGGER_UNSUPPORTED_REGIONS = ["us-gov-west-1", "us-iso-east-1"] |
@@ -207,6 +216,7 @@ def create_image_uri( |
207 | 216 |
|
208 | 217 | if _accelerator_type_valid_for_framework( |
209 | 218 | framework=framework, |
| 219 | + py_version=py_version, |
210 | 220 | accelerator_type=accelerator_type, |
211 | 221 | optimized_families=optimized_families, |
212 | 222 | ): |
@@ -259,21 +269,27 @@ def create_image_uri( |
259 | 269 |
|
260 | 270 |
|
261 | 271 | def _accelerator_type_valid_for_framework( |
262 | | - framework, accelerator_type=None, optimized_families=None |
| 272 | + framework, py_version, accelerator_type=None, optimized_families=None |
263 | 273 | ): |
264 | 274 | """ |
265 | 275 | Args: |
266 | 276 | framework: |
| 277 | + py_version: |
267 | 278 | accelerator_type: |
268 | 279 | optimized_families: |
269 | 280 | """ |
270 | 281 | if accelerator_type is None: |
271 | 282 | return False |
272 | 283 |
|
| 284 | + if py_version == "py2" and framework in PY2_RESTRICTED_EIA_FRAMEWORKS: |
| 285 | + raise ValueError( |
| 286 | + "{} is not supported with Amazon Elastic Inference in Python 2.".format(framework) |
| 287 | + ) |
| 288 | + |
273 | 289 | if framework not in VALID_EIA_FRAMEWORKS: |
274 | 290 | raise ValueError( |
275 | 291 | "{} is not supported with Amazon Elastic Inference. Currently only " |
276 | | - "Python-based TensorFlow and MXNet are supported.".format(framework) |
| 292 | + "Python-based TensorFlow, MXNet, PyTorch are supported.".format(framework) |
277 | 293 | ) |
278 | 294 |
|
279 | 295 | if optimized_families: |
|
0 commit comments