104104 "pytorch-serving-eia" : [1 , 3 , 1 ],
105105}
106106
107+ INFERENTIA_VERSION_RANGES = {
108+ "neo-mxnet" : [[1 , 5 , 1 ], [1 , 5 , 1 ]],
109+ "neo-tensorflow" : [[1 , 15 , 0 ], [1 , 15 , 0 ]],
110+ }
111+
112+ INFERENTIA_SUPPORTED_REGIONS = ["us-east-1" , "us-west-2" ]
113+
107114DEBUGGER_UNSUPPORTED_REGIONS = ["us-gov-west-1" , "us-iso-east-1" ]
108115
109116
@@ -124,6 +131,23 @@ def is_version_equal_or_higher(lowest_version, framework_version):
124131 return version_list >= lowest_version [0 : len (version_list )]
125132
126133
134+ def is_version_equal_or_lower (highest_version , framework_version ):
135+ """Determine whether the ``framework_version`` is equal to or lower than
136+ ``highest_version``
137+
138+ Args:
139+ highest_version (List[int]): highest version represented in an integer
140+ list
141+ framework_version (str): framework version string
142+
143+ Returns:
144+ bool: Whether or not ``framework_version`` is equal to or lower than
145+ ``highest_version``
146+ """
147+ version_list = [int (s ) for s in framework_version .split ("." )]
148+ return version_list <= highest_version [0 : len (version_list )]
149+
150+
127151def _is_dlc_version (framework , framework_version , py_version ):
128152 """Return if the framework's version uses the corresponding DLC image.
129153
@@ -144,6 +168,23 @@ def _is_dlc_version(framework, framework_version, py_version):
144168 return False
145169
146170
171+ def _is_inferentia_supported (framework , framework_version ):
172+ """Return if Inferentia supports the framework and its version.
173+
174+ Args:
175+ framework (str): The framework name, e.g. "tensorflow"
176+ framework_version (str): The framework version
177+
178+ Returns:
179+ bool: Whether or not Inferentia supports the framework and its version.
180+ """
181+ lowest_version_list = INFERENTIA_VERSION_RANGES .get (framework )[0 ]
182+ highest_version_list = INFERENTIA_VERSION_RANGES .get (framework )[1 ]
183+ return is_version_equal_or_higher (
184+ lowest_version_list , framework_version
185+ ) and is_version_equal_or_lower (highest_version_list , framework_version )
186+
187+
147188def _registry_id (region , framework , py_version , account , framework_version ):
148189 """Return the Amazon ECR registry number (or AWS account ID) for
149190 the given framework, framework version, Python version, and region.
@@ -240,11 +281,34 @@ def create_image_uri(
240281 # 'cpu' or 'gpu'.
241282 if family in optimized_families :
242283 device_type = family
284+ elif family .startswith ("inf" ):
285+ device_type = "inf"
243286 elif family [0 ] in ["g" , "p" ]:
244287 device_type = "gpu"
245288 else :
246289 device_type = "cpu"
247290
291+ if device_type == "inf" :
292+ if region not in INFERENTIA_SUPPORTED_REGIONS :
293+ raise ValueError (
294+ "Inferentia is not supported in region {}. Supported regions are {}" .format (
295+ region , ", " .join (INFERENTIA_SUPPORTED_REGIONS )
296+ )
297+ )
298+ if framework not in INFERENTIA_VERSION_RANGES :
299+ raise ValueError (
300+ "Inferentia does not support {}. Currently it supports "
301+ "MXNet and TensorFlow with more frameworks coming soon." .format (
302+ framework .split ("-" )[- 1 ]
303+ )
304+ )
305+ if not _is_inferentia_supported (framework , framework_version ):
306+ raise ValueError (
307+ "Inferentia is not supported with {} version {}." .format (
308+ framework .split ("-" )[- 1 ], framework_version
309+ )
310+ )
311+
248312 use_dlc_image = _is_dlc_version (framework , framework_version , py_version )
249313
250314 if not py_version or (use_dlc_image and framework == "tensorflow-serving-eia" ):
0 commit comments