Skip to content

Commit 706f525

Browse files
authored
Merge pull request #4083 from VesnaT/base_sql
[ENH] OWBaseSql: Base widget for connecting to a database
2 parents 10965d5 + 26ba3d6 commit 706f525

File tree

2 files changed

+283
-0
lines changed

2 files changed

+283
-0
lines changed

Orange/widgets/utils/owbasesql.py

Lines changed: 187 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,187 @@
1+
from typing import Type
2+
from collections import OrderedDict
3+
4+
from AnyQt.QtWidgets import QLineEdit, QSizePolicy
5+
6+
from Orange.data import Table
7+
from Orange.data.sql.backend import Backend
8+
from Orange.data.sql.backend.base import BackendError
9+
from Orange.widgets import gui, report
10+
from Orange.widgets.credentials import CredentialManager
11+
from Orange.widgets.settings import Setting
12+
from Orange.widgets.utils.signals import Output
13+
from Orange.widgets.widget import OWWidget, Msg
14+
15+
16+
class OWBaseSql(OWWidget, openclass=True):
17+
"""Base widget for connecting to a database.
18+
Override `get_backend` when subclassing to get corresponding backend.
19+
"""
20+
class Outputs:
21+
data = Output("Data", Table)
22+
23+
class Error(OWWidget.Error):
24+
connection = Msg("{}")
25+
26+
want_main_area = False
27+
resizing_enabled = False
28+
29+
host = Setting(None) # type: Optional[str]
30+
port = Setting(None) # type: Optional[str]
31+
database = Setting(None) # type: Optional[str]
32+
schema = Setting(None) # type: Optional[str]
33+
username = ""
34+
password = ""
35+
36+
def __init__(self):
37+
super().__init__()
38+
self.backend = None # type: Optional[Backend]
39+
self.data_desc_table = None # type: Optional[Table]
40+
self.database_desc = None # type: Optional[OrderedDict]
41+
self._setup_gui()
42+
self.connect()
43+
44+
def _setup_gui(self):
45+
self.controlArea.setMinimumWidth(360)
46+
47+
vbox = gui.vBox(self.controlArea, "Server", addSpace=True)
48+
self.serverbox = gui.vBox(vbox)
49+
self.servertext = QLineEdit(self.serverbox)
50+
self.servertext.setPlaceholderText("Server")
51+
self.servertext.setToolTip("Server")
52+
self.servertext.editingFinished.connect(self._load_credentials)
53+
if self.host:
54+
self.servertext.setText(self.host if not self.port else
55+
"{}:{}".format(self.host, self.port))
56+
self.serverbox.layout().addWidget(self.servertext)
57+
58+
self.databasetext = QLineEdit(self.serverbox)
59+
self.databasetext.setPlaceholderText("Database[/Schema]")
60+
self.databasetext.setToolTip("Database or optionally Database/Schema")
61+
if self.database:
62+
self.databasetext.setText(
63+
self.database if not self.schema else
64+
"{}/{}".format(self.database, self.schema))
65+
self.serverbox.layout().addWidget(self.databasetext)
66+
self.usernametext = QLineEdit(self.serverbox)
67+
self.usernametext.setPlaceholderText("Username")
68+
self.usernametext.setToolTip("Username")
69+
70+
self.serverbox.layout().addWidget(self.usernametext)
71+
self.passwordtext = QLineEdit(self.serverbox)
72+
self.passwordtext.setPlaceholderText("Password")
73+
self.passwordtext.setToolTip("Password")
74+
self.passwordtext.setEchoMode(QLineEdit.Password)
75+
76+
self.serverbox.layout().addWidget(self.passwordtext)
77+
78+
self._load_credentials()
79+
80+
self.connectbutton = gui.button(self.serverbox, self, "Connect",
81+
callback=self.connect)
82+
self.connectbutton.setSizePolicy(QSizePolicy.Fixed, QSizePolicy.Fixed)
83+
84+
def _load_credentials(self):
85+
self._parse_host_port()
86+
cm = self._credential_manager(self.host, self.port)
87+
self.username = cm.username
88+
self.password = cm.password
89+
90+
if self.username:
91+
self.usernametext.setText(self.username)
92+
if self.password:
93+
self.passwordtext.setText(self.password)
94+
95+
def _save_credentials(self):
96+
cm = self._credential_manager(self.host, self.port)
97+
cm.username = self.username or ""
98+
cm.password = self.password or ""
99+
100+
@staticmethod
101+
def _credential_manager(host, port):
102+
return CredentialManager("SQL Table: {}:{}".format(host, port))
103+
104+
def _parse_host_port(self):
105+
hostport = self.servertext.text().split(":")
106+
self.host = hostport[0]
107+
self.port = hostport[1] if len(hostport) == 2 else None
108+
109+
def _check_db_settings(self):
110+
self._parse_host_port()
111+
self.database, _, self.schema = self.databasetext.text().partition("/")
112+
self.username = self.usernametext.text() or None
113+
self.password = self.passwordtext.text() or None
114+
115+
def connect(self):
116+
self.clear()
117+
self._check_db_settings()
118+
if not self.host or not self.database:
119+
return
120+
try:
121+
backend = self.get_backend()
122+
if backend is None:
123+
return
124+
self.backend = backend(dict(
125+
host=self.host,
126+
port=self.port,
127+
database=self.database,
128+
user=self.username,
129+
password=self.password
130+
))
131+
self.on_connection_success()
132+
except BackendError as err:
133+
self.on_connection_error(err)
134+
135+
def get_backend(self) -> Type[Backend]:
136+
"""
137+
Derived widgets should override this to get corresponding backend.
138+
139+
Returns
140+
-------
141+
backend: Type[Backend]
142+
"""
143+
raise NotImplementedError
144+
145+
def on_connection_success(self):
146+
self._save_credentials()
147+
self.database_desc = OrderedDict((
148+
("Host", self.host), ("Port", self.port),
149+
("Database", self.database), ("User name", self.username)
150+
))
151+
152+
def on_connection_error(self, err):
153+
error = str(err).split("\n")[0]
154+
self.Error.connection(error)
155+
156+
def open_table(self):
157+
data = self.get_table()
158+
self.data_desc_table = data
159+
self.Outputs.data.send(data)
160+
info = str(len(data)) if data else self.info.NoOutput
161+
self.info.set_output_summary(info)
162+
163+
def get_table(self) -> Table:
164+
"""
165+
Derived widgets should override this to get corresponding table.
166+
167+
Returns
168+
-------
169+
table: Table
170+
"""
171+
raise NotImplementedError
172+
173+
def clear(self):
174+
self.Error.connection.clear()
175+
self.database_desc = None
176+
self.data_desc_table = None
177+
self.Outputs.data.send(None)
178+
self.info.set_output_summary(self.info.NoOutput)
179+
180+
def send_report(self):
181+
if not self.database_desc:
182+
self.report_paragraph("No database connection.")
183+
return
184+
self.report_items("Database", self.database_desc)
185+
if self.data_desc_table:
186+
self.report_items("Data",
187+
report.describe_data(self.data_desc_table))
Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
# pylint: disable=missing-docstring
2+
import unittest
3+
from unittest.mock import Mock
4+
from collections import OrderedDict
5+
from types import SimpleNamespace
6+
7+
from Orange.data import Table
8+
from Orange.data.sql.backend import Backend
9+
from Orange.widgets.tests.base import WidgetTest
10+
from Orange.widgets.utils.owbasesql import OWBaseSql
11+
from Orange.data.sql.backend.base import BackendError
12+
13+
14+
USERNAME = "UN"
15+
PASSWORD = "PASS"
16+
17+
18+
class BrokenBackend(Backend): # pylint: disable=abstract-method
19+
def __init__(self, connection_params):
20+
super().__init__(connection_params)
21+
raise BackendError("Error connecting to DB.")
22+
23+
24+
class TestableSqlWidget(OWBaseSql):
25+
name = "SQL"
26+
27+
def __init__(self):
28+
self.mocked_backend = Mock()
29+
super().__init__()
30+
31+
def get_backend(self):
32+
return self.mocked_backend
33+
34+
def get_table(self) -> Table:
35+
return Table("iris")
36+
37+
@staticmethod
38+
def _credential_manager(_, __): # pylint: disable=arguments-differ
39+
return SimpleNamespace(username=USERNAME, password=PASSWORD)
40+
41+
42+
class TestOWBaseSql(WidgetTest):
43+
def setUp(self):
44+
self.host, self.port, self.db = "host", "port", "DB"
45+
settings = {"host": self.host, "port": self.port,
46+
"database": self.db, "schema": ""}
47+
self.widget = self.create_widget(TestableSqlWidget,
48+
stored_settings=settings)
49+
50+
def test_connect(self):
51+
self.widget.mocked_backend.assert_called_once_with(
52+
{"host": "host", "port": "port", "database": self.db,
53+
"user": USERNAME, "password": PASSWORD})
54+
self.assertDictEqual(
55+
self.widget.database_desc,
56+
OrderedDict((("Host", "host"), ("Port", "port"),
57+
("Database", self.db), ("User name", USERNAME))))
58+
59+
def test_connection_error(self):
60+
self.widget.get_backend = Mock(return_value=BrokenBackend)
61+
self.widget.connectbutton.click()
62+
self.assertTrue(self.widget.Error.connection.is_shown())
63+
self.assertIsNone(self.widget.database_desc)
64+
65+
def test_output(self):
66+
self.widget.open_table()
67+
self.assertIsNotNone(self.get_output(self.widget.Outputs.data))
68+
self.assertIsNotNone(self.widget.data_desc_table)
69+
70+
def test_output_error(self):
71+
self.widget.get_table = lambda: None
72+
self.widget.open_table()
73+
self.assertIsNone(self.get_output(self.widget.Outputs.data))
74+
self.assertIsNone(self.widget.data_desc_table)
75+
76+
def test_missing_database_parameter(self):
77+
self.widget.open_table()
78+
self.widget.databasetext.setText("")
79+
self.widget.mocked_backend.reset_mock()
80+
self.widget.connectbutton.click()
81+
self.widget.mocked_backend.assert_not_called()
82+
self.assertIsNone(self.get_output(self.widget.Outputs.data))
83+
self.assertIsNone(self.widget.data_desc_table)
84+
self.assertFalse(self.widget.Error.connection.is_shown())
85+
86+
def test_report(self):
87+
self.widget.report_button.click() # DB connection
88+
self.widget.open_table()
89+
self.widget.report_button.click() # table
90+
self.widget.databasetext.setText("")
91+
self.widget.connectbutton.click()
92+
self.widget.report_button.click() # empty
93+
94+
95+
if __name__ == "__main__":
96+
unittest.main()

0 commit comments

Comments
 (0)