Skip to content

Commit 2184f7a

Browse files
committed
add more UTs for sharding
1 parent bb4e15b commit 2184f7a

File tree

2 files changed: +499 -15 lines changed

2 files changed

+499
-15
lines changed

src/sagemaker/serve/builder/model_builder.py

Lines changed: 25 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1234,7 +1234,11 @@ def _model_builder_optimize_wrapper(
12341234
Returns:
12351235
Model: A deployable ``Model`` object.
12361236
"""
1237-
if hasattr(self, "enable_network_isolation") and sharding_config:
1237+
if (
1238+
hasattr(self, "enable_network_isolation")
1239+
and self.enable_network_isolation
1240+
and sharding_config
1241+
):
12381242
raise ValueError(
12391243
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
12401244
"Loading of model requires network access."
@@ -1272,19 +1276,25 @@ def _model_builder_optimize_wrapper(
12721276
)
12731277
)
12741278

1275-
if sharding_config and (
1276-
(env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars)
1277-
or (
1278-
sharding_config.get("OverrideEnvironment")
1279-
and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config["OverrideEnvironment"]
1279+
if sharding_config:
1280+
has_tensor_parallel_degree_in_env_vars = (
1281+
env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
12801282
)
1281-
):
1282-
raise ValueError(
1283-
(
1284-
"OPTION_TENSOR_PARALLEL_DEGREE is a required "
1285-
"environment variable with sharding config."
1286-
)
1283+
has_tensor_parallel_degree_in_overrides = (
1284+
sharding_config
1285+
and sharding_config.get("OverrideEnvironment")
1286+
and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config.get("OverrideEnvironment")
12871287
)
1288+
if (
1289+
not has_tensor_parallel_degree_in_env_vars
1290+
and not has_tensor_parallel_degree_in_overrides
1291+
):
1292+
raise ValueError(
1293+
(
1294+
"OPTION_TENSOR_PARALLEL_DEGREE is a required "
1295+
"environment variable with sharding config."
1296+
)
1297+
)
12881298

12891299
self.sagemaker_session = sagemaker_session or self.sagemaker_session or Session()
12901300
self.instance_type = instance_type or self.instance_type
@@ -1399,10 +1409,11 @@ def _optimize_for_hf(
13991409
create_optimization_job_args["ModelSource"] = model_source
14001410

14011411
optimization_config, override_env = _extract_optimization_config_and_env(
1402-
quantization_config, compilation_config
1412+
quantization_config, compilation_config, sharding_config
14031413
)
14041414
create_optimization_job_args["OptimizationConfigs"] = [optimization_config]
1405-
self.pysdk_model.env.update(override_env)
1415+
if override_env:
1416+
self.pysdk_model.env.update(override_env)
14061417

14071418
output_config = {"S3OutputLocation": output_path}
14081419
if kms_key:

0 commit comments

Comments
 (0)