Skip to content

Commit 74eeebd

Browse files
authored
Merge pull request #12 from EetuaLaine/main
Restructure setup.py to ensure executability from any directory
2 parents 63abd3f + 1e94300 commit 74eeebd

File tree

11 files changed

+160
-152
lines changed

11 files changed

+160
-152
lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -8,4 +8,6 @@ src/wraval.egg-info/
88
prompts/*
99
.idea
1010
src/wraval/custom_prompts/*
11-
src/wraval/testing.py
11+
src/wraval/testing.py
12+
src/wraval/model_artifacts/*
13+
!src/wraval/model_artifacts/code/

README.md

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,15 @@ WRAVAL helps in evaluating LLMs for writing assistant tasks like summarization,
44

55
## Quick start
66

7+
Disclaimer: this project requires a machine that supports bitsandbytes and CUDA.
8+
9+
Before installing, execute the following to ensure correct dependencies:
10+
11+
```bash
12+
pip install uv
13+
uv pip compile pyproject.toml -o requirements.txt
14+
```
15+
716
```bash
817
pip install -e .
918
wraval generate

config/settings.toml

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,7 @@
11
[default]
22
region = 'us-east-1'
33
data_dir = 's3://llm-finetune-us-east-1-{aws_account}/eval/tones/'
4-
# "./data"
5-
deploy_bucket_name = 's3://llm-finetune-us-east-1-{aws_account}'
4+
deploy_bucket_name = 'llm-finetune-us-east-1-{aws_account}'
65
deploy_bucket_prefix = 'models'
76
sagemaker_execution_role_arn = 'arn:aws:iam::{aws_account}:role/sagemaker-execution-role-us-east-1'
87

pyproject.toml

Lines changed: 15 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -13,22 +13,23 @@ authors = [{ name = "Gabriel Benedict", email = "[email protected]" }]
1313

1414
# Dependencies – see note below for loading from requirements.txt
1515
dependencies = [
16-
"tqdm==4.66.4",
17-
"pandas==2.2.3",
18-
"beautifulsoup4==4.12.3",
19-
"boto3==1.34.143",
20-
"plotly==5.24.1",
21-
"transformers==4.48.1",
22-
"datasets==3.2.0",
23-
"evaluate==0.4.3",
24-
"dynaconf==3.2.7",
25-
"torch~=2.6.0",
26-
"botocore~=1.34.162",
27-
"sagemaker",
16+
"tqdm~=4.66.4",
17+
"pandas~=2.2.3",
18+
"beautifulsoup4~=4.12.3",
19+
"boto3",
20+
"plotly~=5.24.1",
21+
"transformers==4.51.0",
22+
"datasets~=3.2.0",
23+
"evaluate~=0.4.3",
24+
"dynaconf~=3.2.7",
25+
"torch",
26+
"botocore",
27+
"sagemaker==2.244.2",
2828
"numpy",
2929
"requests",
30-
"bitsandbytes",
31-
"accelerate"
30+
"bitsandbytes==0.45.5",
31+
"accelerate",
32+
"torchvision"
3233
]
3334

3435
[project.scripts]

requirements.txt

Lines changed: 68 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
# This file was autogenerated by uv via the following command:
22
# uv pip compile pyproject.toml -o requirements.txt
3+
accelerate==1.1.1
4+
# via wraval (pyproject.toml)
35
aiohappyeyeballs==2.6.1
46
# via aiohttp
57
aiohttp==3.11.18
@@ -14,6 +16,8 @@ antlr4-python3-runtime==4.9.3
1416
# via omegaconf
1517
anyio==4.9.0
1618
# via starlette
19+
async-timeout==5.0.1
20+
# via aiohttp
1721
attrs==23.2.0
1822
# via
1923
# aiohttp
@@ -22,6 +26,8 @@ attrs==23.2.0
2226
# sagemaker
2327
beautifulsoup4==4.12.3
2428
# via wraval (pyproject.toml)
29+
bitsandbytes==0.44.1
30+
# via wraval (pyproject.toml)
2531
boto3==1.34.143
2632
# via
2733
# wraval (pyproject.toml)
@@ -56,6 +62,8 @@ dynaconf==3.2.7
5662
# via wraval (pyproject.toml)
5763
evaluate==0.4.3
5864
# via wraval (pyproject.toml)
65+
exceptiongroup==1.3.0
66+
# via anyio
5967
fastapi==0.115.12
6068
# via sagemaker
6169
filelock==3.18.0
@@ -82,6 +90,7 @@ hf-xet==1.1.0
8290
# via huggingface-hub
8391
huggingface-hub==0.31.1
8492
# via
93+
# accelerate
8594
# datasets
8695
# evaluate
8796
# tokenizers
@@ -131,15 +140,53 @@ networkx==3.4.2
131140
numpy==1.26.4
132141
# via
133142
# wraval (pyproject.toml)
143+
# accelerate
144+
# bitsandbytes
134145
# datasets
135146
# evaluate
136147
# pandas
137148
# sagemaker
149+
# torchvision
138150
# transformers
151+
nvidia-cublas-cu12==12.4.5.8
152+
# via
153+
# nvidia-cudnn-cu12
154+
# nvidia-cusolver-cu12
155+
# torch
156+
nvidia-cuda-cupti-cu12==12.4.127
157+
# via torch
158+
nvidia-cuda-nvrtc-cu12==12.4.127
159+
# via torch
160+
nvidia-cuda-runtime-cu12==12.4.127
161+
# via torch
162+
nvidia-cudnn-cu12==9.1.0.70
163+
# via torch
164+
nvidia-cufft-cu12==11.2.1.3
165+
# via torch
166+
nvidia-curand-cu12==10.3.5.147
167+
# via torch
168+
nvidia-cusolver-cu12==11.6.1.9
169+
# via torch
170+
nvidia-cusparse-cu12==12.3.1.170
171+
# via
172+
# nvidia-cusolver-cu12
173+
# torch
174+
nvidia-cusparselt-cu12==0.6.2
175+
# via torch
176+
nvidia-nccl-cu12==2.21.5
177+
# via torch
178+
nvidia-nvjitlink-cu12==12.4.127
179+
# via
180+
# nvidia-cusolver-cu12
181+
# nvidia-cusparse-cu12
182+
# torch
183+
nvidia-nvtx-cu12==12.4.127
184+
# via torch
139185
omegaconf==2.2.3
140186
# via sagemaker
141187
packaging==25.0
142188
# via
189+
# accelerate
143190
# datasets
144191
# evaluate
145192
# huggingface-hub
@@ -154,6 +201,8 @@ pandas==2.2.3
154201
# sagemaker
155202
pathos==0.3.2
156203
# via sagemaker
204+
pillow==11.2.1
205+
# via torchvision
157206
platformdirs==4.3.8
158207
# via
159208
# sagemaker
@@ -171,7 +220,9 @@ propcache==0.3.1
171220
protobuf==4.25.7
172221
# via sagemaker
173222
psutil==7.0.0
174-
# via sagemaker
223+
# via
224+
# accelerate
225+
# sagemaker
175226
pyarrow==20.0.0
176227
# via datasets
177228
pydantic==2.11.4
@@ -190,6 +241,7 @@ pytz==2025.2
190241
# via pandas
191242
pyyaml==6.0.2
192243
# via
244+
# accelerate
193245
# datasets
194246
# huggingface-hub
195247
# omegaconf
@@ -220,15 +272,15 @@ rpds-py==0.24.0
220272
s3transfer==0.10.4
221273
# via boto3
222274
safetensors==0.5.3
223-
# via transformers
275+
# via
276+
# accelerate
277+
# transformers
224278
sagemaker==2.236.0
225279
# via wraval (pyproject.toml)
226280
sagemaker-core==1.0.16
227281
# via sagemaker
228282
schema==0.7.7
229283
# via sagemaker
230-
setuptools==80.3.1
231-
# via torch
232284
six==1.17.0
233285
# via
234286
# google-pasta
@@ -250,6 +302,12 @@ tenacity==9.1.2
250302
tokenizers==0.21.1
251303
# via transformers
252304
torch==2.6.0
305+
# via
306+
# wraval (pyproject.toml)
307+
# accelerate
308+
# bitsandbytes
309+
# torchvision
310+
torchvision==0.21.0
253311
# via wraval (pyproject.toml)
254312
tqdm==4.66.4
255313
# via
@@ -261,16 +319,22 @@ tqdm==4.66.4
261319
# transformers
262320
transformers==4.48.1
263321
# via wraval (pyproject.toml)
322+
triton==3.2.0
323+
# via torch
264324
typing-extensions==4.13.2
265325
# via
266326
# anyio
327+
# exceptiongroup
267328
# fastapi
268329
# huggingface-hub
330+
# multidict
269331
# pydantic
270332
# pydantic-core
271333
# referencing
334+
# rich
272335
# torch
273336
# typing-inspection
337+
# uvicorn
274338
typing-inspection==0.4.0
275339
# via pydantic
276340
tzdata==2025.2

setup.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,11 @@
1010
package_dir={"": "src"},
1111
install_requires=required,
1212
data_files=[
13-
('config', ['config/settings.toml'])
13+
('config', ['config/settings.toml']),
14+
('model_artifacts/code', [
15+
'src/wraval/model_artifacts/code/inference.py',
16+
'src/wraval/model_artifacts/code/requirements.txt'
17+
])
1418
],
1519
include_package_data=True,
1620
entry_points={

src/wraval/actions/action_deploy.py

Lines changed: 56 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,22 +1,39 @@
1-
from argparse import ArgumentParser
2-
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3-
from sagemaker.huggingface import HuggingFaceModel
4-
import torch
1+
import json
2+
import os
53
import tarfile
4+
from argparse import ArgumentParser
5+
66
import boto3
7-
import json
7+
import torch
8+
from sagemaker.huggingface import HuggingFaceModel
9+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
10+
11+
PACKAGE_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
12+
MODEL_DIRECTORY = os.path.join(PACKAGE_DIR, "model_artifacts")
13+
CODE_PATH = "code"
814

9-
MODEL_DIRECTORY = '../../../model_artifacts'
15+
def parse_args():
16+
arg_parser = ArgumentParser()
17+
arg_parser.add_argument("--model_name", type=str, required=True, choices=(
18+
"Qwen/Qwen2.5-1.5B-Instruct",
19+
"microsoft/Phi-3.5-mini-instruct",
20+
"microsoft/Phi-4-mini-instruct"
21+
)
22+
)
23+
arg_parser.add_argument("--bucket_name", type=str, required=True)
24+
arg_parser.add_argument("--bucket_prefix", type=str, required=True)
25+
arg_parser.add_argument("--sagemaker_execution_role_arn", type=str, required=True)
26+
return arg_parser.parse_args()
1027

1128
def cleanup_endpoints(endpoint_name):
1229

13-
sagemaker_client = boto3.client("sagemaker", region_name='us-east-1')
30+
sagemaker_client = boto3.client("sagemaker", region_name="us-east-1")
1431

15-
endpoints = sagemaker_client.list_endpoints()['Endpoints']
16-
endpoints_configs = sagemaker_client.list_endpoint_configs()['EndpointConfigs']
32+
endpoints = sagemaker_client.list_endpoints()["Endpoints"]
33+
endpoints_configs = sagemaker_client.list_endpoint_configs()["EndpointConfigs"]
1734

18-
endpoints_names = [e['EndpointName'] for e in endpoints]
19-
endpoints_configs_names = [e['EndpointConfigName'] for e in endpoints_configs]
35+
endpoints_names = [e["EndpointName"] for e in endpoints]
36+
endpoints_configs_names = [e["EndpointConfigName"] for e in endpoints_configs]
2037

2138
if endpoint_name in endpoints_names:
2239
sagemaker_client.delete_endpoint(EndpointConfigName=endpoint_name)
@@ -44,14 +61,14 @@ def load_artifacts(settings):
4461

4562
def write_model_to_s3(settings, model_name):
4663
tar_file_name = f"{model_name}.tar.gz"
47-
64+
4865
with tarfile.open(tar_file_name, "w:gz") as tar:
4966
tar.add(MODEL_DIRECTORY, arcname=".")
50-
67+
5168
s3_model_path = f"{settings.deploy_bucket_prefix}/{tar_file_name}"
5269
s3_client = boto3.client("s3")
5370
s3_client.upload_file(tar_file_name, settings.deploy_bucket_name, s3_model_path)
54-
71+
5572
s3_uri = f"s3://{settings.deploy_bucket_name}/{s3_model_path}"
5673
print(f"Model uploaded to: {s3_uri}")
5774
return s3_uri
@@ -92,15 +109,35 @@ def validate_deployment(predictor):
92109
print(f"Validation failed: {e}")
93110
raise e
94111

95-
def deploy(settings, cleanup_endpoints=False):
96-
sanitized_model_name = settings.hf_name.split('/')[1].replace('.', '-')
97-
if cleanup_endpoints:
98-
cleanup_endpoints(sanitized_model_name)
112+
def validate_model_directory():
113+
endpoint_code_path = os.path.join(MODEL_DIRECTORY, CODE_PATH)
114+
inference_script_name = "inference.py"
115+
requirements_name = "requirements.txt"
116+
if not os.path.isdir(endpoint_code_path):
117+
raise ValueError(f"{endpoint_code_path} is missing.")
118+
if not os.path.isfile(os.path.join(endpoint_code_path, inference_script_name)):
119+
raise ValueError(f"{inference_script_name} is missing from the code directory.")
120+
if not os.path.isfile(os.path.join(endpoint_code_path, requirements_name)):
121+
raise ValueError(f"{requirements_name} is missing from the code directory.")
122+
123+
124+
def cleanup_model_directory():
125+
for item in os.listdir(MODEL_DIRECTORY):
126+
item_path = os.path.join(MODEL_DIRECTORY, item)
127+
if item == CODE_PATH:
128+
continue
129+
if os.path.isfile(item_path):
130+
os.remove(item_path)
131+
132+
def deploy(settings):
133+
validate_model_directory()
134+
cleanup_model_directory()
135+
sanitized_model_name = settings.hf_name.split("/")[1].replace(".", "-")
99136
load_artifacts(settings)
100137
s3_uri = write_model_to_s3(settings, sanitized_model_name)
101138
predictor = deploy_endpoint(
102139
s3_uri,
103-
settings.sagemaker_execution_role_arn,
140+
settings.sagemaker_execution_role_arn,
104141
sanitized_model_name
105142
)
106143
validate_deployment(predictor)

0 commit comments

Comments
 (0)