diff --git a/databricks-dbrx-instruct/README.md b/databricks-dbrx-instruct/README.md
new file mode 100644
index 000000000..f8791b45e
--- /dev/null
+++ b/databricks-dbrx-instruct/README.md
@@ -0,0 +1,23 @@
+# Databricks DBRX Instruct Truss
+
+This Truss packages DBRX-Instruct, Databricks' mixture-of-experts (MoE) large language model fine-tuned for instruction following, suitable for tasks such as question answering, summarization, and code generation.
+
+## Deploying
+
+To deploy this model using Truss:
+
+1. Clone this repository
+2. Create a Baseten account and install the Truss CLI (`pip install truss`)
+3. From the `databricks-dbrx-instruct` directory, run `truss push` to deploy the model to Baseten
+
+## Model Overview
+
+TODO: Add a brief overview of the DBRX-Instruct model and its key capabilities
+
+## API Documentation
+
+TODO: Document the key API endpoints, request parameters, and response format
+
+## Example Usage
+
+TODO: Provide example code snippets demonstrating how to use the deployed model via its API
diff --git a/databricks-dbrx-instruct/config.yaml b/databricks-dbrx-instruct/config.yaml
new file mode 100644
index 000000000..932b79829
--- /dev/null
+++ b/databricks-dbrx-instruct/config.yaml
@@ -0,0 +1 @@
+# Empty file
diff --git a/databricks-dbrx-instruct/model/__init__.py b/databricks-dbrx-instruct/model/__init__.py
new file mode 100644
index 000000000..932b79829
--- /dev/null
+++ b/databricks-dbrx-instruct/model/__init__.py
@@ -0,0 +1 @@
+# Empty file
diff --git a/databricks-dbrx-instruct/model/model.py b/databricks-dbrx-instruct/model/model.py
new file mode 100644
index 000000000..06337caec
--- /dev/null
+++ b/databricks-dbrx-instruct/model/model.py
@@ -0,0 +1,56 @@
+import logging
+
+import torch
+from transformers import AutoModelForCausalLM, AutoTokenizer
+
+logger = logging.getLogger(__name__)
+
+
+class Model:
+    def __init__(self, model_name="databricks/dbrx-instruct") -> None:
+        self.model_name = model_name
+        self.model = None
+        self.tokenizer = None
+
+    def load(self):
+        try:
+            # DBRX ships custom model and tokenizer code, so
+            # trust_remote_code is required; bfloat16 with device_map="auto"
+            # spreads the 132B-parameter weights across available GPUs.
+            self.tokenizer = AutoTokenizer.from_pretrained(
+                self.model_name, trust_remote_code=True
+            )
+            self.model = AutoModelForCausalLM.from_pretrained(
+                self.model_name,
+                torch_dtype=torch.bfloat16,
+                device_map="auto",
+                trust_remote_code=True,
+            )
+        except Exception as e:
+            logger.error(f"Failed to load model {self.model_name}: {e}")
+            raise
+
+    def preprocess(self, request: dict) -> dict:
+        prompt = request.get("prompt", "")
+        return {
+            "input_ids": self.tokenizer.encode(prompt, return_tensors="pt"),
+            "max_new_tokens": request.get("max_new_tokens", 256),
+        }
+
+    def postprocess(self, output) -> dict:
+        return {
+            "generated_text": self.tokenizer.decode(output[0], skip_special_tokens=True)
+        }
+
+    def predict(self, request: dict) -> dict:
+        try:
+            processed_input = self.preprocess(request)
+            # Move the tokenized prompt onto the same device as the model.
+            processed_input["input_ids"] = processed_input["input_ids"].to(
+                self.model.device
+            )
+            output = self.model.generate(**processed_input)
+            return self.postprocess(output)
+        except Exception as e:
+            logger.error(f"Prediction failed: {e}")
+            raise
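
The README's "Example Usage" section is still a TODO. As a starting point, here is a minimal sketch of calling the deployed model over HTTP. It assumes Baseten's standard model-invocation endpoint and `Api-Key` authorization header; `MODEL_ID` and `API_KEY` are placeholders, and the `prompt` / `max_new_tokens` request fields are simply the keys that `Model.preprocess` reads.

```python
import requests

model_id = "MODEL_ID"  # placeholder: your Baseten model ID
api_key = "API_KEY"    # placeholder: your Baseten API key

# POST a JSON request to the deployed Truss; preprocess() reads "prompt"
# and "max_new_tokens", and postprocess() returns "generated_text".
resp = requests.post(
    f"https://model-{model_id}.api.baseten.co/production/predict",
    headers={"Authorization": f"Api-Key {api_key}"},
    json={"prompt": "What is DBRX?", "max_new_tokens": 128},
)
resp.raise_for_status()
print(resp.json()["generated_text"])
```

Note that `config.yaml` is currently an empty placeholder; before `truss push` will work it needs at minimum a model name, the Python requirements (`torch`, `transformers`), and GPU resources sized for a model this large.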