6363 "tensorflow-scriptmode" : "tensorflow-training" ,
6464 "mxnet" : "mxnet-training" ,
6565 "tensorflow-serving" : "tensorflow-inference" ,
66- "mxnet-serving" : "mxnet-inference" ,
66+ "tensorflow-serving-eia" : "tensorflow-inference-eia" ,
67+ "mxnet-serving-eia" : "mxnet-inference-eia" ,
6768}
6869
6970MERGED_FRAMEWORKS_LOWEST_VERSIONS = {
7071 "tensorflow-scriptmode" : [1 , 13 , 1 ],
7172 "mxnet" : [1 , 4 , 1 ],
7273 "tensorflow-serving" : [1 , 13 , 0 ],
73- "mxnet-serving" : [1 , 4 , 1 ],
74+ "tensorflow-serving-eia" : [1 , 14 , 0 ],
75+ "mxnet-serving-eia" : [1 , 4 , 1 ],
7476}
7577
7678
@@ -101,7 +103,7 @@ def _is_merged_versions(framework, framework_version):
101103 return False
102104
103105
104- def _using_merged_images (region , framework , py_version , accelerator_type , framework_version ):
106+ def _using_merged_images (region , framework , py_version , framework_version ):
105107 """
106108 Args:
107109 region:
@@ -116,8 +118,11 @@ def _using_merged_images(region, framework, py_version, accelerator_type, framew
116118 return (
117119 (not is_gov_region )
118120 and is_merged_versions
119- and (is_py3 or _is_tf_14_or_later (framework , framework_version ))
120- and accelerator_type is None
121+ and (
122+ is_py3
123+ or _is_tf_14_or_later (framework , framework_version )
124+ or _is_mxnet_serving_141_or_later (framework , framework_version )
125+ )
121126 )
122127
123128
@@ -135,7 +140,25 @@ def _is_tf_14_or_later(framework, framework_version):
135140 )
136141
137142
138- def _registry_id (region , framework , py_version , account , accelerator_type , framework_version ):
143+ def _is_mxnet_serving_141_or_later (framework , framework_version ):
144+ """
145+ Args:
146+ framework:
147+ framework_version:
148+ """
149+ asimov_lowest_mxnet = [1 , 4 , 1 ]
150+
151+ version = [int (s ) for s in framework_version .split ("." )]
152+
153+ if len (version ) == 2 :
154+ version .append (0 )
155+
156+ return (
157+ framework .startswith ("mxnet-serving" ) and version >= asimov_lowest_mxnet [0 : len (version )]
158+ )
159+
160+
161+ def _registry_id (region , framework , py_version , account , framework_version ):
139162 """
140163 Args:
141164 region:
@@ -145,7 +168,7 @@ def _registry_id(region, framework, py_version, account, accelerator_type, frame
145168 accelerator_type:
146169 framework_version:
147170 """
148- if _using_merged_images (region , framework , py_version , accelerator_type , framework_version ):
171+ if _using_merged_images (region , framework , py_version , framework_version ):
149172 if region in ASIMOV_OPT_IN_ACCOUNTS_BY_REGION :
150173 return ASIMOV_OPT_IN_ACCOUNTS_BY_REGION .get (region )
151174 return "763104351884"
@@ -187,13 +210,19 @@ def create_image_uri(
187210 if py_version and py_version not in VALID_PY_VERSIONS :
188211 raise ValueError ("invalid py_version argument: {}" .format (py_version ))
189212
213+ if _accelerator_type_valid_for_framework (
214+ framework = framework ,
215+ accelerator_type = accelerator_type ,
216+ optimized_families = optimized_families ,
217+ ):
218+ framework += "-eia"
219+
190220 # Handle Account Number for Gov Cloud and frameworks with DLC merged images
191221 account = _registry_id (
192222 region = region ,
193223 framework = framework ,
194224 py_version = py_version ,
195225 account = account ,
196- accelerator_type = accelerator_type ,
197226 framework_version = framework_version ,
198227 )
199228
@@ -218,19 +247,14 @@ def create_image_uri(
218247 else :
219248 device_type = "cpu"
220249
221- if py_version :
222- tag = "{}-{}-{}" .format (framework_version , device_type , py_version )
223- else :
224- tag = "{}-{}" .format (framework_version , device_type )
250+ using_merged_images = _using_merged_images (region , framework , py_version , framework_version )
225251
226- if _accelerator_type_valid_for_framework (
227- framework = framework ,
228- accelerator_type = accelerator_type ,
229- optimized_families = optimized_families ,
230- ):
231- framework += "-eia"
252+ if not py_version or (using_merged_images and framework == "tensorflow-serving-eia" ):
253+ tag = "{}-{}" .format (framework_version , device_type )
254+ else :
255+ tag = "{}-{}-{}" .format (framework_version , device_type , py_version )
232256
233- if _using_merged_images ( region , framework , py_version , accelerator_type , framework_version ) :
257+ if using_merged_images :
234258 return "{}/{}:{}" .format (
235259 get_ecr_image_uri_prefix (account , region ), MERGED_FRAMEWORKS_REPO_MAP [framework ], tag
236260 )
0 commit comments