Skip to content

Commit 646b3f8

Browse files
committed
add null/bad tests
1 parent 08a1a16 commit 646b3f8

File tree

9 files changed

+106
-11
lines changed

9 files changed

+106
-11
lines changed

build/lib/data_algebra/SQLite.py

Lines changed: 27 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,32 @@
1+
import math
12
import pandas
3+
import numpy
4+
import numbers
25

36
import data_algebra.util
47
import data_algebra.data_ops
58
import data_algebra.db_model
69

10+
711
# map from op-name to special SQL formatting code
8-
SQLite_formatters = {"___": lambda dbmodel, expression: expression.to_python()}
12+
13+
def _sqlite_is_bad_expr(dbmodel, expression):
14+
return "is_bad(" + dbmodel.expr_to_sql(expression.args[0], want_inline_parens=False) + ")"
15+
16+
17+
SQLite_formatters = {
18+
"is_bad": _sqlite_is_bad_expr,
19+
}
20+
21+
22+
def _check_scalar_bad(x):
23+
if x is None:
24+
return 1
25+
if not isinstance(x, numbers.Number):
26+
return 0
27+
if numpy.isinf(x) or numpy.isnan(x):
28+
return 1
29+
return 0
930

1031

1132
class SQLiteModel(data_algebra.db_model.DBModel):
@@ -19,6 +40,11 @@ def __init__(self):
1940
sql_formatters=SQLite_formatters,
2041
)
2142

43+
def prepare_connection(self, conn):
44+
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function
45+
conn.create_function("exp", 1, math.exp)
46+
conn.create_function("is_bad", 1, _check_scalar_bad)
47+
2248
def quote_identifier(self, identifier):
2349
if not isinstance(identifier, str):
2450
raise Exception("expected identifier to be a str")

build/lib/data_algebra/db_model.py

Lines changed: 25 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,24 @@
1010
import data_algebra.cdata
1111

1212

13+
# map from op-name to special SQL formatting code
14+
15+
def _db_is_null_expr(dbmodel, expression):
16+
return "(" + dbmodel.expr_to_sql(expression.args[0], want_inline_parens=False) + " IS NULL)"
17+
18+
19+
def _db_is_bad_expr(dbmodel, expression):
20+
subexpr = dbmodel.expr_to_sql(expression.args[0], want_inline_parens=True)
21+
return "(" + subexpr + " IS NULL OR " + subexpr + " >= Infinity OR " + subexpr + " <= -Infinity " +\
22+
" OR (" + subexpr + " != 0 AND " + subexpr + " == -" + subexpr + "))"
23+
24+
25+
db_expr_formatters = {
26+
"is_null": _db_is_null_expr,
27+
"is_bad": _db_is_bad_expr,
28+
}
29+
30+
1331
class DBModel:
1432
"""A model of how SQL should be generated for a given database.
1533
"""
@@ -31,11 +49,17 @@ def __init__(
3149
self.string_quote = string_quote
3250
if sql_formatters is None:
3351
sql_formatters = {}
34-
self.sql_formatters = sql_formatters
52+
self.sql_formatters = sql_formatters.copy()
53+
for k in db_expr_formatters.keys():
54+
if k not in self.sql_formatters.keys():
55+
self.sql_formatters[k] = db_expr_formatters[k]
3556
if op_replacements is None:
3657
op_replacements = {"==": "="}
3758
self.op_replacements = op_replacements
3859

60+
def prepare_connection(self, conn):
61+
pass
62+
3963
def quote_identifier(self, identifier):
4064
if not isinstance(identifier, str):
4165
raise Exception("expected identifier to be a str")

coverage.txt

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,21 +20,21 @@ tests/test_sqlite.py . [100%]
2020
Name Stmts Miss Cover
2121
--------------------------------------------------
2222
data_algebra/PostgreSQL.py 19 4 79%
23-
data_algebra/SQLite.py 28 5 82%
23+
data_algebra/SQLite.py 44 6 86%
2424
data_algebra/__init__.py 5 0 100%
2525
data_algebra/cdata.py 49 5 90%
2626
data_algebra/cdata_impl.py 94 16 83%
2727
data_algebra/data_ops.py 692 124 82%
2828
data_algebra/data_pipe.py 158 33 79%
29-
data_algebra/db_model.py 339 84 75%
29+
data_algebra/db_model.py 350 86 75%
3030
data_algebra/env.py 54 12 78%
3131
data_algebra/expr_rep.py 294 89 70%
3232
data_algebra/pending_eval.py 34 34 0%
3333
data_algebra/pipe.py 65 19 71%
3434
data_algebra/util.py 69 6 91%
3535
data_algebra/yaml.py 73 9 88%
3636
--------------------------------------------------
37-
TOTAL 1973 440 78%
37+
TOTAL 2000 443 78%
3838

3939

40-
========================== 14 passed in 2.22 seconds ===========================
40+
========================== 14 passed in 2.42 seconds ===========================

data_algebra/SQLite.py

Lines changed: 24 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,34 @@
1+
import math
12
import pandas
3+
import numpy
4+
import numbers
25

36
import data_algebra.util
47
import data_algebra.data_ops
58
import data_algebra.db_model
69

710

811
# map from op-name to special SQL formatting code
12+
13+
def _sqlite_is_bad_expr(dbmodel, expression):
14+
return "is_bad(" + dbmodel.expr_to_sql(expression.args[0], want_inline_parens=False) + ")"
15+
16+
917
SQLite_formatters = {
18+
"is_bad": _sqlite_is_bad_expr,
1019
}
1120

1221

22+
def _check_scalar_bad(x):
23+
if x is None:
24+
return 1
25+
if not isinstance(x, numbers.Number):
26+
return 0
27+
if numpy.isinf(x) or numpy.isnan(x):
28+
return 1
29+
return 0
30+
31+
1332
class SQLiteModel(data_algebra.db_model.DBModel):
1433
"""A model of how SQL should be generated for SQLite"""
1534

@@ -21,6 +40,11 @@ def __init__(self):
2140
sql_formatters=SQLite_formatters,
2241
)
2342

