1- from tqdm import tqdm
21from typing import List
3- from langchain .docstore .document import Document
2+
3+ from langchain .prompts import ChatPromptTemplate , PromptTemplate
4+ from langchain_core .language_models .chat_models import BaseChatModel as LangChainBaseChatModel
5+ from langchain_core .language_models .llms import BaseLLM as LangChainBaseLLM
6+ from langchain_core .output_parsers import StrOutputParser
7+ from tqdm import tqdm
8+
49from keybert .llm ._base import BaseLLM
510from keybert .llm ._utils import process_candidate_keywords
611
7-
8- DEFAULT_PROMPT = "What is this document about? Please provide keywords separated by commas."
12+ """NOTE
13+ KeyBERT only supports `langchain >= 0.1` which features:
14+ - [Runnable Interface](https://python.langchain.com/docs/concepts/runnables/)
15+ - [LangChain Expression Language (LCEL)](https://python.langchain.com/docs/concepts/lcel/)
16+ """
917
1018
1119class LangChain (BaseLLM ):
1220 """Using chains in langchain to generate keywords.
1321
14- Currently, only chains from question answering is implemented. See:
15- https://langchain.readthedocs.io/en/latest/modules/chains/combine_docs_examples/question_answering.html
16-
17- NOTE: The resulting keywords are expected to be separated by commas so
18- any changes to the prompt will have to make sure that the resulting
19- keywords are comma-separated.
20-
2122 Arguments:
22- chain : A langchain chain that has two input parameters, `input_documents` and `query` .
23+ llm: A langchain LLM or chat model instance, e.g. ChatOpenAI, OpenAI, etc.
2324 prompt: The prompt to be used in the model. If no prompt is given,
24- `self.default_prompt_` is used instead.
25+ `self.DEFAULT_PROMPT_TEMPLATE` is used instead.
26+ NOTE: The prompt should contain:
27+ 1. Placeholders
28+ - `[DOCUMENT]`: Required. The document to extract keywords from.
29+ - `[CANDIDATES]`: Optional. The candidate keywords to fine-tune the extraction.
30+ 2. Output format instructions
31+ - Include this or something similar in your prompt:
32+ "Extracted keywords must be separated by comma."
2533 verbose: Set this to True if you want to see a progress bar for the
26- keyword extraction.
34+ keyword extraction.
2735
2836 Usage:
2937
@@ -32,14 +40,18 @@ class LangChain(BaseLLM):
3240 like openai:
3341
3442 `pip install langchain`
35- `pip install openai`
43+ `pip install langchain-openai`
3644
3745 Then, you can create your LLM as follows:
3846
3947 ```python
40- from langchain.chains.question_answering import load_qa_chain
41- from langchain.llms import OpenAI
42- chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff")
48+ from langchain_openai import ChatOpenAI
49+
50+ _llm = ChatOpenAI(
51+ model="gpt-4o",
52+ api_key="my-openai-api-key",
53+ temperature=0,
54+ )
4355 ```
4456
4557 Finally, you can pass the LLM to KeyBERT as follows:
@@ -49,14 +61,39 @@ class LangChain(BaseLLM):
4961 from keybert import KeyLLM
5062
5163 # Create your LLM
52- llm = LangChain(chain )
64+ llm = LangChain(_llm )
5365
5466 # Load it in KeyLLM
5567 kw_model = KeyLLM(llm)
5668
5769 # Extract keywords
58- document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
59- keywords = kw_model.extract_keywords(document)
70+ docs = [
71+ "KeyBERT: A minimal method for keyword extraction with BERT. The keyword extraction is done by finding the sub-phrases in a document that are the most similar to the document itself. First, document embeddings are extracted with BERT to get a document-level representation. Then, word embeddings are extracted for N-gram words/phrases. Finally, we use cosine similarity to find the words/phrases that are the most similar to the document. The most similar words could then be identified as the words that best describe the entire document.",
72+ "KeyLLM: A minimal method for keyword extraction with Large Language Models (LLM). The keyword extraction is done by simply asking the LLM to extract a number of keywords from a single piece of text.",
73+ ]
74+ keywords = kw_model.extract_keywords(docs=docs)
75+ print(keywords)
76+
77+ # Output:
78+ # [
79+ # ['KeyBERT', 'keyword extraction', 'BERT', 'document embeddings', 'word embeddings', 'N-gram phrases', 'cosine similarity', 'document representation'],
80+ # ['KeyLLM', 'keyword extraction', 'Large Language Models', 'LLM', 'minimal method']
81+ # ]
82+
83+
84+ # fine tune with candidate keywords
85+ candidates = [
86+ ["keyword extraction", "Large Language Models", "LLM", "BERT", "transformer", "embeddings"],
87+ ["keyword extraction", "Large Language Models", "LLM", "BERT", "transformer", "embeddings"],
88+ ]
89+ keywords = kw_model.extract_keywords(docs=docs, candidate_keywords=candidates)
90+ print(keywords)
91+
92+ # Output:
93+ # [
94+ # ['keyword extraction', 'BERT', 'document embeddings', 'word embeddings', 'cosine similarity', 'N-gram phrases'],
95+ # ['KeyLLM', 'keyword extraction', 'Large Language Models', 'LLM']
96+ # ]
6097 ```
6198
6299 You can also use a custom prompt:
@@ -67,16 +104,35 @@ class LangChain(BaseLLM):
67104 ```
68105 """
69106
# Prompt used when the caller does not supply one. `[DOCUMENT]` and
# `[CANDIDATES]` are bracket placeholders that `_get_chain` rewrites into
# LangChain template variables before the chain is built.
DEFAULT_PROMPT_TEMPLATE = """
# Task
You are provided with a document and possibly a list of candidate keywords.

If no candidate keywords are provided, your task is to extract keywords from the document.
If candidate keywords are provided, your task is to improve the candidate keywords to best describe the topic of the document.

# Document
[DOCUMENT]

# Candidate Keywords
[CANDIDATES]


Now extract the keywords from the document.
The keywords must be comma separated.
For example: "keyword1, keyword2, keyword3"
"""
125+
def __init__(
    self,
    # Quoted so the `X | Y` union is never evaluated at definition time;
    # unquoted, it raises TypeError on Python versions before 3.10.
    llm: "LangChainBaseChatModel | LangChainBaseLLM",
    prompt: "str | None" = None,
    verbose: bool = False,
):
    """Store the model and prompt, then build the keyword-extraction chain.

    Arguments:
        llm: A langchain chat model or LLM instance used to generate keywords.
        prompt: The prompt to use. Falls back to `self.DEFAULT_PROMPT_TEMPLATE`
                when `None` is given.
        verbose: Set this to True to show a progress bar during extraction.
    """
    self.llm = llm
    self.prompt = prompt if prompt is not None else self.DEFAULT_PROMPT_TEMPLATE
    self.verbose = verbose
    # Built last: `_get_chain` reads both `self.llm` and `self.prompt`.
    self.chain = self._get_chain()
