Skip to content

Commit 9edeee1

Browse files
committed
working on example
1 parent dbc8c21 commit 9edeee1

File tree

2 files changed

+23
-7
lines changed

2 files changed

+23
-7
lines changed

Examples/WindowFunctions/Arrow.ipynb

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,8 @@
3737
"\n",
3838
"table_description = describe_table(d)\n",
3939
"\n",
40+
"d['irrelevant_column'] = 1\n",
41+
"\n",
4042
"id_ops_a = table_description. \\\n",
4143
" project(group_by=['g']). \\\n",
4244
" extend({\n",
@@ -57,7 +59,7 @@
5759
{
5860
"name": "stdout",
5961
"text": [
60-
"[{'x', 'v', 'g'} -> ['g', 'x', 'v', 'ngroup']]\n"
62+
"[{'g', 'v', 'x'} -> ['g', 'x', 'v', 'ngroup']]\n"
6163
],
6264
"output_type": "stream"
6365
}
@@ -112,7 +114,7 @@
112114
{
113115
"name": "stdout",
114116
"text": [
115-
"[{'v', 'x', 'g', 'ngroup'} -> ['g', 'x', 'v', 'ngroup', 'row_number', 'shift_v']]\n"
117+
"[{'g', 'v', 'ngroup', 'x'} -> ['g', 'x', 'v', 'ngroup', 'row_number', 'shift_v']]\n"
116118
],
117119
"output_type": "stream"
118120
}
@@ -172,7 +174,7 @@
172174
{
173175
"name": "stdout",
174176
"text": [
175-
"[{'shift_v', 'g', 'v', 'row_number', 'ngroup', 'x'} -> ['g', 'x', 'v', 'ngroup', 'row_number', 'shift_v', 'size', 'max_v', 'min_v', 'sum_v', 'mean_v', 'count_v', 'size_v']]\n"
177+
"[{'g', 'x', 'row_number', 'v', 'shift_v', 'ngroup'} -> ['g', 'x', 'v', 'ngroup', 'row_number', 'shift_v', 'size', 'max_v', 'min_v', 'sum_v', 'mean_v', 'count_v', 'size_v']]\n"
176178
],
177179
"output_type": "stream"
178180
}

data_algebra/arrow.py

Lines changed: 18 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -32,21 +32,35 @@ def _r_copy_replace(self, ops):
3232
def transform(self, other):
3333
"""replace self input table with other"""
3434
if isinstance(other, pandas.DataFrame):
35+
cols = set(other.columns)
36+
missing = set(self.incoming_columns) - cols
37+
if len(missing) > 0:
38+
raise ValueError("missing required columns: " + str(missing))
39+
if len(cols - set(self.incoming_columns)):
40+
other = other[self.incoming_columns]
3541
return self.v.transform(other)
3642
if isinstance(other, data_algebra.data_ops.ViewRepresentation):
3743
other = DataOpArrow(other)
3844
if not isinstance(other, DataOpArrow):
3945
raise TypeError("other must be a DataOpArrow")
46+
missing = set(self.incoming_columns) - set(other.outgoing_columns)
47+
if len(missing) > 0:
48+
raise ValueError("missing required columns: " + str(missing))
49+
if len(set(other.outgoing_columns) - set(self.incoming_columns)):
50+
# extra columns, in a strict categorical formulation we would
51+
# reject this. Instead, insert a select-columns node to get the match
52+
other = DataOpArrow(other.v.select_columns([c for c in self.incoming_columns]))
53+
# check categorical arrow composition conditions
4054
if set(self.incoming_columns) != set(other.outgoing_columns):
41-
raise TypeError("arrow composition conditions not met (incoming columsn don't match outgoing)")
55+
raise ValueError("arrow composition conditions not met (incoming column set doesn't match outgoing)")
4256
return DataOpArrow(other._r_copy_replace(self.v))
4357

44-
def __rrshift__(self, other): # override other >> self
45-
return self.transform(other)
46-
4758
def __rshift__(self, other): # override self >> other
4859
return other.transform(self)
4960

61+
def __rrshift__(self, other): # override other >> self
62+
return self.transform(other)
63+
5064
def __repr__(self):
5165
return "DataOpArrow(" + self.v.__repr__() + ")"
5266

0 commit comments

Comments
 (0)