diff --git a/bmtk/simulator/bionet/modules/record_cellvars.py b/bmtk/simulator/bionet/modules/record_cellvars.py index 6204408c5..8bef42606 100644 --- a/bmtk/simulator/bionet/modules/record_cellvars.py +++ b/bmtk/simulator/bionet/modules/record_cellvars.py @@ -28,17 +28,6 @@ from bmtk.simulator.bionet.io_tools import io from bmtk.utils.io import cell_vars -try: - # Check to see if h5py is built to run in parallel - if h5py.get_config().mpi: - MembraneRecorder = cell_vars.CellVarRecorderParallel - else: - MembraneRecorder = cell_vars.CellVarRecorder - -except Exception as e: - MembraneRecorder = cell_vars.CellVarRecorder - -MembraneRecorder._io = io pc = h.ParallelContext() MPI_RANK = int(pc.id()) @@ -86,8 +75,12 @@ def __init__(self, tmp_dir, file_name, variable_name, cells, sections='all', buf self._local_gids = [] self._sections = sections - self._var_recorder = MembraneRecorder(self._file_name, self._tmp_dir, self._all_variables, - buffer_data=buffer_data, mpi_rank=MPI_RANK, mpi_size=N_HOSTS) + recorder_cls = cell_vars.get_cell_var_recorder_cls(file_name) + recorder_cls._io = io + self._var_recorder = recorder_cls( + self._file_name, self._tmp_dir, self._all_variables, + buffer_data=buffer_data, mpi_rank=MPI_RANK, mpi_size=N_HOSTS + ) self._gid_list = [] # list of all gids that will have their variables saved self._data_block = {} # table of variable data indexed by [gid][variable] @@ -119,7 +112,7 @@ def initialize(self, sim): # TODO: Make sure the seg has the recorded variable(s) sec_list.append(sec_id) seg_list.append(seg.x) - + self._var_recorder.add_cell(gid, sec_list, seg_list) self._var_recorder.initialize(sim.n_steps, sim.nsteps_block) diff --git a/bmtk/simulator/bionet/modules/record_netcons.py b/bmtk/simulator/bionet/modules/record_netcons.py index 49552c916..28224ad4a 100644 --- a/bmtk/simulator/bionet/modules/record_netcons.py +++ b/bmtk/simulator/bionet/modules/record_netcons.py @@ -10,14 +10,6 @@ from bmtk.simulator.bionet.pointprocesscell import PointProcessCell from bmtk.utils.io import cell_vars -try: - # Check to see if h5py is built to run in parallel - if h5py.get_config().mpi: - MembraneRecorder = cell_vars.CellVarRecorderParallel - else: - MembraneRecorder = cell_vars.CellVarRecorder -except Exception as e: - MembraneRecorder = cell_vars.CellVarRecorder pc = h.ParallelContext() MPI_RANK = int(pc.id()) @@ -46,9 +38,13 @@ def __init__(self, tmp_dir, file_name, variable_name, cells, sections='all', syn self._all_gids = cells self._local_gids = [] self._sections = sections - - self._var_recorder = MembraneRecorder(self._file_name, self._tmp_dir, self._all_variables, - buffer_data=buffer_data, mpi_rank=MPI_RANK, mpi_size=N_HOSTS) + + recorder_cls = cell_vars.get_cell_var_recorder_cls(file_name) + recorder_cls._io = io + self._var_recorder = recorder_cls( + self._file_name, self._tmp_dir, self._all_variables, + buffer_data=buffer_data, mpi_rank=MPI_RANK, mpi_size=N_HOSTS + ) self._virt_lookup = {} self._gid_lookup = {} diff --git a/bmtk/simulator/pointnet/modules/multimeter_reporter.py b/bmtk/simulator/pointnet/modules/multimeter_reporter.py index 12d86ac83..b57ad1d8c 100644 --- a/bmtk/simulator/pointnet/modules/multimeter_reporter.py +++ b/bmtk/simulator/pointnet/modules/multimeter_reporter.py @@ -1,7 +1,7 @@ import os import glob import pandas as pd -from bmtk.utils.io.cell_vars import CellVarRecorder +from bmtk.utils.io.cell_vars import CellVarRecorderH5 as CellVarRecorder from bmtk.simulator.pointnet.io_tools import io import nest diff --git a/bmtk/utils/io/cell_vars.py b/bmtk/utils/io/cell_vars.py index 408176c1d..74f0971cf 100644 --- a/bmtk/utils/io/cell_vars.py +++ b/bmtk/utils/io/cell_vars.py @@ -1,7 +1,12 @@ import os +from datetime import datetime +from collections import defaultdict import h5py import numpy as np +from pynwb import NWBFile, NWBHDF5IO +from nwbext_simulation_output import Compartments, CompartmentSeries + from bmtk.utils import io from bmtk.utils.sonata.utils import add_hdf5_magic, add_hdf5_version @@ -11,17 +16,38 @@ comm = MPI.COMM_WORLD rank = comm.Get_rank() nhosts = comm.Get_size() - except Exception as exc: - pass + comm = None + rank = 1 + + +def get_cell_var_recorder_cls(file_name): + """Return the right class for recording cellvars based on the filename and whether parallel h5py is enabled""" + try: + in_mpi = h5py.get_config().mpi + except Exception as e: + in_mpi = False + + if file_name.endswith('.nwb'): + # NWB + if in_mpi: + return CellVarRecorderNWBParallel + else: + return CellVarRecorderNWB + else: + # HDF5 + if in_mpi: + return CellVarRecorderH5Parallel + else: + return CellVarRecorderH5 -class CellVarRecorder(object): +class CellVarRecorderH5(object): """Used to save cell membrane variables (V, Ca2+, etc) to the described hdf5 format. For parallel simulations this class will write to a seperate tmp file on each rank, then use the merge method to combine the results. This is less efficent, but doesn't require the user to install mpi4py and build h5py in - parallel mode. For better performance use the CellVarRecorderParrallel class instead. + parallel mode. For better performance use one of the CellVarRecorder{H5,NWB}Parallel classes instead. """ _io = io @@ -36,7 +62,7 @@ def __init__(self, var_name): def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0, mpi_size=1): self._file_name = file_name - self._h5_handle = None + self._file_handle = None self._tmp_dir = tmp_dir self._variables = variables if isinstance(variables, list) else [variables] self._n_vars = len(self._variables) # Used later to keep track if more than one var is saved to the same file. @@ -46,7 +72,8 @@ def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0, self._tmp_files = [] self._saved_file = file_name - if mpi_size > 1: + if mpi_size > 1 and not isinstance(self, ParallelRecorderMixin): + self._io.log_warning('Was unable to run h5py in parallel (mpi) mode.' + ' Saving of membrane variable(s) may slow down.') tmp_fname = os.path.basename(file_name) # make sure file names don't clash if there are multiple reports @@ -56,8 +83,8 @@ def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0, self._mapping_gids = [] # list of gids in the order they appear in the data self._gid_map = {} # table for looking up the gid offsets - self._map_attrs = {} # Used for additonal attributes in /mapping - + self._map_attrs = defaultdict(list) # Used for additonal attributes in /mapping + self._mapping_element_ids = [] # sections self._mapping_element_pos = [] # segments self._mapping_index = [0] # index_pointer @@ -123,10 +150,10 @@ def _calc_offset(self): self._gids_beg = 0 self._gids_end = self._n_gids_local - def _create_h5_file(self): - self._h5_handle = h5py.File(self._file_name, 'w') - add_hdf5_version(self._h5_handle) - add_hdf5_magic(self._h5_handle) + def _create_file(self, **io_kwargs): + self._file_handle = h5py.File(self._file_name, 'w', **io_kwargs) + add_hdf5_version(self._file_handle) + add_hdf5_magic(self._file_handle) def add_cell(self, gid, sec_list, seg_list, **map_attrs): assert(len(sec_list) == len(seg_list)) @@ -140,16 +167,26 @@ def add_cell(self, gid, sec_list, seg_list, **map_attrs): self._n_segments_local += n_segs self._n_gids_local += 1 for k, v in map_attrs.items(): - if k not in self._map_attrs: - self._map_attrs[k] = v - else: - self._map_attrs[k].extend(v) + self._map_attrs[k].extend(v) def initialize(self, n_steps, buffer_size=0): self._calc_offset() - self._create_h5_file() + self._create_file() + self._init_mapping() + self._total_steps = n_steps + self._buffer_block_size = buffer_size + self._init_buffers() + + if not self._buffer_data: + # If data is not being buffered and instead written to the main block, + # we have to add a rank offset to the gid offset + for gid, gid_offset in self._gid_map.items(): + self._gid_map[gid] = (gid_offset[0] + self._seg_offset_beg, gid_offset[1] + self._seg_offset_beg) + + self._is_initialized = True - var_grp = self._h5_handle.create_group('/mapping') + def _init_mapping(self): + var_grp = self._file_handle.create_group('/mapping') var_grp.create_dataset('gids', shape=(self._n_gids_all,), dtype=np.uint) var_grp.create_dataset('element_id', shape=(self._n_segments_all,), dtype=np.uint) var_grp.create_dataset('element_pos', shape=(self._n_segments_all,), dtype=np.float) @@ -164,32 +201,24 @@ def initialize(self, n_steps, buffer_size=0): var_grp['index_pointer'][self._gids_beg:(self._gids_end+1)] = self._mapping_index for k, v in self._map_attrs.items(): var_grp[k][self._seg_offset_beg:self._seg_offset_end] = v - - self._total_steps = n_steps - self._buffer_block_size = buffer_size - if not self._buffer_data: - # If data is not being buffered and instead written to the main block, we have to add a rank offset - # to the gid offset - for gid, gid_offset in self._gid_map.items(): - self._gid_map[gid] = (gid_offset[0] + self._seg_offset_beg, gid_offset[1] + self._seg_offset_beg) - + + def _init_buffers(self): for var_name, data_tables in self._data_blocks.items(): # If users are trying to save multiple variables in the same file put data table in its own /{var} group # (not sonata compliant). Otherwise the data table is located at the root - data_grp = self._h5_handle if self._n_vars == 1 else self._h5_handle.create_group('/{}'.format(var_name)) + data_grp = self._file_handle if self._n_vars == 1 else self._file_handle.create_group('/{}'.format(var_name)) if self._buffer_data: # Set up in-memory block to buffer recorded variables before writing to the dataset - data_tables.buffer_block = np.zeros((buffer_size, self._n_segments_local), dtype=np.float) - data_tables.data_block = data_grp.create_dataset('data', shape=(n_steps, self._n_segments_all), + data_tables.buffer_block = np.zeros((self._buffer_block_size, self._n_segments_local), dtype=np.float) + data_tables.data_block = data_grp.create_dataset('data', shape=(self._total_steps, self._n_segments_all), dtype=np.float, chunks=True) data_tables.data_block.attrs['variable_name'] = var_name else: # Since we are not buffering data, we just write directly to the on-disk dataset - data_tables.buffer_block = data_grp.create_dataset('data', shape=(n_steps, self._n_segments_all), + data_tables.buffer_block = data_grp.create_dataset('data', shape=(self._total_steps, self._n_segments_all), dtype=np.float, chunks=True) data_tables.buffer_block.attrs['variable_name'] = var_name - self._is_initialized = True def record_cell(self, gid, var_name, seg_vals, tstep): """Record cell parameters. @@ -234,7 +263,7 @@ def flush(self): data_table.data_block[blk_beg:blk_end, :] = data_table.buffer_block[:block_size, :] def close(self): - self._h5_handle.close() + self._file_handle.close() def merge(self): if self._mpi_size > 1 and self._mpi_rank == 0: @@ -290,7 +319,6 @@ def merge(self): gids_ds[beg:end] = tmp_mapping_grp['gids'] index_pointer_ds[beg:(end+1)] = update_index - # combine the /var/data datasets for var_name in self._variables: data_name = '/data' if self._n_vars == 1 else '/{}/data'.format(var_name) @@ -305,33 +333,85 @@ def merge(self): os.remove(tmp_file) -class CellVarRecorderParallel(CellVarRecorder): - """ - Unlike the parent, this take advantage of parallel h5py to writting to the results file across different ranks. +class CellVarRecorderNWB(CellVarRecorderH5): + def __init__(self, file_name, tmp_dir, variables, buffer_data=True, mpi_rank=0, mpi_size=1): + super(CellVarRecorderNWB, self).__init__( + file_name, tmp_dir, variables, buffer_data=buffer_data, + mpi_rank=mpi_rank, mpi_size=mpi_size + ) + self._compartments = Compartments('compartments') + self._compartmentseries = {} + + def _create_file(self, **io_kwargs): + self._nwbio = NWBHDF5IO(self._file_name, 'w', **io_kwargs) + self._file_handle = NWBFile('description', 'id', datetime.now().astimezone()) # TODO: pass in descr, id + + def add_cell(self, gid, sec_list, seg_list, **map_attrs): + if map_attrs: + raise NotImplementedError('Cannot use map_attrs with NWB') # TODO: support this + self._compartments.add_row(number=sec_list, position=seg_list, id=gid) + super(CellVarRecorderNWB, self).add_cell(gid, sec_list, seg_list, **map_attrs) + + def _init_mapping(self): + # Cell/section id and pos are in the Compartments table + # 1/dt is the rate of the recorded datasets. + # tstart was used as session_start_time when creating the NWBFile + # nwb doesn't store tstop + pass + + def _init_buffers(self): + self._file_handle.add_acquisition(self._compartments) + for var_name, data_tables in self._data_blocks.items(): + cs = CompartmentSeries( + var_name, data=np.zeros((self._total_steps, self._n_segments_all)), + compartments=self._compartments, unit='mV', rate=1000.0/self.dt + ) + self._compartmentseries[var_name] = cs + self._file_handle.add_acquisition(cs) + data_tables.buffer_block = np.zeros((self._buffer_block_size, self._n_segments_local), dtype=np.float) + data_tables.data_block = self._compartmentseries[var_name].data + + self._nwbio.write(self._file_handle) + + # Re-read data sets so that pynwb forgets it has them in memory + # (forces immediate write upon modification) + self._nwbio.close() + self._nwbio = NWBHDF5IO(self._file_name, 'a', comm=comm) + self._file_handle = self._nwbio.read() + for var_name, data_tables in self._data_blocks.items(): + self._data_blocks[var_name].data_block = self._file_handle.acquisition[var_name].data + + def close(self): + self._nwbio.close() + + def merge(self): + raise NotImplementedError("Can't merge NWB files across ranks") - """ - def __init__(self, file_name, tmp_dir, variables, buffer_data=True): - super(CellVarRecorder, self).__init__(file_name, tmp_dir, variables, buffer_data=buffer_data, mpi_rank=0, - mpi_size=1) +class ParallelRecorderMixin(): + """ + When inherited along with one of the CellVarRecorder classes, this takes + advantage of parallel h5py to collectively write the results file from multiple ranks. + """ def _calc_offset(self): # iterate through the ranks let rank r determine the offset from rank r-1 for r in range(comm.Get_size()): if rank == r: - if rank < (nhosts - 1): - # pass the num of segments and num of gids to the next rank - offsets = np.array([self._n_segments_local, self._n_gids_local], dtype=np.uint) - comm.Send([offsets, MPI.UNSIGNED_INT], dest=(rank+1)) - if rank > 0: # get num of segments and gids from prev. rank and calculate offsets - offset = np.empty(2, dtype=np.uint) + offsets = np.empty(2, dtype=np.uint) comm.Recv([offsets, MPI.UNSIGNED_INT], source=(r-1)) self._seg_offset_beg = offsets[0] - self._seg_offset_end = self._seg_offset_beg + self._n_segments_local + self._gids_beg = offsets[1] + + self._seg_offset_end = int(self._seg_offset_beg) \ + + int(self._n_segments_local) + self._gids_end = int(self._gids_beg) + int(self._n_gids_local) - self._gids_beg = offset[1] - self._gids_end = self._gids_beg + self._n_gids_local + if rank < (nhosts - 1): + # pass the next rank its offset + offsets = np.array([self._seg_offset_end, self._gids_end], dtype=np.uint) + comm.Send([offsets, MPI.UNSIGNED_INT], dest=(rank+1)) comm.Barrier() @@ -345,10 +425,18 @@ def _calc_offset(self): self._n_segments_all = total_counts[0] self._n_gids_all = total_counts[1] - def _create_h5_file(self): - self._h5_handle = h5py.File(self._file_name, 'w', driver='mpio', comm=MPI.COMM_WORLD) - add_hdf5_version(self._h5_handle) - add_hdf5_magic(self._h5_handle) - def merge(self): pass + + +class CellVarRecorderH5Parallel(ParallelRecorderMixin, CellVarRecorderH5): + def _create_file(self, **io_kwargs): + io_kwargs['driver'] = 'mpio' + io_kwargs['comm'] = comm + super(CellVarRecorderH5Parallel, self)._create_file(**io_kwargs) + + +class CellVarRecorderNWBParallel(ParallelRecorderMixin, CellVarRecorderNWB): + def _create_file(self, **io_kwargs): + io_kwargs['comm'] = comm + super(CellVarRecorderNWBParallel, self)._create_file(**io_kwargs)