7373 _generate_model_source ,
7474 _extract_optimization_config_and_env ,
7575 _is_s3_uri ,
76- _normalize_local_model_path ,
7776 _custom_speculative_decoding ,
7877 _extract_speculative_draft_model_provider ,
7978)
@@ -833,6 +832,8 @@ def build( # pylint: disable=R0911
833832 # until we deprecate HUGGING_FACE_HUB_TOKEN.
834833 if self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" ) and not self .env_vars .get ("HF_TOKEN" ):
835834 self .env_vars ["HF_TOKEN" ] = self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" )
835+ elif self .env_vars .get ("HF_TOKEN" ) and not self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" ):
836+ self .env_vars ["HUGGING_FACE_HUB_TOKEN" ] = self .env_vars .get ("HF_TOKEN" )
836837
837838 self .sagemaker_session .settings ._local_download_dir = self .model_path
838839
@@ -851,7 +852,9 @@ def build( # pylint: disable=R0911
851852
852853 self ._build_validations ()
853854
854- if not self ._is_jumpstart_model_id () and self .model_server :
855+ if (
856+ not (isinstance (self .model , str ) and self ._is_jumpstart_model_id ())
857+ ) and self .model_server :
855858 return self ._build_for_model_server ()
856859
857860 if isinstance (self .model , str ):
@@ -1216,18 +1219,15 @@ def _model_builder_optimize_wrapper(
12161219 raise ValueError ("Quantization config and compilation config are mutually exclusive." )
12171220
12181221 self .sagemaker_session = sagemaker_session or self .sagemaker_session or Session ()
1219-
12201222 self .instance_type = instance_type or self .instance_type
12211223 self .role_arn = role_arn or self .role_arn
12221224
1223- self .build (mode = self .mode , sagemaker_session = self .sagemaker_session )
12241225 job_name = job_name or f"modelbuilderjob-{ uuid .uuid4 ().hex } "
1225-
12261226 if self ._is_jumpstart_model_id ():
1227+ self .build (mode = self .mode , sagemaker_session = self .sagemaker_session )
12271228 input_args = self ._optimize_for_jumpstart (
12281229 output_path = output_path ,
12291230 instance_type = instance_type ,
1230- role_arn = self .role_arn ,
12311231 tags = tags ,
12321232 job_name = job_name ,
12331233 accept_eula = accept_eula ,
@@ -1240,10 +1240,13 @@ def _model_builder_optimize_wrapper(
12401240 max_runtime_in_sec = max_runtime_in_sec ,
12411241 )
12421242 else :
1243+ if self .model_server != ModelServer .DJL_SERVING :
1244+ logger .info ("Overriding model server to DJL_SERVING." )
1245+ self .model_server = ModelServer .DJL_SERVING
1246+
1247+ self .build (mode = self .mode , sagemaker_session = self .sagemaker_session )
12431248 input_args = self ._optimize_for_hf (
12441249 output_path = output_path ,
1245- instance_type = instance_type ,
1246- role_arn = self .role_arn ,
12471250 tags = tags ,
12481251 job_name = job_name ,
12491252 quantization_config = quantization_config ,
@@ -1269,8 +1272,6 @@ def _model_builder_optimize_wrapper(
12691272 def _optimize_for_hf (
12701273 self ,
12711274 output_path : str ,
1272- instance_type : Optional [str ] = None ,
1273- role_arn : Optional [str ] = None ,
12741275 tags : Optional [Tags ] = None ,
12751276 job_name : Optional [str ] = None ,
12761277 quantization_config : Optional [Dict ] = None ,
@@ -1285,9 +1286,6 @@ def _optimize_for_hf(
12851286
12861287 Args:
12871288 output_path (str): Specifies where to store the compiled/quantized model.
1288- instance_type (Optional[str]): Target deployment instance type that
1289- the model is optimized for.
1290- role_arn (Optional[str]): Execution role. Defaults to ``None``.
12911289 tags (Optional[Tags]): Tags for labeling a model optimization job. Defaults to ``None``.
12921290 job_name (Optional[str]): The name of the model optimization job. Defaults to ``None``.
12931291 quantization_config (Optional[Dict]): Quantization configuration. Defaults to ``None``.
@@ -1305,13 +1303,6 @@ def _optimize_for_hf(
13051303 Returns:
13061304 Optional[Dict[str, Any]]: Model optimization job input arguments.
13071305 """
1308- if self .model_server != ModelServer .DJL_SERVING :
1309- logger .info ("Overwriting model server to DJL." )
1310- self .model_server = ModelServer .DJL_SERVING
1311-
1312- self .role_arn = role_arn or self .role_arn
1313- self .instance_type = instance_type or self .instance_type
1314-
13151306 self .pysdk_model = _custom_speculative_decoding (
13161307 self .pysdk_model , speculative_decoding_config , False
13171308 )
@@ -1371,13 +1362,12 @@ def _optimize_prepare_for_hf(self):
13711362 )
13721363 else :
13731364 if not custom_model_path :
1374- custom_model_path = f"/tmp/sagemaker/model-builder/{ self .model } /code "
1365+ custom_model_path = f"/tmp/sagemaker/model-builder/{ self .model } "
13751366 download_huggingface_model_metadata (
13761367 self .model ,
1377- custom_model_path ,
1368+ os . path . join ( custom_model_path , "code" ) ,
13781369 self .env_vars .get ("HUGGING_FACE_HUB_TOKEN" ),
13791370 )
1380- custom_model_path = _normalize_local_model_path (custom_model_path )
13811371
13821372 self .pysdk_model .model_data , env = self ._prepare_for_mode (
13831373 model_path = custom_model_path ,
0 commit comments