 import boto3
 import json
 
-MODEL_DIRECTORY = "model_artifacts"
+def cleanup_endpoints(endpoint_name):
 
-def parse_args():
-    arg_parser = ArgumentParser()
-    arg_parser.add_argument("--model_name", type=str, required=True, choices=(
-        "Qwen/Qwen2.5-1.5B-Instruct",
-        "microsoft/Phi-3.5-mini-instruct"
-        )
-    )
-    arg_parser.add_argument("--bucket_name", type=str, required=True)
-    arg_parser.add_argument("--bucket_prefix", type=str, required=True)
-    arg_parser.add_argument("--sagemaker_execution_role_arn", type=str, required=True)
-    return arg_parser.parse_args()
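+    # Connect to SageMaker and list the current endpoints and endpoint configs
+    # so we can tell whether this endpoint name already has resources attached.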
+    sagemaker_client = boto3.client("sagemaker", region_name='us-east-1')
+
+    endpoints = sagemaker_client.list_endpoints()['Endpoints']
+    endpoints_configs = sagemaker_client.list_endpoint_configs()['EndpointConfigs']
 
-def load_artifacts(args):
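+    # Delete the endpoint and/or its config only if they actually exist, so the
+    # cleanup is safe to run when nothing has been deployed under this name yet.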
+    endpoints_names = [e['EndpointName'] for e in endpoints]
+    endpoints_configs_names = [e['EndpointConfigName'] for e in endpoints_configs]
+
+    if endpoint_name in endpoints_names:
+        sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
+    if endpoint_name in endpoints_configs_names:
+        sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_name)
+
+def load_artifacts(settings):
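+    # Load the model with 4-bit NF4 quantization; the quantized weights are what
+    # get saved locally and uploaded to S3 below.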
     bnb_config = BitsAndBytesConfig(
         load_in_4bit=True,
         bnb_4bit_quant_type="nf4",
@@ -29,27 +30,27 @@ def load_artifacts(args):
     )
 
     model = AutoModelForCausalLM.from_pretrained(
-        args.model_name,
+        settings.hf_name,
         device_map="auto",
         quantization_config=bnb_config
     )
 
-    tokenizer = AutoTokenizer.from_pretrained(args.model_name)
+    tokenizer = AutoTokenizer.from_pretrained(settings.hf_name)
 
     model.save_pretrained(MODEL_DIRECTORY)
     tokenizer.save_pretrained(MODEL_DIRECTORY)
 
-def write_model_to_s3(args, model_name):
+def write_model_to_s3(settings, model_name):
     tar_file_name = f"{model_name}.tar.gz"
 
     with tarfile.open(tar_file_name, "w:gz") as tar:
         tar.add(MODEL_DIRECTORY, arcname=".")
 
-    s3_model_path = f"{args.bucket_prefix}/{tar_file_name}"
+    s3_model_path = f"{settings.deploy_bucket_prefix}/{tar_file_name}"
     s3_client = boto3.client("s3")
-    s3_client.upload_file(tar_file_name, args.bucket_name, s3_model_path)
+    s3_client.upload_file(tar_file_name, settings.deploy_bucket_name, s3_model_path)
 
-    s3_uri = f"s3://{args.bucket_name}/{s3_model_path}"
+    s3_uri = f"s3://{settings.deploy_bucket_name}/{s3_model_path}"
     print(f"Model uploaded to: {s3_uri}")
     return s3_uri
 
@@ -89,17 +90,15 @@ def validate_deployment(predictor):
         print(f"Validation failed: {e}")
         raise e
 
-def deploy():
-    args = parse_args()
-    load_artifacts(args)
-    sanitized_model_name = args.model_name.split('/')[1].replace('.', '-')
-    s3_uri = write_model_to_s3(args, sanitized_model_name)
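+# deploy() now takes a settings object instead of parsing CLI arguments and is
+# meant to be invoked by an external entry point.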
+def deploy(settings, cleanup_existing_endpoints=False):
+    sanitized_model_name = settings.hf_name.split('/')[1].replace('.', '-')
+    if cleanup_existing_endpoints:
+        cleanup_endpoints(sanitized_model_name)
+    load_artifacts(settings)
+    s3_uri = write_model_to_s3(settings, sanitized_model_name)
     predictor = deploy_endpoint(
-        s3_uri,
-        args.sagemaker_execution_role_arn,
+        s3_uri,
+        settings.sagemaker_execution_role_arn,
         sanitized_model_name
     )
-    validate_deployment(predictor)
-
-if __name__ == "__main__":
-    deploy()
+    validate_deployment(predictor)