Skip to content
Merged
Changes from 15 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
123 changes: 93 additions & 30 deletions keybert/llm/_langchain.py
Original file line number Diff line number Diff line change
@@ -1,29 +1,37 @@
from tqdm import tqdm
from typing import List
from langchain.docstore.document import Document

from langchain.prompts import ChatPromptTemplate, PromptTemplate
from langchain_core.language_models.chat_models import BaseChatModel as LangChainBaseChatModel
from langchain_core.language_models.llms import BaseLLM as LangChainBaseLLM
from langchain_core.output_parsers import StrOutputParser
from tqdm import tqdm

from keybert.llm._base import BaseLLM
from keybert.llm._utils import process_candidate_keywords


# Legacy module-level default prompt — presumably superseded by the class-level
# DEFAULT_PROMPT_TEMPLATE; TODO(review): confirm whether any caller still reads it.
DEFAULT_PROMPT = "What is this document about? Please provide keywords separated by commas."
# NOTE(review): the triple-quoted string below is a bare-string statement, not a
# comment; it records the langchain >= 0.1 requirement (`.invoke()` / LCEL support).
"""NOTE
langchain >= 0.1 is required. Which supports:
- chain.invoke()
- LangChain Expression Language (LCEL) is used and it is not compatible with langchain < 0.1.
"""
Comment on lines 12 to 16
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Might be best to also update this in the documentation here: https://github.com/MaartenGr/KeyBERT/blob/master/docs/guides/llms.md

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sure. Updated



class LangChain(BaseLLM):
"""Using chains in langchain to generate keywords.

    Currently, only question-answering chains are implemented. See:
    https://langchain.readthedocs.io/en/latest/modules/chains/combine_docs_examples/question_answering.html

NOTE: The resulting keywords are expected to be separated by commas so
any changes to the prompt will have to make sure that the resulting
keywords are comma-separated.

Arguments:
chain: A langchain chain that has two input parameters, `input_documents` and `query`.
llm: A langchain LLM class. e.g ChatOpenAI, OpenAI, etc.
prompt: The prompt to be used in the model. If no prompt is given,
`self.default_prompt_` is used instead.
`self.DEFAULT_PROMPT_TEMPLATE` is used instead.
NOTE: The prompt should contain:
1. Placeholders
- `[DOCUMENT]`: Required. The document to extract keywords from.
- `[CANDIDATES]`: Optional. The candidate keywords to fine-tune the extraction.
2. Output format instructions
- Include this or something similar in your prompt:
"Extracted keywords must be separated by comma."
verbose: Set this to True if you want to see a progress bar for the
keyword extraction.
keyword extraction.

Usage:

Expand All @@ -32,14 +40,18 @@ class LangChain(BaseLLM):
like openai:

`pip install langchain`
`pip install openai`
`pip install langchain-openai`

Then, you can create your chain as follows:

```python
from langchain.chains.question_answering import load_qa_chain
from langchain.llms import OpenAI
chain = load_qa_chain(OpenAI(temperature=0, openai_api_key=my_openai_api_key), chain_type="stuff")
from langchain_openai import ChatOpenAI

_llm = ChatOpenAI(
model="gpt-4o",
api_key="my-openai-api-key",
temperature=0,
)
```

Finally, you can pass the chain to KeyBERT as follows:
Expand All @@ -49,14 +61,39 @@ class LangChain(BaseLLM):
from keybert import KeyLLM

# Create your LLM
llm = LangChain(chain)
llm = LangChain(_llm)

# Load it in KeyLLM
kw_model = KeyLLM(llm)

# Extract keywords
document = "The website mentions that it only takes a couple of days to deliver but I still have not received mine."
keywords = kw_model.extract_keywords(document)
docs = [
"KeyBERT: A minimal method for keyword extraction with BERT. The keyword extraction is done by finding the sub-phrases in a document that are the most similar to the document itself. First, document embeddings are extracted with BERT to get a document-level representation. Then, word embeddings are extracted for N-gram words/phrases. Finally, we use cosine similarity to find the words/phrases that are the most similar to the document. The most similar words could then be identified as the words that best describe the entire document.",
"KeyLLM: A minimal method for keyword extraction with Large Language Models (LLM). The keyword extraction is done by simply asking the LLM to extract a number of keywords from a single piece of text.",
]
keywords = kw_model.extract_keywords(docs=docs)
print(keywords)

# Output:
# [
# ['KeyBERT', 'keyword extraction', 'BERT', 'document embeddings', 'word embeddings', 'N-gram phrases', 'cosine similarity', 'document representation'],
# ['KeyLLM', 'keyword extraction', 'Large Language Models', 'LLM', 'minimal method']
# ]


# fine tune with candidate keywords
candidates = [
["keyword extraction", "Large Language Models", "LLM", "BERT", "transformer", "embeddings"],
["keyword extraction", "Large Language Models", "LLM", "BERT", "transformer", "embeddings"],
]
keywords = kw_model.extract_keywords(docs=docs, candidate_keywords=candidates)
print(keywords)

# Output:
# [
# ['keyword extraction', 'BERT', 'document embeddings', 'word embeddings', 'cosine similarity', 'N-gram phrases'],
# ['KeyLLM', 'keyword extraction', 'Large Language Models', 'LLM']
# ]
```

You can also use a custom prompt:
Expand All @@ -67,16 +104,35 @@ class LangChain(BaseLLM):
```
"""

