From b2e77b2864b00ad265e8b5ebb7a4b16289a13ece Mon Sep 17 00:00:00 2001 From: Droid Date: Wed, 1 May 2024 21:35:46 +0000 Subject: [PATCH 1/3] Implemented DBRX truss with model loading and prediction. Added requirements, config, and README files. --- dbrx_truss/README.md | 43 ++++++++++++++++++++++++++ dbrx_truss/__init__.py | 1 + dbrx_truss/config.yaml | 13 ++++++++ dbrx_truss/model/__init__.py | 1 + dbrx_truss/model/model.py | 60 ++++++++++++++++++++++++++++++++++++ dbrx_truss/requirements.txt | 4 +++ 6 files changed, 122 insertions(+) create mode 100644 dbrx_truss/README.md create mode 100644 dbrx_truss/__init__.py create mode 100644 dbrx_truss/config.yaml create mode 100644 dbrx_truss/model/__init__.py create mode 100644 dbrx_truss/model/model.py create mode 100644 dbrx_truss/requirements.txt diff --git a/dbrx_truss/README.md b/dbrx_truss/README.md new file mode 100644 index 000000000..6c35ec706 --- /dev/null +++ b/dbrx_truss/README.md @@ -0,0 +1,43 @@ +# DBRX Truss + +This truss makes the [DBRX](https://huggingface.co/databricks/dbrx-instruct) model available on the Baseten platform for efficient inference. DBRX is an open-source large language model trained by Databricks. It is a 132B parameter model capable of instruction following and general language tasks. + +## Setup + +This truss requires Python 3.11 and the dependencies listed in `requirements.txt`. It is configured to run on A10G GPUs for optimal performance. + +## Usage + +Once deployed on Baseten, the truss exposes an endpoint for making prediction requests to the model. + +### Request Format + +Requests should be made with a JSON payload in the following format: + +```json +{ + "prompt": "What is machine learning?" 
+} +``` + +### Parameters + +The following inference parameters can be configured in `config.yaml`: + +- `max_new_tokens`: Max number of tokens to generate in the response (default: 100) +- `temperature`: Controls randomness of output (default: 0.7) +- `top_p`: Nucleus sampling probability threshold (default: 0.95) +- `top_k`: Number of highest probability vocabulary tokens to keep (default: 50) +- `repetition_penalty`: Penalty for repeated tokens (default: 1.01) + +## Original Model + +DBRX was developed and open-sourced by Databricks. For more information, see: + +- [DBRX Model Card](https://github.com/databricks/dbrx/blob/master/MODEL_CARD_dbrx_instruct.md) +- [Databricks Blog Post](https://www.databricks.com/blog/introducing-dbrx-new-state-art-open-llm) +- [HuggingFace Model Page](https://huggingface.co/databricks/dbrx-instruct) + +## About Baseten + +This truss was created by [Baseten](https://www.baseten.co/) to enable easy deployment and serving of the open-source DBRX model at scale. Baseten is a platform for building powerful AI apps. diff --git a/dbrx_truss/__init__.py b/dbrx_truss/__init__.py new file mode 100644 index 000000000..932b79829 --- /dev/null +++ b/dbrx_truss/__init__.py @@ -0,0 +1 @@ +# Empty file diff --git a/dbrx_truss/config.yaml b/dbrx_truss/config.yaml new file mode 100644 index 000000000..b05df900a --- /dev/null +++ b/dbrx_truss/config.yaml @@ -0,0 +1,13 @@ +python_version: py311 +requirements_file: requirements.txt + +resources: + accelerator: A10G + use_gpu: true + +model_metadata: + example_model_input: | + { + "prompt": "What is machine learning?" 
class Model:
    """Truss model wrapper that serves the DBRX instruct model.

    ``load`` fetches the tokenizer and the weights once; ``predict`` wraps a
    single user prompt in the chat template, generates a completion and
    returns only the newly generated text.
    """

    # Used when config.yaml does not provide ``model_metadata.repo_id``.
    DEFAULT_REPO_ID = "databricks/dbrx-instruct"

    def __init__(self, data_dir: str, config: Dict, **kwargs):
        """Store the Truss-provided settings; no weights are loaded here.

        Args:
            data_dir: Data directory handed over by the serving runtime.
            config: Parsed configuration mapping; generation settings and
                ``model_metadata`` are looked up in it.
            **kwargs: Accepted for interface compatibility and ignored.
        """
        self.data_dir = data_dir
        self.config = config
        self.cuda_available = torch.cuda.is_available()
        # Populated by load(); None means "not loaded yet".
        self.tokenizer = None
        self.model = None

    def _repo_id(self) -> str:
        """Return the Hugging Face repo id, preferring the configured one."""
        metadata = self.config.get("model_metadata") or {}
        return metadata.get("repo_id", self.DEFAULT_REPO_ID)

    def load(self):
        """Instantiate the tokenizer and the model.

        On CUDA machines the model is loaded in bf16 (fp16 when bf16 is not
        supported) with ``device_map="auto"``; otherwise the library
        defaults are used.
        """
        # Function-scope import keeps the file's import block untouched.
        import importlib.util

        repo_id = self._repo_id()
        self.tokenizer = AutoTokenizer.from_pretrained(
            repo_id, trust_remote_code=True, token=True
        )

        if not self.cuda_available:
            self.model = AutoModelForCausalLM.from_pretrained(
                repo_id, trust_remote_code=True, token=True
            )
            return

        # The previous check (`"flash_attn" in locals()`) could never be
        # true; ask the import system whether the package is installed.
        has_flash_attn = importlib.util.find_spec("flash_attn") is not None
        if torch.cuda.is_bf16_supported():
            dtype = torch.bfloat16
        else:
            dtype = torch.float16

        self.model = AutoModelForCausalLM.from_pretrained(
            repo_id,
            trust_remote_code=True,
            token=True,
            torch_dtype=dtype,
            device_map="auto",
            attn_implementation="flash_attention_2" if has_flash_attn else "eager",
        )

    def predict(self, request: Dict) -> Dict:
        """Generate a completion for ``request["prompt"]``.

        Args:
            request: Mapping with a ``"prompt"`` string.

        Returns:
            ``{"result": <generated text>}`` containing only the tokens
            produced after the prompt, with special tokens removed.

        Raises:
            KeyError: If ``request`` has no ``"prompt"`` entry.
        """
        # Load lazily and only once; reloading per request would
        # re-instantiate the whole model on every call.
        if self.model is None or self.tokenizer is None:
            self.load()

        prompt = request["prompt"]
        messages = [{"role": "user", "content": prompt}]

        input_ids = self.tokenizer.apply_chat_template(
            messages, tokenize=True, add_generation_prompt=True, return_tensors="pt"
        )
        input_ids = input_ids.to(self.model.device)
        prompt_length = input_ids.shape[-1]

        pad_token_id = self.tokenizer.pad_token_id
        if pad_token_id is None:
            # Fall back to EOS so generate() always has a pad id.
            pad_token_id = self.tokenizer.eos_token_id

        generation_kwargs = {
            "max_new_tokens": self.config.get("max_new_tokens", 100),
            "repetition_penalty": self.config.get("repetition_penalty", 1.01),
            "pad_token_id": pad_token_id,
        }
        temperature = self.config.get("temperature", 0.7)
        if temperature and temperature > 0:
            # Sampling settings are only honoured when do_sample is True.
            generation_kwargs.update(
                do_sample=True,
                temperature=temperature,
                top_p=self.config.get("top_p", 0.95),
                top_k=self.config.get("top_k", 50),
            )
        else:
            # A zero temperature means greedy decoding.
            generation_kwargs["do_sample"] = False

        generated = self.model.generate(input_ids=input_ids, **generation_kwargs)

        # Drop the echoed prompt by position instead of splitting on a
        # template-specific marker string.
        new_tokens = generated[0][prompt_length:]
        response_text = self.tokenizer.decode(new_tokens, skip_special_tokens=True)

        return {"result": response_text}