Skip to content

Commit 1a60e21

Browse files
anakin87ZanSara
andauthored
refactor: simplify Summarizer, add Document Merger (#3452)
* remove generate_single_summary * update schemas * remove unused import * fix mypy * fix mypy * test: summarizer doesnt change content * other test correction * move test_summarizer_translation to test_extractor_translation * fix test * first try for doc merger * reintroduce and deprecate generate_single_summary * progress in document merger * document merger! * mypy, pylint fixes * use generator * added test that will fail in 1.12 * adapt to review * extended deprecation docstring * Update test/nodes/test_extractor_translation.py * Update test/nodes/test_summarizer.py * Update test/nodes/test_summarizer.py * black * documents fixture Co-authored-by: Sara Zan <[email protected]>
1 parent 0a04dec commit 1a60e21

File tree

11 files changed

+358
-207
lines changed

11 files changed

+358
-207
lines changed

docs/_src/api/pydoc/other.yml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
loaders:
22
- type: python
33
search_path: [../../../../haystack/nodes/other]
4-
modules: ['docs2answers', 'join_docs', 'join_answers', 'route_documents']
4+
modules: ['docs2answers', 'join_docs', 'join_answers', 'route_documents', 'document_merger']
55
ignore_when_discovered: ['__init__']
66
processors:
77
- type: filter
Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
import logging
2+
from copy import deepcopy
3+
from typing import Optional, List, Dict, Union, Any
4+
5+
from haystack.schema import Document
6+
from haystack.nodes.base import BaseComponent
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
class DocumentMerger(BaseComponent):
12+
"""
13+
A node to merge the texts of the documents.
14+
"""
15+
16+
outgoing_edges = 1
17+
18+
def __init__(self, separator: str = " "):
19+
"""
20+
:param separator: The separator that appears between subsequent merged documents.
21+
"""
22+
super().__init__()
23+
self.separator = separator
24+
25+
def merge(self, documents: List[Document], separator: Optional[str] = None) -> List[Document]:
26+
"""
27+
Produce a list made up of a single document, which contains all the texts of the documents provided.
28+
29+
:param separator: The separator that appears between subsequent merged documents.
30+
:return: List of Documents
31+
"""
32+
if len(documents) == 0:
33+
raise ValueError("Document Merger needs at least one document to merge.")
34+
if not all(doc.content_type == "text" for doc in documents):
35+
raise ValueError(
36+
"Some of the documents provided are non-textual. Document Merger only works on textual documents."
37+
)
38+
39+
separator = separator if separator is not None else self.separator
40+
41+
merged_content = separator.join([doc.content for doc in documents])
42+
common_meta = self._keep_common_keys([doc.meta for doc in documents])
43+
44+
merged_document = Document(content=merged_content, meta=common_meta)
45+
return [merged_document]
46+
47+
def run(self, documents: List[Document], separator: Optional[str] = None): # type: ignore
48+
results: Dict = {"documents": []}
49+
if documents:
50+
results["documents"] = self.merge(documents=documents, separator=separator)
51+
return results, "output_1"
52+
53+
def run_batch( # type: ignore
54+
self, documents: Union[List[Document], List[List[Document]]], separator: Optional[str] = None
55+
):
56+
is_doclist_flat = isinstance(documents[0], Document)
57+
if is_doclist_flat:
58+
flat_result: List[Document] = self.merge(
59+
documents=[doc for doc in documents if isinstance(doc, Document)], separator=separator
60+
)
61+
return {"documents": flat_result}, "output_1"
62+
else:
63+
nested_result: List[List[Document]] = [
64+
self.merge(documents=docs_lst, separator=separator)
65+
for docs_lst in documents
66+
if isinstance(docs_lst, list)
67+
]
68+
return {"documents": nested_result}, "output_1"
69+
70+
def _keep_common_keys(self, list_of_dicts: List[Dict[str, Any]]) -> dict:
71+
merge_dictionary = deepcopy(list_of_dicts[0])
72+
for key, value in list_of_dicts[0].items():
73+
74+
# if not all other dicts have this key, delete directly
75+
if not all(key in dict.keys() for dict in list_of_dicts):
76+
del merge_dictionary[key]
77+
78+
# if they all have it and it's a dictionary, merge recursively
79+
elif isinstance(value, dict):
80+
# Get all the subkeys to merge in a new list
81+
list_of_subdicts = [dictionary[key] for dictionary in list_of_dicts]
82+
merge_dictionary[key] = self._keep_common_keys(list_of_subdicts)
83+
84+
# If all dicts have this key and it's not a dictionary, delete only if the values differ
85+
elif not all(value == dict[key] for dict in list_of_dicts):
86+
del merge_dictionary[key]
87+
88+
return merge_dictionary

haystack/nodes/summarizer/base.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -19,12 +19,8 @@ def predict(self, documents: List[Document], generate_single_summary: Optional[b
1919
Abstract method for creating a summary.
2020
2121
:param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on.
22-
:param generate_single_summary: Whether to generate a single summary for all documents or one summary per document.
23-
If set to "True", all docs will be joined to a single string that will then
24-
be summarized.
25-
Important: The summary will depend on the order of the supplied documents!
26-
:return: List of Documents, where Document.content contains the summarization and Document.meta["context"]
27-
the original, not summarized text
22+
:param generate_single_summary: This parameter is deprecated and will be removed in Haystack 1.12
23+
:return: List of Documents, where Document.meta["summary"] contains the summarization
2824
"""
2925
pass
3026

@@ -54,7 +50,7 @@ def run_batch( # type: ignore
5450
):
5551

5652
results = self.predict_batch(
57-
documents=documents, generate_single_summary=generate_single_summary, batch_size=batch_size
53+
documents=documents, batch_size=batch_size, generate_single_summary=generate_single_summary
5854
)
5955

6056
return {"documents": results}, "output_1"

haystack/nodes/summarizer/transformers.py

Lines changed: 68 additions & 92 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,6 @@
66
import torch
77
from tqdm.auto import tqdm
88
from transformers import pipeline
9-
from transformers.models.auto.modeling_auto import AutoModelForSeq2SeqLM
109

1110
from haystack.schema import Document
1211
from haystack.nodes.summarizer.base import BaseSummarizer
@@ -34,19 +33,18 @@ class TransformersSummarizer(BaseSummarizer):
3433
|
3534
| # Summarize
3635
| summary = summarizer.predict(
37-
| documents=docs,
38-
| generate_single_summary=True
39-
| )
36+
| documents=docs)
4037
|
41-
| # Show results (List of Documents, containing summary and original text)
38+
| # Show results (List of Documents, containing summary and original content)
4239
| print(summary)
4340
|
4441
| [
4542
| {
46-
| "text": "California's largest electricity provider has turned off power to hundreds of thousands of customers.",
43+
| "content": "PGE stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. ...",
4744
| ...
4845
| "meta": {
49-
| "context": "PGE stated it scheduled the blackouts in response to forecasts for high winds amid dry conditions. ..."
46+
| "summary": "California's largest electricity provider has turned off power to hundreds of thousands of customers.",
47+
| ...
5048
| },
5149
| ...
5250
| },
@@ -83,12 +81,9 @@ def __init__(
8381
:param min_length: Minimum length of summarized text
8482
:param use_gpu: Whether to use GPU (if available).
8583
:param clean_up_tokenization_spaces: Whether or not to clean up the potential extra spaces in the text output
86-
:param separator_for_single_summary: If `generate_single_summary=True` in `predict()`, we need to join all docs
87-
into a single text. This separator appears between those subsequent docs.
88-
:param generate_single_summary: Whether to generate a single summary for all documents or one summary per document.
89-
If set to "True", all docs will be joined to a single string that will then
90-
be summarized.
91-
Important: The summary will depend on the order of the supplied documents!
84+
:param separator_for_single_summary: This parameter is deprecated and will be removed in Haystack 1.12
85+
:param generate_single_summary: This parameter is deprecated and will be removed in Haystack 1.12.
86+
To obtain single summaries from multiple documents, consider using the [DocumentMerger](https://docs.haystack.deepset.ai/docs/document_merger).
9287
:param batch_size: Number of documents to process at a time.
9388
:param progress_bar: Whether to show a progress bar.
9489
:param use_auth_token: The API token used to download private models from Huggingface.
@@ -103,27 +98,32 @@ def __init__(
10398
"""
10499
super().__init__()
105100

101+
if generate_single_summary is True:
102+
raise ValueError(
103+
"'generate_single_summary' has been removed. Instead, you can use the Document Merger to merge documents before applying the Summarizer."
104+
)
105+
106106
self.devices, _ = initialize_device_settings(devices=devices, use_cuda=use_gpu, multi_gpu=False)
107107
if len(self.devices) > 1:
108108
logger.warning(
109109
f"Multiple devices are not supported in {self.__class__.__name__} inference, "
110110
f"using the first device {self.devices[0]}."
111111
)
112112

113-
# TODO AutoModelForSeq2SeqLM is only necessary with transformers==4.1.1, with newer versions use the pipeline directly
114113
if tokenizer is None:
115114
tokenizer = model_name_or_path
116-
model = AutoModelForSeq2SeqLM.from_pretrained(
117-
pretrained_model_name_or_path=model_name_or_path, revision=model_version, use_auth_token=use_auth_token
118-
)
115+
119116
self.summarizer = pipeline(
120-
"summarization", model=model, tokenizer=tokenizer, device=self.devices[0], use_auth_token=use_auth_token
117+
task="summarization",
118+
model=model_name_or_path,
119+
tokenizer=tokenizer,
120+
revision=model_version,
121+
device=self.devices[0],
122+
use_auth_token=use_auth_token,
121123
)
122124
self.max_length = max_length
123125
self.min_length = min_length
124126
self.clean_up_tokenization_spaces = clean_up_tokenization_spaces
125-
self.separator_for_single_summary = separator_for_single_summary
126-
self.generate_single_summary = generate_single_summary
127127
self.print_log: Set[str] = set()
128128
self.batch_size = batch_size
129129
self.progress_bar = progress_bar
@@ -134,29 +134,23 @@ def predict(self, documents: List[Document], generate_single_summary: Optional[b
134134
These document can for example be retrieved via the Retriever.
135135
136136
:param documents: Related documents (e.g. coming from a retriever) that the answer shall be conditioned on.
137-
:param generate_single_summary: Whether to generate a single summary for all documents or one summary per document.
138-
If set to "True", all docs will be joined to a single string that will then
139-
be summarized.
140-
Important: The summary will depend on the order of the supplied documents!
141-
:return: List of Documents, where Document.text contains the summarization and Document.meta["context"]
142-
the original, not summarized text
137+
:param generate_single_summary: This parameter is deprecated and will be removed in Haystack 1.12.
138+
To obtain single summaries from multiple documents, consider using the [DocumentMerger](https://docs.haystack.deepset.ai/docs/document_merger).
139+
:return: List of Documents, where Document.meta["summary"] contains the summarization
143140
"""
141+
if generate_single_summary is True:
142+
raise ValueError(
143+
"'generate_single_summary' has been removed. Instead, you can use the Document Merger to merge documents before applying the Summarizer."
144+
)
145+
144146
if self.min_length > self.max_length:
145147
raise AttributeError("min_length cannot be greater than max_length")
146148

147149
if len(documents) == 0:
148150
raise AttributeError("Summarizer needs at least one document to produce a summary.")
149151

150-
if generate_single_summary is None:
151-
generate_single_summary = self.generate_single_summary
152-
153152
contexts: List[str] = [doc.content for doc in documents]
154153

155-
if generate_single_summary:
156-
# Documents order is very important to produce summary.
157-
# Different order of same documents produce different summary.
158-
contexts = [self.separator_for_single_summary.join(contexts)]
159-
160154
encoded_input = self.summarizer.tokenizer(contexts, verbose=False)
161155
for input_id in encoded_input["input_ids"]:
162156
tokens_count: int = len(input_id)
@@ -182,15 +176,9 @@ def predict(self, documents: List[Document], generate_single_summary: Optional[b
182176

183177
result: List[Document] = []
184178

185-
if generate_single_summary:
186-
for context, summarized_answer in zip(contexts, summaries):
187-
cur_doc = Document(content=summarized_answer["summary_text"], meta={"context": context})
188-
result.append(cur_doc)
189-
else:
190-
for context, summarized_answer, document in zip(contexts, summaries, documents):
191-
cur_doc = Document(content=summarized_answer["summary_text"], meta=document.meta)
192-
cur_doc.meta.update({"context": context})
193-
result.append(cur_doc)
179+
for summary, document in zip(summaries, documents):
180+
document.meta.update({"summary": summary["summary_text"]})
181+
result.append(document)
194182

195183
return result
196184

@@ -206,13 +194,14 @@ def predict_batch(
206194
207195
:param documents: Single list of related documents or list of lists of related documents
208196
(e.g. coming from a retriever) that the answer shall be conditioned on.
209-
:param generate_single_summary: Whether to generate a single summary for each provided document list or
210-
one summary per document.
211-
If set to "True", all docs of a document list will be joined to a single string
212-
that will then be summarized.
213-
Important: The summary will depend on the order of the supplied documents!
197+
:param generate_single_summary: This parameter is deprecated and will be removed in Haystack 1.12.
198+
To obtain single summaries from multiple documents, consider using the [DocumentMerger](https://docs.haystack.deepset.ai/docs/document_merger).
214199
:param batch_size: Number of Documents to process at a time.
215200
"""
201+
if generate_single_summary is True:
202+
raise ValueError(
203+
"'generate_single_summary' has been removed. Instead, you can use the Document Merger to merge documents before applying the Summarizer."
204+
)
216205

217206
if self.min_length > self.max_length:
218207
raise AttributeError("min_length cannot be greater than max_length")
@@ -225,34 +214,17 @@ def predict_batch(
225214
if batch_size is None:
226215
batch_size = self.batch_size
227216

228-
if generate_single_summary is None:
229-
generate_single_summary = self.generate_single_summary
230-
231-
single_doc_list = False
232-
if isinstance(documents[0], Document):
233-
single_doc_list = True
234-
235-
if single_doc_list:
217+
is_doclist_flat = isinstance(documents[0], Document)
218+
if is_doclist_flat:
236219
contexts = [doc.content for doc in documents if isinstance(doc, Document)]
237220
else:
238221
contexts = [
239222
[doc.content for doc in docs if isinstance(doc, Document)]
240223
for docs in documents
241224
if isinstance(docs, list)
242225
]
243-
244-
if generate_single_summary:
245-
if single_doc_list:
246-
contexts = [self.separator_for_single_summary.join(contexts)]
247-
else:
248-
contexts = [self.separator_for_single_summary.join(context_group) for context_group in contexts]
249-
number_of_docs = [1 for _ in contexts]
250-
else:
251-
if single_doc_list:
252-
number_of_docs = [1 for _ in contexts]
253-
else:
254-
number_of_docs = [len(context_group) for context_group in contexts]
255-
contexts = list(itertools.chain.from_iterable(contexts))
226+
number_of_docs = [len(context_group) for context_group in contexts]
227+
contexts = list(itertools.chain.from_iterable(contexts))
256228

257229
encoded_input = self.summarizer.tokenizer(contexts, verbose=False)
258230
for input_id in encoded_input["input_ids"]:
@@ -286,26 +258,30 @@ def predict_batch(
286258
):
287259
summaries.extend(summary_batch)
288260

289-
# Group summaries together
290-
grouped_summaries = []
291-
grouped_contexts = []
292-
left_idx = 0
293-
right_idx = 0
294-
for number in number_of_docs:
295-
right_idx = left_idx + number
296-
grouped_summaries.append(summaries[left_idx:right_idx])
297-
grouped_contexts.append(contexts[left_idx:right_idx])
298-
left_idx = right_idx
299-
300-
result = []
301-
for summary_group, context_group in zip(grouped_summaries, grouped_contexts):
302-
cur_summaries = [
303-
Document(content=summary["summary_text"], meta={"context": context})
304-
for summary, context in zip(summary_group, context_group)
305-
]
306-
if single_doc_list:
307-
result.append(cur_summaries[0])
308-
else:
309-
result.append(cur_summaries) # type: ignore
310-
311-
return result
261+
if is_doclist_flat:
262+
flat_result: List[Document] = []
263+
flat_doc_list: List[Document] = [doc for doc in documents if isinstance(doc, Document)]
264+
for summary, document in zip(summaries, flat_doc_list):
265+
document.meta.update({"summary": summary["summary_text"]})
266+
flat_result.append(document)
267+
return flat_result
268+
else:
269+
nested_result: List[List[Document]] = []
270+
nested_doc_list: List[List[Document]] = [lst for lst in documents if isinstance(lst, list)]
271+
272+
# Group summaries together
273+
grouped_summaries = []
274+
left_idx = 0
275+
right_idx = 0
276+
for number in number_of_docs:
277+
right_idx = left_idx + number
278+
grouped_summaries.append(summaries[left_idx:right_idx])
279+
left_idx = right_idx
280+
281+
for summary_group, docs_group in zip(grouped_summaries, nested_doc_list):
282+
cur_summaries = []
283+
for summary, document in zip(summary_group, docs_group):
284+
document.meta.update({"summary": summary["summary_text"]})
285+
cur_summaries.append(document)
286+
nested_result.append(cur_summaries)
287+
return nested_result

0 commit comments

Comments
 (0)