Skip to content

Commit bc131da

Browse files
committed
action deploy - add model directory
1 parent 38453cb commit bc131da

File tree

3 files changed

+37
-0
lines changed

3 files changed

+37
-0
lines changed

model_artifacts/inference.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import torch
2+
from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
3+
4+
def model_fn(model_dir, *args):
5+
# Load model from HuggingFace Hub
6+
bnb_config = BitsAndBytesConfig(
7+
load_in_4bit=True,
8+
bnb_4bit_quant_type="nf4",
9+
bnb_4bit_use_double_quant=True,
10+
bnb_4bit_compute_dtype=torch.bfloat16
11+
)
12+
model = AutoModelForCausalLM.from_pretrained(
13+
model_dir,
14+
device_map="auto",
15+
quantization_config=bnb_config
16+
)
17+
tokenizer = AutoTokenizer.from_pretrained(model_dir)
18+
return model, tokenizer
19+
20+
def predict_fn(data, model_and_tokenizer, *args):
21+
# destruct model and tokenizer
22+
model, tokenizer = model_and_tokenizer
23+
# Tokenize sentences
24+
sentences = data.pop("inputs", data)
25+
tokenizer.padding_side = "left"
26+
tokenizer.pad_token = tokenizer.eos_token
27+
model.config.pad_token_id = model.config.eos_token_id
28+
inputs = tokenizer(sentences, return_tensors="pt", padding=True).to(model.device)
29+
output_sequences = model.generate(**inputs, max_new_tokens=20)
30+
return tokenizer.batch_decode(output_sequences, skip_special_tokens=True)

model_artifacts/requirements.txt

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,5 @@
1+
bitsandbytes==0.44.1
2+
accelerate==1.6.0
3+
transformers==4.51.1
4+
torch==2.5.0
5+
torchvision==0.20

src/wraval/actions/action_deploy.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@
66
import boto3
77
import json
88

9+
MODEL_DIRECTORY = '../../model_artifacts'
10+
911
def cleanup_endpoints(endpoint_name):
1012

1113
sagemaker_client = boto3.client("sagemaker", region_name='us-east-1')

0 commit comments

Comments
 (0)