Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 15 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,19 @@
# Changelog
## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [3.2.1]
## [Latest](https://github.com/int-brain-lab/ONE/commits/main) [3.3.0]
This version makes some performance improvements and supports local mode eid2pid/pid2eid.

### Modified

- cache ALFPath properties
- improvements to OneAlyx.get_details performance
- update cache tables with calls to OneAlyx.eid2pid, pid2eid, eid2path, path2eid, get_details
- OneAlyx.eid2pid and pid2eid support local mode queries so long as the insertions table exists

### Added

- AlyxClient.rest_cache_dir property allows users to change location of REST response cache

## [3.2.1]

### Modified

Expand Down
2 changes: 1 addition & 1 deletion one/__init__.py
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
"""The Open Neurophysiology Environment (ONE) API."""
__version__ = '3.2.1'
__version__ = '3.3.0'
7 changes: 4 additions & 3 deletions one/alf/path.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
"""
import os
import pathlib
from functools import cached_property
from collections import OrderedDict
from datetime import datetime
from typing import Union, Optional, Iterable
Expand Down Expand Up @@ -844,17 +845,17 @@ def parse_alf_name(self, as_dict=True):
"""
return filename_parts(self.name, assert_valid=False, as_dict=as_dict)

@property
@cached_property
def dataset_name_parts(self):
"""tuple of str: the dataset name parts, with empty strings for missing parts."""
return tuple(p or '' for p in self.parse_alf_name(as_dict=False))

@property
@cached_property
def session_parts(self):
"""tuple of str: the session path parts, with empty strings for missing parts."""
return tuple(p or '' for p in session_path_parts(self, assert_valid=False))

@property
@cached_property
def alf_parts(self):
"""tuple of str: the full ALF path parts, with empty strings for missing parts."""
return tuple(p or '' for p in self.parse_alf_path(as_dict=False))
Expand Down
105 changes: 78 additions & 27 deletions one/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
from weakref import WeakMethod
from datetime import datetime, timedelta
from functools import lru_cache, partial
from itertools import islice
from inspect import unwrap
from pathlib import Path, PurePosixPath
from typing import Any, Union, Optional, List
Expand All @@ -28,7 +29,7 @@
import one.alf.io as alfio
import one.alf.path as alfiles
import one.alf.exceptions as alferr
from one.alf.path import ALFPath
from one.alf.path import ALFPath, ensure_alf_path
from .alf.cache import (
make_parquet_db, load_tables, remove_table_files, merge_tables,
default_cache, cast_index_object)
Expand Down Expand Up @@ -2020,9 +2021,26 @@ def pid2eid(self, pid: str, query_type=None) -> (UUID, str):

"""
query_type = query_type or self.mode
if query_type == 'local' and 'insertions' not in self._cache.keys():
raise NotImplementedError('Converting probe IDs required remote connection')
rec = self.alyx.rest('insertions', 'read', id=str(pid))
if query_type == 'local': # and 'insertions' not in self._cache.keys():
if 'insertions' not in self._cache.keys():
raise NotImplementedError('Converting probe IDs requires remote connection')
else:
# If local, use the cache table
pid = UUID(pid) if isinstance(pid, str) else pid
try:
rec = self._cache['insertions'].loc[pd.IndexSlice[:, pid], 'name']
(eid, _), name = next(rec.items())
return eid, name
except KeyError:
return None, None
try:
rec = self.alyx.rest('insertions', 'read', id=pid)
except requests.exceptions.HTTPError as ex:
if ex.response.status_code == 404:
_logger.error(f'Probe {pid} not found in Alyx')
return None, None
raise ex
self._update_insertions_table([rec])
return UUID(rec['session']), rec['name']

def eid2pid(self, eid, query_type=None, details=False, **kwargs) -> (UUID, str, list):
Expand Down Expand Up @@ -2063,15 +2081,33 @@ def eid2pid(self, eid, query_type=None, details=False, **kwargs) -> (UUID, str,
"""
query_type = query_type or self.mode
if query_type == 'local' and 'insertions' not in self._cache.keys():
raise NotImplementedError('Converting probe IDs required remote connection')
raise NotImplementedError('Converting to probe ID requires remote connection')
eid = self.to_eid(eid) # Ensure we have a UUID str
if not eid:
return (None,) * (3 if details else 2)
recs = self.alyx.rest('insertions', 'list', session=eid, **kwargs)
pids = [UUID(x['id']) for x in recs]
labels = [x['name'] for x in recs]
if query_type == 'local':
try: # If local, use the cache table
rec = self._cache['insertions'].loc[(eid,), :]
pids, names = map(list, zip(*rec.sort_values('name')['name'].items()))
if details:
rec['session'] = str(eid)
session_info = self._cache['sessions'].loc[eid].to_dict()
session_info['date'] = session_info['date'].isoformat()
session_info['projects'] = session_info['projects'].split(',')
rec['session_info'] = session_info
# Convert to list of dicts after casting UUIDs to strings
recs = cast_index_object(rec, str).reset_index().to_dict('records')
return pids, names, recs
return pids, names
except KeyError:
return (None,) * (3 if details else 2)

if recs := self.alyx.rest('insertions', 'list', session=eid, **kwargs):
self._update_insertions_table(recs)
pids = [UUID(x['id']) for x in recs] or None
labels = [x['name'] for x in recs] or None
if details:
return pids, labels, recs
return pids, labels, recs or None
else:
return pids, labels

Expand Down Expand Up @@ -2757,10 +2793,11 @@ def eid2path(self, eid, query_type=None) -> Listable(ALFPath):
return [unwrapped(self, e, query_type='remote') for e in eid]

# if it wasn't successful, query Alyx
ses = self.alyx.rest('sessions', 'list', django=f'pk,{str(eid)}')
ses = self.alyx.rest('sessions', 'list', id=eid)
if len(ses) == 0:
return None
else:
self._update_sessions_table(ses)
return ALFPath(self.cache_dir).joinpath(
ses[0]['lab'], 'Subjects', ses[0]['subject'], ses[0]['start_time'][:10],
str(ses[0]['number']).zfill(3))
Expand Down Expand Up @@ -2788,7 +2825,7 @@ def path2eid(self, path_obj: Union[str, Path], query_type=None) -> Listable(str)
eid_list.append(self.path2eid(p))
return eid_list
# else ensure the path ends with mouse, date, number
path_obj = ALFPath(path_obj)
path_obj = ensure_alf_path(path_obj)

# try the cached info to possibly avoid hitting database
mode = query_type or self.mode
Expand Down Expand Up @@ -2969,26 +3006,40 @@ def get_details(self, eid: str, full: bool = False, query_type=None):
[Errno 404] Remote session not found on Alyx.

