diff --git a/distilbert/README.md b/distilbert/README.md new file mode 100644 index 000000000..d536ff448 --- /dev/null +++ b/distilbert/README.md @@ -0,0 +1,15 @@ +# DistilBERT +This truss runs the [DistilBERT](https://huggingface.co/docs/transformers/en/model_doc/distilbert) model as an endpoint on Baseten. + +## Deploy +``` +pip install --upgrade truss +truss push --publish # grab an api key from https://app.baseten.co/settings/api_keys +``` + +The deployment will take a few minutes the first. Once it's ready in the you UI you can proceed to calling the API. + +## Test +``` +truss predict --published -d '{"text": "some text to embed"}' +``` \ No newline at end of file diff --git a/distilbert/config.yaml b/distilbert/config.yaml new file mode 100644 index 000000000..3f15fbd8d --- /dev/null +++ b/distilbert/config.yaml @@ -0,0 +1,6 @@ + +model_name: DistilBert +python_version: py310 +requirements_file: ./requirements.txt +resources: + accelerator: T4 diff --git a/distilbert/model/__init__.py b/distilbert/model/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/distilbert/model/model.py b/distilbert/model/model.py new file mode 100644 index 000000000..ede40ef2f --- /dev/null +++ b/distilbert/model/model.py @@ -0,0 +1,30 @@ +import torch +from transformers import AutoTokenizer, AutoModel + + +class Model: + def __init__(self, **kwargs): + self._model = None + + def load(self): + # Load model here and assign to self._model. + self.device = ( + "cuda" if torch.cuda.is_available() else "cpu" + ) # the device to load the model onto + + self._tokenizer = AutoTokenizer.from_pretrained( + "distilbert/distilbert-base-uncased", device=self.device + ) + self._model = AutoModel.from_pretrained( + "distilbert/distilbert-base-uncased", + torch_dtype=torch.float16, + ).to(self.device) + + def predict(self, model_input): + # Run model inference here + + text = model_input.get("text") + + encoded_input = self._tokenizer(text, return_tensors='pt').to(self.device) + + return self._model(**encoded_input).last_hidden_state.tolist() diff --git a/distilbert/requirements.txt b/distilbert/requirements.txt new file mode 100644 index 000000000..c42de737b --- /dev/null +++ b/distilbert/requirements.txt @@ -0,0 +1,3 @@ +hf-transfer==0.1.6 +torch==2.2.2 +transformers==4.40.0