43+
def prepare_connection(self, conn):
44+
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function
45+
conn.create_function("exp", 1, math.exp)
46+
conn.create_function("is_bad", 1, _check_scalar_bad)
47+
2448
def quote_identifier(self, identifier):
2549
if not isinstance(identifier, str):
2650
raise Exception("expected identifier to be a str")

data_algebra/db_model.py

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,9 @@ def __init__(
5757
op_replacements = {"==": "="}
5858
self.op_replacements = op_replacements
5959

60+
def prepare_connection(self, conn):
61+
pass
62+
6063
def quote_identifier(self, identifier):
6164
if not isinstance(identifier, str):
6265
raise Exception("expected identifier to be a str")
463 Bytes
Binary file not shown.

dist/data_algebra-0.1.4.tar.gz

416 Bytes
Binary file not shown.

tests/test_null_bad.py

Lines changed: 22 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,25 +1,26 @@
11

22
import math
33
import numpy
4+
import sqlite3
45

56
import data_algebra.util
67
from data_algebra.data_ops import *
78
import data_algebra.SQLite
89

910
def test_null_bad():
10-
ops = TableDescription("d", ["x"]).extend({
11+
ops = TableDescription("d", ["x"]).extend({
1112
"x_is_null": "x.is_null()",
1213
"x_is_bad": "x.is_bad()"
1314
})
1415

1516
d = pandas.DataFrame({
16-
'x': [1, numpy.nan, math.inf, -math.inf, None, 2]
17+
'x': [1, numpy.nan, math.inf, -math.inf, None, 0]
1718
})
1819

1920
d2 = ops.transform(d)
2021

2122
expect = pandas.DataFrame({
22-
'x': [1, numpy.nan, math.inf, -math.inf, None, 2],
23+
'x': [1, numpy.nan, math.inf, -math.inf, None, 0],
2324
'x_is_null': [False, True, False, False, True, False],
2425
'x_is_bad': [False, True, True, True, True, False]
2526
})
@@ -31,3 +32,21 @@ def test_null_bad():
3132

3233
sql = ops.to_sql(db_model, pretty=True)
3334
assert isinstance(sql, str)
35+
36+
conn = sqlite3.connect(":memory:")
37+
db_model.prepare_connection(conn)
38+
39+
db_model.insert_table(conn, d, 'd')
40+
41+
res = db_model.read_query(conn, sql)
42+
43+
conn.close()
44+
45+
expectr = pandas.DataFrame({
46+
'x': [1, numpy.nan, math.inf, -math.inf, None, 0],
47+
'x_is_null': [0, 1, 0, 0, 1, 0],
48+
'x_is_bad': [0, 1, 1, 1, 1, 0]
49+
})
50+
51+
assert all(res['x_is_null'] == expectr['x_is_null'])
52+
assert all(res['x_is_bad'] == expectr['x_is_bad'])

tests/test_scoring_example.py

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -96,8 +96,7 @@ def test_scoring_example():
9696
sql_s = ops.to_sql(db_model_p, pretty=True)
9797

9898
conn = sqlite3.connect(":memory:")
99-
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function
100-
conn.create_function("exp", 1, math.exp)
99+
db_model_s.prepare_connection(conn)
101100

102101
db_model_s.insert_table(conn, d_local, "d")
103102
back = db_model_s.read_table(conn, "d")

0 commit comments

Comments
 (0)