77
88class Table :
99 def __init__ (self , data , columns : list , table_name : str = "data" ):
10- self .conn = sqlite3 .connect (":memory:" )
10+ self ._conn = sqlite3 .connect (":memory:" )
1111 self ._name = table_name
1212 self ._columns = columns
1313 if isinstance (data , csv .DictReader ):
@@ -22,60 +22,62 @@ def _create_table_from_reader(self, reader: csv.DictReader, columns: list):
2222
2323 column_defs = ", " .join (f"`{ col } ` TEXT" for col in columns )
2424
25- self .conn .execute (f"CREATE TABLE { self ._name } ({ column_defs } )" )
25+ self ._conn .execute (f"CREATE TABLE { self ._name } ({ column_defs } )" )
2626
2727 placeholders = ", " .join ("?" * len (columns ))
2828 for row in rows :
2929 values = [row [col ] for col in columns ]
30- self .conn .execute (
30+ self ._conn .execute (
3131 f"INSERT INTO { self ._name } VALUES ({ placeholders } )" , values
3232 )
3333
34- self .conn .commit ()
34+ self ._conn .commit ()
3535
3636 def _create_table_from_tuples (self , data : list , columns : list ):
3737 if not data :
3838 return
3939
4040 column_defs = ", " .join (f"`{ col } ` TEXT" for col in columns )
4141
42- self .conn .execute (f"CREATE TABLE { self ._name } ({ column_defs } )" )
42+ self ._conn .execute (f"CREATE TABLE { self ._name } ({ column_defs } )" )
4343
4444 placeholders = ", " .join ("?" * len (columns ))
4545 for row in data :
46- self .conn .execute (f"INSERT INTO { self ._name } VALUES ({ placeholders } )" , row )
46+ self ._conn .execute (f"INSERT INTO { self ._name } VALUES ({ placeholders } )" , row )
4747
48- self .conn .commit ()
48+ self ._conn .commit ()
4949
5050 def query (self , sql : str ):
51- return self .conn .execute (sql ).fetchall ()
51+ return self ._conn .execute (sql ).fetchall ()
5252
5353 def close (self ):
54- if self .conn :
55- self .conn .close ()
54+ if self ._conn :
55+ self ._conn .close ()
5656
5757 @property
5858 def columns (self ):
59- cursor = self .conn .execute (f"PRAGMA table_info({ self ._name } )" )
59+ cursor = self ._conn .execute (f"PRAGMA table_info({ self ._name } )" )
6060 return [row [1 ] for row in cursor .fetchall ()]
6161
6262 def head (self , n : int = 5 ):
63- cursor = self .conn .execute (f"SELECT * FROM { self ._name } LIMIT { n } " )
63+ cursor = self ._conn .execute (f"SELECT * FROM { self ._name } LIMIT { n } " )
6464 rows = cursor .fetchall ()
6565 return PrintableTable (self .columns , rows )
6666
6767 def tail (self , n : int = 5 ):
68- cursor = self .conn .execute (
68+ cursor = self ._conn .execute (
6969 f"SELECT * FROM { self ._name } ORDER BY rowid DESC LIMIT { n } "
7070 )
7171 rows = cursor .fetchall ()
7272 return PrintableTable (self .columns , rows )
7373
7474 def group_by (self , group_by_column : str , return_columns : List [str ] = None ):
75+ if group_by_column not in self .columns :
76+ raise ValueError (f"Column '{ group_by_column } ' not found in table" )
7577 if return_columns is None :
7678 return_columns = self .columns
7779 column_names = ", " .join ([f"`{ col } `" for col in return_columns ])
78- cursor = self .conn .execute (
80+ cursor = self ._conn .execute (
7981 f"SELECT { column_names } FROM { self ._name } ORDER BY `{ group_by_column } `"
8082 )
8183 rows = cursor .fetchall ()
@@ -88,19 +90,26 @@ def group_by(self, group_by_column: str, return_columns: List[str] = None):
8890
8991 def unique (self , columns ):
9092 if isinstance (columns , str ):
91- cursor = self .conn .execute (f"SELECT DISTINCT `{ columns } ` FROM { self ._name } " )
93+ cursor = self ._conn .execute (
94+ f"SELECT DISTINCT `{ columns } ` FROM { self ._name } "
95+ )
9296 values = [row [0 ] for row in cursor .fetchall ()]
9397 return Column (columns , values )
9498 else :
9599 column_names = ", " .join ([f"`{ col } `" for col in columns ])
96- cursor = self .conn .execute (
100+ cursor = self ._conn .execute (
97101 f"SELECT DISTINCT { column_names } FROM { self ._name } "
98102 )
99103 return Table (cursor .fetchall (), columns , f"{ self ._name } _unique" )
100104
101105 def where (self , where_clause : str ):
102- cursor = self .conn .execute (f"SELECT * FROM { self ._name } WHERE { where_clause } " )
103- return Table (cursor .fetchall (), self .columns , f"{ self ._name } _filtered" )
106+ try :
107+ cursor = self ._conn .execute (
108+ f"SELECT * FROM { self ._name } WHERE { where_clause } "
109+ )
110+ return Table (cursor .fetchall (), self .columns , f"{ self ._name } _filtered" )
111+ except Exception as e :
112+ raise ValueError (f"Invalid WHERE clause: { where_clause } " ) from e
104113
105114 def where_not_null (self , null_column ):
106115 if isinstance (null_column , list ):
@@ -111,13 +120,13 @@ def where_not_null(self, null_column):
111120 table_suffix = null_column
112121
113122 table_name = f"{ self ._name } _not_null_{ table_suffix } "
114- cursor = self .conn .execute (f"SELECT * FROM { self ._name } WHERE { conditions } " )
123+ cursor = self ._conn .execute (f"SELECT * FROM { self ._name } WHERE { conditions } " )
115124 return Table (cursor .fetchall (), table_name = table_name , columns = self .columns )
116125
117126 def where_in (self , column_name : str , values : set , columns : list ):
118127 placeholders = ", " .join ("?" * len (values ))
119128 column_names = ", " .join ([f"`{ col } `" for col in columns ])
120- cursor = self .conn .execute (
129+ cursor = self ._conn .execute (
121130 f"SELECT { column_names } FROM { self ._name } WHERE `{ column_name } ` IN ({ placeholders } )" ,
122131 list (values ),
123132 )
@@ -134,21 +143,23 @@ def to_dict(self, key_column: str = None, value_column: str = None):
134143 raise ValueError (
135144 f"Columns { key_column } and { value_column } must be different"
136145 )
137- cursor = self .conn .execute (
146+ cursor = self ._conn .execute (
138147 f"SELECT `{ key_column } `, `{ value_column } ` FROM { self ._name } "
139148 )
140149 return dict (cursor .fetchall ())
141150
142151 def value_counts (self , column : str ):
143- cursor = self .conn .execute (
152+ if column not in self .columns :
153+ raise ValueError (f"Column '{ column } ' not found in table" )
154+ cursor = self ._conn .execute (
144155 f"SELECT `{ column } `, COUNT(*) FROM { self ._name } GROUP BY `{ column } ` ORDER BY COUNT(*) DESC"
145156 )
146157 return Table (cursor .fetchall (), [column , "count" ], f"{ self ._name } _counts" )
147158
148159 def agg (self , group_column : str , agg_column : str , func ):
149160 builtin_funcs = {list , set }
150161 query = f"SELECT `{ group_column } `, `{ agg_column } ` FROM { self ._name } GROUP BY `{ group_column } `, `{ agg_column } `"
151- result = self .conn .execute (query ).fetchall ()
162+ result = self ._conn .execute (query ).fetchall ()
152163 d = defaultdict (list )
153164 for k , v in result :
154165 d [k ].append (v )
@@ -160,43 +171,50 @@ def agg(self, group_column: str, agg_column: str, func):
160171
161172 def __setitem__ (self , column : str , values ):
162173 if column in self .columns :
163- self .conn .execute (f"ALTER TABLE { self ._name } DROP COLUMN `{ column } `" )
164- self .conn .execute (f"ALTER TABLE { self ._name } ADD COLUMN `{ column } ` TEXT" )
174+ self ._conn .execute (f"ALTER TABLE { self ._name } DROP COLUMN `{ column } `" )
175+ self ._conn .execute (f"ALTER TABLE { self ._name } ADD COLUMN `{ column } ` TEXT" )
165176 for i , value in enumerate (values ):
166- self .conn .execute (
177+ self ._conn .execute (
167178 f"UPDATE { self ._name } SET `{ column } ` = ? WHERE rowid = ?" ,
168179 (value , i + 1 ),
169180 )
170- self .conn .commit ()
181+ self ._conn .commit ()
171182
172183 def __getitem__ (self , column ):
173184 if isinstance (column , list ):
185+ for col in column :
186+ if col not in self .columns :
187+ raise ValueError (f"Column '{ col } ' not found in table" )
174188 column_names = ", " .join ([f"`{ col } `" for col in column ])
175- result = self .conn .execute (
189+ result = self ._conn .execute (
176190 f"SELECT { column_names } FROM { self ._name } "
177191 ).fetchall ()
178192 return Table (result , column , f"{ self ._name } _subset" )
179193 else :
180- result = self .conn .execute (
194+ if column not in self .columns :
195+ raise ValueError (f"Column '{ column } ' not found in table" )
196+ result = self ._conn .execute (
181197 f"SELECT `{ column } ` FROM { self ._name } "
182198 ).fetchall ()
183199 values = [row [0 ] for row in result ]
184200 return Column (column , values )
185201
186202 def rename (self , column_mapping : dict ):
187203 for old_name , new_name in column_mapping .items ():
188- self .conn .execute (
204+ if old_name not in self .columns :
205+ raise ValueError (f"Column '{ old_name } ' not found in table" )
206+ self ._conn .execute (
189207 f"ALTER TABLE { self ._name } RENAME COLUMN `{ old_name } ` TO `{ new_name } `"
190208 )
191- self .conn .commit ()
209+ self ._conn .commit ()
192210 return self
193211
194212 def union (self , other_table ):
195213 if self .columns != other_table .columns :
196214 raise ValueError ("Tables must have the same columns for union" )
197215
198- self_data = self .conn .execute (f"SELECT * FROM { self ._name } " ).fetchall ()
199- other_data = other_table .conn .execute (
216+ self_data = self ._conn .execute (f"SELECT * FROM { self ._name } " ).fetchall ()
217+ other_data = other_table ._conn .execute (
200218 f"SELECT * FROM { other_table ._name } "
201219 ).fetchall ()
202220
@@ -205,24 +223,29 @@ def union(self, other_table):
205223
206224 def remove (self , column_name : str , values ):
207225 placeholders = ", " .join ("?" * len (values ))
208- self .conn .execute (
226+ self ._conn .execute (
209227 f"DELETE FROM { self ._name } WHERE `{ column_name } ` IN ({ placeholders } )" ,
210228 list (values ),
211229 )
212- self .conn .commit ()
230+ self ._conn .commit ()
213231 return self
214232
215233 def concat_columns (self , columns : list ):
234+ for col in columns :
235+ if col not in self .columns :
236+ raise ValueError (f"Column '{ col } ' not found in table" )
216237 column_names = " || " .join ([f"`{ col } `" for col in columns ])
217- result = self .conn .execute (
238+ result = self ._conn .execute (
218239 f"SELECT { column_names } FROM { self ._name } "
219240 ).fetchall ()
220241 values = [row [0 ] for row in result ]
221242 concat_name = "_" .join (columns )
222243 return Column (concat_name , values )
223244
224245 def explode (self , column : str , delimiter : str ):
225- all_data = self .conn .execute (f"SELECT * FROM { self ._name } " ).fetchall ()
246+ if column not in self .columns :
247+ raise ValueError (f"Column '{ column } ' not found in table" )
248+ all_data = self ._conn .execute (f"SELECT * FROM { self ._name } " ).fetchall ()
226249 col_index = self .columns .index (column )
227250
228251 exploded_data = []
@@ -239,7 +262,7 @@ def explode(self, column: str, delimiter: str):
239262 return Table (exploded_data , self .columns , f"{ self ._name } _exploded" )
240263
241264 def __len__ (self ):
242- cursor = self .conn .execute (f"SELECT COUNT(*) FROM { self ._name } " )
265+ cursor = self ._conn .execute (f"SELECT COUNT(*) FROM { self ._name } " )
243266 return cursor .fetchone ()[0 ]
244267
245268 def __str__ (self ):
0 commit comments