Skip to content

Commit e3df8ab

Browse files
authored
move hyde into chains (#728)
Co-authored-by: scadEfUr <>
1 parent 0ffeabd commit e3df8ab

File tree

7 files changed

+34
-14
lines changed

7 files changed

+34
-14
lines changed

β€Ždocs/modules/utils/combine_docs_examples/hyde.ipynb

Lines changed: 9 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -21,8 +21,8 @@
2121
"outputs": [],
2222
"source": [
2323
"from langchain.llms import OpenAI\n",
24-
"from langchain.embeddings import OpenAIEmbeddings, HypotheticalDocumentEmbedder\n",
25-
"from langchain.chains import LLMChain\n",
24+
"from langchain.embeddings import OpenAIEmbeddings\n",
25+
"from langchain.chains import LLMChain, HypotheticalDocumentEmbedder\n",
2626
"from langchain.prompts import PromptTemplate"
2727
]
2828
},
@@ -220,7 +220,7 @@
220220
],
221221
"metadata": {
222222
"kernelspec": {
223-
"display_name": "Python 3 (ipykernel)",
223+
"display_name": "llm-env",
224224
"language": "python",
225225
"name": "python3"
226226
},
@@ -234,7 +234,12 @@
234234
"name": "python",
235235
"nbconvert_exporter": "python",
236236
"pygments_lexer": "ipython3",
237-
"version": "3.10.9"
237+
"version": "3.9.0 (default, Nov 15 2020, 06:25:35) \n[Clang 10.0.0 ]"
238+
},
239+
"vscode": {
240+
"interpreter": {
241+
"hash": "9dd01537e9ab68cf47cb0398488d182358f774f73101197b3bd1b5502c6ec7f9"
242+
}
238243
}
239244
},
240245
"nbformat": 4,

β€Žlangchain/chains/__init__.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
"""Chains are easily reusable components which can be linked together."""
22
from langchain.chains.api.base import APIChain
33
from langchain.chains.conversation.base import ConversationChain
4+
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
45
from langchain.chains.llm import LLMChain
56
from langchain.chains.llm_bash.base import LLMBashChain
67
from langchain.chains.llm_checker.base import LLMCheckerChain
@@ -41,4 +42,5 @@
4142
"OpenAIModerationChain",
4243
"SQLDatabaseSequentialChain",
4344
"load_chain",
45+
"HypotheticalDocumentEmbedder",
4446
]

β€Žlangchain/embeddings/hyde/base.py renamed to β€Žlangchain/chains/hyde/base.py

Lines changed: 21 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -4,18 +4,19 @@
44
"""
55
from __future__ import annotations
66

7-
from typing import List
7+
from typing import Dict, List
88

99
import numpy as np
1010
from pydantic import BaseModel, Extra
1111

12+
from langchain.chains.base import Chain
13+
from langchain.chains.hyde.prompts import PROMPT_MAP
1214
from langchain.chains.llm import LLMChain
1315
from langchain.embeddings.base import Embeddings
14-
from langchain.embeddings.hyde.prompts import PROMPT_MAP
1516
from langchain.llms.base import BaseLLM
1617

1718

18-
class HypotheticalDocumentEmbedder(Embeddings, BaseModel):
19+
class HypotheticalDocumentEmbedder(Chain, Embeddings, BaseModel):
1920
"""Generate hypothetical document for query, and then embed that.
2021
2122
Based on https://arxiv.org/abs/2212.10496
@@ -30,10 +31,24 @@ class Config:
3031
extra = Extra.forbid
3132
arbitrary_types_allowed = True
3233

34+
@property
35+
def input_keys(self) -> List[str]:
36+
"""Input keys for Hyde's LLM chain."""
37+
return self.llm_chain.input_keys
38+
39+
@property
40+
def output_keys(self) -> List[str]:
41+
"""Output keys for Hyde's LLM chain."""
42+
return self.llm_chain.output_keys
43+
3344
def embed_documents(self, texts: List[str]) -> List[List[float]]:
3445
"""Call the base embeddings."""
3546
return self.base_embeddings.embed_documents(texts)
3647

48+
def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
49+
"""Combine embeddings into final embeddings."""
50+
return list(np.array(embeddings).mean(axis=0))
51+
3752
def embed_query(self, text: str) -> List[float]:
3853
"""Generate a hypothetical document and embedded it."""
3954
var_name = self.llm_chain.input_keys[0]
@@ -42,9 +57,9 @@ def embed_query(self, text: str) -> List[float]:
4257
embeddings = self.embed_documents(documents)
4358
return self.combine_embeddings(embeddings)
4459

45-
def combine_embeddings(self, embeddings: List[List[float]]) -> List[float]:
46-
"""Combine embeddings into final embeddings."""
47-
return list(np.array(embeddings).mean(axis=0))
60+
def _call(self, inputs: Dict[str, str]) -> Dict[str, str]:
61+
"""Call the internal llm chain."""
62+
return self.llm_chain._call(inputs)
4863

4964
@classmethod
5065
def from_llm(

β€Žlangchain/embeddings/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,11 @@
22
from langchain.embeddings.cohere import CohereEmbeddings
33
from langchain.embeddings.huggingface import HuggingFaceEmbeddings
44
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
5-
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
65
from langchain.embeddings.openai import OpenAIEmbeddings
76

87
__all__ = [
98
"OpenAIEmbeddings",
109
"HuggingFaceEmbeddings",
1110
"CohereEmbeddings",
1211
"HuggingFaceHubEmbeddings",
13-
"HypotheticalDocumentEmbedder",
1412
]

β€Žtests/unit_tests/test_hyde.py renamed to β€Žtests/unit_tests/chains/test_hyde.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -4,9 +4,9 @@
44
import numpy as np
55
from pydantic import BaseModel
66

7+
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
8+
from langchain.chains.hyde.prompts import PROMPT_MAP
79
from langchain.embeddings.base import Embeddings
8-
from langchain.embeddings.hyde.base import HypotheticalDocumentEmbedder
9-
from langchain.embeddings.hyde.prompts import PROMPT_MAP
1010
from langchain.llms.base import BaseLLM
1111
from langchain.schema import Generation, LLMResult
1212

0 commit comments

Comments
Β (0)