@@ -1234,7 +1234,11 @@ def _model_builder_optimize_wrapper(
12341234 Returns:
12351235 Model: A deployable ``Model`` object.
12361236 """
1237- if hasattr (self , "enable_network_isolation" ) and sharding_config :
1237+ if (
1238+ hasattr (self , "enable_network_isolation" )
1239+ and self .enable_network_isolation
1240+ and sharding_config
1241+ ):
12381242 raise ValueError (
12391243 "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
12401244 "Loading of model requires network access."
@@ -1272,19 +1276,25 @@ def _model_builder_optimize_wrapper(
12721276 )
12731277 )
12741278
1275- if sharding_config and (
1276- (env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars )
1277- or (
1278- sharding_config .get ("OverrideEnvironment" )
1279- and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config ["OverrideEnvironment" ]
1279+ if sharding_config :
1280+ has_tensor_parallel_degree_in_env_vars = (
1281+ env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
12801282 )
1281- ):
1282- raise ValueError (
1283- (
1284- "OPTION_TENSOR_PARALLEL_DEGREE is a required "
1285- "environment variable with sharding config."
1286- )
1283+ has_tensor_parallel_degree_in_overrides = (
1284+ sharding_config
1285+ and sharding_config .get ("OverrideEnvironment" )
1286+ and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config .get ("OverrideEnvironment" )
12871287 )
1288+ if (
1289+ not has_tensor_parallel_degree_in_env_vars
1290+ and not has_tensor_parallel_degree_in_overrides
1291+ ):
1292+ raise ValueError (
1293+ (
1294+ "OPTION_TENSOR_PARALLEL_DEGREE is a required "
1295+ "environment variable with sharding config."
1296+ )
1297+ )
12881298
12891299 self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
12901300 self .instance_type = instance_type or self .instance_type
@@ -1399,10 +1409,11 @@ def _optimize_for_hf(
13991409 create_optimization_job_args ["ModelSource" ] = model_source
14001410
14011411 optimization_config , override_env = _extract_optimization_config_and_env (
1402- quantization_config , compilation_config
1412+ quantization_config , compilation_config , sharding_config
14031413 )
14041414 create_optimization_job_args ["OptimizationConfigs" ] = [optimization_config ]
1405- self .pysdk_model .env .update (override_env )
1415+ if override_env :
1416+ self .pysdk_model .env .update (override_env )
14061417
14071418 output_config = {"S3OutputLocation" : output_path }
14081419 if kms_key :
0 commit comments