Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 72 additions & 0 deletions orangecontrib/text/tests/test_owdocumentembedding.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import pytest
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Cannot import pytest as pytest is not a dependency of Orange.

from orangecontrib.text.vectorization.document_embedder import DocumentEmbedder
from orangecontrib.text import Corpus
from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
from orangecontrib.text.vectorization.document_embedder import _ServerEmbedder
from urllib.parse import urlparse
import socket

@pytest.fixture
def dummy_corpus():
return Corpus.from_documents(["This is a test document."], name="test")

def test_embedding_valid_server(dummy_corpus):
embedder = DocumentEmbedder(language="en", aggregator="Mean")
new_corpus, skipped = embedder._transform(dummy_corpus, None)
assert new_corpus is not None
assert skipped is None or len(skipped) == 0

def test_invalid_server_raises(dummy_corpus):
class BrokenEmbedder(DocumentEmbedder):
def _transform(self, corpus, _, callback=None):

embedder = _ServerEmbedder(
aggregator="mean",
model_name="fasttext-en",
max_parallel_requests=100,
server_url="https://api.invalidserver.io",
embedder_type="text",
)

url = urlparse(embedder.server_url)
host, port = url.hostname, url.port or (443 if url.scheme == "https" else 80)
try:
socket.create_connection((host, port), timeout=3)
except Exception as e:
raise EmbeddingConnectionError("The server is not responding") from e

return [], None

embedder = BrokenEmbedder(language="en", aggregator="Mean")
with pytest.raises(EmbeddingConnectionError, match="server is not responding"):
embedder._transform(dummy_corpus, None)

def test_no_internet_raises(dummy_corpus, monkeypatch):
class NoInternetEmbedder(DocumentEmbedder):
def _transform(self, corpus, _, callback=None):

embedder = _ServerEmbedder(
aggregator="mean",
model_name="fasttext-en",
max_parallel_requests=100,
server_url="https://api.garaza.io",
embedder_type="text",
)

def raise_os_error(*args, **kwargs):
raise OSError("Simulated: No internet connection")

monkeypatch.setattr("socket.create_connection", raise_os_error)

url = urlparse(embedder.server_url)
host, port = url.hostname, url.port or (443 if url.scheme == "https" else 80)
try:
socket.create_connection((host, port), timeout=3)
except Exception as e:
raise EmbeddingConnectionError("No internet connection") from e

return [], None

embedder = NoInternetEmbedder(language="en", aggregator="Mean")
with pytest.raises(EmbeddingConnectionError, match="No internet connection"):
embedder._transform(dummy_corpus, None)
40 changes: 35 additions & 5 deletions orangecontrib/text/vectorization/document_embedder.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,10 @@
from orangecontrib.text import Corpus
from orangecontrib.text.vectorization.base import BaseVectorizer

from Orange.misc.utils.embedder_utils import EmbeddingConnectionError
import socket
from urllib.parse import urlparse

AGGREGATORS = ["mean", "sum", "max", "min"]
AGGREGATORS_ITEMS = ['Mean', 'Sum', 'Max', 'Min']
# fmt: off
Expand Down Expand Up @@ -87,10 +91,36 @@ def _transform(
server_url="https://api.garaza.io",
embedder_type="text",
)
embs = embedder.embedd_data(
list(corpus.ngrams) if isinstance(corpus, Corpus) else corpus,
callback=callback,
)

try:
url = urlparse(embedder.server_url)
host, port = url.hostname, url.port or (443 if url.scheme == "https" else 80)

try:
sock = socket.create_connection((host, port), timeout=3)
sock.close()
except socket.gaierror as e:
try:
socket.gethostbyname("example.com")
raise ConnectionError("The server is not responding (bad hostname)") from e
except socket.gaierror:
raise OSError("No internet connection (DNS failure)") from e
except (ConnectionRefusedError, socket.timeout, OSError):
raise ConnectionError("The server is not responding (socket check)")

embs = embedder.embedd_data(
list(corpus.ngrams) if isinstance(corpus, Corpus) else corpus,
callback=callback,
)
if not embs or all(e is None for e in embs):
raise ConnectionError("The server is not responding (no embeddings returned)")

except OSError as e:
raise EmbeddingConnectionError("No internet connection") from e
except ConnectionError as e:
raise EmbeddingConnectionError("The server is not responding") from e
except Exception as e:
raise EmbeddingConnectionError(f"Unknown network error: {e}") from e

if isinstance(corpus, list):
return embs
Expand Down Expand Up @@ -167,4 +197,4 @@ async def _encode_data_instance(self, data_instance: Any) -> Optional[bytes]:
if __name__ == '__main__':
with DocumentEmbedder(language='en', aggregator='Max') as embedder:
embedder.clear_cache()
embedder(Corpus.from_file('deerwester'))
embedder(Corpus.from_file('deerwester'))
20 changes: 18 additions & 2 deletions orangecontrib/text/widgets/owdocumentembedding.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,6 +53,9 @@ class Error(OWWidget.Error):
"No internet connection. Please establish a connection or use "
"another vectorizer."
)
server_unresponsive = Msg(
"The server is not responding. Please check the server address or try again later."
)
unexpected_error = Msg("Embedding error: {}")

class Warning(OWWidget.Warning):
Expand Down Expand Up @@ -149,8 +152,21 @@ def on_done(self, result):

def on_exception(self, ex: Exception):
self.cancel_button.setDisabled(True)
ex_msg = str(ex.__cause__ or ex).lower()

if isinstance(ex, EmbeddingConnectionError):
self.Error.no_connection()
if "getaddrinfo failed" in ex_msg or "no internet" in ex_msg or "temporary failure" in ex_msg:
self.Error.no_connection()
elif any(signal in ex_msg for signal in [
"connection refused",
"failed to establish a new connection",
"connection aborted",
"connection reset",
"the server is not responding"
]):
self.Error.server_unresponsive()
else:
self.Error.unexpected_error(str(ex))
else:
self.Error.unexpected_error(str(ex))
self.cancel()
Expand Down Expand Up @@ -190,4 +206,4 @@ def send_report(self):
if __name__ == "__main__":
from orangewidget.utils.widgetpreview import WidgetPreview

WidgetPreview(OWDocumentEmbedding).run(Corpus.from_file("book-excerpts"))
WidgetPreview(OWDocumentEmbedding).run(Corpus.from_file("book-excerpts"))
Loading