@@ -1234,7 +1234,11 @@ def _model_builder_optimize_wrapper(
1234
1234
Returns:
1235
1235
Model: A deployable ``Model`` object.
1236
1236
"""
1237
- if hasattr (self , "enable_network_isolation" ) and sharding_config :
1237
+ if (
1238
+ hasattr (self , "enable_network_isolation" )
1239
+ and self .enable_network_isolation
1240
+ and sharding_config
1241
+ ):
1238
1242
raise ValueError (
1239
1243
"EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1240
1244
"Loading of model requires network access."
@@ -1272,19 +1276,25 @@ def _model_builder_optimize_wrapper(
1272
1276
)
1273
1277
)
1274
1278
1275
- if sharding_config and (
1276
- (env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" not in env_vars )
1277
- or (
1278
- sharding_config .get ("OverrideEnvironment" )
1279
- and "OPTION_TENSOR_PARALLEL_DEGREE" not in sharding_config ["OverrideEnvironment" ]
1279
+ if sharding_config :
1280
+ has_tensor_parallel_degree_in_env_vars = (
1281
+ env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
1280
1282
)
1281
- ):
1282
- raise ValueError (
1283
- (
1284
- "OPTION_TENSOR_PARALLEL_DEGREE is a required "
1285
- "environment variable with sharding config."
1286
- )
1283
+ has_tensor_parallel_degree_in_overrides = (
1284
+ sharding_config
1285
+ and sharding_config .get ("OverrideEnvironment" )
1286
+ and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config .get ("OverrideEnvironment" )
1287
1287
)
1288
+ if (
1289
+ not has_tensor_parallel_degree_in_env_vars
1290
+ and not has_tensor_parallel_degree_in_overrides
1291
+ ):
1292
+ raise ValueError (
1293
+ (
1294
+ "OPTION_TENSOR_PARALLEL_DEGREE is a required "
1295
+ "environment variable with sharding config."
1296
+ )
1297
+ )
1288
1298
1289
1299
self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
1290
1300
self .instance_type = instance_type or self .instance_type
@@ -1399,10 +1409,11 @@ def _optimize_for_hf(
1399
1409
create_optimization_job_args ["ModelSource" ] = model_source
1400
1410
1401
1411
optimization_config , override_env = _extract_optimization_config_and_env (
1402
- quantization_config , compilation_config
1412
+ quantization_config , compilation_config , sharding_config
1403
1413
)
1404
1414
create_optimization_job_args ["OptimizationConfigs" ] = [optimization_config ]
1405
- self .pysdk_model .env .update (override_env )
1415
+ if override_env :
1416
+ self .pysdk_model .env .update (override_env )
1406
1417
1407
1418
output_config = {"S3OutputLocation" : output_path }
1408
1419
if kms_key :
0 commit comments