# Default prompt sent to the LLM; `[DOCUMENT]` is required, `[CANDIDATES]` optional.
# Fixed typos present in the original prompt text ("possiblily", "task to is").
DEFAULT_PROMPT_TEMPLATE = """
# Task
You are provided with a document and possibly a list of candidate keywords.

If no candidate keywords are provided, your task is to extract keywords from the document.
If candidate keywords are provided, your task is to improve the candidate keywords to best describe the topic of the document.

# Document
[DOCUMENT]

# Candidate Keywords
[CANDIDATES]


Now extract the keywords from the document.
The keywords must be comma separated.
For example: "keyword1, keyword2, keyword3"
"""

def __init__(
    self,
    chain,
    llm: LangChainBaseChatModel | LangChainBaseLLM,
    prompt: str = None,
    verbose: bool = False,
):
    """Initialize the LangChain-based keyword extractor.

    Arguments:
        chain: Legacy parameter kept for signature compatibility; the chain
            is now built internally from `llm` — TODO(review): confirm it can
            be dropped once all callers are migrated.
        llm: A langchain chat model or completion LLM (e.g. ChatOpenAI, OpenAI).
        prompt: Prompt containing the `[DOCUMENT]` placeholder (required) and
            optionally `[CANDIDATES]`. Defaults to `self.DEFAULT_PROMPT_TEMPLATE`.
        verbose: Set to True to show a progress bar during extraction.
    """
    self.llm = llm
    # NOTE(review): the merged diff contained dead stores (`self.chain = chain`
    # and an earlier `self.prompt = ... DEFAULT_PROMPT`) that were unconditionally
    # overwritten below; they are removed here without behavior change.
    self.prompt = prompt if prompt is not None else self.DEFAULT_PROMPT_TEMPLATE
    self.default_prompt_ = DEFAULT_PROMPT
    self.verbose = verbose
    self.chain = self._get_chain()

def extract_keywords(self, documents: List[str], candidate_keywords: List[List[str]] = None):
    """Extract keywords for each document.

    Arguments:
        documents: The documents to extract keywords from.
        candidate_keywords: Optional candidate keywords, one list per document,
            used to fine-tune the extraction.

    Returns:
        all_keywords: One list of stripped keywords per document, obtained by
            splitting the LLM output on commas.
    """
    all_keywords = []
    candidate_keywords = process_candidate_keywords(documents, candidate_keywords)

    for document, candidates in tqdm(zip(documents, candidate_keywords), disable=not self.verbose):
        # NOTE(review): the merged diff still executed the legacy pre-LCEL path
        # (`prompt.replace(...)`, `Document`, `chain.run`) before `chain.invoke`,
        # i.e. one extra, discarded LLM call per document; only the LCEL call is kept.
        keywords = self.chain.invoke({"DOCUMENT": document, "CANDIDATES": candidates})
        # The prompt instructs comma-separated output; split and strip accordingly.
        keywords = [keyword.strip() for keyword in keywords.split(",")]
        all_keywords.append(keywords)

    return all_keywords

def _get_chain(self):
    """Build the LCEL pipeline: prompt template -> LLM -> string output parser."""
    # Translate the bracket placeholders into langchain's {curly} template variables.
    template_str = self.prompt.replace("[DOCUMENT]", "{DOCUMENT}")
    template_str = template_str.replace("[CANDIDATES]", "{CANDIDATES}")
    # Chat models take a message-style template; completion LLMs take a plain one.
    if isinstance(self.llm, LangChainBaseChatModel):
        template = ChatPromptTemplate([("human", template_str)])
    else:
        template = PromptTemplate(template=template_str)
    return template | self.llm | StrOutputParser()
Loading