Skip to content

Commit 1df0243

Browse files
committed
rebuild and retest
1 parent 75182c8 commit 1df0243

Some content is hidden

Large commits have some of their content hidden by default. Use the search box below to find content that may be hidden.

42 files changed

+48663
-48484
lines changed

build/lib/data_algebra/SQLite.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -83,7 +83,6 @@ def _sqlite_remainder_expr(dbmodel, expression):
8383
+ ")"
8484
)
8585

86-
8786
def _sqlite_logical_or_expr(dbmodel, expression):
8887
"""
8988
Return SQL or.
@@ -110,6 +109,8 @@ def _sqlite_logical_and_expr(dbmodel, expression):
110109
"is_inf": _sqlite_is_inf_expr,
111110
"rand": _sqlite_RAND_expr,
112111
"remainder": _sqlite_remainder_expr,
112+
"%": _sqlite_remainder_expr,
113+
"mod": _sqlite_remainder_expr,
113114
"logical_or": _sqlite_logical_or_expr,
114115
"logical_and": _sqlite_logical_and_expr,
115116
}
@@ -318,6 +319,8 @@ def prepare_connection(self, conn):
318319
"""
319320
Insert user functions into db.
320321
"""
322+
# # https://stackoverflow.com/questions/52416482/load-sqlite3-extension-in-python3-sqlite
323+
# conn.enable_load_extension(True)
321324
# https://docs.python.org/3/library/sqlite3.html#sqlite3.Connection.create_function
322325
conn.create_function("is_bad", 1, _check_scalar_bad)
323326
conn.create_function("is_nan", 1, _check_scalar_nan)

build/lib/data_algebra/db_model.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -844,7 +844,7 @@ def db_handle(self, conn, *, db_engine=None):
844844

845845
def prepare_connection(self, conn):
846846
"""
847-
Do any augmentation or preperation of a database connection. Example: adding stored procedures.
847+
Do any augmentation or preparation of a database connection. Example: adding stored procedures.
848848
"""
849849
pass
850850

@@ -854,7 +854,7 @@ def prepare_connection(self, conn):
854854
def execute(self, conn, q):
855855
"""
856856
857-
:param conn: database connectionex
857+
:param conn: database connection
858858
:param q: sql query
859859
"""
860860
if isinstance(q, data_algebra.data_ops.ViewRepresentation):

build/lib/data_algebra/pandas_base.py

Lines changed: 91 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
"""
44

55
from abc import ABC
6-
from typing import Any, Callable, Dict
6+
from typing import Any, Callable, Dict, List, Optional
77
import datetime
88
import types
99
import numbers
@@ -21,6 +21,47 @@
2121
# TODO: possibly import dask, Nvidia Rapids, modin, datatable versions
2222

2323

24+
def none_mark_scalar_or_length(v) -> Optional[int]:
25+
"""
26+
Test if item is a scalar (returning None) if it is, else length of object.
27+
28+
:param v: value to test
29+
:return: None if value is a scalar, else length.
30+
"""
31+
# get some of the obvious types, and str (as str doesn't throw on len)
32+
if isinstance(v, (type(None), str, int, float)):
33+
return None # obvious scalar
34+
# len() throws on scalars other than str
35+
try:
36+
return len(v)
37+
except TypeError:
38+
return None # len() failed, probably a scalar
39+
40+
41+
def promote_scalar_to_array(vi, *, target_len: int) -> List:
42+
"""
43+
Convert a scalar into a vector. Pass a non-trivial array through.
44+
45+
:param vi: value to promote to scalar
46+
:target_len: length for vector
47+
:return: list
48+
"""
49+
assert isinstance(target_len, int)
50+
assert target_len >= 0
51+
if target_len <= 0:
52+
return []
53+
len_v = none_mark_scalar_or_length(vi)
54+
# noinspection PyBroadException
55+
if len_v is None:
56+
return [vi] * target_len # scalar
57+
if len_v == target_len:
58+
return vi
59+
if len_v == 1:
60+
return [vi[0]] * target_len # TODO: see if we can eliminate this one
61+
else:
62+
raise ValueError("incompatible column lengths")
63+
64+
2465
def _negate_or_subtract(*args):
2566
if len(args) == 1:
2667
return numpy.negative(args[0])
@@ -436,56 +477,48 @@ def sql_proxy_step(self, op, *, data_map: dict, narrow: bool):
436477
res = db_handle.read_query("\n".join(op.sql))
437478
return res
438479

