diff --git a/Orange/data/sql/backend/base.py b/Orange/data/sql/backend/base.py index 208e01a375c..b7ffad70c32 100644 --- a/Orange/data/sql/backend/base.py +++ b/Orange/data/sql/backend/base.py @@ -68,6 +68,37 @@ def list_tables(self, schema=None): tables.append(TableDesc(name, schema, sql)) return tables + def n_tables_query(self, schema=None) -> str: + """Return a query to count tables in database. + + Parameters + ---------- + schema : Optional[str] + If set, only tables from schema should be listed + + Returns + ------- + Query string. + """ + raise NotImplementedError + + def n_tables(self, schema=None) -> int: + """Return number of tables in database. + + Parameters + ---------- + schema : Optional[str] + If set, only tables from given schema will be listed. + + Returns + ------- + Number of tables in the database. + """ + query = self.n_tables_query(schema) + with self.execute_sql_query(query) as cur: + res = cur.fetchone() + return res[0] + def get_fields(self, table_name): """Return a list of field names and metadata in the given table diff --git a/Orange/data/sql/backend/mssql.py b/Orange/data/sql/backend/mssql.py index 4562b59b328..db1de80cd5c 100644 --- a/Orange/data/sql/backend/mssql.py +++ b/Orange/data/sql/backend/mssql.py @@ -43,6 +43,9 @@ def list_tables_query(self, schema=None): ORDER BY [TABLE_NAME] """ + def n_tables_query(self, _=None) -> str: + return "SELECT COUNT(*) FROM information_schema.tables" + def quote_identifier(self, name): return "[{}]".format(name) diff --git a/Orange/data/sql/backend/postgres.py b/Orange/data/sql/backend/postgres.py index d02adaad26f..ec22004f02e 100644 --- a/Orange/data/sql/backend/postgres.py +++ b/Orange/data/sql/backend/postgres.py @@ -21,7 +21,7 @@ class Psycopg2Backend(Backend): display_name = "PostgreSQL" connection_pool = None - auto_create_extensions = True + auto_create_extensions = False def __init__(self, connection_params): super().__init__(connection_params) @@ -113,6 +113,12 @@ def list_tables_query(self, schema=None): AND NOT c.relname LIKE '\\_\\_%' ORDER BY 1,2;""".format(schema_clause) + def n_tables_query(self, schema=None) -> str: + query = "SELECT COUNT(*) FROM information_schema.tables" + if schema: + query += f" WHERE table_schema = '{schema}'" + return query + def create_variable(self, field_name, field_metadata, type_hints, inspect_table=None): if field_name in type_hints: diff --git a/Orange/widgets/data/owsql.py b/Orange/widgets/data/owsql.py index a8bdd5d9330..b2d46decc23 100644 --- a/Orange/widgets/data/owsql.py +++ b/Orange/widgets/data/owsql.py @@ -1,7 +1,10 @@ -from AnyQt.QtWidgets import QComboBox, QTextEdit, QMessageBox, QApplication +from AnyQt.QtWidgets import QComboBox, QTextEdit, QMessageBox, QApplication, \ + QGridLayout, QLineEdit from AnyQt.QtGui import QCursor from AnyQt.QtCore import Qt +from orangewidget.utils.combobox import ComboBoxSearch + from Orange.data import Table from Orange.data.sql.backend import Backend from Orange.data.sql.backend.base import BackendError @@ -14,6 +17,7 @@ from Orange.widgets.widget import Output, Msg MAX_DL_LIMIT = 1000000 +MAX_TABLES = 1000 def is_postgres(backend): @@ -52,11 +56,12 @@ class Outputs: buttons_area_orientation = None + TABLE, CUSTOM_SQL = range(2) selected_backend = Setting(None) + data_source = Setting(TABLE) table = Setting(None) sql = Setting("") guess_values = Setting(True) - download = Setting(False) materialize = Setting(False) materialize_table_name = Setting("") @@ -64,9 +69,6 @@ class Outputs: class Information(OWBaseSql.Information): data_sampled = Msg("Data description was generated from a sample.") - class Warning(OWBaseSql.Warning): - missing_extension = Msg("Database is missing extensions: {}") - class Error(OWBaseSql.Error): no_backends = Msg("Please install a backend to use this widget.") @@ -76,9 +78,9 @@ def __init__(self): self.backendcombo = None self.tables = None self.tablecombo = None + self.tabletext = None self.sqltext = None self.custom_sql = None - self.downloadcb = None super().__init__() def _setup_gui(self): @@ -106,21 +108,33 @@ def __backend_changed(self): self.selected_backend = backend.display_name if backend else None def _add_tables_controls(self): - vbox = gui.vBox(self.controlArea, "Tables") - box = gui.vBox(vbox) + box = gui.vBox(self.controlArea, 'Data Selection') + form = QGridLayout() + radio_buttons = gui.radioButtons( + box, self, 'data_source', orientation=form, + callback=self.__on_data_source_changed) + radio_table = gui.appendRadioButton( + radio_buttons, 'Table:', addToLayout=False) + radio_custom_sql = gui.appendRadioButton( + radio_buttons, 'Custom SQL:', addToLayout=False) + self.tables = TableModel() - self.tablecombo = QComboBox( + self.tablecombo = ComboBoxSearch( minimumContentsLength=35, sizeAdjustPolicy=QComboBox.AdjustToMinimumContentsLengthWithIcon ) self.tablecombo.setModel(self.tables) self.tablecombo.setToolTip('table') self.tablecombo.activated[int].connect(self.select_table) - box.layout().addWidget(self.tablecombo) + + self.tabletext = QLineEdit(placeholderText='TABLE_NAME') + self.tabletext.setToolTip('table') + self.tabletext.editingFinished.connect(self.select_table) + self.tabletext.setVisible(False) self.custom_sql = gui.vBox(box) - self.custom_sql.setVisible(False) + self.custom_sql.setVisible(self.data_source == self.CUSTOM_SQL) self.sqltext = QTextEdit(self.custom_sql) self.sqltext.setPlainText(self.sql) self.custom_sql.layout().addWidget(self.sqltext) @@ -133,15 +147,18 @@ def _add_tables_controls(self): gui.button(self.custom_sql, self, 'Execute', callback=self.open_table) - box.layout().addWidget(self.custom_sql) + form.addWidget(radio_table, 1, 0, Qt.AlignLeft) + form.addWidget(self.tablecombo, 1, 1) + form.addWidget(self.tabletext, 1, 1) + form.addWidget(radio_custom_sql, 2, 0, Qt.AlignLeft) gui.checkBox(box, self, "guess_values", "Auto-discover categorical variables", callback=self.open_table) - self.downloadcb = gui.checkBox(box, self, "download", - "Download data to local memory", - callback=self.open_table) + def __on_data_source_changed(self): + self.custom_sql.setVisible(self.data_source == self.CUSTOM_SQL) + self.select_table() def highlight_error(self, text=""): err = ['', 'QLineEdit {border: 2px solid red;}'] @@ -155,14 +172,6 @@ def get_backend(self): return self.backends[self.backendcombo.currentIndex()] def on_connection_success(self): - if getattr(self.backend, 'missing_extension', False): - self.Warning.missing_extension( - ", ".join(self.backend.missing_extension)) - self.download = True - self.downloadcb.setEnabled(False) - if not is_postgres(self.backend): - self.download = True - self.downloadcb.setEnabled(False) super().on_connection_success() self.refresh_tables() self.select_table() @@ -173,8 +182,6 @@ def on_connection_error(self, err): def clear(self): super().clear() - self.Warning.missing_extension.clear() - self.downloadcb.setEnabled(True) self.highlight_error() self.tablecombo.clear() self.tablecombo.repaint() @@ -186,37 +193,44 @@ def refresh_tables(self): return self.tables.append("Select a table") - self.tables.append("Custom SQL") - self.tables.extend(self.backend.list_tables(self.schema)) - index = self.tablecombo.findText(str(self.table)) - self.tablecombo.setCurrentIndex(index if index != -1 else 0) + if self.backend.n_tables(self.schema) <= MAX_TABLES: + self.tables.extend(self.backend.list_tables(self.schema)) + index = self.tablecombo.findText(str(self.table)) + self.tablecombo.setCurrentIndex(index if index != -1 else 0) + self.tablecombo.setVisible(True) + self.tabletext.setVisible(False) + else: + self.tablecombo.setVisible(False) + self.tabletext.setVisible(True) self.tablecombo.repaint() # Called on tablecombo selection change: def select_table(self): - curIdx = self.tablecombo.currentIndex() - if self.tablecombo.itemText(curIdx) != "Custom SQL": - self.custom_sql.setVisible(False) + if self.data_source == self.TABLE: return self.open_table() else: - self.custom_sql.setVisible(True) self.data_desc_table = None - self.database_desc["Table"] = "(None)" + if self.database_desc: + self.database_desc["Table"] = "(None)" self.table = None if len(str(self.sql)) > 14: return self.open_table() return None def get_table(self): + if self.backend is None: + return None curIdx = self.tablecombo.currentIndex() - if curIdx <= 0: + if self.data_source == self.TABLE and curIdx <= 0 and \ + self.tabletext.text() == "": if self.database_desc: self.database_desc["Table"] = "(None)" self.data_desc_table = None return None - if self.tablecombo.itemText(curIdx) != "Custom SQL": - self.table = self.tables[self.tablecombo.currentIndex()] + if self.data_source == self.TABLE: + self.table = self.tables[curIdx] if curIdx > 0 else \ + self.tabletext.text() self.database_desc["Table"] = self.table if "Query" in self.database_desc: del self.database_desc["Query"] @@ -290,45 +304,45 @@ def get_table(self): QApplication.restoreOverrideCursor() table.domain = domain - if self.download: - if table.approx_len() > AUTO_DL_LIMIT: - if is_postgres(self.backend): - confirm = QMessageBox(self) - confirm.setIcon(QMessageBox.Warning) - confirm.setText("Data appears to be big. Do you really " - "want to download it to local memory?\n" - "Table length: {:,}. Limit {:,}".format(table.approx_len(), MAX_DL_LIMIT)) - - if table.approx_len() <= MAX_DL_LIMIT: - confirm.addButton("Yes", QMessageBox.YesRole) - no_button = confirm.addButton("No", QMessageBox.NoRole) - sample_button = confirm.addButton("Yes, a sample", - QMessageBox.YesRole) - confirm.exec() - if confirm.clickedButton() == no_button: - return None - elif confirm.clickedButton() == sample_button: - table = table.sample_percentage( - AUTO_DL_LIMIT / table.approx_len() * 100) + if table.approx_len() > AUTO_DL_LIMIT: + if is_postgres(self.backend): + confirm = QMessageBox(self) + confirm.setIcon(QMessageBox.Warning) + confirm.setText("Data appears to be big. Do you really " + "want to download it to local memory?\n" + "Table length: {:,}. Limit {:,}".format( + table.approx_len(), MAX_DL_LIMIT)) + + if table.approx_len() <= MAX_DL_LIMIT: + confirm.addButton("Yes", QMessageBox.YesRole) + no_button = confirm.addButton("No", QMessageBox.NoRole) + sample_button = confirm.addButton("Yes, a sample", + QMessageBox.YesRole) + confirm.exec() + if confirm.clickedButton() == no_button: + return None + elif confirm.clickedButton() == sample_button: + table = table.sample_percentage( + AUTO_DL_LIMIT / table.approx_len() * 100) + else: + if table.approx_len() > MAX_DL_LIMIT: + QMessageBox.warning( + self, 'Warning', + "Data is too big to download.\n" + "Table length: {:,}. Limit {:,}".format(table.approx_len(), MAX_DL_LIMIT) + ) + return None else: - if table.approx_len() > MAX_DL_LIMIT: - QMessageBox.warning( - self, 'Warning', - "Data is too big to download.\n" - "Table length: {:,}. Limit {:,}".format(table.approx_len(), MAX_DL_LIMIT) - ) + confirm = QMessageBox.question( + self, 'Question', + "Data appears to be big. Do you really " + "want to download it to local memory?", + QMessageBox.Yes | QMessageBox.No, QMessageBox.No) + if confirm == QMessageBox.No: return None - else: - confirm = QMessageBox.question( - self, 'Question', - "Data appears to be big. Do you really " - "want to download it to local memory?", - QMessageBox.Yes | QMessageBox.No, QMessageBox.No) - if confirm == QMessageBox.No: - return None - - table.download_data(MAX_DL_LIMIT) - table = Table(table) + + table.download_data(MAX_DL_LIMIT) + table = Table(table) return table diff --git a/Orange/widgets/data/tests/test_owsql.py b/Orange/widgets/data/tests/test_owsql.py index 546a3b95958..72ac87d23dc 100644 --- a/Orange/widgets/data/tests/test_owsql.py +++ b/Orange/widgets/data/tests/test_owsql.py @@ -9,6 +9,17 @@ from Orange.widgets.tests.base import WidgetTest, simulate from Orange.tests.sql.base import DataBaseTest as dbt +mock_msgbox = mock.MagicMock() +mock_msgbox().addButton.return_value = "NO" +mock_msgbox().clickedButton.return_value = "NO" + + +def mock_sqltable(*args, **_): + table = Table(args[1]) + table.get_domain = lambda **_: table.domain + table.download_data = lambda *_: 1 + return table + class TestOWSqlConnected(WidgetTest, dbt): def setUpDB(self): @@ -28,7 +39,7 @@ def test_connection(self): self.assertFalse(self.widget.Error.connection.is_shown()) self.assertIsNotNone(self.widget.database_desc) - tables = ["Select a table", "Custom SQL"] + tables = ["Select a table"] self.assertTrue(set(self.widget.tables).issuperset(set(tables))) @dbt.run_on(["postgres"]) @@ -62,42 +73,6 @@ def set_connection_params(self): class TestOWSql(WidgetTest): - - @mock.patch('Orange.widgets.data.owsql.Backend') - def test_missing_extension(self, mock_backends): - """Test for correctly handled missing backend extension""" - backend = mock.Mock() - backend().display_name = "PostgreSQL" - backend().missing_extension = ["missing extension"] - backend().list_tables.return_value = [] - mock_backends.available_backends.return_value = [backend] - - settings = {"host": "host", "port": "port", - "database": "DB", "schema": "", - "username": "username", "password": "password"} - widget = self.create_widget(OWSql, stored_settings=settings) - - self.assertTrue(widget.Warning.missing_extension.is_shown()) - self.assertTrue(widget.download) - self.assertFalse(widget.downloadcb.isEnabled()) - - @mock.patch('Orange.widgets.data.owsql.Backend') - def test_non_postgres(self, mock_backends): - """Test if download is enforced for non postgres backends""" - backend = mock.Mock() - backend().display_name = "database" - del backend().missing_extension - backend().list_tables.return_value = [] - mock_backends.available_backends.return_value = [backend] - - settings = {"host": "host", "port": "port", - "database": "DB", "schema": "", - "username": "username", "password": "password"} - widget = self.create_widget(OWSql, stored_settings=settings) - - self.assertTrue(widget.download) - self.assertFalse(widget.downloadcb.isEnabled()) - @mock.patch('Orange.widgets.data.owsql.Table', mock.PropertyMock(return_value=Table('iris'))) @mock.patch('Orange.widgets.data.owsql.SqlTable') @@ -108,6 +83,7 @@ def test_restore_table(self, mock_backends, mock_sqltable): backend().display_name = "database" del backend().missing_extension backend().list_tables.return_value = ["a", "b", "c"] + backend().n_tables.return_value = 3 mock_backends.available_backends.return_value = [backend] mock_sqltable().approx_len.return_value = 100 @@ -144,6 +120,93 @@ def test_selected_backend(self, mocked_backends: mock.Mock): widget = self.create_widget(OWSql, stored_settings=settings) self.assertEqual(widget.backendcombo.currentText(), "") + @mock.patch('Orange.widgets.data.owsql.Backend') + def test_data_source(self, mocked_backends: mock.Mock): + widget: OWSql = self.create_widget(OWSql) + widget.controls.data_source.buttons[OWSql.CUSTOM_SQL].click() + + backend = mock.Mock() + backend().display_name = "Dummy Backend" + backend().list_tables.return_value = ["a", "b", "c"] + backend().n_tables.return_value = 3 + mocked_backends.available_backends.return_value = [backend] + + settings = {"selected_backend": "Dummy Backend", + "host": "host", "port": "port", "database": "DB", + "schema": "", "username": "username", + "password": "password"} + widget: OWSql = self.create_widget(OWSql, stored_settings=settings) + self.assertEqual(widget.tablecombo.currentText(), "Select a table") + self.assertFalse(widget.tablecombo.isHidden()) + self.assertTrue(widget.tabletext.isHidden()) + self.assertTrue(widget.custom_sql.isHidden()) + + widget.controls.data_source.buttons[OWSql.CUSTOM_SQL].click() + self.assertEqual(widget.tablecombo.currentText(), "Select a table") + self.assertFalse(widget.tablecombo.isHidden()) + self.assertTrue(widget.tabletext.isHidden()) + self.assertFalse(widget.custom_sql.isHidden()) + + widget.controls.data_source.buttons[OWSql.TABLE].click() + self.assertEqual(widget.tablecombo.currentText(), "Select a table") + self.assertFalse(widget.tablecombo.isHidden()) + self.assertTrue(widget.tabletext.isHidden()) + self.assertTrue(widget.custom_sql.isHidden()) + + @mock.patch('Orange.widgets.data.owsql.MAX_TABLES', 2) + @mock.patch('Orange.widgets.data.owsql.SqlTable', + mock.Mock(side_effect=mock_sqltable)) + @mock.patch('Orange.widgets.data.owsql.Backend') + def test_table_text(self, mocked_backends: mock.Mock): + backend = mock.Mock() + backend().display_name = "Dummy Backend" + backend().list_tables.return_value = ["iris", "zoo", "titanic"] + backend().n_tables.return_value = 3 + mocked_backends.available_backends.return_value = [backend] + + settings = {"selected_backend": "Dummy Backend", + "host": "host", "port": "port", "database": "DB", + "schema": "", "username": "username", + "password": "password"} + widget: OWSql = self.create_widget(OWSql, stored_settings=settings) + self.assertTrue(widget.tablecombo.isHidden()) + self.assertFalse(widget.tabletext.isHidden()) + widget.tabletext.setText("zoo") + widget.select_table() + output = self.get_output(widget.Outputs.data, widget=widget) + self.assertIsInstance(output, Table) + self.assertEqual(len(output), 101) + + @mock.patch('Orange.widgets.data.owsql.AUTO_DL_LIMIT', 120) + @mock.patch('Orange.widgets.data.owsql.is_postgres', + mock.Mock(return_value=True)) + @mock.patch('Orange.widgets.data.owsql.QMessageBox', mock_msgbox) + @mock.patch('Orange.widgets.data.owsql.SqlTable', + mock.Mock(side_effect=mock_sqltable)) + @mock.patch('Orange.widgets.data.owsql.Backend') + def test_auto_dl_limit(self, mocked_backends: mock.Mock): + backend = mock.Mock() + backend().display_name = "Dummy Backend" + backend().list_tables.return_value = ["iris", "zoo", "titanic"] + backend().n_tables.return_value = 3 + mocked_backends.available_backends.return_value = [backend] + + settings = {"selected_backend": "Dummy Backend", + "host": "host", "port": "port", "database": "DB", + "schema": "", "username": "username", + "password": "password"} + widget: OWSql = self.create_widget(OWSql, stored_settings=settings) + widget.tablecombo.setCurrentIndex(2) + widget.select_table() + output = self.get_output(widget.Outputs.data, widget=widget) + self.assertIsInstance(output, Table) + self.assertEqual(len(output), 101) + + widget.tablecombo.setCurrentIndex(1) + widget.select_table() + output = self.get_output(widget.Outputs.data, widget=widget) + self.assertIsNone(output) + if __name__ == "__main__": unittest.main() diff --git a/i18n/si/msgs.jaml b/i18n/si/msgs.jaml index ceedc1def5c..e1f543d269a 100644 --- a/i18n/si/msgs.jaml +++ b/i18n/si/msgs.jaml @@ -1952,6 +1952,8 @@ data/sql/backend/mssql.py: WHERE TABLE_TYPE in ('VIEW' ,'BASE TABLE') ORDER BY [TABLE_NAME] ": false + def `n_tables_query`: + SELECT COUNT(*) FROM information_schema.tables: false def `quote_identifier`: [{}]: false def `create_sql_query`: @@ -2024,6 +2026,9 @@ data/sql/backend/postgres.py: {} AND NOT c.relname LIKE '\\_\\_%' ORDER BY 1,2;": false + def `n_tables_query`: + SELECT COUNT(*) FROM information_schema.tables: false + " WHERE table_schema = '{schema}'": false def `create_variable`: extract(epoch from {}): false ({})::double precision: false @@ -8214,13 +8219,15 @@ widgets/data/owsql.py: Attribute-valued dataset read from the input file.: Podatki prebrani iz baze class `Information`: Data description was generated from a sample.: Podatki vsebujejo samo vzorec. - class `Warning`: - 'Database is missing extensions: {}': Podatkovna baza nima razširitev: {} class `Error`: Please install a backend to use this widget.: Namestite knjižnico za podatkovno bazo. def `_add_tables_controls`: - Tables: Tabele + Data Selection: Podatki + data_source: false + Table:: Tabela + Custom SQL:: Prikrojen SQL table: tabela + TABLE_NAME: IME_TABELE materialize: false 'Materialize to table ': 'Materializiraj v tabelo ' Save results of the query in a table: Shrani podatke poizvedbe v tabelo @@ -8228,24 +8235,17 @@ widgets/data/owsql.py: Execute: Izvedi guess_values: false Auto-discover categorical variables: Samodejno zaznaj kategorične spremenljivke - download: false - Download data to local memory: Shrani podatke v pomnilnik def `highlight_error`: 'QLineEdit {border: 2px solid red;}': false server: false host: false role: false database: false - def `on_connection_success`: - missing_extension: false - ', ': true def `on_connection_error`: \n: false def `refresh_tables`: Select a table: Izberi tabelo - Custom SQL: Prikrojen SQL def `select_table`: - Custom SQL: Prikrojen SQL Table: Tabela (None): (Brez) def `get_table`: