Skip to content

Commit f081477

Browse files
Add include_key parameter to to_arrays() for primary key retrieval
- to_arrays('a', 'b', include_key=True) now returns (keys, a, b), where keys is a list of dicts containing the primary key columns
- Keys are in the same format as table.keys() and are usable for restrictions
- Added comprehensive tests for the new functionality

🤖 Generated with [Claude Code](https://claude.com/claude-code)

Co-Authored-By: Claude Opus 4.5 <noreply@anthropic.com>
1 parent 4164348 commit f081477

File tree

3 files changed

+75
-3
lines changed

3 files changed

+75
-3
lines changed

src/datajoint/expression.py

Lines changed: 21 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -678,13 +678,26 @@ def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset
678678
If attrs specified, returns a tuple of numpy arrays (one per attribute).
679679
680680
:param attrs: attribute names to fetch (if empty, fetch all)
681-
:param include_key: if True and attrs specified, include primary key columns
681+
:param include_key: if True and attrs specified, prepend primary keys as list of dicts
682682
:param order_by: attribute(s) to order by, or "KEY"/"KEY DESC"
683683
:param limit: maximum number of rows to return
684684
:param offset: number of rows to skip
685685
:param squeeze: if True, remove extra dimensions from arrays
686686
:param download_path: path for downloading external data
687-
:return: numpy recarray (no attrs) or tuple of arrays (with attrs)
687+
:return: numpy recarray (no attrs) or tuple of arrays (with attrs).
688+
With include_key=True: (keys, *arrays) where keys is list[dict]
689+
690+
Examples::
691+
692+
# Fetch as structured array
693+
data = table.to_arrays()
694+
695+
# Fetch specific columns as separate arrays
696+
a, b = table.to_arrays('a', 'b')
697+
698+
# Fetch with primary keys for later restrictions
699+
keys, a, b = table.to_arrays('a', 'b', include_key=True)
700+
# keys = [{'id': 1}, {'id': 2}, ...] # same format as table.keys()
688701
"""
689702
from functools import partial
690703

@@ -702,6 +715,10 @@ def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset
702715
projected = expr.proj(*fetch_attrs)
703716
dicts = projected.to_dicts(squeeze=squeeze, download_path=download_path)
704717

718+
# Extract keys if requested
719+
if include_key:
720+
keys = [{k: d[k] for k in expr.primary_key} for d in dicts]
721+
705722
# Extract arrays for requested attributes
706723
result_arrays = []
707724
for attr in attrs:
@@ -714,6 +731,8 @@ def to_arrays(self, *attrs, include_key=False, order_by=None, limit=None, offset
714731
arr = np.array(values, dtype=object)
715732
result_arrays.append(arr)
716733

734+
if include_key:
735+
return (keys, *result_arrays)
717736
return result_arrays[0] if len(attrs) == 1 else tuple(result_arrays)
718737
else:
719738
# Fetch all columns as structured array

src/datajoint/version.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
11
# version bump auto managed by Github Actions:
22
# label_prs.yaml(prep), release.yaml(bump), post_release.yaml(edit)
33
# manually set this version will be eventually overwritten by the above actions
4-
__version__ = "2.0.0a12"
4+
__version__ = "2.0.0a13"

tests/integration/test_fetch.py

Lines changed: 53 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -344,3 +344,56 @@ def test_lazy_iteration(lang, languages):
344344
first = next(iter_obj)
345345
assert isinstance(first, dict)
346346
assert "name" in first and "language" in first
347+
348+
349+
def test_to_arrays_include_key(lang, languages):
350+
"""Test to_arrays with include_key=True returns keys as list of dicts"""
351+
# Fetch with include_key=True
352+
keys, names, langs = lang.to_arrays("name", "language", include_key=True, order_by="KEY")
353+
354+
# keys should be a list of dicts with primary key columns
355+
assert isinstance(keys, list)
356+
assert all(isinstance(k, dict) for k in keys)
357+
assert all(set(k.keys()) == {"name", "language"} for k in keys)
358+
359+
# names and langs should be numpy arrays
360+
assert isinstance(names, np.ndarray)
361+
assert isinstance(langs, np.ndarray)
362+
363+
# Length should match
364+
assert len(keys) == len(names) == len(langs) == len(languages)
365+
366+
# Keys should match the data
367+
for key, name, language in zip(keys, names, langs):
368+
assert key["name"] == name
369+
assert key["language"] == language
370+
371+
# Keys should be usable for restrictions
372+
first_key = keys[0]
373+
restricted = lang & first_key
374+
assert len(restricted) == 1
375+
assert restricted.fetch1("name") == first_key["name"]
376+
377+
378+
def test_to_arrays_include_key_single_attr(subject):
379+
"""Test to_arrays include_key with single attribute"""
380+
keys, species = subject.to_arrays("species", include_key=True)
381+
382+
assert isinstance(keys, list)
383+
assert isinstance(species, np.ndarray)
384+
assert len(keys) == len(species)
385+
386+
# Verify keys have only primary key columns
387+
assert all("subject_id" in k for k in keys)
388+
389+
390+
def test_to_arrays_without_include_key(lang):
391+
"""Test that to_arrays without include_key doesn't return keys"""
392+
result = lang.to_arrays("name", "language")
393+
394+
# Should return tuple of arrays, not (keys, ...)
395+
assert isinstance(result, tuple)
396+
assert len(result) == 2
397+
names, langs = result
398+
assert isinstance(names, np.ndarray)
399+
assert isinstance(langs, np.ndarray)

0 commit comments

Comments
 (0)