@@ -270,20 +270,6 @@ def retrieve(
270270 return ECR_URI_TEMPLATE .format (registry = registry , hostname = hostname , repository = repo )
271271
272272
273- def _get_instance_type_family (instance_type ):
274- """Return the family of the instance type.
275-
276- Regex matches either "ml.<family>.<size>" or "ml_<family>. If input is None
277- or there is no match, return an empty string.
278- """
279- instance_type_family = ""
280- if isinstance (instance_type , str ):
281- match = re .match (r"^ml[\._]([a-z\d]+)\.?\w*$" , instance_type )
282- if match is not None :
283- instance_type_family = match [1 ]
284- return instance_type_family
285-
286-
287273def _get_image_tag (
288274 container_version ,
289275 distribution ,
@@ -297,7 +283,7 @@ def _get_image_tag(
297283 version ,
298284):
299285 """Return image tag based on framework, container, and compute configuration(s)."""
300- instance_type_family = _get_instance_type_family (instance_type )
286+ instance_type_family = utils . get_instance_type_family (instance_type )
301287 if framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
302288 if instance_type_family and final_image_scope == INFERENCE_GRAVITON :
303289 _validate_arg (
@@ -385,7 +371,7 @@ def _config_for_framework_and_scope(framework, image_scope, accelerator_type=Non
385371
386372def _validate_instance_deprecation (framework , instance_type , version ):
387373 """Check if instance type is deprecated for a certain framework with a certain version"""
388- if _get_instance_type_family (instance_type ) == "p2" :
374+ if utils . get_instance_type_family (instance_type ) == "p2" :
389375 if (framework == "pytorch" and Version (version ) >= Version ("1.13" )) or (
390376 framework == "tensorflow" and Version (version ) >= Version ("2.12" )
391377 ):
@@ -409,7 +395,7 @@ def _validate_for_suppported_frameworks_and_instance_type(framework, instance_ty
409395 # Validate for Graviton allowed frameowrks
410396 if (
411397 instance_type is not None
412- and _get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
398+ and utils . get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
413399 and framework not in GRAVITON_ALLOWED_FRAMEWORKS
414400 ):
415401 _validate_framework (framework , GRAVITON_ALLOWED_FRAMEWORKS , "framework" , "Graviton" )
@@ -426,7 +412,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
426412 """Return final image scope based on provided framework and instance type."""
427413 if (
428414 framework in GRAVITON_ALLOWED_FRAMEWORKS
429- and _get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
415+ and utils . get_instance_type_family (instance_type ) in GRAVITON_ALLOWED_TARGET_INSTANCE_FAMILY
430416 ):
431417 return INFERENCE_GRAVITON
432418 if image_scope is None and framework in (XGBOOST_FRAMEWORK , SKLEARN_FRAMEWORK ):
@@ -441,7 +427,7 @@ def _get_final_image_scope(framework, instance_type, image_scope):
441427def _get_inference_tool (inference_tool , instance_type ):
442428 """Extract the inference tool name from instance type."""
443429 if not inference_tool :
444- instance_type_family = _get_instance_type_family (instance_type )
430+ instance_type_family = utils . get_instance_type_family (instance_type )
445431 if instance_type_family .startswith ("inf" ) or instance_type_family .startswith ("trn" ):
446432 return "neuron"
447433 return inference_tool
@@ -529,7 +515,7 @@ def _processor(instance_type, available_processors, serverless_inference_config=
529515 processor = "neuron"
530516 else :
531517 # looks for either "ml.<family>.<size>" or "ml_<family>"
532- family = _get_instance_type_family (instance_type )
518+ family = utils . get_instance_type_family (instance_type )
533519 if family :
534520 # For some frameworks, we have optimized images for specific families, e.g c5 or p3.
535521 # In those cases, we use the family name in the image tag. In other cases, we use
@@ -559,7 +545,7 @@ def _should_auto_select_container_version(instance_type, distribution):
559545 p4d = False
560546 if instance_type :
561547 # looks for either "ml.<family>.<size>" or "ml_<family>"
562- family = _get_instance_type_family (instance_type )
548+ family = utils . get_instance_type_family (instance_type )
563549 if family :
564550 p4d = family == "p4d"
565551
0 commit comments