1010from pydala .helpers .polars import pl
1111
1212from .dataset import CsvDataset , JsonDataset , ParquetDataset , PyarrowDataset
13+
14+ from abc import ABC , abstractmethod
15+
class AbstractLoader(ABC):
    """Template for format-specific table loaders used by the catalog.

    Subclasses supply format matching, eager reads, and dataset-class
    selection; :meth:`load` drives the common flow. ``catalog`` is expected
    to expose ``_get_table_params``, ``fs`` (mapping of filesystem name to
    filesystem) and ``ddb_con`` (DuckDB connection) — see the catalog class
    in this module.
    """

    @abstractmethod
    def _matches_format(self, params) -> bool:
        """Return True if this loader handles the table's configured format."""

    @abstractmethod
    def _read_data(self, catalog, params, **kwargs) -> pl.DataFrame:
        """Eagerly read the table and return it as a DataFrame."""

    @abstractmethod
    def _get_dataset_class(self, with_metadata: bool = True):
        """Return the dataset class to instantiate for lazy access."""
        # NOTE: `load` calls this with one positional argument, so the
        # abstract signature must accept `with_metadata`.

    def load(self, catalog, table_name, as_dataset: bool, with_metadata: bool = False, **kwargs):
        """Load ``table_name`` through ``catalog``.

        Returns None when the table's format does not match this loader.
        When ``as_dataset`` is False the data is read eagerly, registered on
        the catalog's DuckDB connection under ``table_name``, and returned as
        a DataFrame; otherwise a dataset instance is returned.
        """
        params = catalog._get_table_params(table_name=table_name)
        if not self._matches_format(params):
            return None
        if not as_dataset:
            df = self._read_data(catalog, params, **kwargs)
            catalog.ddb_con.register(table_name, df)
            return df
        dataset_cls = self._get_dataset_class(with_metadata)
        return dataset_cls(
            params.path,
            filesystem=catalog.fs[params.filesystem],
            name=table_name,
            ddb_con=catalog.ddb_con,
            **kwargs,
        )
36+
class ParquetLoader(AbstractLoader):
    """Loader for tables stored in the Parquet format."""

    def _matches_format(self, params) -> bool:
        # Substring match against the configured format string.
        return "parquet" in params.format.lower()

    def _read_data(self, catalog, params, **kwargs) -> pl.DataFrame:
        """Read a single parquet file, or a parquet dataset otherwise."""
        filesystem = catalog.fs[params.filesystem]
        reader = (
            filesystem.read_parquet
            if params.path.endswith(".parquet")
            else filesystem.read_parquet_dataset
        )
        return reader(params.path, **kwargs)

    def _get_dataset_class(self, with_metadata: bool = True):
        # Metadata-aware dataset by default; plain pyarrow-backed otherwise.
        if with_metadata:
            return ParquetDataset
        return PyarrowDataset
49+
class CsvLoader(AbstractLoader):
    """Loader for tables stored as CSV files."""

    def _matches_format(self, params) -> bool:
        # Substring match against the configured format string.
        return "csv" in params.format.lower()

    def _read_data(self, catalog, params, **kwargs) -> pl.DataFrame:
        """Read a single CSV file, or a CSV dataset otherwise."""
        filesystem = catalog.fs[params.filesystem]
        reader = (
            filesystem.read_csv
            if params.path.endswith(".csv")
            else filesystem.read_csv_dataset
        )
        return reader(params.path, **kwargs)

    def _get_dataset_class(self, with_metadata: bool = True):
        # CSV has a single dataset flavour; `with_metadata` is ignored.
        return CsvDataset
62+
class JsonLoader(AbstractLoader):
    """Loader for tables stored as JSON files."""

    def _matches_format(self, params) -> bool:
        # Substring match against the configured format string.
        return "json" in params.format.lower()

    def _read_data(self, catalog, params, **kwargs) -> pl.DataFrame:
        """Read a single JSON file, or a JSON dataset otherwise."""
        filesystem = catalog.fs[params.filesystem]
        reader = (
            filesystem.read_json
            if params.path.endswith(".json")
            else filesystem.read_json_dataset
        )
        return reader(params.path, **kwargs)

    def _get_dataset_class(self, with_metadata: bool = True):
        # JSON has a single dataset flavour; `with_metadata` is ignored.
        return JsonDataset
