Skip to content

Commit 7b4882a

Browse files
hwchase17 and yakigac
authored
Harrison/tf embeddings (#817)
Co-authored-by: Ryohei Kuroki <[email protected]>
1 parent 5d4b6e4 commit 7b4882a

File tree

6 files changed

+1020
-301
lines changed

6 files changed

+1020
-301
lines changed

docs/modules/utils/combine_docs_examples/embeddings.ipynb

Lines changed: 70 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,6 @@
7777
]
7878
},
7979
{
80-
"attachments": {},
8180
"cell_type": "markdown",
8281
"id": "42f76e43",
8382
"metadata": {},
@@ -138,7 +137,6 @@
138137
]
139138
},
140139
{
141-
"attachments": {},
142140
"cell_type": "markdown",
143141
"id": "ed47bb62",
144142
"metadata": {},
@@ -196,11 +194,79 @@
196194
"source": [
197195
"doc_result = embeddings.embed_documents([text])"
198196
]
197+
},
198+
{
199+
"cell_type": "markdown",
200+
"id": "fff4734f",
201+
"metadata": {},
202+
"source": [
203+
"## TensorflowHub\n",
204+
"Let's load the TensorflowHub Embedding class."
205+
]
206+
},
207+
{
208+
"cell_type": "code",
209+
"execution_count": 1,
210+
"id": "f822104b",
211+
"metadata": {},
212+
"outputs": [],
213+
"source": [
214+
"from langchain.embeddings import TensorflowHubEmbeddings"
215+
]
216+
},
217+
{
218+
"cell_type": "code",
219+
"execution_count": 5,
220+
"id": "bac84e46",
221+
"metadata": {},
222+
"outputs": [
223+
{
224+
"name": "stderr",
225+
"output_type": "stream",
226+
"text": [
227+
"2023-01-30 23:53:01.652176: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
228+
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
229+
"2023-01-30 23:53:34.362802: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
230+
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n"
231+
]
232+
}
233+
],
234+
"source": [
235+
"embeddings = TensorflowHubEmbeddings()"
236+
]
237+
},
238+
{
239+
"cell_type": "code",
240+
"execution_count": 6,
241+
"id": "4790d770",
242+
"metadata": {},
243+
"outputs": [],
244+
"source": [
245+
"text = \"This is a test document.\""
246+
]
247+
},
248+
{
249+
"cell_type": "code",
250+
"execution_count": 7,
251+
"id": "f556dcdb",
252+
"metadata": {},
253+
"outputs": [],
254+
"source": [
255+
"query_result = embeddings.embed_query(text)"
256+
]
257+
},
258+
{
259+
"cell_type": "code",
260+
"execution_count": null,
261+
"id": "90f0db94",
262+
"metadata": {},
263+
"outputs": [],
264+
"source": []
199265
}
200266
],
201267
"metadata": {
202268
"kernelspec": {
203-
"display_name": "cohere",
269+
"display_name": "Python 3 (ipykernel)",
204270
"language": "python",
205271
"name": "python3"
206272
},
@@ -214,7 +280,7 @@
214280
"name": "python",
215281
"nbconvert_exporter": "python",
216282
"pygments_lexer": "ipython3",
217-
"version": "3.10.8"
283+
"version": "3.10.9"
218284
},
219285
"vscode": {
220286
"interpreter": {

langchain/embeddings/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
77
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
88
from langchain.embeddings.openai import OpenAIEmbeddings
9+
from langchain.embeddings.tensorflow_hub import TensorflowHubEmbeddings
910

1011
logger = logging.getLogger(__name__)
1112

@@ -14,6 +15,7 @@
1415
"HuggingFaceEmbeddings",
1516
"CohereEmbeddings",
1617
"HuggingFaceHubEmbeddings",
18+
"TensorflowHubEmbeddings",
1719
]
1820

1921

langchain/embeddings/tensorflow_hub.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
"""Wrapper around TensorflowHub embedding models."""
2+
from typing import Any, List
3+
4+
from pydantic import BaseModel, Extra
5+
6+
from langchain.embeddings.base import Embeddings
7+
8+
# TensorFlow Hub URL used as the default value of
# ``TensorflowHubEmbeddings.model_url`` when the caller does not supply one.
DEFAULT_MODEL_URL = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
9+
10+
11+
class TensorflowHubEmbeddings(BaseModel, Embeddings):
    """Wrapper around tensorflow_hub embedding models.

    To use, you should have the ``tensorflow_hub`` and ``tensorflow_text``
    python packages installed.

    Example:
        .. code-block:: python

            from langchain.embeddings import TensorflowHubEmbeddings
            url = "https://tfhub.dev/google/universal-sentence-encoder-multilingual/3"
            tf = TensorflowHubEmbeddings(model_url=url)
    """

    embed: Any  #: :meta private:
    model_url: str = DEFAULT_MODEL_URL
    """URL of the TensorFlow Hub model to load."""

    def __init__(self, **kwargs: Any) -> None:
        """Validate the fields, then load the model found at ``model_url``.

        Args:
            **kwargs: Field values forwarded to the pydantic constructor
                (e.g. ``model_url``).

        Raises:
            ValueError: If ``tensorflow_hub`` or ``tensorflow_text`` cannot
                be imported.
        """
        super().__init__(**kwargs)
        try:
            import tensorflow_hub

            # Imported only for its side effects; the name itself is unused.
            # NOTE(review): presumably registers ops the model needs - confirm.
            import tensorflow_text  # noqa

            self.embed = tensorflow_hub.load(self.model_url)
        except ImportError as e:
            # Single literal: the previous two adjacent literals were joined
            # without a space ("packages.Please").
            raise ValueError(
                "Could not import some python packages. Please install them."
            ) from e

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Compute doc embeddings using a TensorflowHub embedding model.

        Newlines in each text are replaced with spaces before embedding.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text. An empty input yields an
            empty list without calling the model.
        """
        if not texts:
            return []
        cleaned = [text.replace("\n", " ") for text in texts]
        return self.embed(cleaned).numpy().tolist()

    def embed_query(self, text: str) -> List[float]:
        """Compute query embeddings using a TensorflowHub embedding model.

        Newlines in the text are replaced with spaces before embedding.

        Args:
            text: The text to embed.

        Returns:
            Embeddings for the text.
        """
        text = text.replace("\n", " ")
        # Call the model with a batch of one, the same way embed_documents
        # does, so that [0] selects the single resulting embedding.
        embedding = self.embed([text]).numpy()[0]
        return embedding.tolist()

0 commit comments

Comments
 (0)