Commit ebea40c
Add 'truncate' parameter for CohereEmbeddings (#798)
Currently, the 'truncate' parameter of the Cohere API is not supported. This means that, by default, if you try to generate an embedding for a text that is too long, the call simply fails with an error (which is frustrating when using this embedding source e.g. with GPT-Index, because it is hard to handle the error properly when generating a lot of embeddings). With the parameter, one can choose to truncate either the START or the END of the text to fit the maximum token length and still generate an embedding without raising an error. In this PR, I added this parameter to the class.

_Arguably, there should be a better way to handle this error, e.g. by optionally calling a function that gets triggered when the token limit is reached and can split the document or some such. Especially in the GPT-Index use case, it's often hard to estimate the token count for each document, and I'd rather sort out the troublemakers, or simply split them, than interrupt the whole execution. Thoughts?_

Co-authored-by: Harrison Chase <[email protected]>
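The START/END semantics described above can be illustrated with a small standalone sketch. Note this is a hypothetical helper using a naive whitespace tokenizer and an invented `max_tokens` limit, purely to show what the three modes mean; it is not code from the PR and not Cohere's actual tokenization:

```python
def truncate_text(text: str, max_tokens: int, truncate: str = "NONE") -> str:
    """Illustrate Cohere-style truncation with a naive whitespace tokenizer.

    truncate="START": drop tokens from the start, keeping the last max_tokens.
    truncate="END":   drop tokens from the end, keeping the first max_tokens.
    truncate="NONE":  raise if the text is too long (mirroring the API error).
    """
    tokens = text.split()
    if len(tokens) <= max_tokens:
        return text  # short enough, nothing to do
    if truncate == "START":
        return " ".join(tokens[-max_tokens:])
    if truncate == "END":
        return " ".join(tokens[:max_tokens])
    raise ValueError(f"text has {len(tokens)} tokens, limit is {max_tokens}")
```

With `max_tokens=3`, `"a b c d e"` becomes `"a b c"` under END and `"c d e"` under START, while NONE raises, which is the default failure mode the PR works around.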
1 parent b9045f7 commit ebea40c

File tree

1 file changed: 9 additions, 2 deletions


langchain/embeddings/cohere.py

Lines changed: 9 additions & 2 deletions
@@ -25,6 +25,9 @@ class CohereEmbeddings(BaseModel, Embeddings):
     model: str = "large"
     """Model name to use."""
 
+    truncate: str = "NONE"
+    """Truncate embeddings that are too long from start or end ("NONE"|"START"|"END")"""
+
     cohere_api_key: Optional[str] = None
 
     class Config:
@@ -58,7 +61,9 @@ def embed_documents(self, texts: List[str]) -> List[List[float]]:
         Returns:
             List of embeddings, one for each text.
         """
-        embeddings = self.client.embed(model=self.model, texts=texts).embeddings
+        embeddings = self.client.embed(
+            model=self.model, texts=texts, truncate=self.truncate
+        ).embeddings
         return embeddings
 
     def embed_query(self, text: str) -> List[float]:
@@ -70,5 +75,7 @@ def embed_query(self, text: str) -> List[float]:
         Returns:
             Embeddings for the text.
         """
-        embedding = self.client.embed(model=self.model, texts=[text]).embeddings[0]
+        embedding = self.client.embed(
+            model=self.model, texts=[text], truncate=self.truncate
+        ).embeddings[0]
         return embedding