75+
# Registry: maps a lowercase format name to the loader instance that knows
# how to read and wrap tables of that format.
LOADERS = {
    "parquet": ParquetLoader(),
    "csv": CsvLoader(),
    "json": JsonLoader(),
}
1382from .filesystem import FileSystem
1483from .helpers .misc import delattr_rec , get_nested_keys , getattr_rec , setattr_rec
1584from .helpers .sql import get_table_names
@@ -162,87 +231,22 @@ def files(self, table_name: str) -> list[str]:
162231 )
163232
def load_parquet(
    self, table_name: str, as_dataset: bool = True, with_metadata: bool = True, **kwargs
) -> ParquetDataset | PyarrowDataset | pl.DataFrame | None:
    """Load a Parquet table as an eager DataFrame or a dataset.

    Returns None when the table's configured format is not parquet — the
    unconditional delegation to ``self.load`` would otherwise dispatch on
    the table's actual format and hand back a non-parquet result, breaking
    this method's contract and declared return type.
    Note: delegating to ``self.load`` also caches the result in
    ``self.table``.
    """
    params = self._get_table_params(table_name=table_name)
    if "parquet" not in params.format.lower():
        return None
    return self.load(
        table_name, as_dataset=as_dataset, with_metadata=with_metadata, **kwargs
    )
197238
def load_csv(
    self, table_name: str, as_dataset: bool = True, **kwargs
) -> CsvDataset | pl.DataFrame | None:
    """Load a CSV table as an eager DataFrame or a CsvDataset.

    Returns None when the table's configured format is not csv — the
    unconditional delegation to ``self.load`` would otherwise dispatch on
    the table's actual format and hand back a non-CSV result, breaking
    this method's contract and declared return type.
    Note: delegating to ``self.load`` also caches the result in
    ``self.table``.
    """
    params = self._get_table_params(table_name=table_name)
    if "csv" not in params.format.lower():
        return None
    return self.load(table_name, as_dataset=as_dataset, with_metadata=False, **kwargs)
222244
def load_json(
    self, table_name: str, as_dataset: bool = True, **kwargs
) -> JsonDataset | pl.DataFrame | None:
    """Load a JSON table as an eager DataFrame or a JsonDataset.

    Returns None when the table's configured format is not json — the
    unconditional delegation to ``self.load`` would otherwise dispatch on
    the table's actual format and hand back a non-JSON result, breaking
    this method's contract and declared return type.
    Note: delegating to ``self.load`` also caches the result in
    ``self.table``.
    """
    params = self._get_table_params(table_name=table_name)
    if "json" not in params.format.lower():
        return None
    return self.load(table_name, as_dataset=as_dataset, with_metadata=False, **kwargs)
246250
def load(
    self,
    table_name: str,
    as_dataset: bool = True,
    with_metadata: bool = True,
    reload: bool = False,
    **kwargs,
):
    """Load a table by name, dispatching on its configured format.

    Results are cached in ``self.table``; pass ``reload=True`` to force a
    fresh load. Returns None when no loader is registered for the format.

    NOTE(review): the middle of this signature was not visible in the diff
    hunk; parameter names come from the body and defaults mirror the
    sibling ``load_*`` methods — confirm against the full file.
    """
    params = self._get_table_params(table_name=table_name)
    loader = LOADERS.get(params.format.lower())
    if loader is None:
        return None
    # Bug fix: this previously read `table_name not in self.table and not
    # reload`, which meant reload=True never reloaded a cached table and
    # raised KeyError for a table that had never been loaded.
    if reload or table_name not in self.table:
        self.table[table_name] = loader.load(
            self, table_name, as_dataset, with_metadata, **kwargs
        )
    return self.table[table_name]
280269
281270 # def _ddb_table_mapping(self, table_name: str):
282271 # params = getattr_rec(self._catalog, self._get_table_from_table_name(table_name=table_name))
0 commit comments