Commit 4993339

Merge pull request #65 from BAMresearch/csv_iosource
Csv iosource
2 parents: 8b32e24 + f41a4f8

14 files changed: +597 -166 lines

src/modacor/io/csv/__init__.py

Whitespace-only changes.

src/modacor/io/csv/csv_source.py

Lines changed: 210 additions & 0 deletions
@@ -0,0 +1,210 @@
# SPDX-License-Identifier: BSD-3-Clause
# /usr/bin/env python3
# -*- coding: utf-8 -*-

from __future__ import annotations

__coding__ = "utf-8"
__authors__ = ["Brian R. Pauw"]
__copyright__ = "Copyright 2025, The MoDaCor team"
__date__ = "12/12/2025"
__status__ = "Development"  # "Development", "Production"
# end of header and standard imports

__all__ = ["CSVSource"]

from collections.abc import Callable
from pathlib import Path
from typing import Any

import numpy as np
from attrs import define, field, validators

from modacor.dataclasses.messagehandler import MessageHandler
from modacor.io.io_source import ArraySlice

from ..io_source import IoSource


def _is_callable(_, __, value):
    if not callable(value):
        raise TypeError("method must be callable")


@define(kw_only=True)
class CSVSource(IoSource):
    """
    IoSource for loading columnar data from CSV-like text files using NumPy's
    loadtxt or genfromtxt.

    Expected usage
    --------------
    - Data is 1D per column (no multi-dimensional fields).
    - Columns are returned as 1D arrays; each column corresponds to one data_key.
    - For np.loadtxt, column names must be provided via a dtype with field names, e.g.:
          dtype=[("q", float), ("I", float), ("I_sigma", float)]
    - For np.genfromtxt, column names come from the first row or are specified
      explicitly via the `names` parameter, so that the columns can be clearly
      identified later. Typical patterns:
          np.genfromtxt(..., names=True, delimiter=..., ...)    # use first row as names
          np.genfromtxt(..., names=["q", "I", "I_sigma"], ...)  # specify names explicitly

    Configuration
    -------------
    `iosource_method_kwargs` is passed directly to the NumPy function `method`.
    This allows you to use all standard NumPy options, e.g.:

    For np.genfromtxt:
        delimiter=","
        skip_header=3
        max_rows=1000
        usecols=(0, 1, 2)
        names=True or names=["q", "I", "sigma"]
        dtype=None or dtype=float
        encoding="utf-8"
        comments="#"
        ...

    For np.loadtxt:
        delimiter=","
        skiprows=3
        max_rows=1000
        usecols=(0, 1, 2)
        dtype=float
        encoding="utf-8"
        comments="#"
        ...

    Notes
    -----
    - 2D arrays (no field names) are not supported in this implementation.
      If the resulting array does not have `dtype.names`, a ValueError is raised.
    """

    # external API:
    resource_location: Path = field(converter=Path, validator=validators.instance_of(Path))
    method: Callable[..., np.ndarray] = field(
        default=np.genfromtxt, validator=_is_callable
    )  # default to genfromtxt, better for names
    # internal use (type hints; real values set per-instance)
    _data_cache: np.ndarray | None = field(init=False, default=None)
    _data_dict_cache: dict[str, np.ndarray] = field(factory=dict)
    _file_datasets_dtypes: dict[str, np.dtype] = field(init=False)
    _file_datasets_shapes: dict[str, tuple[int, ...]] = field(init=False)
    logger: MessageHandler = field(init=False)

    def __attrs_post_init__(self) -> None:
        # super().__init__(source_reference=self.source_reference, iosource_method_kwargs=self.iosource_method_kwargs)
        self.logger = MessageHandler(level=self.logging_level, name="CSVSource")
        # Check that the file exists
        if not self.resource_location.is_file():
            self.logger.error(f"CSVSource: file {self.resource_location} does not exist.")

        # Bookkeeping structures for the IoSource API
        self._file_datasets_shapes: dict[str, tuple[int, ...]] = {}
        self._file_datasets_dtypes: dict[str, np.dtype] = {}

        # Load and preprocess data immediately
        self._load_data()
        self._preload()

    # ------------------------------------------------------------------ #
    # Internal loading / preprocessing                                   #
    # ------------------------------------------------------------------ #

    def _load_data(self) -> None:
        """
        Load the CSV data into a structured NumPy array using the configured
        method (np.genfromtxt or np.loadtxt).

        iosource_method_kwargs are passed directly to that method.
        """
        self.logger.info(
            f"CSVSource loading data from {self.resource_location} "
            f"using {self.method.__name__} with options: {self.iosource_method_kwargs}"
        )

        try:
            self._data_cache = self.method(self.resource_location, **self.iosource_method_kwargs)
        except Exception as exc:  # noqa: BLE001
            self.logger.error(f"Error while loading CSV data from {self.resource_location}: {exc}")
            raise

        if self._data_cache is None:
            raise ValueError(f"CSVSource: no data loaded from file {self.resource_location}.")
        # Ensure we have a structured array with named fields
        if self._data_cache.dtype.names is None:
            raise ValueError(
                "CSVSource expected a structured array with named fields, "
                "but dtype.names is None.\n"
                "Hint: use np.genfromtxt with 'names=True' or 'names=[...]', "
                "or provide an appropriate 'dtype' with field names."
            )

    def _preload(self) -> None:
        """
        Populate dataset lists, shapes, and dtypes from the structured array.
        """
        assert self._data_cache is not None  # for type checkers

        self._data_dict_cache = {}
        self._file_datasets_shapes.clear()
        self._file_datasets_dtypes.clear()

        for name in self._data_cache.dtype.names:
            column = self._data_cache[name]
            self._data_dict_cache[name] = column
            self._file_datasets_shapes[name] = column.shape
            self._file_datasets_dtypes[name] = column.dtype

        self.logger.info(f"CSVSource loaded datasets: {list(self._file_datasets_shapes.keys())}")

    # ------------------------------------------------------------------ #
    # IoSource API                                                       #
    # ------------------------------------------------------------------ #

    def get_static_metadata(self, data_key: str) -> None:
        """
        CSVSource does not support static metadata; always returns None.
        """
        self.logger.warning(
            f"You asked for static metadata '{data_key}', but CSVSource does not support static metadata."
        )
        return None

    def get_data(self, data_key: str, load_slice: ArraySlice = ...) -> np.ndarray:
        """
        Return the data column corresponding to `data_key`, cast to float,
        with `load_slice` applied.

        - data_key must match one of the field names in the structured array.
        - `load_slice` is applied to that 1D column (e.g. an ellipsis, slice,
          or array of indices).
        """
        if self._data_cache is None:
            raise RuntimeError("CSVSource data cache is empty; loading may have failed.")

        try:
            column = self._data_dict_cache[data_key]
        except KeyError:
            raise KeyError(
                f"Data key '{data_key}' not found in CSV data. Available keys: {list(self._data_dict_cache.keys())}"
            ) from None

        return np.asarray(column[load_slice]).astype(float)

    def get_data_shape(self, data_key: str) -> tuple[int, ...]:
        if data_key in self._file_datasets_shapes:
            return self._file_datasets_shapes[data_key]
        return ()

    def get_data_dtype(self, data_key: str) -> np.dtype | None:
        if data_key in self._file_datasets_dtypes:
            return self._file_datasets_dtypes[data_key]
        return None

    def get_data_attributes(self, data_key: str) -> dict[str, Any]:
        """
        CSV has no per-dataset attributes; return a dict mapping the key to None.
        """
        self.logger.warning(
            f"You asked for attributes of '{data_key}', but CSVSource does not support data attributes."
        )
        return {data_key: None}
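
For orientation, a minimal usage sketch of the new class (not part of the commit; the file name, reference name, and column names are hypothetical, and the keyword arguments follow the attrs fields shown above):

import numpy as np

from modacor.io.csv.csv_source import CSVSource

# Construct the source; iosource_method_kwargs is forwarded verbatim to
# np.genfromtxt, so names=True takes the column names from the header row.
source = CSVSource(
    source_reference="saxs_curve",  # hypothetical reference name
    resource_location="scattering_data.csv",  # hypothetical file
    iosource_method_kwargs={"delimiter": ",", "names": True},
)

q = source.get_data("q")  # full column, cast to float
i_head = source.get_data("I", slice(0, 10))  # first ten rows only
print(source.get_data_shape("q"), source.get_data_dtype("q"))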
Lines changed: 27 additions & 14 deletions
@@ -13,13 +13,13 @@
 __status__ = "Development"  # "Development", "Production"
 # end of header and standard imports
 
-__all__ = ["HDFLoader"]
+__all__ = ["HDFSource"]
 
-from logging import WARNING
 from pathlib import Path
 
 import h5py
 import numpy as np
+from attrs import define, field, validators
 
 from modacor.dataclasses.messagehandler import MessageHandler
 
@@ -29,17 +29,30 @@
 from ..io_source import IoSource
 
 
-class HDFLoader(IoSource):
-    _data_cache: dict[str, np.ndarray] = None
-    _file_path: Path | None = None
-    _static_metadata_cache: dict[str, Any] = None
-
-    def __init__(self, source_reference: str, logging_level=WARNING, resource_location: Path | str | None = None):
-        super().__init__(source_reference=source_reference)
-        self.logger = MessageHandler(level=logging_level, name="HDFLoader")
-        self._file_path = Path(resource_location) if resource_location is not None else None
-        # self._file_reference = None  # let's not leave open file references lying around if we can help it.
-        self._file_datasets = []
+@define(kw_only=True)
+class HDFSource(IoSource):
+    resource_location: Path | str | None = field(
+        init=True, default=None, validator=validators.optional(validators.instance_of((Path, str)))
+    )
+    _data_cache: dict[str, np.ndarray] = field(init=False, factory=dict, validator=validators.instance_of(dict))
+    _file_path: Path | None = field(
+        init=False, default=None, validator=validators.optional(validators.instance_of(Path))
+    )
+    _file_datasets_shapes: dict[str, tuple[int, ...]] = field(
+        init=False, factory=dict, validator=validators.instance_of(dict)
+    )
+    _file_datasets_dtypes: dict[str, np.dtype] = field(init=False, factory=dict, validator=validators.instance_of(dict))
+    _static_metadata_cache: dict[str, Any] = field(init=False, factory=dict, validator=validators.instance_of(dict))
+    logger: MessageHandler = field(init=False)
+
+    # source_reference comes from IoSource
+    # iosource_method_kwargs comes from IoSource
+
+    def __attrs_post_init__(self):
+        # super().__init__(source_reference=source_reference)
+        self.logger = MessageHandler(level=self.logging_level, name="HDFSource")
+        self._file_path = Path(self.resource_location) if self.resource_location is not None else None
+        # self._file_datasets = []
         self._file_datasets_shapes = {}
         self._file_datasets_dtypes = {}
         self._data_cache = {}
@@ -61,7 +74,7 @@ def _find_datasets(self, path_name, path_object):
         the datasets within
         """
         if isinstance(path_object, h5py._hl.dataset.Dataset):
-            self._file_datasets.append(path_name)
+            # self._file_datasets.append(path_name)
             self._file_datasets_shapes[path_name] = path_object.shape
             self._file_datasets_dtypes[path_name] = path_object.dtype
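
The hunk above replaces a hand-written __init__ with attrs-generated construction. A generic, standalone sketch of that pattern (not MoDaCor code): @define(kw_only=True) generates a keyword-only __init__ from the declared fields, and __attrs_post_init__ then derives internal state.

from pathlib import Path

from attrs import define, field, validators


@define(kw_only=True)
class Example:
    # declared field: validated and assigned by the generated __init__
    resource_location: Path | str | None = field(
        default=None, validator=validators.optional(validators.instance_of((Path, str)))
    )
    # init=False fields are excluded from __init__ and set afterwards
    _file_path: Path | None = field(init=False, default=None)

    def __attrs_post_init__(self):
        # runs after attrs has set the declared fields
        self._file_path = Path(self.resource_location) if self.resource_location is not None else None


ex = Example(resource_location="data.h5")  # keyword-only: Example("data.h5") raises TypeError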

src/modacor/io/io_source.py

Lines changed: 5 additions & 0 deletions
@@ -4,6 +4,8 @@
 
 from __future__ import annotations
 
+from logging import WARNING
+
 import attrs
 
 __coding__ = "utf-8"
@@ -67,6 +69,9 @@ class IoSource:
     configuration: dict[str, Any] = field(factory=default_config)
     source_reference: str = field(default="", converter=str, validator=attrs.validators.instance_of(str))
     type_reference: str = "IoSource"
+    # for passing extra kwargs to the data loading method if needed (e.g. csv_source)
+    iosource_method_kwargs: dict[str, Any] = field(factory=dict, validator=attrs.validators.instance_of(dict))
+    logging_level: int = field(default=WARNING, validator=attrs.validators.instance_of(int))
 
     def get_data(self, data_key: str, load_slice: Optional[ArraySlice] = None) -> np.ndarray:
         """
Lines changed: 16 additions & 13 deletions
@@ -11,14 +11,15 @@
 __status__ = "Development"  # "Development", "Production"
 # end of header and standard imports
 
-__all__ = ["YAMLLoader"]
+__all__ = ["YAMLSource"]
 
 from logging import WARNING
 from pathlib import Path
 from typing import Any
 
 import numpy as np
 import yaml
+from attrs import define, field, validators
 
 from modacor.dataclasses.messagehandler import MessageHandler
 from modacor.io.io_source import ArraySlice
@@ -37,7 +38,8 @@ def get_from_nested_dict_by_path(data, path):
     return data
 
 
-class YAMLLoader(IoSource):
+@define(kw_only=True)
+class YAMLSource(IoSource):
     """
     This IoSource is used to load and make experiment metadata available to
     the processing pipeline modules.
@@ -48,17 +50,18 @@ class YAMLLoader(IoSource):
     The entries are returned as BaseData elements, with units and uncertainties.
     """
 
-    _yaml_data: dict[str, Any] = dict()
-    _data_cache: dict[str, np.ndarray] = None
-    _file_path: Path | None = None
-    _static_metadata_cache: dict[str, Any] = None
-
-    def __init__(self, source_reference: str, logging_level=WARNING, resource_location: Path | str | None = None):
-        super().__init__(source_reference=source_reference)
-        self.logger = MessageHandler(level=logging_level, name="YAMLLoader")
-        self._file_path = Path(resource_location) if resource_location is not None else None
-        self._file_datasets = []
-        self._file_datasets_shapes = {}
+    resource_location: Path = field(converter=Path, validator=validators.instance_of(Path))
+    _yaml_data: dict[str, Any] = field(factory=dict, validator=validators.instance_of(dict))
+    _data_cache: dict[str, np.ndarray] = field(factory=dict, validator=validators.instance_of(dict))
+    _file_path: Path | None = field(default=None, validator=validators.optional(validators.instance_of(Path)))
+    _static_metadata_cache: dict[str, Any] = field(factory=dict, validator=validators.instance_of(dict))
+    logging_level: int = field(default=WARNING, validator=validators.instance_of(int))
+    logger: MessageHandler = field(init=False)
+
+    def __attrs_post_init__(self):
+        # super().__init__(source_reference=source_reference)
+        self.logger = MessageHandler(level=self.logging_level, name="YAMLSource")
+        self._file_path = Path(self.resource_location) if self.resource_location is not None else None
         self._data_cache = {}  # for values that are float
         self._static_metadata_cache = {}  # for other elements such as strings and tags
         self._preload()  # load the yaml data immediately
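
The second hunk header above references get_from_nested_dict_by_path; its body is not shown in this diff, but the name and the `return data` context suggest the usual key-by-key walk through nested dicts. A self-contained sketch of that behaviour (the YAML content here is hypothetical):

import yaml

document = yaml.safe_load(
    """
instrument:
  detector:
    distance: 1.5
"""
)

# walk the nested dicts one key at a time, as the helper's name implies
value = document
for key in ("instrument", "detector", "distance"):
    value = value[key]

assert value == 1.5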

0 commit comments