Skip to content

Commit 3d04384

Browse files
committed
fix formatting and messaging
1 parent bb4a718 commit 3d04384

File tree

6 files changed

+129
-84
lines changed

6 files changed

+129
-84
lines changed

src/sagemaker/model.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1601,8 +1601,10 @@ def deploy(
16011601
self._base_name = "-".join((self._base_name, compiled_model_suffix))
16021602

16031603
if self._is_sharded_model and endpoint_type != EndpointType.INFERENCE_COMPONENT_BASED:
1604-
logging.warning("Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
1605-
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints.")
1604+
logging.warning(
1605+
"Forcing INFERENCE_COMPONENT_BASED endpoint for sharded model. ADVISORY - "
1606+
"Use INFERENCE_COMPONENT_BASED endpoints over MODEL_BASED endpoints."
1607+
)
16061608
endpoint_type = EndpointType.INFERENCE_COMPONENT_BASED
16071609

16081610
# Support multiple models on same endpoint

src/sagemaker/serve/builder/model_builder.py

Lines changed: 34 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -104,7 +104,7 @@
104104
get_huggingface_model_metadata,
105105
download_huggingface_model_metadata,
106106
)
107-
from sagemaker.serve.validations.optimization import validate_optimization_configuration
107+
from sagemaker.serve.validations.optimization import _validate_optimization_configuration
108108

109109
logger = logging.getLogger(__name__)
110110

@@ -1161,15 +1161,6 @@ def optimize(
11611161
Model: A deployable ``Model`` object.
11621162
"""
11631163

1164-
# TODO: ideally these dictionaries need to be sagemaker_core shapes
1165-
validate_optimization_configuration(
1166-
instance_type=instance_type,
1167-
quantization_config=quantization_config,
1168-
compilation_config=compilation_config,
1169-
sharding_config=sharding_config,
1170-
speculative_decoding_config=speculative_decoding_config,
1171-
)
1172-
11731164
# need to get telemetry_opt_out info before telemetry decorator is called
11741165
self.serve_settings = self._get_serve_setting()
11751166

@@ -1243,6 +1234,17 @@ def _model_builder_optimize_wrapper(
12431234
Returns:
12441235
Model: A deployable ``Model`` object.
12451236
"""
1237+
1238+
# TODO: ideally these dictionaries need to be sagemaker_core shapes
1239+
# TODO: for organization, abstract all validation behind this fn
1240+
_validate_optimization_configuration(
1241+
instance_type=instance_type,
1242+
quantization_config=quantization_config,
1243+
compilation_config=compilation_config,
1244+
sharding_config=sharding_config,
1245+
speculative_decoding_config=speculative_decoding_config,
1246+
)
1247+
12461248
self.is_compiled = compilation_config is not None
12471249
self.is_quantized = quantization_config is not None
12481250
self.speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider(
@@ -1255,11 +1257,29 @@ def _model_builder_optimize_wrapper(
12551257
if quantization_config and compilation_config:
12561258
raise ValueError("Quantization config and compilation config are mutually exclusive.")
12571259

1258-
if sharding_config and (quantization_config or compilation_config or speculative_decoding_config):
1259-
raise ValueError("Sharding config is mutually exclusive and cannot be combined with any other optimization.")
1260+
if sharding_config and (
1261+
quantization_config or compilation_config or speculative_decoding_config
1262+
):
1263+
raise ValueError(
1264+
(
1265+
"Sharding config is mutually exclusive "
1266+
"and cannot be combined with any other optimization."
1267+
)
1268+
)
12601269

1261-
if sharding_config and ((env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars) or (sharding_config.get("OverrideEnvironment") and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config["OverrideEnvironment"])):
1262-
raise ValueError("OPTION_TENSOR_PARALLEL_DEGREE is required environment variable with Sharding config.")
1270+
if sharding_config and (
1271+
(env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars)
1272+
or (
1273+
sharding_config.get("OverrideEnvironment")
1274+
and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config["OverrideEnvironment"]
1275+
)
1276+
):
1277+
raise ValueError(
1278+
(
1279+
"OPTION_TENSOR_PARALLEL_DEGREE is required "
1280+
"environment variable with Sharding config."
1281+
)
1282+
)
12631283

12641284
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
12651285
self.instance_type = instance_type or self.instance_type

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 4 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -259,8 +259,9 @@ def _is_s3_uri(s3_uri: Optional[str]) -> bool:
259259

260260

261261
def _extract_optimization_config_and_env(
262-
quantization_config: Optional[Dict] = None, compilation_config: Optional[Dict] = None,
263-
sharding_config: Optional[Dict] = None
262+
quantization_config: Optional[Dict] = None,
263+
compilation_config: Optional[Dict] = None,
264+
sharding_config: Optional[Dict] = None,
264265
) -> Optional[Tuple[Optional[Dict], Optional[Dict]]]:
265266
"""Extracts optimization config and environment variables.
266267
@@ -282,9 +283,7 @@ def _extract_optimization_config_and_env(
282283
"OverrideEnvironment"
283284
)
284285
if sharding_config:
285-
return {"ModelShardingConfig": sharding_config}, sharding_config.get(
286-
"OverrideEnvironment"
287-
)
286+
return {"ModelShardingConfig": sharding_config}, sharding_config.get("OverrideEnvironment")
288287
return None, None
289288

290289

src/sagemaker/serve/validations/optimization.py

Lines changed: 81 additions & 61 deletions
Original file line numberDiff line numberDiff line change
@@ -10,155 +10,175 @@
1010
# distributed on an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF
1111
# ANY KIND, either express or implied. See the License for the specific
1212
# language governing permissions and limitations under the License.
13-
"""Holds the validation logic used for the .optimize() function"""
13+
"""Holds the validation logic used for the .optimize() function. INTERNAL only"""
14+
from __future__ import absolute_import
15+
16+
import textwrap
17+
import logging
1418
from typing import Any, Dict, Set
1519
from enum import Enum
1620
from pydantic import BaseModel
17-
import textwrap
18-
import logging
1921

2022
logger = logging.getLogger(__name__)
2123

2224

23-
class OptimizationContainer(Enum):
25+
class _OptimizationContainer(Enum):
26+
"""Optimization containers"""
27+
2428
TRT = "trt"
2529
VLLM = "vllm"
2630
NEURON = "neuron"
2731

2832

29-
class OptimizationCombination(BaseModel):
30-
optimization_container: OptimizationContainer = None
33+
class _OptimizationCombination(BaseModel):
34+
"""Optimization ruleset data structure for comparing input to ruleset"""
35+
36+
optimization_container: _OptimizationContainer = None
3137
compilation: bool
3238
speculative_decoding: bool
3339
sharding: bool
3440
quantization_technique: Set[str | None]
3541

36-
def validate_against(self, optimization_combination, rule_set: OptimizationContainer):
42+
def validate_against(self, optimization_combination, rule_set: _OptimizationContainer):
43+
"""Validator for optimization containers"""
44+
3745
if not optimization_combination.compilation == self.compilation:
38-
raise ValueError("model compilation is not supported")
39-
if not optimization_combination.quantization_technique.issubset(self.quantization_technique):
40-
raise ValueError("model quantization is not supported")
46+
raise ValueError("Compilation")
47+
if not optimization_combination.quantization_technique.issubset(
48+
self.quantization_technique
49+
):
50+
raise ValueError(
51+
f"Quantization:{optimization_combination.quantization_technique.pop()}"
52+
)
4153
if not optimization_combination.speculative_decoding == self.speculative_decoding:
42-
raise ValueError("speculative decoding is not supported")
54+
raise ValueError("Speculative Decoding")
4355
if not optimization_combination.sharding == self.sharding:
44-
raise ValueError("model sharding is not supported")
56+
raise ValueError("Sharding")
4557

46-
if rule_set == OptimizationContainer == OptimizationContainer.TRT:
47-
if optimization_combination.compilation and optimization_combination.speculative_decoding:
48-
raise ValueError("model compilation and speculative decoding provided together ")
58+
if rule_set == _OptimizationContainer == _OptimizationContainer.TRT:
59+
if (
60+
optimization_combination.compilation
61+
and optimization_combination.speculative_decoding
62+
):
63+
raise ValueError("Compilation and Speculative Decoding")
4964
else:
50-
if optimization_combination.compilation and optimization_combination.quantization_technique:
51-
raise ValueError("model compilation and model quantization provided together is not supported")
65+
if (
66+
optimization_combination.compilation
67+
and optimization_combination.quantization_technique
68+
):
69+
raise ValueError(
70+
f"Compilation and Quantization:{optimization_combination.quantization_technique.pop()}"
71+
)
5272

5373

5474
TRT_CONFIGURATION = {
5575
"supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
56-
"optimization_combination": OptimizationCombination(
57-
optimization_container=OptimizationContainer.TRT,
76+
"optimization_combination": _OptimizationCombination(
77+
optimization_container=_OptimizationContainer.TRT,
5878
compilation=True,
5979
quantization_technique={"awq", "fp8", "smooth_quant"},
6080
speculative_decoding=False,
6181
sharding=False,
62-
)
82+
),
6383
}
6484
VLLM_CONFIGURATION = {
6585
"supported_instance_families": {"p4d", "p4de", "p5", "g5", "g6"},
66-
"optimization_combination": OptimizationCombination(
67-
optimization_container=OptimizationContainer.VLLM,
86+
"optimization_combination": _OptimizationCombination(
87+
optimization_container=_OptimizationContainer.VLLM,
6888
compilation=False,
6989
quantization_technique={"awq", "fp8"},
7090
speculative_decoding=True,
71-
sharding=True
72-
)
91+
sharding=True,
92+
),
7393
}
7494
NEURON_CONFIGURATION = {
7595
"supported_instance_families": {"inf2", "trn1", "trn1n"},
76-
"optimization_combination": OptimizationCombination(
77-
optimization_container=OptimizationContainer.NEURON,
96+
"optimization_combination": _OptimizationCombination(
97+
optimization_container=_OptimizationContainer.NEURON,
7898
compilation=True,
7999
quantization_technique=set(),
80100
speculative_decoding=False,
81-
sharding=False
82-
)
101+
sharding=False,
102+
),
83103
}
84104

85105
VALIDATION_ERROR_MSG = (
86-
"The model cannot be optimized with the provided configurations on "
87-
"{optimization_container} supported {instance_type} because {validation_error}."
106+
"Optimizations that use {optimization_technique} "
107+
"are not currently supported on {instance_type} instances"
88108
)
89109

90110

91-
def validate_optimization_configuration(
111+
def _validate_optimization_configuration(
92112
instance_type: str,
93113
quantization_config: Dict[str, Any],
94114
compilation_config: Dict[str, Any],
95115
sharding_config: Dict[str, Any],
96-
speculative_decoding_config: Dict[str, Any]
116+
speculative_decoding_config: Dict[str, Any],
97117
):
118+
"""Validate .optimize() input off of standard ruleset"""
119+
98120
split_instance_type = instance_type.split(".")
99121
instance_family = None
100122
if len(split_instance_type) == 3: # invalid instance type will be caught below
101123
instance_family = split_instance_type[1]
102124

103125
if (
104-
not instance_family in TRT_CONFIGURATION["supported_instance_families"] and
105-
not instance_family in VLLM_CONFIGURATION["supported_instance_families"] and
106-
not instance_family in NEURON_CONFIGURATION["supported_instance_families"]
126+
instance_family not in TRT_CONFIGURATION["supported_instance_families"]
127+
and instance_family not in VLLM_CONFIGURATION["supported_instance_families"]
128+
and instance_family not in NEURON_CONFIGURATION["supported_instance_families"]
107129
):
108-
invalid_instance_type_msg = f"""
109-
The model cannot be optimized on {instance_type}. Please optimize on the following instance type families:
110-
- For {OptimizationContainer.TRT} optimized container: {TRT_CONFIGURATION["supported_instance_families"]}
111-
- For {OptimizationContainer.VLLM} optimized container: {VLLM_CONFIGURATION["supported_instance_families"]}
112-
- For {OptimizationContainer.NEURON} optimized container: {NEURON_CONFIGURATION["supported_instance_families"]}
113-
"""
114-
raise ValueError(textwrap.dedent(invalid_instance_type_msg))
115-
116-
optimization_combination = OptimizationCombination(
130+
invalid_instance_type_msg = (
131+
f"Optimizations that use {instance_type} are not currently supported"
132+
)
133+
raise ValueError(invalid_instance_type_msg)
134+
135+
optimization_combination = _OptimizationCombination(
117136
compilation=not compilation_config,
118137
speculative_decoding=not speculative_decoding_config,
119138
sharding=not sharding_config,
120-
quantization_technique={quantization_config.get("OPTION_QUANTIZE") if quantization_config else None}
139+
quantization_technique={
140+
quantization_config.get("OPTION_QUANTIZE") if quantization_config else None
141+
},
121142
)
122143

123144
if instance_type in NEURON_CONFIGURATION["supported_instance_families"]:
124145
try:
125146
(
126-
NEURON_CONFIGURATION["optimization_combination"]
127-
.validate_against(optimization_combination, rule_set=OptimizationContainer.VLLM)
147+
NEURON_CONFIGURATION["optimization_combination"].validate_against(
148+
optimization_combination, rule_set=_OptimizationContainer.VLLM
149+
)
128150
)
129151
except ValueError as neuron_compare_error:
130152
raise ValueError(
131153
VALIDATION_ERROR_MSG.format(
132-
optimization_container=OptimizationContainer.NEURON.value,
133-
instance_type=instance_type,
134-
validation_error=neuron_compare_error
154+
optimization_container=str(neuron_compare_error),
155+
instance_type="Neuron",
135156
)
136157
)
137158
else:
138159
try:
139160
(
140-
TRT_CONFIGURATION["optimization_combination"]
141-
.validate_against(optimization_combination, rule_set=OptimizationContainer.TRT)
161+
TRT_CONFIGURATION["optimization_combination"].validate_against(
162+
optimization_combination, rule_set=_OptimizationContainer.TRT
163+
)
142164
)
143165
except ValueError as trt_compare_error:
144166
try:
145167
(
146-
VLLM_CONFIGURATION["optimization_combination"]
147-
.validate_against(optimization_combination, rule_set=OptimizationContainer.VLLM)
168+
VLLM_CONFIGURATION["optimization_combination"].validate_against(
169+
optimization_combination, rule_set=_OptimizationContainer.VLLM
170+
)
148171
)
149172
except ValueError as vllm_compare_error:
150173
trt_error_msg = VALIDATION_ERROR_MSG.format(
151-
optimization_container=OptimizationContainer.TRT.value,
152-
instance_type=instance_type,
153-
validation_error=trt_compare_error
174+
optimization_container=str(trt_compare_error), instance_type="GPU"
154175
)
155176
vllm_error_msg = VALIDATION_ERROR_MSG.format(
156-
optimization_container=OptimizationContainer.VLLM.value,
157-
instance_type=instance_type,
158-
validation_error=vllm_compare_error
177+
optimization_container=str(vllm_compare_error),
178+
instance_type="GPU",
159179
)
160180
joint_error_msg = f"""
161-
The model cannot be optimized for the following reasons:
181+
Optimization cannot be performed for the following reasons:
162182
- {trt_error_msg}
163183
- {vllm_error_msg}
164184
"""

tests/unit/sagemaker/model/test_model.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -958,6 +958,7 @@ def test_all_framework_models_inference_component_based_endpoint_deploy_path(
958958
sagemaker_session.endpoint_in_service_or_not.reset_mock()
959959
sagemaker_session.create_model.reset_mock()
960960

961+
961962
@patch("sagemaker.utils.repack_model")
962963
@patch("sagemaker.fw_utils.tar_and_upload_dir")
963964
def test_sharded_model_force_inference_component_based_endpoint_deploy_path(
@@ -967,7 +968,7 @@ def test_sharded_model_force_inference_component_based_endpoint_deploy_path(
967968
HuggingFaceModel: {
968969
"pytorch_version": "1.7.1",
969970
"py_version": "py36",
970-
"transformers_version": "4.6.1"
971+
"transformers_version": "4.6.1",
971972
},
972973
}
973974

@@ -1007,6 +1008,7 @@ def test_sharded_model_force_inference_component_based_endpoint_deploy_path(
10071008
sagemaker_session.endpoint_in_service_or_not.reset_mock()
10081009
sagemaker_session.create_model.reset_mock()
10091010

1011+
10101012
@patch("sagemaker.utils.repack_model")
10111013
def test_repack_code_location_with_key_prefix(repack_model, sagemaker_session):
10121014

tests/unit/sagemaker/serve/utils/test_optimize_utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -326,7 +326,9 @@ def test_is_s3_uri(s3_uri, expected):
326326
def test_extract_optimization_config_and_env(
327327
quantization_config, compilation_config, sharding_config, expected_config, expected_env
328328
):
329-
assert _extract_optimization_config_and_env(quantization_config, compilation_config, sharding_config) == (
329+
assert _extract_optimization_config_and_env(
330+
quantization_config, compilation_config, sharding_config
331+
) == (
330332
expected_config,
331333
expected_env,
332334
)

0 commit comments

Comments (0)