Commit a237b26
Simplify fetch API: remove download_path, fetch.py, Fetch1 class
- Move decode logic from fetch._get() to codecs.decode_attribute()
- Remove download_path parameter from all fetch methods (use config["download_path"] or config.override() instead)
- Convert fetch1 from property+Fetch1 class to regular method
- Delete fetch.py entirely (no longer needed)

API changes:
- table.to_dicts(download_path=...) → use config.override(download_path=...)
- table.fetch1 (property) → table.fetch1() (method), usage otherwise unchanged

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <[email protected]>
1 parent a024baa commit a237b26
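For call sites affected by the two API changes, a minimal before/after sketch; `table` is a hypothetical query expression and the snippet assumes a configured DataJoint connection:

```python
import datajoint as dj

# Before this commit:
#   rows = table.to_dicts(download_path="/data")   # per-call download path
#   value = table.fetch1('a')                      # fetch1 was a property returning a Fetch1 object

# After this commit: the download path comes from config, scoped with override()
with dj.config.override(download_path="/data"):
    rows = table.to_dicts()

value = table.fetch1('a')  # fetch1 is now a plain method; this call syntax is unchanged
```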

File tree

4 files changed: +159 −200 lines


src/datajoint/codecs.py

Lines changed: 74 additions & 0 deletions
```diff
@@ -441,6 +441,80 @@ def lookup_codec(codec_spec: str) -> tuple[Codec, str | None]:
     raise DataJointError(f"Codec <{type_name}> is not registered. " "Define a Codec subclass with name='{type_name}'.")
 
 
+# =============================================================================
+# Decode Helper
+# =============================================================================
+
+
+def decode_attribute(attr, data, squeeze: bool = False):
+    """
+    Decode raw database value using attribute's codec or native type handling.
+
+    This is the central decode function used by all fetch methods. It handles:
+    - Codec chains (e.g., <blob@store> → <hash> → bytes)
+    - Native type conversions (JSON, UUID)
+    - External storage downloads (via config["download_path"])
+
+    Args:
+        attr: Attribute from the table's heading.
+        data: Raw value fetched from the database.
+        squeeze: If True, remove singleton dimensions from numpy arrays.
+
+    Returns:
+        Decoded Python value.
+    """
+    import json
+    import uuid as uuid_module
+
+    import numpy as np
+
+    if data is None:
+        return None
+
+    if attr.codec:
+        # Get store if present for external storage
+        store = getattr(attr, "store", None)
+        if store is not None:
+            dtype_spec = f"<{attr.codec.name}@{store}>"
+        else:
+            dtype_spec = f"<{attr.codec.name}>"
+
+        final_dtype, type_chain, _ = resolve_dtype(dtype_spec)
+
+        # Process the final storage type (what's in the database)
+        if final_dtype.lower() == "json":
+            data = json.loads(data)
+        elif final_dtype.lower() in ("longblob", "blob", "mediumblob", "tinyblob"):
+            pass  # Blob data is already bytes
+        elif final_dtype.lower() == "binary(16)":
+            data = uuid_module.UUID(bytes=data)
+
+        # Apply decoders in reverse order: innermost first, then outermost
+        for codec in reversed(type_chain):
+            data = codec.decode(data, key=None)
+
+        # Squeeze arrays if requested
+        if squeeze and isinstance(data, np.ndarray):
+            data = data.squeeze()
+
+        return data
+
+    # No codec - handle native types
+    if attr.json:
+        return json.loads(data)
+
+    if attr.uuid:
+        return uuid_module.UUID(bytes=data)
+
+    if attr.is_blob:
+        return data  # Raw bytes
+
+    # Native types - pass through unchanged
+    return data
+
+
 # =============================================================================
 # Auto-register built-in codecs
 # =============================================================================
```
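As a rough illustration of the new decode path, the sketch below feeds decode_attribute a UUID value. The SimpleNamespace is a hypothetical stand-in for a real heading attribute; real callers pass `heading[name]` from the table's heading:

```python
import uuid
from types import SimpleNamespace

from datajoint.codecs import decode_attribute

# Hypothetical attribute describing a uuid column (stored as binary(16))
attr = SimpleNamespace(codec=None, json=False, uuid=True, is_blob=False)

raw = uuid.uuid4().bytes              # raw bytes as returned by the database
value = decode_attribute(attr, raw)   # -> uuid.UUID instance
assert isinstance(value, uuid.UUID)

assert decode_attribute(attr, None) is None  # NULLs pass through unchanged
```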

src/datajoint/expression.py

Lines changed: 84 additions & 27 deletions
```diff
@@ -18,7 +18,7 @@
 import pandas
 
 from .errors import DataJointError
-from .fetch import Fetch1, _get
+from .codecs import decode_attribute
 from .preview import preview, repr_html
 from .settings import config
 
```
```diff
@@ -582,53 +582,116 @@ def fetch(self):
             "See table.fetch.__doc__ for details."
         )
 
-    @property
-    def fetch1(self):
-        return Fetch1(self)
+    def fetch1(self, *attrs, squeeze=False):
+        """
+        Fetch exactly one row from the query result.
+
+        If no attributes are specified, returns the result as a dict.
+        If attributes are specified, returns the corresponding values as a tuple.
+
+        :param attrs: attribute names to fetch (if empty, fetch all as dict)
+        :param squeeze: if True, remove extra dimensions from arrays
+        :return: dict (no attrs) or tuple/value (with attrs)
+        :raises DataJointError: if not exactly one row in result
+
+        Examples::
+
+            d = table.fetch1()             # returns dict with all attributes
+            a, b = table.fetch1('a', 'b')  # returns tuple of attribute values
+            value = table.fetch1('a')      # returns single value
+        """
+        heading = self.heading
+
+        if not attrs:
+            # Fetch all attributes, return as dict
+            cursor = self.cursor(as_dict=True)
+            row = cursor.fetchone()
+            if not row or cursor.fetchone():
+                raise DataJointError("fetch1 requires exactly one tuple in the input set.")
+            return {name: decode_attribute(heading[name], row[name], squeeze=squeeze) for name in heading.names}
+        else:
+            # Handle "KEY" specially - it means primary key columns
+            def is_key(attr):
+                return attr == "KEY"
+
+            has_key = any(is_key(a) for a in attrs)
+
+            if has_key and len(attrs) == 1:
+                # Just fetching KEY - return the primary key dict
+                keys = self.keys()
+                if len(keys) != 1:
+                    raise DataJointError(f"fetch1 should only return one tuple. {len(keys)} tuples found")
+                return keys[0]
+
+            # Fetch specific attributes, return as tuple
+            # Replace KEY with primary key columns for projection
+            proj_attrs = []
+            for attr in attrs:
+                if is_key(attr):
+                    proj_attrs.extend(self.primary_key)
+                else:
+                    proj_attrs.append(attr)
+
+            dicts = self.proj(*proj_attrs).to_dicts(squeeze=squeeze)
+            if len(dicts) != 1:
+                raise DataJointError(f"fetch1 should only return one tuple. {len(dicts)} tuples found")
+            row = dicts[0]
+
+            # Build result values, handling KEY specially
+            values = []
+            for attr in attrs:
+                if is_key(attr):
+                    # Return dict of primary key columns
+                    values.append({k: row[k] for k in self.primary_key})
+                else:
+                    values.append(row[attr])
+
+            return values[0] if len(attrs) == 1 else tuple(values)
 
     def _apply_top(self, order_by=None, limit=None, offset=None):
         """Apply order_by, limit, offset if specified, return modified expression."""
         if order_by is not None or limit is not None or offset is not None:
             return self.restrict(Top(limit, order_by, offset))
         return self
 
-    def to_dicts(self, order_by=None, limit=None, offset=None, squeeze=False, download_path="."):
+    def to_dicts(self, order_by=None, limit=None, offset=None, squeeze=False):
         """
         Fetch all rows as a list of dictionaries.
 
         :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC"
         :param limit: maximum number of rows to return
         :param offset: number of rows to skip
         :param squeeze: if True, remove extra dimensions from arrays
-        :param download_path: path for downloading external data (attachments, filepaths)
         :return: list of dictionaries, one per row
+
+        For external storage types (attachments, filepaths), files are downloaded
+        to config["download_path"]. Use config.override() to change::
+
+            with dj.config.override(download_path="/data"):
+                data = table.to_dicts()
         """
         expr = self._apply_top(order_by, limit, offset)
         cursor = expr.cursor(as_dict=True)
         heading = expr.heading
-        return [
-            {name: _get(expr.connection, heading[name], row[name], squeeze, download_path) for name in heading.names}
-            for row in cursor
-        ]
+        return [{name: decode_attribute(heading[name], row[name], squeeze) for name in heading.names} for row in cursor]
 
-    def to_pandas(self, order_by=None, limit=None, offset=None, squeeze=False, download_path="."):
+    def to_pandas(self, order_by=None, limit=None, offset=None, squeeze=False):
         """
         Fetch all rows as a pandas DataFrame with primary key as index.
 
         :param order_by: attribute(s) to order by, or "KEY"/"KEY DESC"
         :param limit: maximum number of rows to return
         :param offset: number of rows to skip
         :param squeeze: if True, remove extra dimensions from arrays
-        :param download_path: path for downloading external data
         :return: pandas DataFrame with primary key columns as index
         """
-        dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze, download_path=download_path)
+        dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
         df = pandas.DataFrame(dicts)
         if len(df) > 0 and self.primary_key:
             df = df.set_index(self.primary_key)
         return df
 
-    def to_polars(self, order_by=None, limit=None, offset=None, squeeze=False, download_path="."):
+    def to_polars(self, order_by=None, limit=None, offset=None, squeeze=False):
         """
         Fetch all rows as a polars DataFrame.
```
```diff
@@ -638,17 +701,16 @@ def to_polars(self, order_by=None, limit=None, offset=None, squeeze=False, downl
         :param limit: maximum number of rows to return
         :param offset: number of rows to skip
         :param squeeze: if True, remove extra dimensions from arrays
-        :param download_path: path for downloading external data
         :return: polars DataFrame
         """
         try:
             import polars
         except ImportError:
             raise ImportError("polars is required for to_polars(). " "Install with: pip install datajoint[polars]")
-        dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze, download_path=download_path)
+        dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
         return polars.DataFrame(dicts)
 
-    def to_arrow(self, order_by=None, limit=None, offset=None, squeeze=False, download_path="."):
+    def to_arrow(self, order_by=None, limit=None, offset=None, squeeze=False):
         """
         Fetch all rows as a PyArrow Table.
 
```
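With download_path gone, the tabular exports now share one keyword set and differ only in output type; a usage sketch assuming a query expression `table` and the optional extras installed:

```python
df = table.to_pandas(order_by="KEY", limit=100)  # pandas DataFrame, primary key as index
pl_df = table.to_polars(limit=100)               # polars DataFrame (pip install datajoint[polars])
tbl = table.to_arrow()                           # pyarrow Table (pip install datajoint[arrow])
```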
```diff
@@ -658,19 +720,18 @@ def to_arrow(self, order_by=None, limit=None, offset=None, squeeze=False, downlo
         :param limit: maximum number of rows to return
         :param offset: number of rows to skip
         :param squeeze: if True, remove extra dimensions from arrays
-        :param download_path: path for downloading external data
         :return: pyarrow Table
         """
         try:
             import pyarrow
         except ImportError:
             raise ImportError("pyarrow is required for to_arrow(). " "Install with: pip install datajoint[arrow]")
-        dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze, download_path=download_path)
+        dicts = self.to_dicts(order_by=order_by, limit=limit, offset=offset, squeeze=squeeze)
         if not dicts:
             return pyarrow.table({})
         return pyarrow.Table.from_pylist(dicts)
 
-    def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset=None, squeeze=False, download_path="."):
+    def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset=None, squeeze=False):
         """
         Fetch data as numpy arrays.
 
```
```diff
@@ -683,7 +744,6 @@ def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset
         :param limit: maximum number of rows to return
         :param offset: number of rows to skip
         :param squeeze: if True, remove extra dimensions from arrays
-        :param download_path: path for downloading external data
         :return: numpy recarray (no attrs) or tuple of arrays (with attrs).
             With include_key=True: (keys, *arrays) where keys is list[dict]
 
```
```diff
@@ -713,7 +773,7 @@ def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset
 
         # Project to only needed columns
         projected = expr.proj(*fetch_attrs)
-        dicts = projected.to_dicts(squeeze=squeeze, download_path=download_path)
+        dicts = projected.to_dicts(squeeze=squeeze)
 
         # Extract keys if requested
         if include_key:
```
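Putting the to_arrays() hunks together, its three return shapes (per the docstring above) look like this; attribute names are hypothetical:

```python
# Named attributes: one numpy array per name, returned as a tuple
amplitude, frequency = table.to_arrays('amplitude', 'frequency')

# include_key=True prepends a list of primary-key dicts
keys, amplitude = table.to_arrays('amplitude', include_key=True)

# No names: the full result as a numpy recarray
rec = table.to_arrays()
```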
```diff
@@ -736,7 +796,7 @@ def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset
             return result_arrays[0] if len(attrs) == 1 else tuple(result_arrays)
         else:
             # Fetch all columns as structured array
-            get = partial(_get, expr.connection, squeeze=squeeze, download_path=download_path)
+            get = partial(decode_attribute, squeeze=squeeze)
             cursor = expr.cursor(as_dict=False)
             rows = list(cursor.fetchall())
 
```
```diff
@@ -842,10 +902,7 @@ def __iter__(self):
         cursor = self.cursor(as_dict=True)
         heading = self.heading
         for row in cursor:
-            yield {
-                name: _get(self.connection, heading[name], row[name], squeeze=False, download_path=".")
-                for name in heading.names
-            }
+            yield {name: decode_attribute(heading[name], row[name], squeeze=False) for name in heading.names}
 
     def cursor(self, as_dict=False):
         """
```
