Skip to content

Commit e30b3b3

Browse files
author
Jonathan Makunga
committed
Follow-up fixes
1 parent 0971c55 commit e30b3b3

File tree

4 files changed

+23
-69
lines changed

4 files changed

+23
-69
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 8 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -668,8 +668,6 @@ def _build_for_jumpstart(self):
668668
def _optimize_for_jumpstart(
669669
self,
670670
output_path: Optional[str] = None,
671-
instance_type: Optional[str] = None,
672-
role_arn: Optional[str] = None,
673671
tags: Optional[Tags] = None,
674672
job_name: Optional[str] = None,
675673
accept_eula: Optional[bool] = None,
@@ -685,9 +683,6 @@ def _optimize_for_jumpstart(
685683
686684
Args:
687685
output_path (Optional[str]): Specifies where to store the compiled/quantized model.
688-
instance_type (Optional[str]): Target deployment instance type that
689-
the model is optimized for.
690-
role_arn (Optional[str]): Execution role. Defaults to ``None``.
691686
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
692687
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
693688
accept_eula (bool): For models that require a Model Access Config, specify True or
@@ -715,13 +710,13 @@ def _optimize_for_jumpstart(
715710
f"Model '{self.model}' requires accepting end-user license agreement (EULA)."
716711
)
717712

718-
is_compilation = (quantization_config is None) and (
719-
(compilation_config is not None) or _is_inferentia_or_trainium(instance_type)
713+
is_compilation = (not quantization_config) and (
714+
(compilation_config is not None) or _is_inferentia_or_trainium(self.instance_type)
720715
)
721716

722717
pysdk_model_env_vars = dict()
723718
if is_compilation:
724-
pysdk_model_env_vars = self._get_neuron_model_env_vars(instance_type)
719+
pysdk_model_env_vars = self._get_neuron_model_env_vars(self.instance_type)
725720

726721
optimization_config, override_env = _extract_optimization_config_and_env(
727722
quantization_config, compilation_config
@@ -757,8 +752,9 @@ def _optimize_for_jumpstart(
757752
if self.pysdk_model.deployment_config
758753
else None
759754
)
760-
self.instance_type = instance_type or deployment_config_instance_type or _get_nb_instance()
761-
self.role_arn = role_arn or self.role_arn
755+
self.instance_type = (
756+
self.instance_type or deployment_config_instance_type or _get_nb_instance()
757+
)
762758

763759
create_optimization_job_args = {
764760
"OptimizationJobName": job_name,
@@ -788,9 +784,10 @@ def _optimize_for_jumpstart(
788784
}
789785

790786
if quantization_config or is_compilation:
791-
self.pysdk_model.env = _update_environment_variables(
787+
optimization_env_vars = _update_environment_variables(
792788
optimization_env_vars, override_env
793789
)
790+
self.pysdk_model.env.update(optimization_env_vars)
794791
return create_optimization_job_args
795792
return None
796793

src/sagemaker/serve/builder/model_builder.py

Lines changed: 15 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -73,7 +73,6 @@
7373
_generate_model_source,
7474
_extract_optimization_config_and_env,
7575
_is_s3_uri,
76-
_normalize_local_model_path,
7776
_custom_speculative_decoding,
7877
_extract_speculative_draft_model_provider,
7978
)
@@ -833,6 +832,8 @@ def build( # pylint: disable=R0911
833832
# until we deprecate HUGGING_FACE_HUB_TOKEN.
834833
if self.env_vars.get("HUGGING_FACE_HUB_TOKEN") and not self.env_vars.get("HF_TOKEN"):
835834
self.env_vars["HF_TOKEN"] = self.env_vars.get("HUGGING_FACE_HUB_TOKEN")
835+
elif self.env_vars.get("HF_TOKEN") and not self.env_vars.get("HUGGING_FACE_HUB_TOKEN"):
836+
self.env_vars["HUGGING_FACE_HUB_TOKEN"] = self.env_vars.get("HF_TOKEN")
836837

837838
self.sagemaker_session.settings._local_download_dir = self.model_path
838839

@@ -851,7 +852,9 @@ def build( # pylint: disable=R0911
851852

852853
self._build_validations()
853854

854-
if not self._is_jumpstart_model_id() and self.model_server:
855+
if (
856+
not (isinstance(self.model, str) and self._is_jumpstart_model_id())
857+
) and self.model_server:
855858
return self._build_for_model_server()
856859

857860
if isinstance(self.model, str):
@@ -1216,18 +1219,14 @@ def _model_builder_optimize_wrapper(
12161219
raise ValueError("Quantization config and compilation config are mutually exclusive.")
12171220

12181221
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
1219-
12201222
self.instance_type = instance_type or self.instance_type
12211223
self.role_arn = role_arn or self.role_arn
12221224

1223-
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
12241225
job_name = job_name or f"modelbuilderjob-{uuid.uuid4().hex}"
1225-
12261226
if self._is_jumpstart_model_id():
1227+
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
12271228
input_args = self._optimize_for_jumpstart(
12281229
output_path=output_path,
1229-
instance_type=instance_type,
1230-
role_arn=self.role_arn,
12311230
tags=tags,
12321231
job_name=job_name,
12331232
accept_eula=accept_eula,
@@ -1240,10 +1239,13 @@ def _model_builder_optimize_wrapper(
12401239
max_runtime_in_sec=max_runtime_in_sec,
12411240
)
12421241
else:
1242+
if self.model_server != ModelServer.DJL_SERVING:
1243+
logger.info("Overriding model server to DJL_SERVING.")
1244+
self.model_server = ModelServer.DJL_SERVING
1245+
1246+
self.build(mode=self.mode, sagemaker_session=self.sagemaker_session)
12431247
input_args = self._optimize_for_hf(
12441248
output_path=output_path,
1245-
instance_type=instance_type,
1246-
role_arn=self.role_arn,
12471249
tags=tags,
12481250
job_name=job_name,
12491251
quantization_config=quantization_config,
@@ -1256,8 +1258,10 @@ def _model_builder_optimize_wrapper(
12561258
)
12571259

12581260
if input_args:
1261+
print(input_args)
12591262
self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
12601263
job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
1264+
print(job_status)
12611265
return _generate_optimized_model(self.pysdk_model, job_status)
12621266

12631267
self.pysdk_model.remove_tag_with_key(Tag.OPTIMIZATION_JOB_NAME)
@@ -1269,8 +1273,6 @@ def _model_builder_optimize_wrapper(
12691273
def _optimize_for_hf(
12701274
self,
12711275
output_path: str,
1272-
instance_type: Optional[str] = None,
1273-
role_arn: Optional[str] = None,
12741276
tags: Optional[Tags] = None,
12751277
job_name: Optional[str] = None,
12761278
quantization_config: Optional[Dict] = None,
@@ -1285,9 +1287,6 @@ def _optimize_for_hf(
12851287
12861288
Args:
12871289
output_path (str): Specifies where to store the compiled/quantized model.
1288-
instance_type (Optional[str]): Target deployment instance type that
1289-
the model is optimized for.
1290-
role_arn (Optional[str]): Execution role. Defaults to ``None``.
12911290
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
12921291
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
12931292
quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
@@ -1305,13 +1304,6 @@ def _optimize_for_hf(
13051304
Returns:
13061305
Optional[Dict[str, Any]]: Model optimization job input arguments.
13071306
"""
1308-
if self.model_server != ModelServer.DJL_SERVING:
1309-
logger.info("Overwriting model server to DJL.")
1310-
self.model_server = ModelServer.DJL_SERVING
1311-
1312-
self.role_arn = role_arn or self.role_arn
1313-
self.instance_type = instance_type or self.instance_type
1314-
13151307
self.pysdk_model = _custom_speculative_decoding(
13161308
self.pysdk_model, speculative_decoding_config, False
13171309
)
@@ -1371,13 +1363,12 @@ def _optimize_prepare_for_hf(self):
13711363
)
13721364
else:
13731365
if not custom_model_path:
1374-
custom_model_path = f"/tmp/sagemaker/model-builder/{self.model}/code"
1366+
custom_model_path = f"/tmp/sagemaker/model-builder/{self.model}"
13751367
download_huggingface_model_metadata(
13761368
self.model,
1377-
custom_model_path,
1369+
os.path.join(custom_model_path, "code"),
13781370
self.env_vars.get("HUGGING_FACE_HUB_TOKEN"),
13791371
)
1380-
custom_model_path = _normalize_local_model_path(custom_model_path)
13811372

13821373
self.pysdk_model.model_data, env = self._prepare_for_mode(
13831374
model_path=custom_model_path,

src/sagemaker/serve/utils/optimize_utils.py

Lines changed: 0 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -282,26 +282,6 @@ def _extract_optimization_config_and_env(
282282
return None, None
283283

284284

285-
def _normalize_local_model_path(local_model_path: Optional[str]) -> Optional[str]:
286-
"""Normalizes the local model path.
287-
288-
Args:
289-
local_model_path (Optional[str]): The local model path.
290-
291-
Returns:
292-
Optional[str]: The normalized model path.
293-
"""
294-
if local_model_path is None:
295-
return local_model_path
296-
297-
# Removes /code or /code/ path at the end of local_model_path,
298-
# as it is appended during artifacts upload.
299-
pattern = r"/code/?$"
300-
if re.search(pattern, local_model_path):
301-
return re.sub(pattern, "", local_model_path)
302-
return local_model_path
303-
304-
305285
def _custom_speculative_decoding(
306286
model: Model,
307287
speculative_decoding_config: Optional[Dict],

tests/unit/sagemaker/serve/utils/test_optimize_utils.py

Lines changed: 0 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -28,7 +28,6 @@
2828
_generate_additional_model_data_sources,
2929
_generate_channel_name,
3030
_extract_optimization_config_and_env,
31-
_normalize_local_model_path,
3231
_is_optimized,
3332
_custom_speculative_decoding,
3433
_is_inferentia_or_trainium,
@@ -312,19 +311,6 @@ def test_extract_optimization_config_and_env(
312311
)
313312

314313

315-
@pytest.mark.parametrize(
316-
"my_path, expected_path",
317-
[
318-
("local/path/llama/code", "local/path/llama"),
319-
("local/path/llama/code/", "local/path/llama"),
320-
("local/path/llama/", "local/path/llama/"),
321-
("local/path/llama", "local/path/llama"),
322-
],
323-
)
324-
def test_normalize_local_model_path(my_path, expected_path):
325-
assert _normalize_local_model_path(my_path) == expected_path
326-
327-
328314
class TestCustomSpeculativeDecodingConfig(unittest.TestCase):
329315

330316
@patch("sagemaker.model.Model")

0 commit comments

Comments (0)