Skip to content

Commit 0aba49e

Browse files
author
Jonathan Makunga
committed
Refactoring
1 parent e30b3b3 commit 0aba49e

File tree

3 files changed

+7
-12
lines changed

3 files changed

+7
-12
lines changed

src/sagemaker/serve/builder/jumpstart_builder.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -783,11 +783,10 @@ def _optimize_for_jumpstart(
783783
"AcceptEula": True
784784
}
785785

786-
if quantization_config or is_compilation:
787-
optimization_env_vars = _update_environment_variables(
788-
optimization_env_vars, override_env
789-
)
786+
optimization_env_vars = _update_environment_variables(optimization_env_vars, override_env)
787+
if optimization_env_vars:
790788
self.pysdk_model.env.update(optimization_env_vars)
789+
if quantization_config or is_compilation:
791790
return create_optimization_job_args
792791
return None
793792

src/sagemaker/serve/builder/model_builder.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1258,10 +1258,8 @@ def _model_builder_optimize_wrapper(
12581258
)
12591259

12601260
if input_args:
1261-
print(input_args)
12621261
self.sagemaker_session.sagemaker_client.create_optimization_job(**input_args)
12631262
job_status = self.sagemaker_session.wait_for_optimization_job(job_name)
1264-
print(job_status)
12651263
return _generate_optimized_model(self.pysdk_model, job_status)
12661264

12671265
self.pysdk_model.remove_tag_with_key(Tag.OPTIMIZATION_JOB_NAME)

tests/unit/sagemaker/serve/builder/test_model_builder.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -2653,22 +2653,20 @@ def test_optimize_for_hf_with_custom_s3_path(
26532653
model_metadata={
26542654
"CUSTOM_MODEL_PATH": "s3://bucket/path/",
26552655
},
2656+
role_arn="role-arn",
2657+
instance_type="ml.g5.2xlarge",
26562658
)
26572659

26582660
model_builder.pysdk_model = mock_pysdk_model
26592661

26602662
out_put = model_builder._optimize_for_hf(
26612663
job_name="job_name-123",
2662-
instance_type="ml.g5.2xlarge",
2663-
role_arn="role-arn",
26642664
quantization_config={
26652665
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
26662666
},
26672667
output_path="s3://bucket/code/",
26682668
)
26692669

2670-
print(out_put)
2671-
26722670
self.assertEqual(model_builder.role_arn, "role-arn")
26732671
self.assertEqual(model_builder.instance_type, "ml.g5.2xlarge")
26742672
self.assertEqual(model_builder.pysdk_model.env["OPTION_QUANTIZE"], "awq")
@@ -2715,14 +2713,14 @@ def test_optimize_for_hf_without_custom_s3_path(
27152713
model_builder = ModelBuilder(
27162714
model="meta-llama/Meta-Llama-3-8B-Instruct",
27172715
env_vars={"HUGGING_FACE_HUB_TOKEN": "token"},
2716+
role_arn="role-arn",
2717+
instance_type="ml.g5.2xlarge",
27182718
)
27192719

27202720
model_builder.pysdk_model = mock_pysdk_model
27212721

27222722
out_put = model_builder._optimize_for_hf(
27232723
job_name="job_name-123",
2724-
instance_type="ml.g5.2xlarge",
2725-
role_arn="role-arn",
27262724
quantization_config={
27272725
"OverrideEnvironment": {"OPTION_QUANTIZE": "awq"},
27282726
},

0 commit comments

Comments
 (0)