|
| 1 | + |
| 2 | +import copy |
| 3 | + |
| 4 | +import pandas |
| 5 | + |
| 6 | +import data_algebra.data_ops |
| 7 | + |
| 8 | + |
| 9 | + |
| 10 | +class DataOpArrow: |
| 11 | + """ Represent a section of operators as a categorical arrow.""" |
| 12 | + |
| 13 | + def __init__(self, v): |
| 14 | + if not isinstance(v, data_algebra.data_ops.ViewRepresentation): |
| 15 | + raise TypeError("expected v to be data_algebra.data_ops") |
| 16 | + self.v = v |
| 17 | + cused = v.columns_used() |
| 18 | + if len(cused) != 1: |
| 19 | + raise ValueError("v must use exactly one table") |
| 20 | + k = [k for k in cused.keys()][0] |
| 21 | + self.incoming_columns = cused[k] |
| 22 | + self.outgoing_columns = v.column_names |
| 23 | + |
| 24 | + def _r_copy_replace(self, ops): |
| 25 | + """re-write ops replacing any TableDescription with self.v""" |
| 26 | + if isinstance(ops, data_algebra.data_ops.TableDescription): |
| 27 | + return self.v |
| 28 | + node = copy.copy(ops) |
| 29 | + node.sources = [self._r_copy_replace(s) for s in node.sources] |
| 30 | + return node |
| 31 | + |
| 32 | + def transform(self, other): |
| 33 | + """replace self input table with other""" |
| 34 | + if isinstance(other, pandas.DataFrame): |
| 35 | + cols = set(other.columns) |
| 36 | + missing = set(self.incoming_columns) - cols |
| 37 | + if len(missing) > 0: |
| 38 | + raise ValueError("missing required columns: " + str(missing)) |
| 39 | + if len(cols - set(self.incoming_columns)): |
| 40 | + other = other[self.incoming_columns] |
| 41 | + return self.v.transform(other) |
| 42 | + if isinstance(other, data_algebra.data_ops.ViewRepresentation): |
| 43 | + other = DataOpArrow(other) |
| 44 | + if not isinstance(other, DataOpArrow): |
| 45 | + raise TypeError("other must be a DataOpArrow") |
| 46 | + missing = set(self.incoming_columns) - set(other.outgoing_columns) |
| 47 | + if len(missing) > 0: |
| 48 | + raise ValueError("missing required columns: " + str(missing)) |
| 49 | + if len(set(other.outgoing_columns) - set(self.incoming_columns)): |
| 50 | + # extra columns, in a strict categorical formulation we would |
| 51 | + # reject this. instead insert a select columns node to get the match |
| 52 | + other = DataOpArrow(other.v.select_columns([c for c in self.incoming_columns])) |
| 53 | + # check categorical arrow composition conditions |
| 54 | + if set(self.incoming_columns) != set(other.outgoing_columns): |
| 55 | + raise ValueError("arrow composition conditions not met (incoming column set doesn't match outgoing)") |
| 56 | + return DataOpArrow(other._r_copy_replace(self.v)) |
| 57 | + |
| 58 | + def __rshift__(self, other): # override self >> other |
| 59 | + return other.transform(self) |
| 60 | + |
| 61 | + def __rrshift__(self, other): # override other >> self |
| 62 | + return self.transform(other) |
| 63 | + |
| 64 | + def __repr__(self): |
| 65 | + return "DataOpArrow(" + self.v.__repr__() + ")" |
| 66 | + |
| 67 | + def __str__(self): |
| 68 | + return "[" + str(self.incoming_columns) + " -> " + str(self.outgoing_columns) + "]" |
0 commit comments