1313"""This module contains functions for obtaining JumpStart resource requirements."""
1414from __future__ import absolute_import
1515
16- from typing import Optional
16+ from typing import Dict , Optional , Tuple
1717
1818from sagemaker .jumpstart .constants import (
1919 DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
2828from sagemaker .session import Session
2929from sagemaker .compute_resource_requirements .resource_requirements import ResourceRequirements
3030
31+ REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP : Dict [
32+ str , Dict [str , Tuple [str , str ]]
33+ ] = {
34+ "requests" : {
35+ "num_accelerators" : ("num_accelerators" , "num_accelerators" ),
36+ "num_cpus" : ("num_cpus" , "num_cpus" ),
37+ "copies" : ("copies" , "copy_count" ),
38+ "min_memory_mb" : ("memory" , "min_memory" ),
39+ },
40+ "limits" : {
41+ "max_memory_mb" : ("memory" , "max_memory" ),
42+ },
43+ }
44+
3145
3246def _retrieve_default_resources (
3347 model_id : str ,
@@ -37,6 +51,7 @@ def _retrieve_default_resources(
3751 tolerate_vulnerable_model : bool = False ,
3852 tolerate_deprecated_model : bool = False ,
3953 sagemaker_session : Session = DEFAULT_JUMPSTART_SAGEMAKER_SESSION ,
54+ instance_type : Optional [str ] = None ,
4055) -> ResourceRequirements :
4156 """Retrieves the default resource requirements for the model.
4257
@@ -60,6 +75,8 @@ def _retrieve_default_resources(
6075 object, used for SageMaker interactions. If not
6176 specified, one is created using the default AWS configuration
6277 chain. (Default: sagemaker.jumpstart.constants.DEFAULT_JUMPSTART_SAGEMAKER_SESSION).
78+ instance_type (str): An instance type to optionally supply in order to get
79+ host requirements specific to the instance type. (Default: None).
6380 Returns:
6481 ResourceRequirements: The default resource requirements to use for the model or None.
6582
@@ -87,23 +104,44 @@ def _retrieve_default_resources(
87104 is_dynamic_container_deployment_supported = (
88105 model_specs .dynamic_container_deployment_supported
89106 )
90- default_resource_requirements = model_specs .hosting_resource_requirements
107+ default_resource_requirements : Dict [str , int ] = (
108+ model_specs .hosting_resource_requirements or {}
109+ )
91110 else :
92111 raise NotImplementedError (
93112 f"Unsupported script scope for retrieving default resource requirements: '{ scope } '"
94113 )
95114
115+ instance_specific_resource_requirements : Dict [str , int ] = (
116+ model_specs .hosting_instance_type_variants .get_instance_specific_resource_requirements (
117+ instance_type
118+ )
119+ if instance_type
120+ and getattr (model_specs , "hosting_instance_type_variants" , None ) is not None
121+ else {}
122+ )
123+
124+ default_resource_requirements = {
125+ ** default_resource_requirements ,
126+ ** instance_specific_resource_requirements ,
127+ }
128+
96129 if is_dynamic_container_deployment_supported :
97- requests = {}
98- if "num_accelerators" in default_resource_requirements :
99- requests ["num_accelerators" ] = default_resource_requirements ["num_accelerators" ]
100- if "min_memory_mb" in default_resource_requirements :
101- requests ["memory" ] = default_resource_requirements ["min_memory_mb" ]
102- if "num_cpus" in default_resource_requirements :
103- requests ["num_cpus" ] = default_resource_requirements ["num_cpus" ]
104-
105- limits = {}
106- if "max_memory_mb" in default_resource_requirements :
107- limits ["memory" ] = default_resource_requirements ["max_memory_mb" ]
108- return ResourceRequirements (requests = requests , limits = limits )
130+
131+ all_resource_requirement_kwargs = {}
132+
133+ for (
134+ requirement_type ,
135+ spec_field_to_resource_requirement_map ,
136+ ) in REQUIREMENT_TYPE_TO_SPEC_FIELD_NAME_TO_RESOURCE_REQUIREMENT_NAME_MAP .items ():
137+ requirement_kwargs = {}
138+ for spec_field , resource_requirement in spec_field_to_resource_requirement_map .items ():
139+ if spec_field in default_resource_requirements :
140+ requirement_kwargs [resource_requirement [0 ]] = default_resource_requirements [
141+ spec_field
142+ ]
143+
144+ all_resource_requirement_kwargs [requirement_type ] = requirement_kwargs
145+
146+ return ResourceRequirements (** all_resource_requirement_kwargs )
109147 return None
0 commit comments