5 files changed: +54 -0 lines changed
+# DistilBERT
+This truss runs the [DistilBERT](https://huggingface.co/docs/transformers/en/model_doc/distilbert) model as an endpoint on Baseten.
+
+## Deploy
+```
+pip install --upgrade truss
+truss push --publish # grab an api key from https://app.baseten.co/settings/api_keys
+```
+
+The first deployment will take a few minutes. Once it shows as ready in the Baseten UI, you can proceed to calling the API.
+
+## Test
+```
+truss predict --published -d '{"text": "some text to embed"}'
+```
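The `truss predict` command above sends a JSON body with a `text` field to the published deployment. The same call can be sketched from Python. This is a minimal sketch, not part of the PR: the URL pattern and the `Api-Key` authorization header are assumptions about how Baseten exposes a published model, and `build_predict_request` and `call_model` are hypothetical helper names. Check the endpoint shown in the Baseten UI before relying on it.

```python
import json
import urllib.request


def build_predict_request(model_id: str, api_key: str, text: str) -> urllib.request.Request:
    """Build (but do not send) a POST request carrying the README's payload."""
    # Assumed URL pattern for a published deployment; confirm it in the Baseten UI.
    url = f"https://model-{model_id}.api.baseten.co/production/predict"
    payload = json.dumps({"text": text}).encode("utf-8")
    return urllib.request.Request(
        url,
        data=payload,
        headers={
            # Assumed header format for the API key created in the Deploy step.
            "Authorization": f"Api-Key {api_key}",
            "Content-Type": "application/json",
        },
        method="POST",
    )


def call_model(request: urllib.request.Request):
    """Send the request and return the decoded JSON response."""
    with urllib.request.urlopen(request) as response:
        return json.loads(response.read())
```

Keeping the request construction separate from the network call makes it possible to inspect the URL, headers, and body without a live deployment.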

+
+model_name: DistilBert
+python_version: py310
+requirements_file: ./requirements.txt
+resources:
+  accelerator: T4

+import torch
+from transformers import AutoModel, AutoTokenizer
+
+
+class Model:
+    def __init__(self, **kwargs):
+        self._tokenizer = None
+        self._model = None
+
+    def load(self):
+        # Pick the device to load the model onto: CUDA GPU, then Apple MPS,
+        # then CPU as the fallback when no accelerator is available.
+        if torch.cuda.is_available():
+            self.device = "cuda"
+        elif torch.backends.mps.is_available():
+            self.device = "mps"
+        else:
+            self.device = "cpu"
+        # Half precision saves memory on accelerators; use float32 on CPU.
+        dtype = torch.float32 if self.device == "cpu" else torch.float16
+
+        # The tokenizer runs on the CPU, so it takes no device argument.
+        self._tokenizer = AutoTokenizer.from_pretrained(
+            "distilbert/distilbert-base-uncased"
+        )
+        self._model = AutoModel.from_pretrained(
+            "distilbert/distilbert-base-uncased",
+            torch_dtype=dtype,
+        ).to(self.device)
+
+    def predict(self, model_input):
+        text = model_input.get("text")
+
+        # Tokenize and move the input tensors to the model's device.
+        encoded_input = self._tokenizer(text, return_tensors="pt").to(self.device)
+
+        # Inference only: skip gradient tracking to save memory.
+        with torch.no_grad():
+            output = self._model(**encoded_input)
+        # Nested list shaped [batch, tokens, hidden_size].
+        return output.last_hidden_state.tolist()
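`predict` returns `last_hidden_state` as a nested list, which holds one vector per token rather than one vector per input text. A client that wants a single embedding for the text has to reduce the token axis itself. The sketch below shows mean pooling, which is one common choice and not something this truss does; `mean_pool` is a hypothetical client-side helper, and it assumes the response is shaped `[batch, tokens, hidden_size]` with no padding tokens (true for a single input string).

```python
def mean_pool(last_hidden_state):
    """Average the token vectors of each sequence into one vector per sequence.

    `last_hidden_state` is the nested list returned by the endpoint,
    shaped [batch, tokens, hidden_size].
    """
    pooled = []
    for token_vectors in last_hidden_state:
        n_tokens = len(token_vectors)
        hidden_size = len(token_vectors[0])
        pooled.append(
            [
                sum(vector[i] for vector in token_vectors) / n_tokens
                for i in range(hidden_size)
            ]
        )
    return pooled
```

For `distilbert-base-uncased` each pooled vector has 768 entries. The average includes the special `[CLS]` and `[SEP]` tokens that the tokenizer adds.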

+hf-transfer==0.1.6
+torch==2.2.2
+transformers==4.40.0