Skip to content

Commit ca2fc81

Browse files
committed
Simple Table
- Add more tests to Table - Add tests to capture failure scenarios in Table
1 parent 4942a5b commit ca2fc81

File tree

3 files changed

+210
-38
lines changed

3 files changed

+210
-38
lines changed

pyard/simple_table.py

Lines changed: 61 additions & 38 deletions
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@
77

88
class Table:
99
def __init__(self, data, columns: list, table_name: str = "data"):
10-
self.conn = sqlite3.connect(":memory:")
10+
self._conn = sqlite3.connect(":memory:")
1111
self._name = table_name
1212
self._columns = columns
1313
if isinstance(data, csv.DictReader):
@@ -22,60 +22,62 @@ def _create_table_from_reader(self, reader: csv.DictReader, columns: list):
2222

2323
column_defs = ", ".join(f"`{col}` TEXT" for col in columns)
2424

25-
self.conn.execute(f"CREATE TABLE {self._name} ({column_defs})")
25+
self._conn.execute(f"CREATE TABLE {self._name} ({column_defs})")
2626

2727
placeholders = ", ".join("?" * len(columns))
2828
for row in rows:
2929
values = [row[col] for col in columns]
30-
self.conn.execute(
30+
self._conn.execute(
3131
f"INSERT INTO {self._name} VALUES ({placeholders})", values
3232
)
3333

34-
self.conn.commit()
34+
self._conn.commit()
3535

3636
def _create_table_from_tuples(self, data: list, columns: list):
3737
if not data:
3838
return
3939

4040
column_defs = ", ".join(f"`{col}` TEXT" for col in columns)
4141

42-
self.conn.execute(f"CREATE TABLE {self._name} ({column_defs})")
42+
self._conn.execute(f"CREATE TABLE {self._name} ({column_defs})")
4343

4444
placeholders = ", ".join("?" * len(columns))
4545
for row in data:
46-
self.conn.execute(f"INSERT INTO {self._name} VALUES ({placeholders})", row)
46+
self._conn.execute(f"INSERT INTO {self._name} VALUES ({placeholders})", row)
4747

48-
self.conn.commit()
48+
self._conn.commit()
4949

5050
def query(self, sql: str, params=()):
    """Run a SQL statement on this table's connection and return all rows.

    Args:
        sql: The statement to execute. It is passed to SQLite as-is.
        params: Optional values bound to ``?`` (sequence) or ``:name``
            (mapping) placeholders in ``sql``. Prefer this over building
            values into the SQL string.

    Returns:
        A list of row tuples as produced by ``fetchall()``.
    """
    return self._conn.execute(sql, params).fetchall()
5252

5353
def close(self):
    """Close the underlying SQLite connection, if one is set."""
    connection = self._conn
    if not connection:
        return
    connection.close()
5656

5757
@property
def columns(self):
    """List the column names currently defined on the backing table.

    Names are read live from ``PRAGMA table_info`` on every access, so
    the result reflects columns added, dropped or renamed since creation.
    """
    table_info = self._conn.execute(f"PRAGMA table_info({self._name})")
    # Each table_info row is (cid, name, type, notnull, dflt_value, pk).
    return [info[1] for info in table_info.fetchall()]
6161

6262
def head(self, n: int = 5):
    """Return up to the first ``n`` rows wrapped in a ``PrintableTable``.

    Args:
        n: Maximum number of rows to return. It is bound as a SQL
            parameter rather than interpolated into the statement, so it
            cannot change the statement text.
    """
    cursor = self._conn.execute(f"SELECT * FROM {self._name} LIMIT ?", (n,))
    rows = cursor.fetchall()
    return PrintableTable(self.columns, rows)
6666

6767
def tail(self, n: int = 5):
    """Return up to the last ``n`` rows wrapped in a ``PrintableTable``.

    Rows are selected with ``ORDER BY rowid DESC``, so they come back
    newest-first (reverse insertion order).

    Args:
        n: Maximum number of rows to return. It is bound as a SQL
            parameter rather than interpolated into the statement, so it
            cannot change the statement text.
    """
    cursor = self._conn.execute(
        f"SELECT * FROM {self._name} ORDER BY rowid DESC LIMIT ?", (n,)
    )
    rows = cursor.fetchall()
    return PrintableTable(self.columns, rows)
7373

7474
def group_by(self, group_by_column: str, return_columns: List[str] = None):
75+
if group_by_column not in self.columns:
76+
raise ValueError(f"Column '{group_by_column}' not found in table")
7577
if return_columns is None:
7678
return_columns = self.columns
7779
column_names = ", ".join([f"`{col}`" for col in return_columns])
78-
cursor = self.conn.execute(
80+
cursor = self._conn.execute(
7981
f"SELECT {column_names} FROM {self._name} ORDER BY `{group_by_column}`"
8082
)
8183
rows = cursor.fetchall()
@@ -88,19 +90,26 @@ def group_by(self, group_by_column: str, return_columns: List[str] = None):
8890

8991
def unique(self, columns):
    """Return distinct values of one column, or distinct rows over several.

    Args:
        columns: A single column name (``str``) or a collection of names.

    Returns:
        For a ``str``: a ``Column`` holding that column's distinct values.
        Otherwise: a new ``Table`` named ``<name>_unique`` holding the
        distinct combinations of the requested columns.
    """
    if not isinstance(columns, str):
        selected = ", ".join([f"`{col}`" for col in columns])
        distinct_rows = self._conn.execute(
            f"SELECT DISTINCT {selected} FROM {self._name}"
        ).fetchall()
        return Table(distinct_rows, columns, f"{self._name}_unique")

    distinct_rows = self._conn.execute(
        f"SELECT DISTINCT `{columns}` FROM {self._name}"
    ).fetchall()
    return Column(columns, [record[0] for record in distinct_rows])
100104

101105
def where(self, where_clause: str):
    """Return a new table holding the rows that satisfy ``where_clause``.

    Args:
        where_clause: SQL boolean expression placed verbatim after
            ``WHERE``. It is not escaped, so it must not come from
            untrusted input.

    Returns:
        A new ``Table`` named ``<name>_filtered`` with the same columns.

    Raises:
        ValueError: If SQLite rejects the clause; the sqlite3 exception
            is chained as the cause.
    """
    # Only the query itself is guarded, so a failure while building the
    # result table is not misreported as a bad WHERE clause.
    # sqlite3.Warning is included because older Python versions raise it
    # (not sqlite3.Error) for multi-statement input.
    try:
        rows = self._conn.execute(
            f"SELECT * FROM {self._name} WHERE {where_clause}"
        ).fetchall()
    except (sqlite3.Error, sqlite3.Warning) as e:
        raise ValueError(f"Invalid WHERE clause: {where_clause}") from e
    return Table(rows, self.columns, f"{self._name}_filtered")
104113

105114
def where_not_null(self, null_column):
106115
if isinstance(null_column, list):
@@ -111,13 +120,13 @@ def where_not_null(self, null_column):
111120
table_suffix = null_column
112121

113122
table_name = f"{self._name}_not_null_{table_suffix}"
114-
cursor = self.conn.execute(f"SELECT * FROM {self._name} WHERE {conditions}")
123+
cursor = self._conn.execute(f"SELECT * FROM {self._name} WHERE {conditions}")
115124
return Table(cursor.fetchall(), table_name=table_name, columns=self.columns)
116125

117126
def where_in(self, column_name: str, values: set, columns: list):
118127
placeholders = ", ".join("?" * len(values))
119128
column_names = ", ".join([f"`{col}`" for col in columns])
120-
cursor = self.conn.execute(
129+
cursor = self._conn.execute(
121130
f"SELECT {column_names} FROM {self._name} WHERE `{column_name}` IN ({placeholders})",
122131
list(values),
123132
)
@@ -134,21 +143,23 @@ def to_dict(self, key_column: str = None, value_column: str = None):
134143
raise ValueError(
135144
f"Columns {key_column} and {value_column} must be different"
136145
)
137-
cursor = self.conn.execute(
146+
cursor = self._conn.execute(
138147
f"SELECT `{key_column}`, `{value_column}` FROM {self._name}"
139148
)
140149
return dict(cursor.fetchall())
141150

142151
def value_counts(self, column: str):
    """Count how often each distinct value of ``column`` occurs.

    Args:
        column: Name of the column to tally.

    Returns:
        A new ``Table`` named ``<name>_counts`` with columns
        ``[column, "count"]``, most frequent value first.

    Raises:
        ValueError: If ``column`` is not a column of this table.
    """
    if column in self.columns:
        tallies = self._conn.execute(
            f"SELECT `{column}`, COUNT(*) FROM {self._name} GROUP BY `{column}` ORDER BY COUNT(*) DESC"
        ).fetchall()
        return Table(tallies, [column, "count"], f"{self._name}_counts")
    raise ValueError(f"Column '{column}' not found in table")
147158

148159
def agg(self, group_column: str, agg_column: str, func):
149160
builtin_funcs = {list, set}
150161
query = f"SELECT `{group_column}`, `{agg_column}` FROM {self._name} GROUP BY `{group_column}`, `{agg_column}`"
151-
result = self.conn.execute(query).fetchall()
162+
result = self._conn.execute(query).fetchall()
152163
d = defaultdict(list)
153164
for k, v in result:
154165
d[k].append(v)
@@ -160,43 +171,50 @@ def agg(self, group_column: str, agg_column: str, func):
160171

161172
def __setitem__(self, column: str, values):
162173
if column in self.columns:
163-
self.conn.execute(f"ALTER TABLE {self._name} DROP COLUMN `{column}`")
164-
self.conn.execute(f"ALTER TABLE {self._name} ADD COLUMN `{column}` TEXT")
174+
self._conn.execute(f"ALTER TABLE {self._name} DROP COLUMN `{column}`")
175+
self._conn.execute(f"ALTER TABLE {self._name} ADD COLUMN `{column}` TEXT")
165176
for i, value in enumerate(values):
166-
self.conn.execute(
177+
self._conn.execute(
167178
f"UPDATE {self._name} SET `{column}` = ? WHERE rowid = ?",
168179
(value, i + 1),
169180
)
170-
self.conn.commit()
181+
self._conn.commit()
171182

172183
def __getitem__(self, column):
173184
if isinstance(column, list):
185+
for col in column:
186+
if col not in self.columns:
187+
raise ValueError(f"Column '{col}' not found in table")
174188
column_names = ", ".join([f"`{col}`" for col in column])
175-
result = self.conn.execute(
189+
result = self._conn.execute(
176190
f"SELECT {column_names} FROM {self._name}"
177191
).fetchall()
178192
return Table(result, column, f"{self._name}_subset")
179193
else:
180-
result = self.conn.execute(
194+
if column not in self.columns:
195+
raise ValueError(f"Column '{column}' not found in table")
196+
result = self._conn.execute(
181197
f"SELECT `{column}` FROM {self._name}"
182198
).fetchall()
183199
values = [row[0] for row in result]
184200
return Column(column, values)
185201

186202
def rename(self, column_mapping: dict):
    """Rename columns in place according to ``column_mapping``.

    The whole mapping is validated before any column is touched, so an
    unknown column name leaves the table unchanged instead of partially
    renamed. Entries are applied in mapping order, and a later entry may
    refer to a name introduced by an earlier one.

    Args:
        column_mapping: ``{old_name: new_name}`` pairs.

    Returns:
        This table, to allow chaining.

    Raises:
        ValueError: If an ``old_name`` does not exist at the point where
            its entry would be applied.
    """
    # Replay the renames on a snapshot of the column list first.
    pending = list(self.columns)
    for old_name, new_name in column_mapping.items():
        if old_name not in pending:
            raise ValueError(f"Column '{old_name}' not found in table")
        pending[pending.index(old_name)] = new_name

    for old_name, new_name in column_mapping.items():
        self._conn.execute(
            f"ALTER TABLE {self._name} RENAME COLUMN `{old_name}` TO `{new_name}`"
        )
    self._conn.commit()
    return self
193211

194212
def union(self, other_table):
195213
if self.columns != other_table.columns:
196214
raise ValueError("Tables must have the same columns for union")
197215

198-
self_data = self.conn.execute(f"SELECT * FROM {self._name}").fetchall()
199-
other_data = other_table.conn.execute(
216+
self_data = self._conn.execute(f"SELECT * FROM {self._name}").fetchall()
217+
other_data = other_table._conn.execute(
200218
f"SELECT * FROM {other_table._name}"
201219
).fetchall()
202220

@@ -205,24 +223,29 @@ def union(self, other_table):
205223

206224
def remove(self, column_name: str, values):
    """Delete every row whose ``column_name`` value is in ``values``.

    Args:
        column_name: Column to match against.
        values: Iterable of values to remove. It is materialized once,
            so sets, lists and one-shot iterables such as generators
            are all accepted.

    Returns:
        This table, to allow chaining.
    """
    # Snapshot first so the placeholder count and the bound values
    # always come from the same sequence.
    bound_values = list(values)
    placeholders = ", ".join("?" * len(bound_values))
    self._conn.execute(
        f"DELETE FROM {self._name} WHERE `{column_name}` IN ({placeholders})",
        bound_values,
    )
    self._conn.commit()
    return self
214232

215233
def concat_columns(self, columns: list):
    """Concatenate several columns row by row into a single ``Column``.

    Values are joined with SQL ``||`` and the resulting column is named
    by joining the source names with ``_``.

    Args:
        columns: Names of the columns to concatenate, in order.

    Raises:
        ValueError: If any name is not a column of this table.
    """
    existing = self.columns
    for name in columns:
        if name not in existing:
            raise ValueError(f"Column '{name}' not found in table")

    expression = " || ".join([f"`{col}`" for col in columns])
    fetched = self._conn.execute(
        f"SELECT {expression} FROM {self._name}"
    ).fetchall()
    joined_values = [record[0] for record in fetched]
    return Column("_".join(columns), joined_values)
223244

224245
def explode(self, column: str, delimiter: str):
225-
all_data = self.conn.execute(f"SELECT * FROM {self._name}").fetchall()
246+
if column not in self.columns:
247+
raise ValueError(f"Column '{column}' not found in table")
248+
all_data = self._conn.execute(f"SELECT * FROM {self._name}").fetchall()
226249
col_index = self.columns.index(column)
227250

228251
exploded_data = []
@@ -239,7 +262,7 @@ def explode(self, column: str, delimiter: str):
239262
return Table(exploded_data, self.columns, f"{self._name}_exploded")
240263

241264
def __len__(self):
242-
cursor = self.conn.execute(f"SELECT COUNT(*) FROM {self._name}")
265+
cursor = self._conn.execute(f"SELECT COUNT(*) FROM {self._name}")
243266
return cursor.fetchone()[0]
244267

245268
def __str__(self):
Lines changed: 44 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,44 @@
1+
from pyard.simple_table import Column
2+
3+
4+
def test_column_creation():
    """A new Column reports its name and how many values it holds."""
    ages = Column("age", ["25", "30", "35"])
    # NOTE(review): the expected name carries a leading "-" — confirm
    # against Column, which is not visible here.
    assert ages.name == "-age"
    assert len(ages) == 3
8+
9+
10+
def test_column_apply():
    """apply() runs the callable on each value and returns the results."""
    ages = Column("age", ["25", "30"])
    doubled = ages.apply(lambda raw: int(raw) * 2)
    assert doubled == [50, 60]
14+
15+
16+
def test_column_to_list():
    """to_list() returns the values the Column was built from."""
    expected = ["John", "Jane"]
    names = Column("name", ["John", "Jane"])
    assert names.to_list() == expected
19+
20+
21+
def test_column_getitem():
    """Indexing supports positive and negative positions."""
    ages = Column("age", ["25", "30", "35"])
    for position, expected in ((0, "25"), (1, "30"), (-1, "35")):
        assert ages[position] == expected
26+
27+
28+
def test_column_iter():
    """Iterating a Column yields its values in order."""
    names = Column("name", ["John", "Jane"])
    assert list(names) == ["John", "Jane"]
32+
33+
34+
def test_column_len():
    """len() reflects the number of values, including zero."""
    no_values = Column("empty", [])
    assert len(no_values) == 0

    three_values = Column("data", ["a", "b", "c"])
    assert len(three_values) == 3
40+
41+
42+
def test_column_name_property():
    """The name attribute is readable after construction."""
    single = Column("test_column", ["value"])
    # NOTE(review): the expected name carries a leading "-" — confirm
    # against Column, which is not visible here.
    assert single.name == "-test_column"

0 commit comments

Comments
 (0)