"""
def process(d, root=self.cache_dir):
"""Returns dict in similar format to One.search output."""
det_fields = ['subject', 'start_time', 'number', 'lab', 'projects',
'url', 'task_protocol', 'local_path']
out = {k: v for k, v in d.items() if k in det_fields}
out['projects'] = ','.join(out['projects'])
out['date'] = datetime.fromisoformat(out['start_time']).date()
out['local_path'] = session_record2path(out, root)
return out

if (query_type or self.mode) == 'local':
return super().get_details(eid, full=full)
# If eid is a list of eIDs recurse through list and return the results
if isinstance(eid, (list, util.LazyId)):
details_list = []
for p in eid:
details_list.append(self.get_details(p, full=full))
return details_list
# load all details
dets = self.alyx.rest('sessions', 'read', eid)
eids = ensure_list(eid)
details = dict.fromkeys(map(str, eids), None) # create map to skip duplicates
if full:
return dets
# If it's not full return the normal output like from a one.search
det_fields = ['subject', 'start_time', 'number', 'lab', 'projects',
'url', 'task_protocol', 'local_path']
out = {k: v for k, v in dets.items() if k in det_fields}
out['projects'] = ','.join(out['projects'])
out.update({'local_path': self.eid2path(eid),
'date': datetime.fromisoformat(out['start_time']).date()})
return out
for e in details:
# check for duplicates
details[e] = self.alyx.rest('sessions', 'read', id=e)
session, datasets = ses2records(details[e])
merge_tables(
self._cache, sessions=session, datasets=datasets.copy(),
origin=self.alyx.base_url)
details = [details[str(e)].copy() for e in eids]
else:
# batch to ensure the list is not too long for the GET request
iterator = iter(details.keys())
while batch := tuple(islice(iterator, 50)):
ret = self.alyx.rest('sessions', 'list', django=f'pk__in,{batch}')
details.update({d['id']: d for d in ret})
self._update_sessions_table(details.values())
details = [process(details[str(e)]) for e in eids]
# Return either a single dict or a list of dicts depending on the input type
return (details if isinstance(eid, (list, util.LazyId)) else details[0])


def _setup(**kwargs):
Expand Down

Large diffs are not rendered by default.

33 changes: 33 additions & 0 deletions one/tests/test_converters.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import unittest
from unittest import mock
from pathlib import Path, PurePosixPath, PureWindowsPath
from requests.exceptions import HTTPError
from uuid import UUID, uuid4
import datetime

Expand Down Expand Up @@ -338,11 +339,29 @@ def test_path2eid(self):

def test_pid2eid(self):
"""Test for OneAlyx.pid2eid method."""
if 'insertions' in self.one._cache:
del self.one._cache['insertions']
self.assertRaises(NotImplementedError, self.one.pid2eid, self.pid, query_type='local')
self.assertEqual((self.eid, 'probe00'), self.one.pid2eid(self.pid))
# Check cache table updated
self.assertIn('insertions', self.one._cache)
self.assertIn(self.eid, self.one._cache['insertions'].index)
# Local mode should now work
self.assertEqual((self.eid, 'probe00'), self.one.pid2eid(self.pid, query_type='local'))
# Test behaviour when pid not found
pid = UUID('00000000-0000-0000-0000-000000000000')
self.assertEqual((None, None), self.one.pid2eid(pid, query_type='local'))
self.assertEqual((None, None), self.one.pid2eid(pid, query_type='remote'))
# Non-404 status code should raise
err = HTTPError()
err.response = self.one._cache.__class__({'status_code': 500})
with mock.patch.object(self.one.alyx, 'get', side_effect=err):
self.assertRaises(HTTPError, self.one.pid2eid, pid, query_type='remote')

def test_eid2pid(self):
"""Test for OneAlyx.eid2pid method."""
if 'insertions' in self.one._cache:
del self.one._cache['insertions']
self.assertRaises(NotImplementedError, self.one.eid2pid, self.eid, query_type='local')
# Check invalid eid
self.assertEqual((None, None), self.one.eid2pid(None))
Expand All @@ -357,6 +376,20 @@ def test_eid2pid(self):
expected_keys = {'id', 'name', 'model', 'serial'}
for d in det:
self.assertTrue(set(d.keys()) >= expected_keys)
# Check cache table updated
cache = self.one._cache
self.assertIn('insertions', cache)
self.assertTrue(cache['insertions'].index.get_level_values(1).isin(expected[0]).all())
# Check local mode should now work
self.assertEqual(expected, self.one.eid2pid(self.eid, query_type='local'))
*_, det = self.one.eid2pid(self.eid, details=True, query_type='local')
for d in det:
self.assertTrue(set(d.keys()) >= expected_keys)
# Check behaviour when eid not found in local mode
eid = UUID('00000000-0000-0000-0000-000000000000')
self.assertEqual((None, None), self.one.eid2pid(eid, query_type='local'))
out = self.one.eid2pid(eid, query_type='local', details=True)
self.assertEqual((None, None, None), out)

def test_ses2records(self):
"""Test one.converters.ses2records function."""
Expand Down
20 changes: 15 additions & 5 deletions one/tests/test_one.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@
validate_date_range, index_last_before, filter_datasets, _collection_spec,
filter_revision_last_before, parse_id, autocomplete, LazyId
)
from one.webclient import AlyxClient
import one.params
import one.alf.exceptions as alferr
from one.converters import datasets2records
Expand Down Expand Up @@ -1238,13 +1239,15 @@ def test_dataset2type(self):
self.one.dataset2type(bad_id)

def test_pid2eid(self):
"""Test OneAlyx.pid2eid."""
"""Test OneAlyx.pid2eid.

