Skip to content

Commit 2d849ed

Browse files
committed
get if_else working
improve result comparison
1 parent 3eea4c9 commit 2d849ed

File tree

5 files changed

+88
-10
lines changed

5 files changed

+88
-10
lines changed

data_algebra/db_model.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,18 @@ def _db_is_bad_expr(dbmodel, expression):
4444
)
4545

4646

47+
def _db_if_else_expr(dbmodel, expression):
    """Render a three-argument ``if_else`` expression as a SQL ``CASE``.

    ``expression.args`` is read positionally as (condition, value when the
    condition holds, value when it does not).  Each argument is converted
    with ``dbmodel.expr_to_sql(..., want_inline_parens=True)``.

    :param dbmodel: object providing ``expr_to_sql``.
    :param expression: object whose ``args`` holds the three operands.
    :return: SQL text of the form
        ``CASE WHEN c THEN x WHEN NOT c THEN y ELSE NULL END``.
    """
    if_expr = dbmodel.expr_to_sql(expression.args[0], want_inline_parens=True)
    x_expr = dbmodel.expr_to_sql(expression.args[1], want_inline_parens=True)
    y_expr = dbmodel.expr_to_sql(expression.args[2], want_inline_parens=True)
    # The condition is tested twice and followed by an explicit ELSE NULL
    # rather than a plain "ELSE y".
    # NOTE(review): presumably so that a NULL condition yields NULL instead
    # of falling through to y -- confirm against the target databases.
    return (
        "CASE"
        + " WHEN " + if_expr + " THEN " + x_expr
        + " WHEN NOT " + if_expr + " THEN " + y_expr
        + " ELSE NULL END"
    )
56+
57+
58+
4759
def _db_neg_expr(dbmodel, expression):
    """Render unary negation of the expression's first argument as SQL.

    The operand is converted with
    ``dbmodel.expr_to_sql(..., want_inline_parens=True)`` and wrapped as
    ``( -operand )``.
    """
    operand_sql = dbmodel.expr_to_sql(
        expression.args[0], want_inline_parens=True
    )
    pieces = ["( -", operand_sql, " )"]
    return "".join(pieces)
@@ -53,6 +65,7 @@ def _db_neg_expr(dbmodel, expression):
5365
"is_null": _db_is_null_expr,
5466
"is_bad": _db_is_bad_expr,
5567
"neg": _db_neg_expr,
68+
"if_else": _db_if_else_expr,
5669
}
5770

5871

data_algebra/expr_rep.py

Lines changed: 15 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,6 @@
11
from typing import Union
22
import collections
33

4-
import pandas
5-
64
import data_algebra.util
75
import data_algebra.env
86

@@ -43,6 +41,15 @@ def __uop_expr__(self, op, *, params=None):
4341
raise TypeError("op is supposed to be a string")
4442
return Expression(op, (self,), params=params)
4543

44+
def __triop_expr__(self, op, x, y):
    """Build a three-argument ``Expression`` with ``self`` as first operand.

    :param op: operator name; must be a ``str``.
    :param x: second operand; wrapped in ``Value`` unless already a ``Term``.
    :param y: third operand; wrapped in ``Value`` unless already a ``Term``.
    :return: ``Expression(op, (self, x, y), inline=False)``.
    :raises TypeError: if ``op`` is not a string.
    """
    if not isinstance(op, str):
        raise TypeError("op is supposed to be a string")
    # Promote plain Python values to Value terms, x first and then y.
    second, third = [
        operand if isinstance(operand, Term) else Value(operand)
        for operand in (x, y)
    ]
    return Expression(op, (self, second, third), inline=False)
52+
4653
# tree re-write
4754

4855
def replace_view(self, view):
@@ -289,6 +296,9 @@ def is_null(self):
289296
def is_bad(self):
    """Return the unary ``"is_bad"`` expression applied to this term."""
    op_name = "is_bad"
    return self.__uop_expr__(op_name)
291298

299+
def if_else(self, x, y):
    """Return the ``"if_else"`` expression with this term as the condition.

    :param x: operand used where the condition holds.
    :param y: operand used where the condition does not hold.
    :return: the result of ``self.__triop_expr__("if_else", x, y)``.
    """
    op_name = "if_else"
    return self.__triop_expr__(op_name, x, y)
301+
292302

293303
class Value(Term):
294304
def __init__(self, value):
@@ -337,15 +347,12 @@ def get_column_names(self, columns_seen):
337347
}
338348

339349

