Skip to content

Commit d564308

Browse files
rfc: instruct embeddings (#811)
Co-authored-by: seanaedmiston <[email protected]>
1 parent 576609e commit d564308

File tree

4 files changed

+155
-3
lines changed

4 files changed

+155
-3
lines changed

docs/modules/utils/combine_docs_examples/embeddings.ipynb

Lines changed: 59 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -255,12 +255,70 @@
255255
"query_result = embeddings.embed_query(text)"
256256
]
257257
},
258+
{
259+
"cell_type": "markdown",
260+
"id": "59428e05",
261+
"metadata": {},
262+
"source": [
263+
"## InstructEmbeddings\n",
264+
"Let's load the HuggingFace instruct Embeddings class."
265+
]
266+
},
258267
{
259268
"cell_type": "code",
260-
"execution_count": null,
269+
"execution_count": 8,
270+
"id": "92c5b61e",
271+
"metadata": {},
272+
"outputs": [],
273+
"source": [
274+
"from langchain.embeddings import HuggingFaceInstructEmbeddings"
275+
]
276+
},
277+
{
278+
"cell_type": "code",
279+
"execution_count": 9,
280+
"id": "062547b9",
281+
"metadata": {},
282+
"outputs": [
283+
{
284+
"name": "stdout",
285+
"output_type": "stream",
286+
"text": [
287+
"load INSTRUCTOR_Transformer\n",
288+
"max_seq_length 512\n"
289+
]
290+
}
291+
],
292+
"source": [
293+
"embeddings = HuggingFaceInstructEmbeddings(query_instruction=\"Represent the query for retrieval: \")"
294+
]
295+
},
296+
{
297+
"cell_type": "code",
298+
"execution_count": 10,
299+
"id": "e1dcc4bd",
300+
"metadata": {},
301+
"outputs": [],
302+
"source": [
303+
"text = \"This is a test document.\""
304+
]
305+
},
306+
{
307+
"cell_type": "code",
308+
"execution_count": 11,
261309
"id": "90f0db94",
262310
"metadata": {},
263311
"outputs": [],
312+
"source": [
313+
"query_result = embeddings.embed_query(text)"
314+
]
315+
},
316+
{
317+
"cell_type": "code",
318+
"execution_count": null,
319+
"id": "a961cdb5",
320+
"metadata": {},
321+
"outputs": [],
264322
"source": []
265323
}
266324
],

langchain/embeddings/__init__.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,10 @@
33
from typing import Any
44

55
from langchain.embeddings.cohere import CohereEmbeddings
6-
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
6+
from langchain.embeddings.huggingface import (
7+
HuggingFaceEmbeddings,
8+
HuggingFaceInstructEmbeddings,
9+
)
710
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
811
from langchain.embeddings.openai import OpenAIEmbeddings
912
from langchain.embeddings.tensorflow_hub import TensorflowHubEmbeddings
@@ -16,6 +19,7 @@
1619
"CohereEmbeddings",
1720
"HuggingFaceHubEmbeddings",
1821
"TensorflowHubEmbeddings",
22+
"HuggingFaceInstructEmbeddings",
1923
]
2024

2125

langchain/embeddings/huggingface.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,11 @@
66
from langchain.embeddings.base import Embeddings
77

88
DEFAULT_MODEL_NAME = "sentence-transformers/all-mpnet-base-v2"
9+
DEFAULT_INSTRUCT_MODEL = "hkunlp/instructor-large"
10+
DEFAULT_EMBED_INSTRUCTION = "Represent the document for retrieval: "
11+
DEFAULT_QUERY_INSTRUCTION = (
12+
"Represent the question for retrieving supporting documents: "
13+
)
914

1015

1116
class HuggingFaceEmbeddings(BaseModel, Embeddings):
@@ -68,3 +73,68 @@ def embed_query(self, text: str) -> List[float]:
6873
text = text.replace("\n", " ")
6974
embedding = self.client.encode(text)
7075
return embedding.tolist()
76+
77+
78+
class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
    """Wrapper around ``InstructorEmbedding`` INSTRUCTOR embedding models.

    Every text is embedded together with an instruction string:
    ``embed_instruction`` is paired with documents and ``query_instruction``
    is paired with queries.

    To use, you should have the ``InstructorEmbedding`` python package
    installed.

    Example:
        .. code-block:: python

            from langchain.embeddings import HuggingFaceInstructEmbeddings
            model_name = "hkunlp/instructor-large"
            hf = HuggingFaceInstructEmbeddings(model_name=model_name)
    """

    client: Any  #: :meta private:
    model_name: str = DEFAULT_INSTRUCT_MODEL
    """Model name to use."""
    embed_instruction: str = DEFAULT_EMBED_INSTRUCTION
    """Instruction to use for embedding documents."""
    query_instruction: str = DEFAULT_QUERY_INSTRUCTION
    """Instruction to use for embedding query."""

    def __init__(self, **kwargs: Any):
        """Initialize the model fields and load the INSTRUCTOR client.

        Args:
            **kwargs: Field values forwarded to the pydantic initializer.

        Raises:
            ValueError: If importing ``InstructorEmbedding`` or constructing
                the ``INSTRUCTOR`` client raises ``ImportError``.
        """
        super().__init__(**kwargs)
        try:
            # Imported lazily so the dependency is only needed when this
            # class is actually instantiated.
            from InstructorEmbedding import INSTRUCTOR

            self.client = INSTRUCTOR(self.model_name)
        except ImportError as e:
            raise ValueError("Dependencies for InstructorEmbedding not found.") from e

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a HuggingFace instruct model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.
        """
        # The client is given one [instruction, text] pair per document.
        instruction_pairs = [[self.embed_instruction, text] for text in texts]
        embeddings = self.client.encode(instruction_pairs)
        return embeddings.tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a HuggingFace instruct model.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        instruction_pair = [self.query_instruction, text]
        # encode() is called with a one-element batch; take its only row.
        embedding = self.client.encode([instruction_pair])[0]
        return embedding.tolist()

tests/integration_tests/embeddings/test_huggingface.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,10 @@
11
"""Test huggingface embeddings."""
22
import unittest
33

4-
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
4+
from langchain.embeddings.huggingface import (
5+
HuggingFaceEmbeddings,
6+
HuggingFaceInstructEmbeddings,
7+
)
58

69

710
@unittest.skip("This test causes a segfault.")
@@ -21,3 +24,20 @@ def test_huggingface_embedding_query() -> None:
2124
embedding = HuggingFaceEmbeddings()
2225
output = embedding.embed_query(document)
2326
assert len(output) == 768
27+
28+
29+
def test_huggingface_instructor_embedding_documents() -> None:
    """Embed a single document with the instructor wrapper and check sizes."""
    embedder = HuggingFaceInstructEmbeddings()
    vectors = embedder.embed_documents(["foo bar"])
    # One vector per input document, each with 768 components.
    assert len(vectors) == 1
    assert len(vectors[0]) == 768
36+
37+
38+
def test_huggingface_instructor_embedding_query() -> None:
    """Embed a single query with the instructor wrapper and check its size."""
    embedder = HuggingFaceInstructEmbeddings()
    vector = embedder.embed_query("foo bar")
    # The query embedding is expected to have 768 components.
    assert len(vector) == 768

0 commit comments

Comments
 (0)