439-
def columns_to_frame_(self, cols, *, target_rows=0):
480+
def columns_to_frame_(self, cols: Dict[str, Any], *, target_rows: Optional[int] = None):
440481
"""
441482
Convert a dictionary of column names to series-like objects and scalars into a Pandas data frame.
483+
Deal with special cases, such as some columns coming in as scalars (often from Panda aggregation).
442484
443485
:param cols: dictionary mapping column names to columns
444486
:param target_rows: number of rows we are shooting for
445487
:return: Pandas data frame.
446488
"""
447489
# noinspection PyUnresolvedReferences
448490
assert isinstance(cols, dict)
491+
assert isinstance(target_rows, (int, type(None)))
492+
if target_rows is not None:
493+
assert target_rows >= 0
449494
if len(cols) < 1:
450-
return self.pd.DataFrame(cols)
451-
for k, v in cols.items():
452-
try:
453-
target_rows = max(target_rows, len(v))
454-
except TypeError:
455-
target_rows = max(target_rows, 1) # scalar
495+
# all scalars, so nothing carrying index information
496+
if target_rows is not None:
497+
return self.pd.DataFrame({}, index=range(target_rows)).reset_index(drop=True, inplace=False)
498+
else:
499+
return self.pd.DataFrame({})
500+
was_all_scalars = True
501+
for v in cols.values():
502+
ln = none_mark_scalar_or_length(v)
503+
if ln is not None:
504+
was_all_scalars = False
505+
if target_rows is None:
506+
target_rows = ln
507+
else:
508+
assert target_rows == ln
509+
if was_all_scalars:
510+
if target_rows is None:
511+
target_rows = 1
512+
# all scalars, so nothing carrying index information
513+
promoted_cols = {k: promote_scalar_to_array(v, target_len=target_rows) for (k, v) in cols.items()}
514+
return self.pd.DataFrame(promoted_cols, index=range(target_rows)).reset_index(drop=True, inplace=False)
515+
assert target_rows is not None
456516
if target_rows < 1:
457-
# noinspection PyBroadException
458-
try:
459-
res = self.pd.DataFrame(cols)
460-
if res.shape[0] > 0:
461-
res = res.loc[[False] * res.shape[0], :].reset_index(
462-
drop=True, inplace=False
463-
)
464-
except Exception:
465-
res = self.pd.DataFrame({k: [] for k in cols.keys()})
466-
return res
467-
517+
# no rows, so presuming no index information (shouldn't have come from an aggregation)
518+
return self.pd.DataFrame({k: [] for k in cols.keys()})
468519
# agg can return scalars, which then can't be made into a self.pd.DataFrame
469-
def promote_scalar(vi, *, target_len):
470-
"""
471-
Convert a scalar into a vector.
472-
"""
473-
# noinspection PyBroadException
474-
try:
475-
len_v = len(vi)
476-
if len_v != target_len:
477-
if len_v == 0:
478-
return [None] * target_len
479-
elif len_v == 1:
480-
return [vi[0]] * target_len
481-
else:
482-
raise ValueError("incompatible column lengths")
483-
except Exception:
484-
return [vi] * target_len # scalar
485-
return vi
486-
487-
cols = {k: promote_scalar(v, target_len=target_rows) for (k, v) in cols.items()}
488-
return self.pd.DataFrame(cols)
520+
promoted_cols = {k: promote_scalar_to_array(v, target_len=target_rows) for (k, v) in cols.items()}
521+
return self.pd.DataFrame(promoted_cols)
489522

