1- from argparse import ArgumentParser
2- from transformers import AutoModelForCausalLM , AutoTokenizer , BitsAndBytesConfig
3- from sagemaker .huggingface import HuggingFaceModel
4- import torch
1+ import json
2+ import os
53import tarfile
4+ from argparse import ArgumentParser
5+
66import boto3
7- import json
7+ import torch
8+ from sagemaker .huggingface import HuggingFaceModel
9+ from transformers import AutoModelForCausalLM , AutoTokenizer , BitsAndBytesConfig
10+
# Directory two levels above this file's absolute path (the parent of the
# directory that contains this module).
PACKAGE_DIR = os.path.abspath(
    os.path.join(os.path.dirname(__file__), os.pardir)
)
# Location of the model artifacts, resolved relative to PACKAGE_DIR.
MODEL_DIRECTORY = os.path.join(PACKAGE_DIR, "model_artifacts")
# Name of the sub-directory of MODEL_DIRECTORY that holds the endpoint code.
CODE_PATH = "code"
814
9- MODEL_DIRECTORY = '../../../model_artifacts'
def parse_args(argv=None):
    """Parse the script's command-line options.

    Args:
        argv: Optional sequence of argument strings to parse. When ``None``
            (the default) argparse falls back to ``sys.argv[1:]``, which is
            what the function did before this parameter existed.

    Returns:
        argparse.Namespace with the attributes ``model_name``,
        ``bucket_name``, ``bucket_prefix`` and
        ``sagemaker_execution_role_arn`` (all strings).

    Raises:
        SystemExit: Raised by argparse when a required option is missing or
            ``--model_name`` is not one of the accepted choices.
    """
    arg_parser = ArgumentParser()
    arg_parser.add_argument(
        "--model_name",
        type=str,
        required=True,
        # Only these model identifiers are accepted.
        choices=(
            "Qwen/Qwen2.5-1.5B-Instruct",
            "microsoft/Phi-3.5-mini-instruct",
            "microsoft/Phi-4-mini-instruct",
        ),
    )
    arg_parser.add_argument("--bucket_name", type=str, required=True)
    arg_parser.add_argument("--bucket_prefix", type=str, required=True)
    arg_parser.add_argument(
        "--sagemaker_execution_role_arn", type=str, required=True
    )
    return arg_parser.parse_args(argv)
1027
1128def cleanup_endpoints (endpoint_name ):
1229
13- sagemaker_client = boto3 .client ("sagemaker" , region_name = ' us-east-1' )
30+ sagemaker_client = boto3 .client ("sagemaker" , region_name = " us-east-1" )
1431
15- endpoints = sagemaker_client .list_endpoints ()[' Endpoints' ]
16- endpoints_configs = sagemaker_client .list_endpoint_configs ()[' EndpointConfigs' ]
32+ endpoints = sagemaker_client .list_endpoints ()[" Endpoints" ]
33+ endpoints_configs = sagemaker_client .list_endpoint_configs ()[" EndpointConfigs" ]
1734
18- endpoints_names = [e [' EndpointName' ] for e in endpoints ]
19- endpoints_configs_names = [e [' EndpointConfigName' ] for e in endpoints_configs ]
35+ endpoints_names = [e [" EndpointName" ] for e in endpoints ]
36+ endpoints_configs_names = [e [" EndpointConfigName" ] for e in endpoints_configs ]
2037
2138 if endpoint_name in endpoints_names :
2239 sagemaker_client .delete_endpoint (EndpointConfigName = endpoint_name )
@@ -44,14 +61,14 @@ def load_artifacts(settings):
4461
def write_model_to_s3(settings, model_name):
    """Archive MODEL_DIRECTORY as a gzip tarball and upload it to S3.

    Args:
        settings: Object exposing ``deploy_bucket_name`` and
            ``deploy_bucket_prefix`` attributes.
        model_name: Base name of the archive; ``.tar.gz`` is appended.

    Returns:
        The ``s3://`` URI of the uploaded archive.
    """
    # No whitespace between the name and the extension: a stray space here
    # ends up in the local file name, the S3 key and the returned URI.
    tar_file_name = f"{model_name}.tar.gz"

    # The archive is written to the current working directory and is left
    # there after the upload.
    with tarfile.open(tar_file_name, "w:gz") as tar:
        # arcname="." stores the directory's contents relative to the
        # archive root instead of under the local absolute path.
        tar.add(MODEL_DIRECTORY, arcname=".")

    s3_model_path = f"{settings.deploy_bucket_prefix}/{tar_file_name}"
    s3_client = boto3.client("s3")
    s3_client.upload_file(tar_file_name, settings.deploy_bucket_name, s3_model_path)

    s3_uri = f"s3://{settings.deploy_bucket_name}/{s3_model_path}"
    print(f"Model uploaded to: {s3_uri}")
    return s3_uri
@@ -92,15 +109,35 @@ def validate_deployment(predictor):
92109 print (f"Validation failed: { e } " )
93110 raise e
94111
95- def deploy (settings , cleanup_endpoints = False ):
96- sanitized_model_name = settings .hf_name .split ('/' )[1 ].replace ('.' , '-' )
97- if cleanup_endpoints :
98- cleanup_endpoints (sanitized_model_name )
def validate_model_directory():
    """Check that the endpoint code directory and its required files exist.

    The code directory is ``MODEL_DIRECTORY/CODE_PATH``; it must contain
    ``inference.py`` and ``requirements.txt``.

    Raises:
        ValueError: If the code directory is not a directory, or if either
            required file is absent (``inference.py`` is checked first).
    """
    code_dir = os.path.join(MODEL_DIRECTORY, CODE_PATH)
    if not os.path.isdir(code_dir):
        raise ValueError(f"{code_dir} is missing.")

    # Order matters only for which error is reported when both are missing.
    for required_name in ("inference.py", "requirements.txt"):
        if not os.path.isfile(os.path.join(code_dir, required_name)):
            raise ValueError(f"{required_name} is missing from the code directory.")
122+
123+
def cleanup_model_directory():
    """Delete regular files at the top level of MODEL_DIRECTORY.

    The entry named CODE_PATH is always skipped, and anything that is not a
    file (e.g. a sub-directory) is left in place.
    """
    candidate_paths = (
        os.path.join(MODEL_DIRECTORY, entry_name)
        for entry_name in os.listdir(MODEL_DIRECTORY)
        if entry_name != CODE_PATH
    )
    for candidate in candidate_paths:
        if os.path.isfile(candidate):
            os.remove(candidate)
131+
132+ def deploy (settings ):
133+ validate_model_directory ()
134+ cleanup_model_directory ()
135+ sanitized_model_name = settings .hf_name .split ("/" )[1 ].replace ("." , "-" )
99136 load_artifacts (settings )
100137 s3_uri = write_model_to_s3 (settings , sanitized_model_name )
101138 predictor = deploy_endpoint (
102139 s3_uri ,
103- settings .sagemaker_execution_role_arn ,
140+ settings .sagemaker_execution_role_arn ,
104141 sanitized_model_name
105142 )
106143 validate_deployment (predictor )
0 commit comments