Skip to content

Commit 0707798

Browse files
committed
fix rebase issues
1 parent f5b568a commit 0707798

File tree

5 files changed

+39
-10
lines changed

5 files changed

+39
-10
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1248,6 +1248,7 @@ def _model_builder_optimize_wrapper(
12481248
# TODO: ideally these dictionaries need to be sagemaker_core shapes
12491249
# TODO: for organization, abstract all validation behind this fn
12501250
_validate_optimization_configuration(
1251+
is_jumpstart=self._is_jumpstart_model_id(),
12511252
instance_type=instance_type,
12521253
quantization_config=quantization_config,
12531254
compilation_config=compilation_config,
@@ -1264,13 +1265,6 @@ def _model_builder_optimize_wrapper(
12641265
if self.mode != Mode.SAGEMAKER_ENDPOINT:
12651266
raise ValueError("Model optimization is only supported in Sagemaker Endpoint Mode.")
12661267

1267-
if sharding_config and (
1268-
quantization_config or compilation_config or speculative_decoding_config
1269-
):
1270-
raise ValueError(
1271-
"Sharding config is mutually exclusive and cannot be combined with any other optimization."
1272-
)
1273-
12741268
if sharding_config and (
12751269
quantization_config or compilation_config or speculative_decoding_config
12761270
):
@@ -1456,7 +1450,9 @@ def _optimize_for_hf(
14561450
quantization_override_env,
14571451
compilation_override_env,
14581452
sharding_override_env,
1459-
) = _extract_optimization_config_and_env(quantization_config, compilation_config)
1453+
) = _extract_optimization_config_and_env(
1454+
quantization_config, compilation_config, sharding_config
1455+
)
14601456
create_optimization_job_args["OptimizationConfigs"] = [
14611457
{k: v} for k, v in optimization_config.items()
14621458
]

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -405,6 +405,7 @@ def _extract_optimization_config_and_env(
405405

406406
return None, None, None, None
407407

408+
408409
def _custom_speculative_decoding(
409410
model: Model,
410411
speculative_decoding_config: Optional[Dict],

src/sagemaker/serve/validations/optimization.py

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -104,6 +104,7 @@ def validate_against(self, optimization_combination, rule_set: _OptimizationCont
104104

105105

106106
def _validate_optimization_configuration(
107+
is_jumpstart: bool,
107108
instance_type: str,
108109
quantization_config: Dict[str, Any],
109110
compilation_config: Dict[str, Any],
@@ -153,6 +154,9 @@ def _validate_optimization_configuration(
153154
and optimization_combination.speculative_decoding == {None}
154155
and optimization_combination.sharding == {None}
155156
):
157+
# JumpStart has defaults for Inf/Trn instances
158+
if is_jumpstart and instance_family in NEURON_CONFIGURATION["supported_instance_families"]:
159+
return
156160
raise ValueError(
157161
(
158162
"Optimizations that provide no optimization configs "

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2927,6 +2927,7 @@ def test_optimize_with_gpu_instance_and_llama_3_1_and_compilation(
29272927
"Compilation is not supported for Llama-3.1 with a GPU instance.",
29282928
lambda: model_builder.optimize(
29292929
job_name="job_name-123",
2930+
instance_type="ml.g5.24xlarge",
29302931
compilation_config={"OverrideEnvironment": {"OPTION_TENSOR_PARALLEL_DEGREE": "2"}},
29312932
output_path="s3://bucket/code/",
29322933
),
@@ -2975,9 +2976,10 @@ def test_optimize_with_gpu_instance_and_compilation_with_speculative_decoding(
29752976

29762977
self.assertRaisesRegex(
29772978
ValueError,
2978-
"Compilation is not supported with speculative decoding with a GPU instance.",
2979+
"Optimizations that use Compilation and Speculative Decoding are not supported for GPU instances.",
29792980
lambda: model_builder.optimize(
29802981
job_name="job_name-123",
2982+
instance_type="ml.g5.24xlarge",
29812983
speculative_decoding_config={
29822984
"ModelProvider": "custom",
29832985
"ModelSource": "s3://data-source",
@@ -3481,6 +3483,7 @@ def test_corner_cases_throw_errors(self):
34813483
ValueError,
34823484
"Optimizations that uses None instance type are not currently supported",
34833485
lambda: _validate_optimization_configuration(
3486+
is_jumpstart=False,
34843487
sharding_config={"key": "value"},
34853488
instance_type=None,
34863489
quantization_config=None,
@@ -3496,6 +3499,7 @@ def test_corner_cases_throw_errors(self):
34963499
"are currently not support on both GPU and Neuron instances."
34973500
),
34983501
lambda: _validate_optimization_configuration(
3502+
is_jumpstart=False,
34993503
instance_type="ml.g5.24xlarge",
35003504
quantization_config=None,
35013505
speculative_decoding_config=None,
@@ -3504,12 +3508,22 @@ def test_corner_cases_throw_errors(self):
35043508
),
35053509
)
35063510

3511+
_validate_optimization_configuration(
3512+
is_jumpstart=True,
3513+
instance_type="ml.inf2.xlarge",
3514+
quantization_config=None,
3515+
speculative_decoding_config=None,
3516+
compilation_config=None,
3517+
sharding_config=None,
3518+
)
3519+
35073520
def test_trt_and_vllm_configurations_throw_errors_for_rule_set(self):
35083521
# Quantization:smoothquant without compilation
35093522
self.assertRaisesRegex(
35103523
ValueError,
35113524
"Optimizations that use Quantization:smoothquant must be provided with Compilation for GPU instances.",
35123525
lambda: _validate_optimization_configuration(
3526+
is_jumpstart=False,
35133527
instance_type="ml.g5.24xlarge",
35143528
quantization_config={
35153529
"OverrideEnvironment": {"OPTION_QUANTIZE": "smoothquant"},
@@ -3525,6 +3539,7 @@ def test_trt_and_vllm_configurations_throw_errors_for_rule_set(self):
35253539
ValueError,
35263540
"Optimizations that use Quantization:test are not supported for GPU instances.",
35273541
lambda: _validate_optimization_configuration(
3542+
is_jumpstart=False,
35283543
instance_type="ml.g5.24xlarge",
35293544
quantization_config={
35303545
"OverrideEnvironment": {"OPTION_QUANTIZE": "test"},
@@ -3540,6 +3555,7 @@ def test_neuron_configurations_throw_errors_for_rule_set(self):
35403555
ValueError,
35413556
"Optimizations that use Speculative Decoding are not supported on Neuron instances.",
35423557
lambda: _validate_optimization_configuration(
3558+
is_jumpstart=False,
35433559
instance_type="ml.inf2.xlarge",
35443560
quantization_config=None,
35453561
speculative_decoding_config={"key": "value"},
@@ -3552,6 +3568,7 @@ def test_neuron_configurations_throw_errors_for_rule_set(self):
35523568
ValueError,
35533569
"Optimizations that use Sharding are not supported on Neuron instances.",
35543570
lambda: _validate_optimization_configuration(
3571+
is_jumpstart=False,
35553572
instance_type="ml.inf2.xlarge",
35563573
quantization_config=None,
35573574
speculative_decoding_config=None,
@@ -3563,6 +3580,7 @@ def test_neuron_configurations_throw_errors_for_rule_set(self):
35633580
def test_trt_configurations_rule_set(self):
35643581
# Can be compiled with quantization
35653582
_validate_optimization_configuration(
3583+
is_jumpstart=False,
35663584
instance_type="ml.g5.24xlarge",
35673585
quantization_config={
35683586
"OverrideEnvironment": {"OPTION_QUANTIZE": "smoothquant"},
@@ -3574,6 +3592,7 @@ def test_trt_configurations_rule_set(self):
35743592

35753593
# Can be just compiled
35763594
_validate_optimization_configuration(
3595+
is_jumpstart=False,
35773596
instance_type="ml.g5.24xlarge",
35783597
quantization_config=None,
35793598
sharding_config=None,
@@ -3583,6 +3602,7 @@ def test_trt_configurations_rule_set(self):
35833602

35843603
# Can be just compiled with empty dict
35853604
_validate_optimization_configuration(
3605+
is_jumpstart=False,
35863606
instance_type="ml.g5.24xlarge",
35873607
quantization_config=None,
35883608
sharding_config=None,
@@ -3593,6 +3613,7 @@ def test_trt_configurations_rule_set(self):
35933613
def test_vllm_configurations_rule_set(self):
35943614
# Can use speculative decoding
35953615
_validate_optimization_configuration(
3616+
is_jumpstart=False,
35963617
instance_type="ml.g5.24xlarge",
35973618
quantization_config=None,
35983619
sharding_config=None,
@@ -3602,6 +3623,7 @@ def test_vllm_configurations_rule_set(self):
36023623

36033624
# Can be quantized
36043625
_validate_optimization_configuration(
3626+
is_jumpstart=False,
36053627
instance_type="ml.g5.24xlarge",
36063628
quantization_config={
36073629
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
@@ -3613,6 +3635,7 @@ def test_vllm_configurations_rule_set(self):
36133635

36143636
# Can be sharded
36153637
_validate_optimization_configuration(
3638+
is_jumpstart=False,
36163639
instance_type="ml.g5.24xlarge",
36173640
quantization_config=None,
36183641
sharding_config={"key": "value"},
@@ -3623,6 +3646,7 @@ def test_vllm_configurations_rule_set(self):
36233646
def test_neuron_configurations_rule_set(self):
36243647
# Can be compiled
36253648
_validate_optimization_configuration(
3649+
is_jumpstart=False,
36263650
instance_type="ml.inf2.xlarge",
36273651
quantization_config=None,
36283652
sharding_config=None,
@@ -3632,6 +3656,7 @@ def test_neuron_configurations_rule_set(self):
36323656

36333657
# Can be compiled with empty dict
36343658
_validate_optimization_configuration(
3659+
is_jumpstart=False,
36353660
instance_type="ml.inf2.xlarge",
36363661
quantization_config=None,
36373662
sharding_config=None,

tests/unit/sagemaker/serve/utils/test_optimize_utils.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -284,7 +284,10 @@ def test_is_draft_model_gated(draft_model_config, expected):
284284

285285

286286
@pytest.mark.parametrize(
287-
"quantization_config, compilation_config, sharding_config, expected_config, expected_quant_env, expected_compilation_env, expected_sharding_env",
287+
(
288+
"quantization_config, compilation_config, sharding_config, expected_config, "
289+
"expected_quant_env, expected_compilation_env, expected_sharding_env"
290+
),
288291
[
289292
(
290293
None,

0 commit comments

Comments (0)