Skip to content

Commit 78dc5ae

Browse files
authored
Automatically load streams on accessor (#79)
* Expose public method to set parent data stream * Implement automatic data loading * Add tests for automatic data loading * Add tests for implicit loading with errored streams * Add example on how to use the context * Export method at the module level * Use public property for consistency
1 parent fd84793 commit 78dc5ae

File tree

3 files changed

+277

-44
lines changed

src/contraqctor/contract/__init__.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,12 @@
11
from . import camera, csv, harp, json, mux, text, utils
2-
from .base import Dataset, DataStream, DataStreamCollection, DataStreamCollectionBase, FilePathBaseParam
2+
from .base import (
3+
Dataset,
4+
DataStream,
5+
DataStreamCollection,
6+
DataStreamCollectionBase,
7+
FilePathBaseParam,
8+
implicit_loading,
9+
)
310
from .utils import print_data_stream_tree
411

512
__all__ = [
@@ -8,6 +15,7 @@
815
"DataStreamCollection",
916
"Dataset",
1017
"DataStreamCollectionBase",
18+
"implicit_loading",
1119
"print_data_stream_tree",
1220
"camera",
1321
"csv",

src/contraqctor/contract/base.py

Lines changed: 111 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import abc
2+
import contextvars
23
import dataclasses
34
import os
5+
from contextlib import contextmanager
46
from typing import (
57
Any,
68
ClassVar,
@@ -21,6 +23,50 @@
2123

2224
from contraqctor import _typing
2325

26+
_implicit_loading = contextvars.ContextVar("implicit_loading", default=True)
27+
28+
29+
@contextmanager
30+
def implicit_loading(value: bool = True):
31+
"""Context manager to control whether streams automatically load data on access.
32+
33+
When enabled, data streams will automatically load their data when accessed. When disabled,
34+
accessing a data stream without prior loading will raise an error. Call `load()` explicitly
35+
instead.
36+
37+
Args:
38+
value: True to enable auto-loading, False to disable. Default is True.
39+
40+
Examples:
41+
```python
42+
# Assume you have nested collections already created
43+
# collection.at("sensors").at("temperature") -> temperature sensor data
44+
# collection.at("sensors").at("humidity") -> humidity sensor data
45+
# collection.at("logs").at("error_log") -> error log file
46+
47+
# With implicit loading enabled (default behavior)
48+
with implicit_loading(True):
49+
# Data loads automatically on access
50+
temp_data = collection.at("sensors").at("temperature").data
51+
humidity_data = collection.at("sensors").at("humidity").data
52+
53+
# With implicit loading disabled - requires explicit loading
54+
with implicit_loading(False):
55+
# This would raise ValueError: "Data has not been loaded yet"
56+
try:
57+
temp_data = collection.at("sensors").at("temperature").data
58+
except ValueError:
59+
# Must load explicitly first
60+
collection.load_all()
61+
temp_data = collection.at("sensors").at("temperature").data
62+
```
63+
"""
64+
token = _implicit_loading.set(value)
65+
try:
66+
yield
67+
finally:
68+
_implicit_loading.reset(token)
69+
2470

2571
@runtime_checkable
2672
class _AtProtocol(Protocol):
@@ -92,9 +138,9 @@ def resolved_name(self) -> str:
92138
"""
93139
builder = self.name
94140
d = self
95-
while d._parent is not None:
96-
builder = f"{d._parent.name}::{builder}"
97-
d = d._parent
141+
while d.parent is not None:
142+
builder = f"{d.parent.name}::{builder}"
143+
d = d.parent
98144
return builder
99145

100146
@property
@@ -115,6 +161,14 @@ def parent(self) -> Optional["DataStream"]:
115161
"""
116162
return self._parent
117163

164+
def set_parent(self, parent: "DataStream") -> None:
165+
"""Set the parent data stream.
166+
167+
Args:
168+
parent: The parent data stream to set.
169+
"""
170+
self._parent = parent
171+
118172
@property
119173
def is_collection(self) -> bool:
120174
"""Check if this data stream is a collection of other streams.
@@ -235,10 +289,25 @@ def data(self) -> _typing.TData:
235289
Raises:
236290
ValueError: If data has not been loaded yet.
237291
"""
292+
return self._solve_data_load()
293+
294+
def _solve_data_load(self) -> _typing.TData:
295+
"""Resolve data loading based on the current state and implicit loading setting."""
296+
if self.has_data:
297+
return cast(_typing.TData, self._data)
298+
299+
# If there is an error we do not auto load
300+
# and instead raise the existing error
301+
# We use .load() to explicitly retry loading
302+
if (not self.has_error) and _implicit_loading.get():
303+
self.load()
304+
238305
if self.has_error:
239306
cast(_typing.ErrorOnLoad, self._data).raise_from_error()
240-
if not self.has_data:
307+
308+
if not (self.has_data):
241309
raise ValueError("Data has not been loaded yet.")
310+
242311
return cast(_typing.TData, self._data)
243312

244313
def clear(self) -> Self:
@@ -293,7 +362,7 @@ def __str__(self):
293362
f"name={self._name}, "
294363
f"description={self._description}, "
295364
f"reader_params={self._reader_params}, "
296-
f"data_type={self._data.__class__.__name__ if self.has_data else 'Not Loaded'}"
365+
f"data_type={self.data.__class__.__name__ if self.has_data else 'Data not loaded'}"
297366
)
298367

299368
def __iter__(self) -> Generator["DataStream", None, None]:
@@ -370,18 +439,19 @@ def __init__(self, data_stream: "DataStreamCollectionBase[TDataStream, Any]"):
370439

371440
def __call__(self, name: str) -> TDataStream:
372441
"""Access a data stream by name."""
373-
if not self._data_stream.has_data:
374-
raise ValueError("data streams have not been read yet. Cannot access data streams.")
442+
443+
self._data_stream._solve_data_load()
444+
375445
try:
376-
return self._data_stream._hashmap[name]
377-
except KeyError:
378-
raise KeyError(f"Stream with name: '{name}' not found in data streams.")
446+
return self._data_stream._data_stream_mapping[name]
447+
except KeyError as exc:
448+
raise KeyError(f"Stream with name: '{name}' not found in data streams.") from exc
379449

380450
def __dir__(self):
381451
"""List available attributes for the At accessor. This ensures autocompletion at runtime."""
382452
base = list(object.__dir__(self))
383-
if hasattr(self, "_data_stream") and hasattr(self._data_stream, "_hashmap"):
384-
h = list(self._data_stream._hashmap.keys())
453+
if hasattr(self, "_data_stream") and hasattr(self._data_stream, "_data_stream_mapping"):
454+
h = list(self._data_stream._data_stream_mapping.keys())
385455
return h + base
386456
else:
387457
return base
@@ -392,8 +462,9 @@ def __getattribute__(self, name: str) -> Any:
392462
return object.__getattribute__(self, name)
393463
except AttributeError:
394464
_data_stream = object.__getattribute__(self, "_data_stream")
395-
if name in _data_stream._hashmap:
396-
return _data_stream._hashmap[name]
465+
if name in _data_stream._data_stream_mapping:
466+
# Redirect to __call__ to get the stream by name
467+
return self.__call__(name)
397468
raise
398469

399470

@@ -423,12 +494,12 @@ def __init__(
423494
**kwargs,
424495
) -> None:
425496
super().__init__(name=name, description=description, reader_params=reader_params, **kwargs)
426-
self._hashmap: Dict[str, TDataStream] = {}
427-
self._update_hashmap()
497+
self._data_stream_mapping: Dict[str, TDataStream] = {}
498+
self._update_data_stream_mapping()
428499
self._at = _At(self)
429500

430-
def _update_hashmap(self) -> None:
431-
"""Update the internal hashmap of child data streams.
501+
def _update_data_stream_mapping(self) -> None:
502+
"""Update the internal mapping of name: child data streams.
432503
433504
Validates that all child streams have unique names and updates the lookup table.
434505
@@ -437,11 +508,11 @@ def _update_hashmap(self) -> None:
437508
"""
438509
if not self.has_data:
439510
return
440-
stream_keys = [stream.name for stream in self.data]
511+
stream_keys = [stream.name for stream in self._data]
441512
duplicates = [name for name in stream_keys if stream_keys.count(name) > 1]
442513
if duplicates:
443514
raise ValueError(f"Duplicate names found in the data stream collection: {set(duplicates)}")
444-
self._hashmap = {stream.name: stream for stream in self.data}
515+
self._data_stream_mapping = {stream.name: stream for stream in self._data}
445516
self._update_parent_references()
446517
return
447518

@@ -450,8 +521,8 @@ def _update_parent_references(self) -> None:
450521
451522
Sets this collection as the parent for all child streams.
452523
"""
453-
for stream in self._hashmap.values():
454-
stream._parent = self
524+
for stream in self._data_stream_mapping.values():
525+
stream.set_parent(self)
455526

456527
@property
457528
def at(self) -> _At[TDataStream]:
@@ -478,7 +549,7 @@ def load(self) -> Self:
478549
if not isinstance(self._data, list):
479550
self._data = _typing.UnsetData
480551
raise ValueError("Data must be a list of DataStreams.")
481-
self._update_hashmap()
552+
self._update_data_stream_mapping()
482553
return self
483554

484555
def __str__(self: Self) -> str:
@@ -494,9 +565,13 @@ def __str__(self: Self) -> str:
494565
if not self.has_data:
495566
return f"{self.__class__.__name__} has not been loaded yet."
496567

497-
for key, value in self._hashmap.items():
568+
for key, value in self._data_stream_mapping.items():
498569
table.append(
499-
[key, value.data.__class__.__name__ if value.has_data else "Unknown", "Yes" if value.has_data else "No"]
570+
[
571+
key,
572+
value.data.__class__.__name__ if value.has_data else "Unknown",
573+
"Yes" if value.has_data else "No",
574+
]
500575
)
501576

502577
max_lengths = [max(len(str(row[i])) for row in table) for i in range(len(table[0]))]
@@ -519,8 +594,9 @@ def __iter__(self) -> Generator[DataStream, None, None]:
519594
DataStream: Child data streams.
520595
521596
"""
522-
for value in self._hashmap.values():
523-
yield value
597+
# We intentionally yield from self.data to trigger
598+
# automatic loading if needed
599+
yield from self.data
524600

525601
def iter_all(self) -> Generator[DataStream, None, None]:
526602
"""Iterator for all child data streams, including nested collections.
@@ -627,7 +703,7 @@ def bind_data_streams(self, data_streams: List[DataStream]) -> Self:
627703
if self.has_data:
628704
raise ValueError("Data streams are already set. Cannot bind again.")
629705
self._data = data_streams
630-
self._update_hashmap()
706+
self._update_data_stream_mapping()
631707
return self
632708

633709
def add_stream(self, stream: DataStream) -> Self:
@@ -660,14 +736,14 @@ def add_stream(self, stream: DataStream) -> Self:
660736
"""
661737
if not self.has_data:
662738
self._data = [stream]
663-
self._update_hashmap()
739+
self._update_data_stream_mapping()
664740
return self
665741

666-
if stream.name in self._hashmap:
742+
if stream.name in self._data_stream_mapping:
667743
raise KeyError(f"Stream with name: '{stream.name}' already exists in data streams.")
668744

669745
self._data.append(stream)
670-
self._update_hashmap()
746+
self._update_data_stream_mapping()
671747
return self
672748

673749
def remove_stream(self, name: str) -> None:
@@ -683,10 +759,10 @@ def remove_stream(self, name: str) -> None:
683759
if not self.has_data:
684760
raise ValueError("Data streams have not been read yet. Cannot access data streams.")
685761

686-
if name not in self._hashmap:
762+
if name not in self._data_stream_mapping:
687763
raise KeyError(f"Data stream with name '{name}' not found in data streams.")
688-
self._data.remove(self._hashmap[name])
689-
self._update_hashmap()
764+
self._data.remove(self._data_stream_mapping[name])
765+
self._update_data_stream_mapping()
690766
return
691767

692768
@classmethod
@@ -709,7 +785,7 @@ def from_data_stream(cls, data_stream: DataStream) -> Self:
709785
raise TypeError("data_stream must be an instance of DataStream.")
710786
if not data_stream.has_data:
711787
raise ValueError("DataStream has not been loaded yet. Cannot create DataStreamCollection.")
712-
data = data_stream.data if data_stream.is_collection else [data_stream.data]
788+
data = data_stream._data if data_stream.is_collection else [data_stream._data]
713789
return cls(name=data_stream.name, data_streams=data, description=data_stream.description)
714790

715791

0 commit comments

Comments (0)