class TogetherEmbeddingFunction(EmbeddingFunction[Documents]):  # type: ignore[misc]
    """
    Embedding function that gets embeddings for a list of texts from Together's embedding models.

    It requires an API key and a model name. The default model name is
    "togethercomputer/m2-bert-80M-8k-retrieval".
    For more, refer to the official documentation at "https://docs.together.ai/docs/embeddings-python".
    """

    def __init__(
        self,
        api_key: str,
        model_name: Optional[str] = "togethercomputer/m2-bert-80M-8k-retrieval",
    ):
        """
        Initialize the TogetherEmbeddingFunction.

        Args:
            api_key (str): The API key for the Together API.
            model_name (Optional[str]): The name of the model to use for embedding.
                Defaults to "togethercomputer/m2-bert-80M-8k-retrieval".

        Raises:
            ValueError: If the `together` python package is not installed.
        """
        # Imported lazily so that `together` stays an optional dependency of the package.
        try:
            import together
        except ImportError as err:
            # Keep ValueError for callers that already catch it, but chain the
            # original ImportError so the root cause stays visible in tracebacks.
            raise ValueError(
                "The together python package is not installed. Please install it with `pip install together`"
            ) from err

        self.model_name = model_name
        # Hand the key to the client directly instead of assigning the
        # module-level `together.api_key`: the key then belongs to this instance
        # only, and several instances with different keys cannot overwrite each
        # other through shared global state.
        # NOTE(review): the v1 `Together` client is expected to read credentials
        # from this argument (or the TOGETHER_API_KEY environment variable) -
        # confirm against the pinned SDK version.
        self.client = together.Together(api_key=api_key)

    def __call__(self, input: Documents) -> Embeddings:
        """
        Get the embeddings for a list of texts.

        Args:
            input (Documents): A list of texts to get embeddings for.

        Returns:
            Embeddings: The embeddings for the texts, one per entry in the API response.

        Example:
            ```python
            import os
            from chromadbx.embeddings.together import TogetherEmbeddingFunction

            ef = TogetherEmbeddingFunction(api_key=os.getenv("TOGETHER_API_KEY"))
            embeddings = ef(["hello world", "goodbye world"])
            ```
        """
        outputs = self.client.embeddings.create(input=input, model=self.model_name)
        # Walk the response entries directly rather than indexing by len(input).
        return cast(Embeddings, [item.embedding for item in outputs.data])
@pytest.mark.skipif(
    "TOGETHER_API_KEY" not in os.environ,
    reason="TOGETHER_API_KEY environment variable is not set",
)
def test_together() -> None:
    """Embed two short texts with the default model and check the result shape."""
    api_key = os.getenv("TOGETHER_API_KEY", "")
    embedding_fn = TogetherEmbeddingFunction(api_key=api_key)

    vectors = embedding_fn(["hello world", "goodbye world"])

    assert vectors is not None
    # One vector per input text.
    assert len(vectors) == 2
    # The test expects 768-dimensional vectors from the default model.
    first_vector = vectors[0]
    assert len(first_vector) == 768