7777
7878from neo import logging_handler
7979
80+ from .utils import get_memmap_chunk_from_opened_file
81+
8082
8183possible_raw_modes = [
8284 "one-file" ,
@@ -182,6 +184,15 @@ def __init__(self, use_cache: bool = False, cache_path: str = "same_as_resource"
182184 self .header = None
183185 self .is_header_parsed = False
184186
187+ self ._has_buffer_description_api = False
188+
189+ def has_buffer_description_api (self ) -> bool :
190+ """
191+ Return if the reader handle the buffer API.
192+ If True then the reader support internally `get_analogsignal_buffer_description()`
193+ """
194+ return self ._has_buffer_description_api
195+
185196 def parse_header (self ):
186197 """
187198 Parses the header of the file(s) to allow for faster computations
@@ -191,6 +202,7 @@ def parse_header(self):
191202 # this must create
192203 # self.header['nb_block']
193204 # self.header['nb_segment']
205+ # self.header['signal_buffers']
194206 # self.header['signal_streams']
195207 # self.header['signal_channels']
196208 # self.header['spike_channels']
@@ -663,6 +675,7 @@ def get_signal_size(self, block_index: int, seg_index: int, stream_index: int |
663675
664676 """
665677 stream_index = self ._get_stream_index_from_arg (stream_index )
678+
666679 return self ._get_signal_size (block_index , seg_index , stream_index )
667680
668681 def get_signal_t_start (self , block_index : int , seg_index : int , stream_index : int | None = None ):
@@ -1311,7 +1324,6 @@ def _get_analogsignal_chunk(
13111324 -------
13121325 array of samples, with each requested channel in a column
13131326 """
1314-
13151327 raise (NotImplementedError )
13161328
13171329 ###
@@ -1350,6 +1362,152 @@ def _rescale_event_timestamp(self, event_timestamps: np.ndarray, dtype: np.dtype
13501362 def _rescale_epoch_duration (self , raw_duration : np .ndarray , dtype : np .dtype ):
13511363 raise (NotImplementedError )
13521364
1365+ ###
1366+ # buffer api zone
1367+ # must be implemented if has_buffer_description_api=True
1368+ def get_analogsignal_buffer_description (self , block_index : int = 0 , seg_index : int = 0 , buffer_id : str = None ):
1369+ if not self .has_buffer_description_api :
1370+ raise ValueError ("This reader do not support buffer_description API" )
1371+ descr = self ._get_analogsignal_buffer_description (block_index , seg_index , buffer_id )
1372+ return descr
1373+
1374+ def _get_analogsignal_buffer_description (self , block_index , seg_index , buffer_id ):
1375+ raise (NotImplementedError )
1376+
1377+
1378+ class BaseRawWithBufferApiIO (BaseRawIO ):
1379+ """
1380+ Generic class for reader that support "buffer api".
1381+
1382+ In short reader that are internally based on:
1383+
1384+ * np.memmap
1385+ * hdf5
1386+
1387+ In theses cases _get_signal_size and _get_analogsignal_chunk are totaly generic and do not need to be implemented in the class.
1388+
1389+ For this class sub classes must implements theses two dict:
1390+ * self._buffer_descriptions[block_index][seg_index] = buffer_description
1391+ * self._stream_buffer_slice[buffer_id] = None or slicer o indices
1392+
1393+ """
1394+
1395+ def __init__ (self , * arg , ** kwargs ):
1396+ super ().__init__ (* arg , ** kwargs )
1397+ self ._has_buffer_description_api = True
1398+
1399+ def _get_signal_size (self , block_index , seg_index , stream_index ):
1400+ buffer_id = self .header ["signal_streams" ][stream_index ]["buffer_id" ]
1401+ buffer_desc = self .get_analogsignal_buffer_description (block_index , seg_index , buffer_id )
1402+ # some hdf5 revert teh buffer
1403+ time_axis = buffer_desc .get ("time_axis" , 0 )
1404+ return buffer_desc ["shape" ][time_axis ]
1405+
1406+ def _get_analogsignal_chunk (
1407+ self ,
1408+ block_index : int ,
1409+ seg_index : int ,
1410+ i_start : int | None ,
1411+ i_stop : int | None ,
1412+ stream_index : int ,
1413+ channel_indexes : list [int ] | None ,
1414+ ):
1415+
1416+ stream_id = self .header ["signal_streams" ][stream_index ]["id" ]
1417+ buffer_id = self .header ["signal_streams" ][stream_index ]["buffer_id" ]
1418+
1419+ buffer_slice = self ._stream_buffer_slice [stream_id ]
1420+
1421+ buffer_desc = self .get_analogsignal_buffer_description (block_index , seg_index , buffer_id )
1422+
1423+ i_start = i_start or 0
1424+ i_stop = i_stop or buffer_desc ["shape" ][0 ]
1425+
1426+ if buffer_desc ["type" ] == "raw" :
1427+
1428+ # open files on demand and keep reference to opened file
1429+ if not hasattr (self , "_memmap_analogsignal_buffers" ):
1430+ self ._memmap_analogsignal_buffers = {}
1431+ if block_index not in self ._memmap_analogsignal_buffers :
1432+ self ._memmap_analogsignal_buffers [block_index ] = {}
1433+ if seg_index not in self ._memmap_analogsignal_buffers [block_index ]:
1434+ self ._memmap_analogsignal_buffers [block_index ][seg_index ] = {}
1435+ if buffer_id not in self ._memmap_analogsignal_buffers [block_index ][seg_index ]:
1436+ fid = open (buffer_desc ["file_path" ], mode = "rb" )
1437+ self ._memmap_analogsignal_buffers [block_index ][seg_index ][buffer_id ] = fid
1438+ else :
1439+ fid = self ._memmap_analogsignal_buffers [block_index ][seg_index ][buffer_id ]
1440+
1441+ num_channels = buffer_desc ["shape" ][1 ]
1442+
1443+ raw_sigs = get_memmap_chunk_from_opened_file (
1444+ fid ,
1445+ num_channels ,
1446+ i_start ,
1447+ i_stop ,
1448+ np .dtype (buffer_desc ["dtype" ]),
1449+ file_offset = buffer_desc ["file_offset" ],
1450+ )
1451+
1452+ elif buffer_desc ["type" ] == "hdf5" :
1453+
1454+ # open files on demand and keep reference to opened file
1455+ if not hasattr (self , "_hdf5_analogsignal_buffers" ):
1456+ self ._hdf5_analogsignal_buffers = {}
1457+ if block_index not in self ._hdf5_analogsignal_buffers :
1458+ self ._hdf5_analogsignal_buffers [block_index ] = {}
1459+ if seg_index not in self ._hdf5_analogsignal_buffers [block_index ]:
1460+ self ._hdf5_analogsignal_buffers [block_index ][seg_index ] = {}
1461+ if buffer_id not in self ._hdf5_analogsignal_buffers [block_index ][seg_index ]:
1462+ import h5py
1463+
1464+ h5file = h5py .File (buffer_desc ["file_path" ], mode = "r" )
1465+ self ._hdf5_analogsignal_buffers [block_index ][seg_index ][buffer_id ] = h5file
1466+ else :
1467+ h5file = self ._hdf5_analogsignal_buffers [block_index ][seg_index ][buffer_id ]
1468+
1469+ hdf5_path = buffer_desc ["hdf5_path" ]
1470+ full_raw_sigs = h5file [hdf5_path ]
1471+
1472+ time_axis = buffer_desc .get ("time_axis" , 0 )
1473+ if time_axis == 0 :
1474+ raw_sigs = full_raw_sigs [i_start :i_stop , :]
1475+ elif time_axis == 1 :
1476+ raw_sigs = full_raw_sigs [:, i_start :i_stop ].T
1477+ else :
1478+ raise RuntimeError ("Should never happen" )
1479+
1480+ if buffer_slice is not None :
1481+ raw_sigs = raw_sigs [:, buffer_slice ]
1482+
1483+ else :
1484+ raise NotImplementedError ()
1485+
1486+ # this is a pre slicing when the stream do not contain all channels (for instance spikeglx when load_sync_channel=False)
1487+ if buffer_slice is not None :
1488+ raw_sigs = raw_sigs [:, buffer_slice ]
1489+
1490+ # channel slice requested
1491+ if channel_indexes is not None :
1492+ raw_sigs = raw_sigs [:, channel_indexes ]
1493+
1494+ return raw_sigs
1495+
1496+ def __del__ (self ):
1497+ if hasattr (self , "_memmap_analogsignal_buffers" ):
1498+ for block_index in self ._memmap_analogsignal_buffers .keys ():
1499+ for seg_index in self ._memmap_analogsignal_buffers [block_index ].keys ():
1500+ for buffer_id , fid in self ._memmap_analogsignal_buffers [block_index ][seg_index ].items ():
1501+ fid .close ()
1502+ del self ._memmap_analogsignal_buffers
1503+
1504+ if hasattr (self , "_hdf5_analogsignal_buffers" ):
1505+ for block_index in self ._hdf5_analogsignal_buffers .keys ():
1506+ for seg_index in self ._hdf5_analogsignal_buffers [block_index ].keys ():
1507+ for buffer_id , h5_file in self ._hdf5_analogsignal_buffers [block_index ][seg_index ].items ():
1508+ h5_file .close ()
1509+ del self ._hdf5_analogsignal_buffers
1510+
13531511
13541512def pprint_vector (vector , lim : int = 8 ):
13551513 vector = np .asarray (vector )
0 commit comments