Skip to content

Commit 3f9686d

Browse files
authored
Merge pull request #1674 from astaric/refactor-sql
[ENH] SQL Server support in SQL widget
2 parents c90e602 + a078fe7 commit 3f9686d

File tree

12 files changed

+683
-326
lines changed

12 files changed

+683
-326
lines changed

Orange/data/sql/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,2 +0,0 @@
1-
import Orange.misc
2-
psycopg2 = Orange.misc.import_late_warning("psycopg2")
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
from .base import Backend
2+
3+
try:
4+
from .postgres import Psycopg2Backend
5+
except ImportError:
6+
pass
7+
8+
try:
9+
from .mssql import PymssqlBackend
10+
except ImportError:
11+
pass

Orange/data/sql/backend/base.py

Lines changed: 238 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,238 @@
1+
import logging
2+
from contextlib import contextmanager
3+
4+
from Orange.util import Registry
5+
6+
log = logging.getLogger(__name__)
7+
8+
9+
class BackendError(Exception):
    """Error raised by SqlTable backends."""
11+
12+
13+
class Backend(metaclass=Registry):
    """Base class for SqlTable backends.

    Concrete backends are expected to override every method that raises
    NotImplementedError here.

    Parameters
    ----------
    connection_params : dict
        Connection parameters; stored unchanged on the instance.
    """

    display_name = ""

    def __init__(self, connection_params):
        self.connection_params = connection_params

    @classmethod
    def available_backends(cls):
        """Return the values of ``cls.registry`` (the available backends)."""
        return cls.registry.values()

    # "meta" methods

    def list_tables_query(self, schema=None):
        """Return the query that lists tables as (schema, table_name) rows.

        Parameters
        ----------
        schema : Optional[str]
            If set, only tables from this schema should be listed.

        Returns
        -------
        The query; it is handed to `execute_sql_query` by `list_tables`.
        """
        raise NotImplementedError

    def list_tables(self, schema=None):
        """Return descriptions of the tables in the database.

        Parameters
        ----------
        schema : Optional[str]
            If set, only tables from the given schema will be listed.

        Returns
        -------
        A list of TableDesc objects, describing the tables in the database.
        """
        quote = self.quote_identifier
        with self.execute_sql_query(self.list_tables_query(schema)) as cursor:
            descriptions = []
            for table_schema, table_name in cursor.fetchall():
                if table_schema:
                    qualified = "{}.{}".format(
                        quote(table_schema), quote(table_name))
                else:
                    # No schema reported: use the bare quoted table name.
                    qualified = quote(table_name)
                descriptions.append(
                    TableDesc(table_name, table_schema, qualified))
            return descriptions

    def get_fields(self, table_name):
        """Return field names and metadata of the given table.

        Parameters
        ----------
        table_name : str

        Returns
        -------
        The ``description`` of the cursor that executed a ``limit=0``
        select of all fields; each entry is passed on to create_variable
        as (field_name, *field_metadata).
        """
        schema_only_query = self.create_sql_query(table_name, ["*"], limit=0)
        with self.execute_sql_query(schema_only_query) as cursor:
            return cursor.description

    def get_distinct_values(self, field_name, table_name):
        """Return the distinct values of a field, if there are few enough.

        Parameters
        ----------
        field_name : name of the field
        table_name : name of the table or query to search

        Returns
        -------
        Tuple[str] of the values, or an empty tuple when the field has
        more than 20 distinct values.
        """
        max_values = 20
        fields = [self.quote_identifier(field_name)]

        # Ask for one row more than the cap so that "too many" is detectable.
        query = self.create_sql_query(table_name, fields,
                                      group_by=fields, order_by=fields,
                                      limit=max_values + 1)
        with self.execute_sql_query(query) as cursor:
            rows = cursor.fetchall()

        if len(rows) > max_values:
            return ()
        return tuple(str(row[0]) for row in rows)

    def create_variable(self, field_name, field_metadata,
                        type_hints, inspect_table=None):
        """Create a variable based on field information.

        Parameters
        ----------
        field_name : str
            name of the field
        field_metadata : tuple
            data to guess field type from
        type_hints : Domain
            domain with variable templates
        inspect_table : Optional[str]
            name of the table to inspect the field values in, or None
            if no inspection is to be performed

        Returns
        -------
        Variable representing the field
        """
        raise NotImplementedError

    def count_approx(self, query):
        """Return the estimated number of rows returned by query.

        Parameters
        ----------
        query : str

        Returns
        -------
        Approximate number of rows
        """
        raise NotImplementedError

    # query related methods

    def create_sql_query(
            self, table_name, fields, filters=(),
            group_by=None, order_by=None, offset=None, limit=None,
            use_time_sample=None):
        """Construct an sql query from the provided elements.

        Parameters
        ----------
        table_name : str
        fields : List[str]
        filters : List[str]
        group_by : List[str]
        order_by : List[str]
        offset : int
        limit : int
        use_time_sample : int

        Returns
        -------
        string containing the sql query
        """
        raise NotImplementedError

    @contextmanager
    def execute_sql_query(self, query, params=None):
        """Context manager for execution of sql queries.

        Usage:
            ```
            with backend.execute_sql_query("SELECT * FROM foo") as cur:
                cur.fetchall()
            ```

        Parameters
        ----------
        query : string
            query to be executed
        params : tuple
            parameters to be passed to the query

        Returns
        -------
        yields a cursor that can be used to access the data
        """
        raise NotImplementedError

    def quote_identifier(self, name):
        """Quote an identifier so it can be safely used in queries.

        Parameters
        ----------
        name : str
            name of the identifier

        Returns
        -------
        quoted identifier that can be used in sql queries
        """
        raise NotImplementedError

    def unquote_identifier(self, quoted_name):
        """Remove the quotes from an identifier name.

        Used when an sql table name is used in a where parameter to
        query special tables.

        Parameters
        ----------
        quoted_name : str

        Returns
        -------
        unquoted name
        """
        raise NotImplementedError
