1313"""Placeholder docstring"""
1414from __future__ import absolute_import
1515
16+ import copy
1617from abc import ABC , abstractmethod
18+ from datetime import datetime , timedelta
1719from typing import Type
1820import logging
1921
2022from sagemaker .model import Model
2123from sagemaker import model_uris
2224from sagemaker .serve .model_server .djl_serving .prepare import prepare_djl_js_resources
25+ from sagemaker .serve .model_server .djl_serving .utils import _get_admissible_tensor_parallel_degrees
2326from sagemaker .serve .model_server .tgi .prepare import prepare_tgi_js_resources , _create_dir_structure
2427from sagemaker .serve .mode .function_pointers import Mode
28+ from sagemaker .serve .utils .exceptions import (
29+ LocalDeepPingException ,
30+ LocalModelOutOfMemoryException ,
31+ LocalModelInvocationException ,
32+ LocalModelLoadException ,
33+ SkipTuningComboException ,
34+ )
2535from sagemaker .serve .utils .predictors import (
2636 DjlLocalModePredictor ,
2737 TgiLocalModePredictor ,
2838)
29- from sagemaker .serve .utils .local_hardware import _get_nb_instance , _get_ram_usage_mb
39+ from sagemaker .serve .utils .local_hardware import (
40+ _get_nb_instance ,
41+ _get_ram_usage_mb ,
42+ )
3043from sagemaker .serve .utils .telemetry_logger import _capture_telemetry
44+ from sagemaker .serve .utils .tuning import (
45+ _pretty_print_results_jumpstart ,
46+ _serial_benchmark ,
47+ _concurrent_benchmark ,
48+ _more_performant ,
49+ _sharded_supported ,
50+ )
3151from sagemaker .serve .utils .types import ModelServer
3252from sagemaker .base_predictor import PredictorBase
3353from sagemaker .jumpstart .model import JumpStartModel
@@ -134,7 +154,7 @@ def _js_builder_deploy_wrapper(self, *args, **kwargs) -> Type[PredictorBase]:
                 model_data=self.pysdk_model.model_data,
             )
         elif not hasattr(self, "prepared_for_tgi"):
-            self.prepared_for_tgi = prepare_tgi_js_resources(
+            self.js_model_config, self.prepared_for_tgi = prepare_tgi_js_resources(
                 model_path=self.model_path,
                 js_id=self.model,
                 dependencies=self.dependencies,
@@ -222,7 +242,7 @@ def _build_for_tgi_jumpstart(self):
         env = {}
         if self.mode == Mode.LOCAL_CONTAINER:
             if not hasattr(self, "prepared_for_tgi"):
-                self.prepared_for_tgi = prepare_tgi_js_resources(
+                self.js_model_config, self.prepared_for_tgi = prepare_tgi_js_resources(
                     model_path=self.model_path,
                     js_id=self.model,
                     dependencies=self.dependencies,
@@ -234,6 +254,183 @@ def _build_for_tgi_jumpstart(self):
 
         self.pysdk_model.env.update(env)
 
+    def _tune_for_js(self, sharded_supported: bool, max_tuning_duration: int = 1800):
+        """Tune for JumpStart models in local container mode.
+
+        Args:
+            sharded_supported (bool): Indicates whether sharding is supported by this ``Model``.
+            max_tuning_duration (int): The maximum time in seconds to spend tuning this
+                ``Model`` locally; also used as the deployment timeout for each candidate.
+                Default: ``1800``
+        Returns:
+            The tuned ``Model``.
+        """
+        if self.mode != Mode.LOCAL_CONTAINER:
+            logger.warning(
+                "Tuning is only a %s capability. Returning original model.", Mode.LOCAL_CONTAINER
+            )
+            return self.pysdk_model
+
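+        # TGI reads its shard count from SM_NUM_GPUS, while DJL Serving reads the
+        # tensor parallel degree from OPTION_TENSOR_PARALLEL_DEGREE, so use
+        # whichever variable the built model already carries.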
+        num_shard_env_var_name = "SM_NUM_GPUS"
+        if "OPTION_TENSOR_PARALLEL_DEGREE" in self.pysdk_model.env:
+            num_shard_env_var_name = "OPTION_TENSOR_PARALLEL_DEGREE"
+
+        initial_env_vars = copy.deepcopy(self.pysdk_model.env)
+        admissible_tensor_parallel_degrees = _get_admissible_tensor_parallel_degrees(
+            self.js_model_config
+        )
+
+        if len(admissible_tensor_parallel_degrees) > 1 and not sharded_supported:
+            admissible_tensor_parallel_degrees = [1]
+            logger.warning(
+                "Sharding across multiple GPUs is not supported for this model. "
+                "Proceeding with a tensor parallel degree of 1."
+            )
+
+        benchmark_results = {}
+        best_tuned_combination = None
+        timeout = datetime.now() + timedelta(seconds=max_tuning_duration)
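+        # Deploy and benchmark each admissible tensor parallel degree until the
+        # tuning deadline passes, keeping the most performant configuration seen.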
+        for tensor_parallel_degree in admissible_tensor_parallel_degrees:
+            if datetime.now() > timeout:
+                logger.info("Max tuning duration reached. Tuning stopped.")
+                break
+
+            self.pysdk_model.env.update({num_shard_env_var_name: str(tensor_parallel_degree)})
+            try:
+                logger.info("Trying tensor parallel degree: %s", tensor_parallel_degree)
+
+                predictor = self.pysdk_model.deploy(model_data_download_timeout=max_tuning_duration)
+
+                avg_latency, p90, avg_tokens_per_second = _serial_benchmark(
+                    predictor, self.schema_builder.sample_input
+                )
+                throughput_per_second, standard_deviation = _concurrent_benchmark(
+                    predictor, self.schema_builder.sample_input
+                )
+
+                tested_env = copy.deepcopy(self.pysdk_model.env)
+                logger.info(
+                    "Average latency: %s, throughput/s: %s for configuration: %s",
+                    avg_latency,
+                    throughput_per_second,
+                    tested_env,
+                )
+                benchmark_results[avg_latency] = [
+                    tested_env,
+                    p90,
+                    avg_tokens_per_second,
+                    throughput_per_second,
+                    standard_deviation,
+                ]
+
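+                # Candidate configurations are compared as lists of the form:
+                # [avg_latency, tensor_parallel_degree, None, p90,
+                #  avg_tokens_per_second, throughput_per_second, standard_deviation]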
+                if not best_tuned_combination:
+                    best_tuned_combination = [
+                        avg_latency,
+                        tensor_parallel_degree,
+                        None,
+                        p90,
+                        avg_tokens_per_second,
+                        throughput_per_second,
+                        standard_deviation,
+                    ]
+                else:
+                    tuned_configuration = [
+                        avg_latency,
+                        tensor_parallel_degree,
+                        None,
+                        p90,
+                        avg_tokens_per_second,
+                        throughput_per_second,
+                        standard_deviation,
+                    ]
+                    if _more_performant(best_tuned_combination, tuned_configuration):
+                        best_tuned_combination = tuned_configuration
+            except LocalDeepPingException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. Failed to invoke the model server: %s",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except LocalModelOutOfMemoryException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. "
+                    "Out of memory when loading the model: %s",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except LocalModelInvocationException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. "
+                    "Failed to invoke the model server: %s. "
+                    "Please check that model server configurations are as expected "
+                    "(Ex. serialization, deserialization, content_type, accept).",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except LocalModelLoadException as e:
+                logger.warning(
+                    "Deployment unsuccessful with %s: %s. Failed to load the model: %s.",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except SkipTuningComboException as e:
+                logger.warning(
+                    "Deployment with %s: %s was expected to succeed, but failed with: %s. "
+                    "Trying the next combination.",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                    str(e),
+                )
+            except Exception:  # pylint: disable=W0703
+                logger.exception(
+                    "Deployment unsuccessful with %s: %s due to an unhandled exception",
+                    num_shard_env_var_name,
+                    tensor_parallel_degree,
+                )
+
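+        # Persist the winning shard count if any configuration succeeded;
+        # otherwise restore the environment variables the model started with.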
+        if best_tuned_combination:
+            self.pysdk_model.env.update({num_shard_env_var_name: str(best_tuned_combination[1])})
+
+            _pretty_print_results_jumpstart(benchmark_results, [num_shard_env_var_name])
+            logger.info(
+                "Model Configuration: %s was most performant with avg latency: %s, "
+                "p90 latency: %s, average tokens per second: %s, throughput/s: %s, "
+                "standard deviation of requests: %s",
+                self.pysdk_model.env,
+                best_tuned_combination[0],
+                best_tuned_combination[3],
+                best_tuned_combination[4],
+                best_tuned_combination[5],
+                best_tuned_combination[6],
+            )
+        else:
+            self.pysdk_model.env.update(initial_env_vars)
+            logger.debug(
+                "Failed to gather any tuning results. "
+                "Please inspect the stack trace emitted from live logging for more details. "
+                "Falling back to default model configurations: %s",
+                self.pysdk_model.env,
+            )
+
+        return self.pysdk_model
+
+    @_capture_telemetry("djl_jumpstart.tune")
+    def tune_for_djl_jumpstart(self, max_tuning_duration: int = 1800):
+        """Tune for JumpStart models served with the DJL DLC."""
+        return self._tune_for_js(sharded_supported=True, max_tuning_duration=max_tuning_duration)
+
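+    # The DJL path passes sharded_supported=True, so tuning may try every
+    # admissible tensor parallel degree; the TGI path only shards models whose
+    # JumpStart config declares support, hence the _sharded_supported check below.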
+    @_capture_telemetry("tgi_jumpstart.tune")
+    def tune_for_tgi_jumpstart(self, max_tuning_duration: int = 1800):
+        """Tune for JumpStart models served with the TGI DLC."""
+        sharded_supported = _sharded_supported(self.model, self.js_model_config)
+        return self._tune_for_js(
+            sharded_supported=sharded_supported, max_tuning_duration=max_tuning_duration
+        )
+
     def _build_for_jumpstart(self):
         """Placeholder docstring"""
         # we do not pickle for jumpstart. set to none
@@ -254,6 +451,8 @@ def _build_for_jumpstart(self):
             self.image_uri = self.pysdk_model.image_uri
 
             self._build_for_djl_jumpstart()
+
+            self.pysdk_model.tune = self.tune_for_djl_jumpstart
         elif "tgi-inference" in image_uri:
             logger.info("Building for TGI JumpStart Model ID...")
             self.model_server = ModelServer.TGI
@@ -262,6 +461,8 @@ def _build_for_jumpstart(self):
             self.image_uri = self.pysdk_model.image_uri
 
             self._build_for_tgi_jumpstart()
+
+            self.pysdk_model.tune = self.tune_for_tgi_jumpstart
         else:
             raise ValueError(
                 "JumpStart Model ID was not packaged with djl-inference or tgi-inference container."