73
73
_generate_model_source ,
74
74
_extract_optimization_config_and_env ,
75
75
_is_s3_uri ,
76
- _normalize_local_model_path ,
77
76
_custom_speculative_decoding ,
78
77
_extract_speculative_draft_model_provider ,
79
78
)
@@ -833,6 +832,8 @@ def build( # pylint: disable=R0911
833
832
# until we deprecate HUGGING_FACE_HUB_TOKEN.
834
833
if self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" ) and not self .env_vars .get ("HF_TOKEN" ):
835
834
self .env_vars ["HF_TOKEN" ] = self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" )
835
+ elif self .env_vars .get ("HF_TOKEN" ) and not self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" ):
836
+ self .env_vars ["HUGGING_FACE_HUB_TOKEN" ] = self .env_vars .get ("HF_TOKEN" )
836
837
837
838
self .sagemaker_session .settings ._local_download_dir = self .model_path
838
839
@@ -851,7 +852,9 @@ def build( # pylint: disable=R0911
851
852
852
853
self ._build_validations ()
853
854
854
- if not self ._is_jumpstart_model_id () and self .model_server :
855
+ if (
856
+ not (isinstance (self .model , str ) and self ._is_jumpstart_model_id ())
857
+ ) and self .model_server :
855
858
return self ._build_for_model_server ()
856
859
857
860
if isinstance (self .model , str ):
@@ -1216,18 +1219,14 @@ def _model_builder_optimize_wrapper(
1216
1219
raise ValueError ("Quantization config and compilation config are mutually exclusive." )
1217
1220
1218
1221
self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
1219
-
1220
1222
self .instance_type = instance_type or self .instance_type
1221
1223
self .role_arn = role_arn or self .role_arn
1222
1224
1223
- self .build (mode = self .mode , sagemaker_session = self .sagemaker_session )
1224
1225
job_name = job_name or f"modelbuilderjob-{ uuid .uuid4 ().hex } "
1225
-
1226
1226
if self ._is_jumpstart_model_id ():
1227
+ self .build (mode = self .mode , sagemaker_session = self .sagemaker_session )
1227
1228
input_args = self ._optimize_for_jumpstart (
1228
1229
output_path = output_path ,
1229
- instance_type = instance_type ,
1230
- role_arn = self .role_arn ,
1231
1230
tags = tags ,
1232
1231
job_name = job_name ,
1233
1232
accept_eula = accept_eula ,
@@ -1240,10 +1239,13 @@ def _model_builder_optimize_wrapper(
1240
1239
max_runtime_in_sec = max_runtime_in_sec ,
1241
1240
)
1242
1241
else :
1242
+ if self .model_server != ModelServer .DJL_SERVING :
1243
+ logger .info ("Overriding model server to DJL_SERVING." )
1244
+ self .model_server = ModelServer .DJL_SERVING
1245
+
1246
+ self .build (mode = self .mode , sagemaker_session = self .sagemaker_session )
1243
1247
input_args = self ._optimize_for_hf (
1244
1248
output_path = output_path ,
1245
- instance_type = instance_type ,
1246
- role_arn = self .role_arn ,
1247
1249
tags = tags ,
1248
1250
job_name = job_name ,
1249
1251
quantization_config = quantization_config ,
@@ -1256,8 +1258,10 @@ def _model_builder_optimize_wrapper(
1256
1258
)
1257
1259
1258
1260
if input_args :
1261
+ print (input_args )
1259
1262
self .sagemaker_session .sagemaker_client .create_optimization_job (** input_args )
1260
1263
job_status = self .sagemaker_session .wait_for_optimization_job (job_name )
1264
+ print (job_status )
1261
1265
return _generate_optimized_model (self .pysdk_model , job_status )
1262
1266
1263
1267
self .pysdk_model .remove_tag_with_key (Tag .OPTIMIZATION_JOB_NAME )
@@ -1269,8 +1273,6 @@ def _model_builder_optimize_wrapper(
1269
1273
def _optimize_for_hf (
1270
1274
self ,
1271
1275
output_path : str ,
1272
- instance_type : Optional [str ] = None ,
1273
- role_arn : Optional [str ] = None ,
1274
1276
tags : Optional [Tags ] = None ,
1275
1277
job_name : Optional [str ] = None ,
1276
1278
quantization_config : Optional [Dict ] = None ,
@@ -1285,9 +1287,6 @@ def _optimize_for_hf(
1285
1287
1286
1288
Args:
1287
1289
output_path (str): Specifies where to store the compiled/quantized model.
1288
- instance_type (Optional[str]): Target deployment instance type that
1289
- the model is optimized for.
1290
- role_arn (Optional[str]): Execution role. Defaults to ``None``.
1291
1290
tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
1292
1291
job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
1293
1292
quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
@@ -1305,13 +1304,6 @@ def _optimize_for_hf(
1305
1304
Returns:
1306
1305
Optional[Dict[str, Any]]: Model optimization job input arguments.
1307
1306
"""
1308
- if self .model_server != ModelServer .DJL_SERVING :
1309
- logger .info ("Overwriting model server to DJL." )
1310
- self .model_server = ModelServer .DJL_SERVING
1311
-
1312
- self .role_arn = role_arn or self .role_arn
1313
- self .instance_type = instance_type or self .instance_type
1314
-
1315
1307
self .pysdk_model = _custom_speculative_decoding (
1316
1308
self .pysdk_model , speculative_decoding_config , False
1317
1309
)
@@ -1371,13 +1363,12 @@ def _optimize_prepare_for_hf(self):
1371
1363
)
1372
1364
else :
1373
1365
if not custom_model_path :
1374
- custom_model_path = f"/tmp/sagemaker/model-builder/{ self .model } /code "
1366
+ custom_model_path = f"/tmp/sagemaker/model-builder/{ self .model } "
1375
1367
download_huggingface_model_metadata (
1376
1368
self .model ,
1377
- custom_model_path ,
1369
+ os . path . join ( custom_model_path , "code" ) ,
1378
1370
self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" ),
1379
1371
)
1380
- custom_model_path = _normalize_local_model_path (custom_model_path )
1381
1372
1382
1373
self .pysdk_model .model_data , env = self ._prepare_for_mode (
1383
1374
model_path = custom_model_path ,
0 commit comments