222+
223+
224+
class TableDesc:
    """Description of a database table.

    Attributes
    ----------
    name : table name
    schema : schema the table was listed under
    sql : expression used to refer to the table in queries
    """

    def __init__(self, name, schema, sql):
        self.name, self.schema, self.sql = name, schema, sql

    def __str__(self):
        # Only the table name is shown when the description is printed.
        return self.name
232+
233+
class ToSql:
    """Callable wrapper around a fixed sql expression.

    Calling the instance returns the expression it was constructed with.
    """

    def __init__(self, sql):
        self.sql = sql

    def __call__(self):
        expression = self.sql
        return expression

Orange/data/sql/backend/mssql.py

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,118 @@
1+
from contextlib import contextmanager
2+
3+
import pymssql
4+
5+
from Orange.data import StringVariable, TimeVariable, ContinuousVariable, DiscreteVariable
6+
from Orange.data.sql.backend import Backend
7+
from Orange.data.sql.backend.base import ToSql, BackendError
8+
9+
10+
class PymssqlBackend(Backend):
    """SqlTable backend for SQL Server, implemented on top of pymssql."""

    display_name = "SQL Server"

    def __init__(self, connection_params):
        """Connect to the server described by ``connection_params``.

        Parameters
        ----------
        connection_params : dict
            Keyword arguments for ``pymssql.connect``. A ``host`` entry is
            passed on as ``server``; entries whose value is None are dropped.
            The dict itself is not modified.

        Raises
        ------
        BackendError
            If pymssql reports an error while connecting.
        """
        # Work on a copy: mutating the caller's dict made a second backend
        # built from the same dict lose its "host"/"server" entry.
        params = dict(connection_params)
        host = params.pop("host", None)
        if host is not None:
            params["server"] = host
        params = {key: value for key, value in params.items()
                  if value is not None}

        super().__init__(params)
        try:
            self.connection = pymssql.connect(**params)
        except pymssql.Error as ex:
            raise BackendError(str(ex)) from ex

    def list_tables_query(self, schema=None):
        """Return the query listing (schema, table_name) of base tables.

        Parameters
        ----------
        schema : Optional[str]
            If set, only tables from this schema are listed.
        """
        schema_clause = ""
        if schema:
            # The query is executed without parameters by list_tables, so
            # the value is embedded as a literal with quotes doubled.
            schema_clause = "AND [TABLE_SCHEMA] = '{}'".format(
                str(schema).replace("'", "''"))
        return """
        SELECT [TABLE_SCHEMA], [TABLE_NAME]
          FROM information_schema.tables
         WHERE TABLE_TYPE='BASE TABLE' {}
      ORDER BY [TABLE_NAME]
        """.format(schema_clause)

    def quote_identifier(self, name):
        """Return ``name`` wrapped in brackets, with any ``]`` doubled."""
        return "[{}]".format(name.replace("]", "]]"))

    def unquote_identifier(self, quoted_name):
        """Strip the surrounding brackets and undo the doubling of ``]``."""
        return quoted_name[1:-1].replace("]]", "]")

    def create_sql_query(self, table_name, fields, filters=(),
                         group_by=None, order_by=None, offset=None, limit=None,
                         use_time_sample=None):
        """Construct a SELECT statement from the provided elements.

        Parameters
        ----------
        table_name : str
        fields : List[str]
        filters : List[str]
            conditions, joined with AND
        group_by : List[str]
        order_by : List[str]
        offset : int
        limit : int
            emitted as TOP when there is no offset, as FETCH FIRST otherwise
        use_time_sample : int

        Returns
        -------
        string containing the sql query
        """
        sql = ["SELECT"]
        # `is not None` rather than truthiness: limit=0 is how the base
        # class asks for column metadata without fetching any rows.
        if limit is not None and not offset:
            sql.extend(["TOP", str(limit)])
        sql.append(', '.join(fields))
        sql.extend(["FROM", table_name])
        if use_time_sample:
            # NOTE(review): system_time(...) does not look like a T-SQL
            # TABLESAMPLE form - confirm this branch works on SQL Server.
            sql.append("TABLESAMPLE system_time(%i)" % use_time_sample)
        if filters:
            sql.extend(["WHERE", " AND ".join(filters)])
        if group_by:
            sql.extend(["GROUP BY", ", ".join(group_by)])

        if offset and not order_by:
            # OFFSET is only accepted after ORDER BY. Order by the alias of
            # the first field when it has one, otherwise by a constant.
            _, separator, alias = fields[0].rpartition(" AS ")
            order_by = [alias.strip()] if separator else ["(SELECT NULL)"]

        if order_by:
            sql.extend(["ORDER BY", ",".join(order_by)])
        if offset:
            sql.extend(["OFFSET", str(offset), "ROWS"])
            if limit:
                sql.extend(["FETCH FIRST", str(limit), "ROWS ONLY"])

        return " ".join(sql)

    @contextmanager
    def execute_sql_query(self, query, params=()):
        """Context manager that executes ``query`` and yields the cursor.

        Parameters
        ----------
        query : string
            query to be executed
        params : tuple
            parameters to be passed to the query; may be empty or None

        The connection is committed when the block exits, also on error.
        """
        # Local import: this module has no top-level logging import.
        import logging
        logging.getLogger(__name__).debug("Executing query: %s", query)
        try:
            with self.connection.cursor() as cur:
                if params:
                    # Pass the parameters as one sequence; unpacking them
                    # into positional arguments broke for more than one.
                    cur.execute(query, params)
                else:
                    cur.execute(query)
                yield cur
        finally:
            self.connection.commit()

    def create_variable(self, field_name, field_metadata, type_hints,
                        inspect_table=None):
        """Create a variable for the field and attach its ``to_sql``.

        Parameters
        ----------
        field_name : str
            name of the field
        field_metadata : tuple
            data to guess the field type from
        type_hints : Domain
            domain with variable templates; a variable found here under
            ``field_name`` is used instead of guessing
        inspect_table : Optional[str]
            passed on to the type guessing

        Returns
        -------
        Variable representing the field
        """
        if field_name in type_hints:
            var = type_hints[field_name]
        else:
            var = self._guess_variable(field_name, field_metadata,
                                       inspect_table)

        field_name_q = self.quote_identifier(field_name)
        if isinstance(var, TimeVariable):
            # Time values are selected as seconds since 1970-01-01.
            var.to_sql = ToSql(
                "DATEDIFF(s, '1970-01-01 00:00:00', {})".format(field_name_q))
        else:
            # Other continuous, discrete and string fields are selected as is.
            var.to_sql = ToSql(field_name_q)
        return var

    def _guess_variable(self, field_name, field_metadata, inspect_table):
        """Guess the variable type from the first item of ``field_metadata``.

        NUMBER and DECIMAL become ContinuousVariable, DATETIME becomes a
        TimeVariable with date and time, everything else a StringVariable.
        """
        type_code, *_ = field_metadata

        if type_code in (pymssql.NUMBER, pymssql.DECIMAL):
            return ContinuousVariable(field_name)

        if type_code == pymssql.DATETIME:
            tv = TimeVariable(field_name)
            tv.have_date = True
            tv.have_time = True
            return tv

        # TODO: inspecting distinct values of string fields (to create a
        # DiscreteVariable when inspect_table is given) is disabled, so
        # inspect_table currently has no effect here.
        return StringVariable(field_name)

    def count_approx(self, query):
        """Not implemented for this backend."""
        # TODO: Figure out how to do count estimates on mssql
        raise NotImplementedError

0 commit comments

Comments
 (0)