490523
def add_data_frame_columns_to_data_frame_(self, res, transient_new_frame):
491524
"""
@@ -542,14 +575,22 @@ def extend_step(self, op, *, data_map, narrow):
542575
"""
543576
if op.node_name != "ExtendNode":
544577
raise TypeError("op was supposed to be a data_algebra.data_ops.ExtendNode")
578+
res = self._eval_value_source(op.sources[0], data_map=data_map, narrow=narrow)
579+
if res.shape[0] <= 0:
580+
# special case out no-row frame
581+
incoming_col_set = set(res.columns)
582+
v_dict = {k: [] for k in res.columns}
583+
for k in op.ops.keys():
584+
if k not in incoming_col_set:
585+
v_dict[k] = []
586+
return self.pd.DataFrame(v_dict)
545587
window_situation = (
546588
op.windowed_situation
547589
or (len(op.partition_by) > 0)
548590
or (len(op.order_by) > 0)
549591
)
550592
if window_situation:
551593
op.check_extend_window_fns_()
552-
res = self._eval_value_source(op.sources[0], data_map=data_map, narrow=narrow)
553594
if not window_situation:
554595
with warnings.catch_warnings():
555596
warnings.simplefilter(
@@ -569,7 +610,6 @@ def extend_step(self, op, *, data_map, narrow):
569610
col_list = col_list + [c]
570611
col_set.add(c)
571612
order_cols = [c for c in col_list] # must be partition by followed by order
572-
573613
for (k, opk) in op.ops.items():
574614
# assumes all args are column names or values, enforce this earlier
575615
if len(opk.args) > 0:
@@ -751,10 +791,15 @@ def project_step(self, op, *, data_map, narrow):
751791
# agg can return scalars, which then can't be made into a self.pd.DataFrame
752792
res = self.columns_to_frame_(cols)
753793
res = res.reset_index(
754-
drop=len(op.group_by) < 1
794+
drop=(len(op.group_by) < 1) or (res.shape[0] <= 0)
755795
) # grouping variables in the index
756796
missing_group_cols = set(op.group_by) - set(res.columns)
757-
assert len(missing_group_cols) <= 0
797+
if res.shape[0] > 0:
798+
if len(missing_group_cols) != 0:
799+
raise ValueError("Missing column groups")
800+
else:
801+
for g in missing_group_cols:
802+
res[g] = []
758803
if "_data_table_temp_col" in res.columns:
759804
res = res.drop("_data_table_temp_col", axis=1, inplace=False)
760805
# double check shape is what we expect
@@ -873,6 +918,9 @@ def natural_join_step(self, op, *, data_map, narrow):
873918
)
874919
left = self._eval_value_source(op.sources[0], data_map=data_map, narrow=narrow)
875920
right = self._eval_value_source(op.sources[1], data_map=data_map, narrow=narrow)
921+
if (left.shape[0] == 0) and (right.shape[0] == 0):
922+
# pandas seems to not like this case
923+
return self.pd.DataFrame({k: [] for k in op.columns_produced()})
876924
common_cols = set([c for c in left.columns]).intersection(
877925
[c for c in right.columns]
878926
)

build/lib/data_algebra/test_util.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -390,7 +390,8 @@ def check_transform_on_handles(
390390
}
391391
empty_res = ops.eval(empty_map)
392392
assert local_data_model.is_appropriate_data_instance(empty_res)
393-
assert set(empty_res.columns) == set(res.columns)
393+
if set(empty_res.columns) != set(res.columns):
394+
raise Exception("columns mismatch")
394395
if empty_produces_empty:
395396
assert empty_res.shape[0] == 0
396397
else:

0 commit comments

Comments (0)