Skip to content

Commit 4fee761

Browse files
committed
DistilBERT
1 parent 248cf40 commit 4fee761

File tree

5 files changed

+54
-0
lines changed

5 files changed

+54
-0
lines changed

distilbert/README.md

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
# DistilBERT
2+
This truss runs the [DistilBERT](https://huggingface.co/docs/transformers/en/model_doc/distilbert) model as an endpoint on Baseten.
3+
4+
## Deploy
5+
```
6+
pip install --upgrade truss
7+
truss push --publish # grab an api key from https://app.baseten.co/settings/api_keys
8+
```
9+
10+
The deployment will take a few minutes the first. Once it's ready in the you UI you can proceed to calling the API.
11+
12+
## Test
13+
```
14+
truss predict --published -d '{"text": "some text to embed"}'
15+
```

distilbert/config.yaml

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,6 @@
1+
2+
model_name: DistilBert
3+
python_version: py310
4+
requirements_file: ./requirements.txt
5+
resources:
6+
accelerator: T4

distilbert/model/__init__.py

Whitespace-only changes.

distilbert/model/model.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
import torch
2+
from transformers import AutoTokenizer, AutoModel
3+
4+
5+
class Model:
6+
def __init__(self, **kwargs):
7+
self._model = None
8+
9+
def load(self):
10+
# Load model here and assign to self._model.
11+
self.device = (
12+
"cuda" if torch.cuda.is_available() else "mps"
13+
) # the device to load the model onto
14+
15+
self._tokenizer = AutoTokenizer.from_pretrained(
16+
"distilbert/distilbert-base-uncased", device=self.device
17+
)
18+
self._model = AutoModel.from_pretrained(
19+
"distilbert/distilbert-base-uncased",
20+
torch_dtype=torch.float16,
21+
).to(self.device)
22+
23+
def predict(self, model_input):
24+
# Run model inference here
25+
26+
text = model_input.get("text")
27+
28+
encoded_input = self._tokenizer(text, return_tensors='pt').to(self.device)
29+
30+
return self._model(**encoded_input).last_hidden_state.tolist()

distilbert/requirements.txt

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,3 @@
1+
hf-transfer==0.1.6
2+
torch==2.2.2
3+
transformers==4.40.0

0 commit comments

Comments
 (0)