Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
70 changes: 64 additions & 6 deletions orangecontrib/text/widgets/owcorpusviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
QSortFilterProxyModel,
Qt,
QUrl,
QAbstractTableModel,
)
from AnyQt.QtWidgets import (
QAbstractItemView,
Expand All @@ -37,6 +38,7 @@
from orangewidget.utils.listview import ListViewSearch

from orangecontrib.text.corpus import Corpus
from Orange.data import ContinuousVariable

HTML = """
<!doctype html>
Expand Down Expand Up @@ -140,7 +142,7 @@ def _count_matches(content: List[str], regex: re.Pattern, state: TaskState) -> i
return matches


class DocumentListModel(QAbstractListModel):
class DocumentListModel(QAbstractTableModel):
"""
Custom model for listing documents. Using custom model since Onrage's
pylistmodel is too slow for large number of documents
Expand All @@ -150,6 +152,7 @@ def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.__visible_data = []
self.__filter_content = []
self.__match_counts = []

def data(self, index: QModelIndex, role: int = Qt.DisplayRole) -> Any:
if role == Qt.DisplayRole:
Expand All @@ -160,12 +163,37 @@ def data(self, index: QModelIndex, role: int = Qt.DisplayRole) -> Any:
def rowCount(self, parent: QModelIndex = None, *args, **kwargs) -> int:
return len(self.__visible_data)

def setup_data(self, data: List[str], content: List[str]):
def setup_data(self, data: List[str], content: List[str], match_counts: List[int] = None):
self.beginResetModel()
self.__visible_data = data
self.__filter_content = content
self.__match_counts = match_counts or [0] * len(data)
self.endResetModel()

def set_match_counts(self, match_counts: List[int]):
assert len(match_counts) == len(self.__visible_data)
self.__match_counts = match_counts
self.dataChanged.emit(self.index(0, 0), self.index(self.rowCount() - 1, 1))

def data(self, index: QModelIndex, role: int = Qt.DisplayRole) -> Any:
row = index.row()
col = index.column() if index.isValid() else 0
if role == Qt.DisplayRole:
if col == 0:
return self.__visible_data[row]
elif col == 1:
return self.__match_counts[row]
elif role == Qt.UserRole:
return self.__filter_content[row]

def columnCount(self, parent=None):
return 2

def headerData(self, section, orientation, role):
if orientation == Qt.Horizontal and role == Qt.DisplayRole:
return ["Title", "Match Count"][section]
return super().headerData(section, orientation, role)

def update_filter_content(self, content: List[str]):
assert len(content) == len(self.__visible_data)
self.__filter_content = content
Expand Down Expand Up @@ -383,13 +411,16 @@ def __init__(self):
self.doc_list.setSelectionMode(QTableView.ExtendedSelection)
self.doc_list.setEditTriggers(QAbstractItemView.NoEditTriggers)
self.doc_list.horizontalHeader().setSectionResizeMode(QHeaderView.Stretch)
self.doc_list.horizontalHeader().setVisible(False)
self.doc_list.horizontalHeader().setVisible(True)
self.splitter.addWidget(self.doc_list)
self.doc_list.setSortingEnabled(True)

self.doc_list_model = DocumentListModel()
proxy_model = DocumentsFilterProxyModel()
proxy_model.setSourceModel(self.doc_list_model)
self.doc_list.setModel(proxy_model)
self.doc_list.setSortingEnabled(True)
self.doc_list.sortByColumn(1, Qt.DescendingOrder)
self.doc_list.selectionModel().selectionChanged.connect(self.selection_changed)
# Document contents
self.doc_webview = gui.WebviewWidget(self.splitter, debug=False)
Expand Down Expand Up @@ -467,9 +498,18 @@ def select_variables(self):
self.display_listbox.set_selection(self.display_features)

def list_docs(self):
"""List documents into the left scrolling area"""
docs = self.regenerate_docs()
self.doc_list_model.setup_data(self.corpus.titles.tolist(), docs)
match_counts = []

try:
regex = re.compile(self.regexp_filter.strip("|"), re.IGNORECASE)
except re.error:
regex = re.compile("")

for doc in docs:
match_counts.append(len(regex.findall(doc)) if regex.pattern else 0)

self.doc_list_model.setup_data(self.corpus.titles.tolist(), docs, match_counts)

def get_selected_indexes(self) -> Set[int]:
m = self.doc_list.model().mapToSource
Expand Down Expand Up @@ -597,6 +637,7 @@ def refresh_search(self):
self.Error.invalid_regex.clear()
if self.corpus is not None:
self.doc_list.model().set_filter_string(self.regexp_filter)
self.doc_list.setColumnHidden(1, not bool(self.regexp_filter.strip("|")))
if not self.selected_documents:
# when currently selected items are filtered selection is empty
# select first element in the view in that case
Expand All @@ -621,8 +662,12 @@ def refresh_search(self):
self.commit.deferred()

def on_done(self, res: int):
"""When matches count is done show the result in the label"""
"""When matches count is done show the result in the label and update match counts"""
self.n_matches = f"{int(res):,}" if res is not None else "n/a"
if self.compiled_regex and self.corpus:
docs = self.doc_list_model.get_filter_content()
match_counts = [len(self.compiled_regex.findall(doc)) for doc in docs]
self.doc_list_model.set_match_counts(match_counts)

def on_exception(self, ex):
raise ex
Expand All @@ -649,6 +694,19 @@ def commit(self):
mask[selected_docs] = 0
unmatched = self.corpus[mask] if mask.any() else None
annotated_corpus = create_annotated_table(self.corpus, selected_docs)

if annotated_corpus is not None:
match_counts = self.doc_list_model._DocumentListModel__match_counts
match_var = ContinuousVariable("Match Count")

domain = annotated_corpus.domain
new_domain = Domain(
domain.attributes,
domain.class_vars,
domain.metas + (match_var,)
)
annotated_corpus = Corpus(new_domain, annotated_corpus.X, annotated_corpus.Y, np.column_stack([annotated_corpus.metas, np.array(match_counts, dtype=object).reshape(-1, 1)]))

self.Outputs.matching_docs.send(matched)
self.Outputs.other_docs.send(unmatched)
self.Outputs.corpus.send(annotated_corpus)
Expand Down
50 changes: 32 additions & 18 deletions orangecontrib/text/widgets/tests/test_owcorpusviewer.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,25 +32,25 @@ def test_data(self):
self.assertListEqual(model.get_filter_content(), contents)
self.assertEqual(model.rowCount(), 3)

self.assertEqual(model.data(model.index(0)), documents[0])
self.assertEqual(model.data(model.index(1)), documents[1])
self.assertEqual(model.data(model.index(2)), documents[2])
self.assertEqual(model.data(model.index(0, 0)), documents[0])
self.assertEqual(model.data(model.index(1, 0)), documents[1])
self.assertEqual(model.data(model.index(2, 0)), documents[2])

def test_data_method(self):
model = DocumentListModel()
documents = ["Doc 1", "Doc 2", "Doc 3"]
contents = ["bar", "foo", "bar foo"]
model.setup_data(documents, contents)

self.assertEqual(model.data(model.index(0), Qt.DisplayRole), documents[0])
self.assertEqual(model.data(model.index(1), Qt.DisplayRole), documents[1])
self.assertEqual(model.data(model.index(2), Qt.DisplayRole), documents[2])
self.assertEqual(model.data(model.index(0, 0), Qt.DisplayRole), documents[0])
self.assertEqual(model.data(model.index(1, 0), Qt.DisplayRole), documents[1])
self.assertEqual(model.data(model.index(2, 0), Qt.DisplayRole), documents[2])

self.assertEqual(model.data(model.index(0), Qt.UserRole), contents[0])
self.assertEqual(model.data(model.index(1), Qt.UserRole), contents[1])
self.assertEqual(model.data(model.index(2), Qt.UserRole), contents[2])
self.assertEqual(model.data(model.index(0, 0), Qt.UserRole), contents[0])
self.assertEqual(model.data(model.index(1, 0), Qt.UserRole), contents[1])
self.assertEqual(model.data(model.index(2, 0), Qt.UserRole), contents[2])

self.assertIsNone(model.data(model.index(2), Qt.BackgroundRole))
self.assertIsNone(model.data(model.index(2, 0), Qt.BackgroundRole))

def test_update_filter_content(self):
model = DocumentListModel()
Expand All @@ -59,9 +59,9 @@ def test_update_filter_content(self):
model.setup_data(documents, contents)

model.update_filter_content(["a", "b", "c"])
self.assertEqual(model.data(model.index(0), Qt.UserRole), "a")
self.assertEqual(model.data(model.index(1), Qt.UserRole), "b")
self.assertEqual(model.data(model.index(2), Qt.UserRole), "c")
self.assertEqual(model.data(model.index(0, 0), Qt.UserRole), "a")
self.assertEqual(model.data(model.index(1, 0), Qt.UserRole), "b")
self.assertEqual(model.data(model.index(2, 0), Qt.UserRole), "c")

with self.assertRaises(AssertionError):
model.update_filter_content(
Expand Down Expand Up @@ -119,26 +119,34 @@ def test_search(self):
self.send_signal(self.widget.Inputs.corpus, self.corpus)
self.widget.regexp_filter = "Human"
self.widget.refresh_search()
self.wait_until_finished()

sel_model = self.widget.doc_list.selectionModel()
sel_model.select(sel_model.model().index(0, 0), QItemSelectionModel.Select | QItemSelectionModel.Rows)

self.process_events()
out_corpus = self.get_output(self.widget.Outputs.matching_docs)
self.assertIsNotNone(out_corpus)
self.assertEqual(len(out_corpus), 1)
self.assertEqual(self.widget.n_matches, 7)
self.assertEqual(int(self.widget.n_matches), 7)

# first document is selected, when filter with word that is not in
# selected document, first of shown documents is selected
self.widget.regexp_filter = "graph"
self.widget.refresh_search()
self.wait_until_finished()
self.process_events()
self.assertEqual(1, len(self.get_output(self.widget.Outputs.matching_docs)))
# word count doesn't depend on selection
self.assertEqual(self.widget.n_matches, 7)
self.assertEqual(int(self.widget.n_matches), 7)

# when filter is removed, matched words is 0
self.widget.regexp_filter = ""
self.widget.refresh_search()
self.wait_until_finished()
self.process_events()
self.wait_until_finished()
self.assertEqual(self.widget.n_matches, 0)
self.assertEqual(int(self.widget.n_matches), 0)

def test_invalid_regex(self):
# Error is shown when invalid regex is entered
Expand Down Expand Up @@ -205,7 +213,7 @@ def test_output(self):
)
self.assertEqual(8, len(self.get_output(self.widget.Outputs.other_docs)))
self.assertEqual(
len(self.corpus.domain.metas) + 1,
len(self.corpus.domain.metas) + 2,
len(self.get_output(self.widget.Outputs.corpus).domain.metas),
)

Expand Down Expand Up @@ -370,7 +378,13 @@ def test_migrate_settings(self):
domain = self.corpus.domain
self.assertListEqual(self.widget.display_features, [domain["Category"]])
self.assertListEqual(self.widget.search_features, [domain["Text"]])


def test_match_count_is_in_metas(self):
self.send_signal(self.widget.Inputs.corpus, self.corpus)
self.widget.doc_list.selectAll()
output = self.get_output(self.widget.Outputs.corpus)
meta_names = [var.name for var in output.domain.metas]
self.assertIn("Match Count", meta_names)

if __name__ == "__main__":
unittest.main()