80136
81137 def extract_keywords (self , documents : List [str ], candidate_keywords : List [List [str ]] = None ):
82138 """Extract topics.
@@ -95,12 +151,19 @@ def extract_keywords(self, documents: List[str], candidate_keywords: List[List[s
95151 candidate_keywords = process_candidate_keywords (documents , candidate_keywords )
96152
97153 for document , candidates in tqdm (zip (documents , candidate_keywords ), disable = not self .verbose ):
98- prompt = self .prompt .replace ("[DOCUMENT]" , document )
99- if candidates is not None :
100- prompt = prompt .replace ("[CANDIDATES]" , ", " .join (candidates ))
101- input_document = Document (page_content = document )
102- keywords = self .chain .run (input_documents = [input_document ], question = self .prompt ).strip ()
154+ keywords = self .chain .invoke ({"DOCUMENT" : document , "CANDIDATES" : candidates })
103155 keywords = [keyword .strip () for keyword in keywords .split ("," )]
104156 all_keywords .append (keywords )
105157
106158 return all_keywords
159+
def _get_chain(self):
    """Build the runnable chain `prompt_template | llm | StrOutputParser()`.

    The bracket placeholders `[DOCUMENT]` and `[CANDIDATES]` in `self.prompt`
    are rewritten to the `{DOCUMENT}` / `{CANDIDATES}` template variables that
    the chain is later invoked with.

    Returns:
        The composed chain: prompt template, then `self.llm`, then a
        `StrOutputParser`.
    """
    # Convert KeyBERT's bracket placeholders into LangChain template variables.
    template = self.prompt.replace("[DOCUMENT]", "{DOCUMENT}").replace("[CANDIDATES]", "{CANDIDATES}")

    # Chat models take a list of messages; completion-style LLMs take a plain
    # string template. The `from_*` constructors are used instead of calling
    # the classes directly so that older langchain-core releases, which do not
    # accept positional messages, are also supported.
    if isinstance(self.llm, LangChainBaseChatModel):
        prompt_template = ChatPromptTemplate.from_messages([("human", template)])
    else:
        prompt_template = PromptTemplate.from_template(template)

    return prompt_template | self.llm | StrOutputParser()
0 commit comments