For a more complete test see `test_converters.TestOnlineConverters.test_pid2eid`.
This test uses the REST fixtures and therefore can be run offline.
"""
pid = 'b529f2d8-cdae-4d59-aba2-cbd1b5572e36'
eid, collection = self.one.pid2eid(pid, query_type='remote')
self.assertEqual(UUID('fc737f3c-2a57-4165-9763-905413e7e341'), eid)
self.assertEqual('probe00', collection)
with self.assertRaises(NotImplementedError):
self.one.pid2eid(pid, query_type='local')

@unittest.mock.patch('sys.stdout', new_callable=io.StringIO)
def test_describe_revision(self, mock_stdout):
Expand Down Expand Up @@ -1845,10 +1848,17 @@ def test_get_details(self):
self.assertEqual(1, det['number'])
self.assertNotIn('data_dataset_session_related', det)

# Test list
det = self.one.get_details([self.eid, self.eid], full=True)
# Test with a list
# For duplicate eids, we should avoid multiple queries
with mock.patch.object(AlyxClient, 'rest', wraps=self.one.alyx.rest) as m:
det = self.one.get_details([self.eid, self.eid], full=True)
m.assert_called_once_with('sessions', 'read', id=str(self.eid))
self.assertIsInstance(det, list)
self.assertEqual(2, len(det))
self.assertIn('data_dataset_session_related', det[0])
# Check that the details dicts are copies (modifying one should not affect the other)
self.assertEqual(det[0], det[1]) # details should be the same
self.assertIsNot(det[0], det[1]) # should be different objects

def test_cache_buildup(self):
"""Test build up of cache table via remote queries.
Expand Down
7 changes: 4 additions & 3 deletions one/webclient.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,7 +109,7 @@ def wrapper_decorator(alyx_client, *args, expires=None, clobber=False, **kwargs)
response will not be used on subsequent calls. If None, the default expiry is applied.
clobber : bool
If True any existing cached response is overwritten.
**kwargs
kwargs
Keyword arguments for applying to wrapped function.

Returns
Expand All @@ -123,7 +123,7 @@ def wrapper_decorator(alyx_client, *args, expires=None, clobber=False, **kwargs)
if args[0].__name__ != mode and mode != '*':
return method(alyx_client, *args, **kwargs)
# Check cache
rest_cache = alyx_client.cache_dir.joinpath('.rest')
rest_cache = alyx_client.rest_cache_dir
sha1 = hashlib.sha1()
sha1.update(bytes(args[1], 'utf-8'))
name = sha1.hexdigest()
Expand Down Expand Up @@ -577,6 +577,7 @@ def __init__(self, base_url=None, username=None, password=None,
# The default expiry is overridden by the `expires` kwarg. If False, the caching is
# turned off.
self.default_expiry = timedelta(minutes=5)
self.rest_cache_dir = self.cache_dir.joinpath('.rest')
self.cache_mode = cache_rest
self._obj_id = id(self)

Expand Down Expand Up @@ -1368,5 +1369,5 @@ def json_field_delete(

def clear_rest_cache(self):
"""Clear all REST response cache files for the base url."""
for file in self.cache_dir.joinpath('.rest').glob('*'):
for file in self.rest_cache_dir.glob('*'):
file.unlink()