340-
pandas_eval_env = {
341-
"is_null": lambda x: pandas.isnull(x),
342-
"is_bad": data_algebra.util.is_bad,
343-
}
344-
345-
346350
# Per-operator renderers that turn an expression node into pandas eval source.
# Each entry maps an operator name to a callable taking the expression and
# returning a string.  The "@name(...)" forms call helper functions by name.
# NOTE(review): those names are presumably resolved from the `local_dict`
# passed to `DataFrame.eval` in pandas_model.py -- confirm.
pd_formatters = {
    "is_bad": lambda expr: "@is_bad(" + expr.args[0].to_pandas() + ")",
    "is_null": lambda expr: "@is_null(" + expr.args[0].to_pandas() + ")",
    # Arguments are (condition, value-if-true, value-if-false).
    "if_else": lambda expr: (
        "@if_else("
        + expr.args[0].to_pandas()
        + ", "
        + expr.args[1].to_pandas()
        + ", "
        + expr.args[2].to_pandas()
        + ")"
    ),
    # Negation asks the operand for inline parentheses.
    "neg": lambda expr: "-" + expr.args[0].to_pandas(want_inline_parens=True),
}
351358

data_algebra/pandas_model.py

Lines changed: 11 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,21 @@
1+
2+
import numpy
13
import pandas
24

35
import data_algebra
6+
import data_algebra.util
47
import data_algebra.data_model
58
import data_algebra.expr_rep
69
import data_algebra.data_ops
710

811

12+
# Helper callables made available by name to the pandas expression
# evaluator; `extend_step` passes this mapping as `local_dict` to
# `DataFrame.eval`.
pandas_eval_env = dict(
    is_null=lambda x: pandas.isnull(x),
    is_bad=data_algebra.util.is_bad,
    # NOTE(review): numpy.where picks x or y purely on the truthiness of c,
    # while the SQL rendering of if_else in db_model.py ends with
    # "ELSE NULL" -- results may differ when the condition is missing; verify.
    if_else=lambda c, x, y: numpy.where(c, x, y),
)
17+
18+
919
class PandasModel(data_algebra.data_model.DataModel):
1020
def __init__(self):
1121
data_algebra.data_model.DataModel.__init__(self)
@@ -52,7 +62,7 @@ def extend_step(self, op, *, data_map, eval_env):
5262
op_src = opk.to_pandas()
5363
res[k] = res.eval(
5464
op_src,
55-
local_dict=data_algebra.expr_rep.pandas_eval_env,
65+
local_dict=pandas_eval_env,
5666
global_dict=eval_env,
5767
)
5868
else:

data_algebra/util.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,8 +107,11 @@ def equivalent_frames(
107107
if not all([ca_null[i] == cb_null[i] for i in range(a.shape[0])]):
108108
return False
109109
if can_convert_v_to_numeric(ca):
110+
ca = numpy.asarray(ca, dtype=float)
111+
cb = numpy.asarray(cb, dtype=float)
110112
dif = ca - cb
111-
if dif.abs().max() > float_tol:
113+
dif = numpy.asarray([abs(d) for d in dif if not pandas.isnull(d)])
114+
if dif.max() > float_tol:
112115
return False
113116
else:
114117
if not all([ca[i] == cb[i] for i in range(a.shape[0])]):

tests/test_if_else.py

Lines changed: 45 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,45 @@
1+
2+
import sqlite3
3+
4+
import pandas
5+
6+
from data_algebra.data_ops import *
7+
from data_algebra.SQLite import SQLiteModel
8+
import data_algebra.util
9+
10+
11+
def test_if_else():
    """Check that ``if_else`` gives the expected column in pandas and SQLite.

    Builds a two-row frame, extends it with ``d = a.if_else(b, c)``, and
    compares both the pandas result and the SQLite query result against the
    same expected frame using ``data_algebra.util.equivalent_frames``.
    """
    d = pandas.DataFrame({
        'a': [True, False],
        'b': [1, 2],
        'c': [3, 4],
    })

    ops = TableDescription('d', ['a', 'b', 'c']).extend(
        {'d': 'a.if_else(b, c)'}
    )

    expect = pandas.DataFrame({
        'a': [True, False],
        'b': [1, 2],
        'c': [3, 4],
        'd': [1, 4],
    })

    # pandas path
    res_pandas = ops.transform(d)
    assert data_algebra.util.equivalent_frames(res_pandas, expect)

    # SQLite path
    db_model = SQLiteModel()
    ops_sql = ops.to_sql(db_model)

    conn = sqlite3.connect(":memory:")
    try:
        db_model.prepare_connection(conn)
        db_model.insert_table(conn, d, 'd')
        res_db = db_model.read_query(conn, ops_sql)
    finally:
        # Close the connection even if one of the database calls raises.
        conn.close()

    assert data_algebra.util.equivalent_frames(res_db, expect)

0 commit comments

Comments
 (0)