105
105
get_huggingface_model_metadata ,
106
106
download_huggingface_model_metadata ,
107
107
)
108
+ from sagemaker .serve .validations .optimization import _validate_optimization_configuration
108
109
109
110
logger = logging .getLogger (__name__ )
110
111
@@ -1120,6 +1121,7 @@ def optimize(
1120
1121
quantization_config : Optional [Dict ] = None ,
1121
1122
compilation_config : Optional [Dict ] = None ,
1122
1123
speculative_decoding_config : Optional [Dict ] = None ,
1124
+ sharding_config : Optional [Dict ] = None ,
1123
1125
env_vars : Optional [Dict ] = None ,
1124
1126
vpc_config : Optional [Dict ] = None ,
1125
1127
kms_key : Optional [str ] = None ,
@@ -1143,6 +1145,8 @@ def optimize(
1143
1145
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
1144
1146
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
1145
1147
Defaults to ``None``
1148
+ sharding_config (Optional[Dict]): Model sharding configuration.
1149
+ Defaults to ``None``
1146
1150
env_vars (Optional[Dict]): Additional environment variables to run the optimization
1147
1151
container. Defaults to ``None``.
1148
1152
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1171,6 +1175,7 @@ def optimize(
1171
1175
quantization_config = quantization_config ,
1172
1176
compilation_config = compilation_config ,
1173
1177
speculative_decoding_config = speculative_decoding_config ,
1178
+ sharding_config = sharding_config ,
1174
1179
env_vars = env_vars ,
1175
1180
vpc_config = vpc_config ,
1176
1181
kms_key = kms_key ,
@@ -1190,6 +1195,7 @@ def _model_builder_optimize_wrapper(
1190
1195
quantization_config : Optional [Dict ] = None ,
1191
1196
compilation_config : Optional [Dict ] = None ,
1192
1197
speculative_decoding_config : Optional [Dict ] = None ,
1198
+ sharding_config : Optional [Dict ] = None ,
1193
1199
env_vars : Optional [Dict ] = None ,
1194
1200
vpc_config : Optional [Dict ] = None ,
1195
1201
kms_key : Optional [str ] = None ,
@@ -1213,6 +1219,8 @@ def _model_builder_optimize_wrapper(
1213
1219
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
1214
1220
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
1215
1221
Defaults to ``None``
1222
+ sharding_config (Optional[Dict]): Model sharding configuration.
1223
+ Defaults to ``None``
1216
1224
env_vars (Optional[Dict]): Additional environment variables to run the optimization
1217
1225
container. Defaults to ``None``.
1218
1226
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1227,6 +1235,26 @@ def _model_builder_optimize_wrapper(
1227
1235
Returns:
1228
1236
Model: A deployable ``Model`` object.
1229
1237
"""
1238
+ if (
1239
+ hasattr (self , "enable_network_isolation" )
1240
+ and self .enable_network_isolation
1241
+ and sharding_config
1242
+ ):
1243
+ raise ValueError (
1244
+ "EnableNetworkIsolation cannot be set to True since SageMaker Fast Model "
1245
+ "Loading of model requires network access."
1246
+ )
1247
+
1248
+ # TODO: ideally these dictionaries need to be sagemaker_core shapes
1249
+ # TODO: for organization, abstract all validation behind this fn
1250
+ _validate_optimization_configuration (
1251
+ instance_type = instance_type ,
1252
+ quantization_config = quantization_config ,
1253
+ compilation_config = compilation_config ,
1254
+ sharding_config = sharding_config ,
1255
+ speculative_decoding_config = speculative_decoding_config ,
1256
+ )
1257
+
1230
1258
self .is_compiled = compilation_config is not None
1231
1259
self .is_quantized = quantization_config is not None
1232
1260
self .speculative_decoding_draft_model_source = _extract_speculative_draft_model_provider (
@@ -1236,6 +1264,43 @@ def _model_builder_optimize_wrapper(
1236
1264
if self .mode != Mode .SAGEMAKER_ENDPOINT :
1237
1265
raise ValueError ("Model optimization is only supported in Sagemaker Endpoint Mode." )
1238
1266
1267
+ if sharding_config and (
1268
+ quantization_config or compilation_config or speculative_decoding_config
1269
+ ):
1270
+ raise ValueError (
1271
+ "Sharding config is mutually exclusive and cannot be combined with any other optimization."
1272
+ )
1273
+
1274
+ if sharding_config and (
1275
+ quantization_config or compilation_config or speculative_decoding_config
1276
+ ):
1277
+ raise ValueError (
1278
+ (
1279
+ "Sharding config is mutually exclusive "
1280
+ "and cannot be combined with any other optimization."
1281
+ )
1282
+ )
1283
+
1284
+ if sharding_config :
1285
+ has_tensor_parallel_degree_in_env_vars = (
1286
+ env_vars and "OPTION_TENSOR_PARALLEL_DEGREE" in env_vars
1287
+ )
1288
+ has_tensor_parallel_degree_in_overrides = (
1289
+ sharding_config
1290
+ and sharding_config .get ("OverrideEnvironment" )
1291
+ and "OPTION_TENSOR_PARALLEL_DEGREE" in sharding_config .get ("OverrideEnvironment" )
1292
+ )
1293
+ if (
1294
+ not has_tensor_parallel_degree_in_env_vars
1295
+ and not has_tensor_parallel_degree_in_overrides
1296
+ ):
1297
+ raise ValueError (
1298
+ (
1299
+ "OPTION_TENSOR_PARALLEL_DEGREE is a required "
1300
+ "environment variable with sharding config."
1301
+ )
1302
+ )
1303
+
1239
1304
self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
1240
1305
self .instance_type = instance_type or self .instance_type
1241
1306
self .role_arn = role_arn or self .role_arn
@@ -1252,6 +1317,7 @@ def _model_builder_optimize_wrapper(
1252
1317
quantization_config = quantization_config ,
1253
1318
compilation_config = compilation_config ,
1254
1319
speculative_decoding_config = speculative_decoding_config ,
1320
+ sharding_config = sharding_config ,
1255
1321
env_vars = env_vars ,
1256
1322
vpc_config = vpc_config ,
1257
1323
kms_key = kms_key ,
@@ -1270,12 +1336,16 @@ def _model_builder_optimize_wrapper(
1270
1336
quantization_config = quantization_config ,
1271
1337
compilation_config = compilation_config ,
1272
1338
speculative_decoding_config = speculative_decoding_config ,
1339
+ sharding_config = sharding_config ,
1273
1340
env_vars = env_vars ,
1274
1341
vpc_config = vpc_config ,
1275
1342
kms_key = kms_key ,
1276
1343
max_runtime_in_sec = max_runtime_in_sec ,
1277
1344
)
1278
1345
1346
+ if sharding_config :
1347
+ self .pysdk_model ._is_sharded_model = True
1348
+
1279
1349
if input_args :
1280
1350
optimization_instance_type = input_args ["DeploymentInstanceType" ]
1281
1351
@@ -1325,6 +1395,7 @@ def _optimize_for_hf(
1325
1395
quantization_config : Optional [Dict ] = None ,
1326
1396
compilation_config : Optional [Dict ] = None ,
1327
1397
speculative_decoding_config : Optional [Dict ] = None ,
1398
+ sharding_config : Optional [Dict ] = None ,
1328
1399
env_vars : Optional [Dict ] = None ,
1329
1400
vpc_config : Optional [Dict ] = None ,
1330
1401
kms_key : Optional [str ] = None ,
@@ -1340,6 +1411,8 @@ def _optimize_for_hf(
1340
1411
compilation_config (Optional[Dict]): Compilation configuration. Defaults to ``None``.
1341
1412
speculative_decoding_config (Optional[Dict]): Speculative decoding configuration.
1342
1413
Defaults to ``None``
1414
+ sharding_config (Optional[Dict]): Model sharding configuration.
1415
+ Defaults to ``None``
1343
1416
env_vars (Optional[Dict]): Additional environment variables to run the optimization
1344
1417
container. Defaults to ``None``.
1345
1418
vpc_config (Optional[Dict]): The VpcConfig set on the model. Defaults to ``None``.
@@ -1363,7 +1436,7 @@ def _optimize_for_hf(
1363
1436
self .pysdk_model , speculative_decoding_config , False
1364
1437
)
1365
1438
1366
- if quantization_config or compilation_config :
1439
+ if quantization_config or compilation_config or sharding_config :
1367
1440
create_optimization_job_args = {
1368
1441
"OptimizationJobName" : job_name ,
1369
1442
"DeploymentInstanceType" : self .instance_type ,
@@ -1378,16 +1451,20 @@ def _optimize_for_hf(
1378
1451
model_source = _generate_model_source (self .pysdk_model .model_data , False )
1379
1452
create_optimization_job_args ["ModelSource" ] = model_source
1380
1453
1381
- optimization_config , quantization_override_env , compilation_override_env = (
1382
- _extract_optimization_config_and_env (quantization_config , compilation_config )
1383
- )
1454
+ (
1455
+ optimization_config ,
1456
+ quantization_override_env ,
1457
+ compilation_override_env ,
1458
+ sharding_override_env ,
1459
+ ) = _extract_optimization_config_and_env (quantization_config , compilation_config )
1384
1460
create_optimization_job_args ["OptimizationConfigs" ] = [
1385
1461
{k : v } for k , v in optimization_config .items ()
1386
1462
]
1387
1463
self .pysdk_model .env .update (
1388
1464
{
1389
1465
** (quantization_override_env or {}),
1390
1466
** (compilation_override_env or {}),
1467
+ ** (sharding_override_env or {}),
1391
1468
}
1392
1469
)
1393
1470
0 commit comments