Skip to content

Commit 38453cb

Browse files
committed
action deploy
1 parent 59832ea commit 38453cb

File tree

5 files changed

+52
-34
lines changed

5 files changed

+52
-34
lines changed

.gitignore

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,4 +7,5 @@ src/wraval.egg-info/
77
**__pycache__/
88
prompts/*
99
.idea
10-
src/wraval/custom_prompts/*
10+
src/wraval/custom_prompts/*
11+
src/wraval/testing.py

config/settings.toml

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
[default]
22
region = 'us-east-1'
3-
data_dir = "./data"
4-
# 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
3+
data_dir = 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
4+
# "./data"
5+
deploy_bucket_name = 's3://llm-finetune-us-east-1-{aws_account}'
6+
deploy_bucket_prefix = 'models'
7+
sagemaker_execution_role_arn = 'arn:aws:iam::{aws_account}:role/sagemaker-execution-role-us-east-1'
58

69
[haiku-3]
710
model = 'anthropic.claude-3-haiku-20240307-v1:0'

pyproject.toml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,9 @@ dependencies = [
2626
"botocore~=1.34.162",
2727
"sagemaker",
2828
"numpy",
29-
"requests"
29+
"requests",
30+
"bitsandbytes",
31+
"accelerate"
3032
]
3133

3234
[project.scripts]
Lines changed: 29 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -6,21 +6,22 @@
66
import boto3
77
import json
88

9-
# Local directory the model and tokenizer are saved into before being
# archived; still referenced by the save_pretrained() and tar.add() calls
# further down in this module, so it must remain defined.
MODEL_DIRECTORY = "model_artifacts"


def cleanup_endpoints(endpoint_name, region_name='us-east-1'):
    """Delete the SageMaker endpoint and endpoint config called ``endpoint_name``.

    Both resources are looked up by exact name first, so calling this for a
    name that does not exist is a no-op rather than an API error.

    Args:
        endpoint_name: Name shared by the endpoint and its endpoint config.
        region_name: AWS region of the SageMaker resources. Defaults to the
            previously hard-coded ``'us-east-1'``.
    """
    sagemaker_client = boto3.client("sagemaker", region_name=region_name)

    def _exists(operation, result_key, name_key):
        # The list_* calls return a single page by default; paginate so that
        # an existing resource is not missed in accounts with many endpoints.
        # NameContains narrows the listing server-side, and the exact match
        # is still checked client-side.
        paginator = sagemaker_client.get_paginator(operation)
        pages = paginator.paginate(NameContains=endpoint_name)
        return any(
            item[name_key] == endpoint_name
            for page in pages
            for item in page[result_key]
        )

    if _exists("list_endpoints", "Endpoints", "EndpointName"):
        # delete_endpoint takes EndpointName, not EndpointConfigName.
        sagemaker_client.delete_endpoint(EndpointName=endpoint_name)
    if _exists("list_endpoint_configs", "EndpointConfigs", "EndpointConfigName"):
        sagemaker_client.delete_endpoint_config(EndpointConfigName=endpoint_name)
23+
24+
def load_artifacts(settings):
2425
bnb_config = BitsAndBytesConfig(
2526
load_in_4bit=True,
2627
bnb_4bit_quant_type="nf4",
@@ -29,27 +30,27 @@ def load_artifacts(args):
2930
)
3031

3132
model = AutoModelForCausalLM.from_pretrained(
32-
args.model_name,
33+
settings.hf_name,
3334
device_map="auto",
3435
quantization_config=bnb_config
3536
)
3637

37-
tokenizer = AutoTokenizer.from_pretrained(args.model_name)
38+
tokenizer = AutoTokenizer.from_pretrained(settings.hf_name)
3839

3940
model.save_pretrained(MODEL_DIRECTORY)
4041
tokenizer.save_pretrained(MODEL_DIRECTORY)
4142

42-
def write_model_to_s3(args, model_name):
43+
def write_model_to_s3(settings, model_name):
4344
tar_file_name = f"{model_name}.tar.gz"
4445

4546
with tarfile.open(tar_file_name, "w:gz") as tar:
4647
tar.add(MODEL_DIRECTORY, arcname=".")
4748

48-
s3_model_path = f"{args.bucket_prefix}/{tar_file_name}"
49+
s3_model_path = f"{settings.deploy_bucket_prefix}/{tar_file_name}"
4950
s3_client = boto3.client("s3")
50-
s3_client.upload_file(tar_file_name, args.bucket_name, s3_model_path)
51+
s3_client.upload_file(tar_file_name, settings.deploy_bucket_name, s3_model_path)
5152

52-
s3_uri = f"s3://{args.bucket_name}/{s3_model_path}"
53+
s3_uri = f"s3://{settings.deploy_bucket_name}/{s3_model_path}"
5354
print(f"Model uploaded to: {s3_uri}")
5455
return s3_uri
5556

@@ -89,17 +90,15 @@ def validate_deployment(predictor):
8990
print(f"Validation failed: {e}")
9091
raise e
9192

92-
def deploy(settings, cleanup_endpoints=False):
    """Quantize, upload and deploy the model named by ``settings.hf_name``.

    Steps: optionally delete an existing endpoint / endpoint config with the
    same name, save the model artifacts locally, upload them to S3, create
    the endpoint and validate it.

    Args:
        settings: Settings object providing ``hf_name``,
            ``sagemaker_execution_role_arn`` and the values read by
            ``load_artifacts`` and ``write_model_to_s3``.
        cleanup_endpoints: When true, remove a previously deployed endpoint
            and endpoint config of the same name before deploying.
    """
    # Use the last path component so names without an "org/" prefix do not
    # raise IndexError; "org/name" yields the same result as before.
    sanitized_model_name = settings.hf_name.split('/')[-1].replace('.', '-')
    if cleanup_endpoints:
        # The boolean parameter shadows the module-level cleanup_endpoints()
        # function inside this scope, so calling the bare name would raise
        # "TypeError: 'bool' object is not callable". Look the function up in
        # the module namespace to keep the public parameter name unchanged.
        globals()["cleanup_endpoints"](sanitized_model_name)
    load_artifacts(settings)
    s3_uri = write_model_to_s3(settings, sanitized_model_name)
    predictor = deploy_endpoint(
        s3_uri,
        settings.sagemaker_execution_role_arn,
        sanitized_model_name
    )
    validate_deployment(predictor)

src/wraval/main.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
from wraval.actions.action_llm_judge import judge
1212
from wraval.actions.aws_utils import get_current_aws_account_id
1313
from wraval.actions.action_results import show_results
14+
from wraval.actions.action_deploy import deploy
1415
import os
1516

1617

@@ -32,6 +33,8 @@ def get_settings(args):
3233
## add the AWS account you are logged into, if necessary.
3334
settings.model = settings.model.format(aws_account=settings.aws_account)
3435
settings.data_dir = settings.data_dir.format(aws_account=settings.aws_account)
36+
settings.deploy_bucket_name = settings.deploy_bucket_name.format(aws_account=settings.aws_account)
37+
settings.sagemaker_execution_role_arn = settings.sagemaker_execution_role_arn.format(aws_account=settings.aws_account)
3538

3639
if args.custom_prompts:
3740
settings.custom_prompts = True
@@ -54,6 +57,7 @@ def parse_args() -> argparse.Namespace:
5457
"human_judge_upload",
5558
"human_judge_parsing",
5659
"show_results",
60+
"deploy"
5761
],
5862
help="Action to perform (generate data or run inference)",
5963
)
@@ -84,6 +88,10 @@ def parse_args() -> argparse.Namespace:
8488
"--custom-prompts", default=False, help="Load custom prompts from a prompt folder"
8589
)
8690

91+
parser.add_argument(
92+
"--cleanup_endpoints", action='store_true'
93+
)
94+
8795
return parser.parse_args()
8896

8997

@@ -117,6 +125,9 @@ def handle_judge(args, settings):
117125
def handle_show_results(args, settings):
    """Pass the settings and the ``args.type`` value on to ``show_results``."""
    result_type = args.type
    show_results(settings, result_type)
119127

128+
def handle_deploy(args, settings):
    """Run ``deploy`` with the settings and the ``--cleanup_endpoints`` flag."""
    should_cleanup = args.cleanup_endpoints
    deploy(settings, should_cleanup)
130+
120131

121132
def main():
122133
args = parse_args()
@@ -131,6 +142,8 @@ def main():
131142
handle_judge(args, settings)
132143
case "show_results":
133144
handle_show_results(args, settings)
145+
case "deploy":
146+
handle_deploy(args, settings)
134147
case _:
135148
raise ValueError(f"Unknown action: {args.action}")
136149

0 commit comments

Comments
 (0)