Commit e55195f

Implement format handlers for Parquet, JSON, and CSV; remove obsolete fsspec utility functions and metadata helpers; refactor schema conversion logic; enhance Writer class with retry logic for dataset writing; introduce utility functions for disk usage tracking and file name management; update dependencies in pyproject.toml.
1 parent: 6527551

File tree

11 files changed: +1331 -1170 lines

pydala/cache.py

Lines changed: 85 additions & 0 deletions
@@ -0,0 +1,85 @@
```python
import os
import posixpath
from typing import Any

from fsspec.implementations.cache_mapper import AbstractCacheMapper
from fsspec.implementations.cached import SimpleCacheFileSystem
from loguru import logger

from .helpers.security import safe_join, validate_path


class FileNameCacheMapper(AbstractCacheMapper):
    """Cache mapper that keeps the original file name inside the cache dir."""

    def __init__(self, directory: str):
        self.directory = validate_path(directory)

    def __call__(self, path: str) -> str:
        # Mirror the remote directory layout under the cache directory,
        # creating parent directories as needed.
        validated_path = validate_path(path)
        full_path = safe_join(self.directory, validated_path)
        os.makedirs(posixpath.dirname(full_path), exist_ok=True)
        return validated_path


class MonitoredSimpleCacheFileSystem(SimpleCacheFileSystem):
    def __init__(self, verbose: bool = False, **kwargs):
        self._verbose = verbose
        super().__init__(**kwargs)
        self._mapper = FileNameCacheMapper(kwargs.get("cache_storage", "~/.tmp"))

    def _check_file(self, path: str):
        self._check_cache()
        cache_path = self._mapper(path)
        for storage in self.storage:
            fn = posixpath.join(storage, cache_path)
            if posixpath.exists(fn):
                return fn
        if self._verbose:
            logger.info(f"Downloading {self.protocol[0]}://{path}")

    def size(self, path: str):
        cached_file = self._check_file(self._strip_protocol(path))
        if cached_file is None:
            return self.fs.size(path)
        return posixpath.getsize(cached_file)

    def __getattribute__(self, item: str) -> Any:
        # Resolve the method set on the class (not through `self`), otherwise
        # this lookup would re-enter __getattribute__ and recurse forever.
        if item in type(self)._delegated_methods:
            return lambda *args, **kwargs: getattr(type(self), item).__get__(
                self, type(self)
            )(*args, **kwargs)
        if item in {"__reduce_ex__", "__reduce__"}:
            raise AttributeError(item)
        if item == "transaction":
            return type(self).transaction.__get__(self, type(self))
        if item in {"_cache", "transaction_type"}:
            return getattr(type(self), item)
        if item == "__class__":
            return type(self)
        # Call the helper through the class for the same reason as above.
        return type(self)._delegate_to_fs(self, item)

    def _delegate_to_fs(self, item: str) -> Any:
        d = object.__getattribute__(self, "__dict__")
        fs = d.get("fs")
        if item in d:
            return d[item]
        if fs is None:
            return super().__getattribute__(item)
        if item in fs.__dict__:
            return fs.__dict__[item]
        cls = type(fs)
        m = getattr(cls, item, None)
        if m is None:
            raise AttributeError(f"'{item}' not found in underlying fs")
        # Bind plain functions to the wrapped fs; the parentheses matter here,
        # since objects without __self__ must not have it dereferenced.
        if callable(m) and (not hasattr(m, "__self__") or m.__self__ is None):
            return m.__get__(fs, cls)
        return m

    _delegated_methods = {
        "size", "glob", "load_cache", "_open", "save_cache", "close_and_update",
        "__init__", "__getattribute__", "__reduce__", "_make_local_details", "open",
        "cat", "cat_file", "cat_ranges", "get", "read_block", "tail", "head", "info",
        "ls", "exists", "isfile", "isdir", "_check_file", "_check_cache", "_mkcache",
        "clear_cache", "clear_expired_cache", "pop_from_cache", "local_file",
        "_paths_from_path", "get_mapper", "open_many", "commit_many", "hash_name",
        "__hash__", "__eq__", "to_json", "to_dict", "cache_size", "pipe_file", "pipe",
        "start_transaction", "end_transaction",
    }
```
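For orientation, here is a minimal usage sketch of the new cache layer; the S3 target, bucket path, and cache directory are illustrative assumptions, not part of this commit:

```python
import fsspec

# Hypothetical setup: wrap a remote filesystem in the monitored cache.
# "s3", the bucket path, and /tmp/pydala-cache are assumed for illustration.
remote = fsspec.filesystem("s3", anon=True)
fs = MonitoredSimpleCacheFileSystem(
    verbose=True,                       # log a message on every cache miss
    fs=remote,                          # underlying filesystem to cache
    cache_storage="/tmp/pydala-cache",  # local directory for cached copies
)

# The first read downloads the object; later reads are served from the
# cache, stored under its original name thanks to FileNameCacheMapper.
with fs.open("my-bucket/data/part-0.parquet", "rb") as f:
    magic = f.read(4)
```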

pydala/catalog.py

Lines changed: 85 additions & 96 deletions
```diff
@@ -10,6 +10,75 @@
 from pydala.helpers.polars import pl

 from .dataset import CsvDataset, JsonDataset, ParquetDataset, PyarrowDataset
+
+from abc import ABC, abstractmethod
+
+
+class AbstractLoader(ABC):
+    @abstractmethod
+    def _matches_format(self, params) -> bool: ...
+
+    @abstractmethod
+    def _read_data(self, catalog, params, **kwargs) -> pl.DataFrame: ...
+
+    @abstractmethod
+    def _get_dataset_class(self, with_metadata: bool = True): ...
+
+    def load(self, catalog, table_name, as_dataset: bool, with_metadata: bool = False, **kwargs):
+        params = catalog._get_table_params(table_name=table_name)
+        if not self._matches_format(params):
+            return None
+        if not as_dataset:
+            df = self._read_data(catalog, params, **kwargs)
+            catalog.ddb_con.register(table_name, df)
+            return df
+        cls = self._get_dataset_class(with_metadata)
+        return cls(
+            params.path,
+            filesystem=catalog.fs[params.filesystem],
+            name=table_name,
+            ddb_con=catalog.ddb_con,
+            **kwargs,
+        )
+
+
+class ParquetLoader(AbstractLoader):
+    def _matches_format(self, params) -> bool:
+        return "parquet" in params.format.lower()
+
+    def _read_data(self, catalog, params, **kwargs) -> pl.DataFrame:
+        fs = catalog.fs[params.filesystem]
+        if params.path.endswith(".parquet"):
+            return fs.read_parquet(params.path, **kwargs)
+        return fs.read_parquet_dataset(params.path, **kwargs)
+
+    def _get_dataset_class(self, with_metadata: bool = True):
+        return ParquetDataset if with_metadata else PyarrowDataset
+
+
+class CsvLoader(AbstractLoader):
+    def _matches_format(self, params) -> bool:
+        return "csv" in params.format.lower()
+
+    def _read_data(self, catalog, params, **kwargs) -> pl.DataFrame:
+        fs = catalog.fs[params.filesystem]
+        if params.path.endswith(".csv"):
+            return fs.read_csv(params.path, **kwargs)
+        return fs.read_csv_dataset(params.path, **kwargs)
+
+    def _get_dataset_class(self, with_metadata: bool = True):
+        return CsvDataset
+
+
+class JsonLoader(AbstractLoader):
+    def _matches_format(self, params) -> bool:
+        return "json" in params.format.lower()
+
+    def _read_data(self, catalog, params, **kwargs) -> pl.DataFrame:
+        fs = catalog.fs[params.filesystem]
+        if params.path.endswith(".json"):
+            return fs.read_json(params.path, **kwargs)
+        return fs.read_json_dataset(params.path, **kwargs)
+
+    def _get_dataset_class(self, with_metadata: bool = True):
+        return JsonDataset
+
+
+# Registry mapping a format name to its loader instance.
+LOADERS = {
+    "parquet": ParquetLoader(),
+    "csv": CsvLoader(),
+    "json": JsonLoader(),
+}
 from .filesystem import FileSystem
 from .helpers.misc import delattr_rec, get_nested_keys, getattr_rec, setattr_rec
 from .helpers.sql import get_table_names
```
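The registry makes new formats pluggable without touching `Catalog.load`. A sketch of a hypothetical extension follows; the Feather format, the `read_ipc`/`read_ipc_dataset` helpers, and `FeatherDataset` are assumptions named by analogy, not part of this commit:

```python
# Hypothetical extension: a loader for Feather/IPC tables. The helpers
# read_ipc/read_ipc_dataset and the FeatherDataset class are assumed,
# named by analogy with the Parquet/CSV/JSON loaders above.
class FeatherLoader(AbstractLoader):
    def _matches_format(self, params) -> bool:
        return "feather" in params.format.lower()

    def _read_data(self, catalog, params, **kwargs) -> pl.DataFrame:
        fs = catalog.fs[params.filesystem]
        if params.path.endswith(".feather"):
            return fs.read_ipc(params.path, **kwargs)
        return fs.read_ipc_dataset(params.path, **kwargs)

    def _get_dataset_class(self, with_metadata: bool = True):
        return FeatherDataset


# One registration line is all Catalog.load needs to dispatch to it.
LOADERS["feather"] = FeatherLoader()
```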
```diff
@@ -162,87 +231,22 @@ def files(self, table_name: str) -> list[str]:
         )

     def load_parquet(
-        self, table_name: str, as_dataset=True, with_metadata: bool = True, **kwargs
+        self, table_name: str, as_dataset: bool = True, with_metadata: bool = True, **kwargs
     ) -> ParquetDataset | PyarrowDataset | pl.DataFrame | None:
-        params = self._get_table_params(table_name=table_name)
-
-        if "parquet" not in params.format.lower():
-            return
-        if not as_dataset:
-            if params.path.endswith(".parquet"):
-                df = self.fs[params.filesystem].read_parquet(params.path, **kwargs)
-                self.ddb_con.register(table_name, df)
-                return df
-
-            df = self.fs[params.filesystem].read_parquet_dataset(params.path, **kwargs)
-            self.ddb_con.register(table_name, df)
-            return df
-
-        if with_metadata:
-            return ParquetDataset(
-                params.path,
-                filesystem=self.fs[params.filesystem],
-                name=table_name,
-                ddb_con=self.ddb_con,
-                **kwargs,
-            )
-
-        return PyarrowDataset(
-            params.path,
-            filesystem=self.fs[params.filesystem],
-            name=table_name,
-            ddb_con=self.ddb_con,
-            **kwargs,
-        )
+        """Load Parquet table as DataFrame or dataset."""
+        return self.load(table_name, as_dataset=as_dataset, with_metadata=with_metadata, **kwargs)

     def load_csv(
         self, table_name: str, as_dataset: bool = True, **kwargs
     ) -> CsvDataset | pl.DataFrame | None:
-        params = self._get_table_params(table_name=table_name)
-
-        if "csv" not in params.format.lower():
-            return
-        if not as_dataset:
-            if params.path.endswith(".csv"):
-                df = self.fs[params.filesystem].read_parquet(params.path, **kwargs)
-                self.ddb_con.register(table_name, df)
-                return df
-
-            df = self.fs[params.filesystem].read_parquet_dataset(params.path, **kwargs)
-            self.ddb_con.register(table_name, df)
-            return df
-
-        return CsvDataset(
-            params.path,
-            filesystem=self.fs[params.filesystem],
-            name=table_name,
-            ddb_con=self.ddb_con,
-            **kwargs,
-        )
+        """Load CSV table as DataFrame or dataset."""
+        return self.load(table_name, as_dataset=as_dataset, with_metadata=False, **kwargs)

     def load_json(
         self, table_name: str, as_dataset: bool = True, **kwargs
     ) -> JsonDataset | pl.DataFrame | None:
-        params = self._get_table_params(table_name=table_name)
-
-        if "json" not in params.format.lower():
-            return
-        if not as_dataset:
-            if params.path.endswith(".json"):
-                df = self.fs[params.filesystem].read_json(params.path, **kwargs)
-                self.ddb_con.register(table_name, df)
-                return df
-
-            df = self.fs[params.filesystem].read_json_dataset(params.path, **kwargs)
-            self.ddb_con.register(table_name, df)
-            return df
-        return JsonDataset(
-            params.path,
-            filesystem=self.fs[params.filesystem],
-            name=table_name,
-            ddb_con=self.ddb_con,
-            **kwargs,
-        )
+        """Load JSON table as DataFrame or dataset."""
+        return self.load(table_name, as_dataset=as_dataset, with_metadata=False, **kwargs)

     def load(
         self,
```
```diff
@@ -253,30 +257,15 @@ def load(
         **kwargs,
     ):
         params = self._get_table_params(table_name=table_name)
-
-        if params.format.lower() == "parquet":
-            if table_name not in self.table and not reload:
-                self.table[table_name] = self.load_parquet(
-                    table_name,
-                    as_dataset=as_dataset,
-                    with_metadata=with_metadata,
-                    **kwargs,
-                )
-            return self.table[table_name]
-
-        elif params.format.lower() == "csv":
-            if table_name not in self.table and not reload:
-                self.table[table_name] = self.load_csv(
-                    table_name, as_dataset=as_dataset, **kwargs
-                )
-            return self.table[table_name]
-
-        elif params.format.lower() == "json":
-            if table_name not in self.table and not reload:
-                self.table[table_name] = self.load_json(table_name, **kwargs)
-            return self.table[table_name]
-
-        # return None
+        format_lower = params.format.lower()
+        loader = LOADERS.get(format_lower)
+        if loader is None:
+            return None
+        # Load on first access, or again when an explicit reload is requested.
+        if table_name not in self.table or reload:
+            self.table[table_name] = loader.load(
+                self, table_name, as_dataset, with_metadata, **kwargs
+            )
+        return self.table[table_name]

     # def _ddb_table_mapping(self, table_name: str):
     #     params = getattr_rec(self._catalog, self._get_table_from_table_name(table_name=table_name))
```