|
| 1 | +import os |
1 | 2 | import numpy as np |
2 | 3 | import pandas as pd |
3 | | -import h5py |
| 4 | +from functools import partial |
| 5 | +from six import string_types |
4 | 6 |
|
5 | | - |
6 | | -import matplotlib.pyplot as plt |
7 | | - |
8 | | - |
9 | | -from bmtk.utils.sonata.config import SonataConfig as ConfigDict |
| 7 | +from bmtk.utils import sonata |
| 8 | +from bmtk.utils.sonata.config import SonataConfig |
10 | 9 | from bmtk.utils.reports import SpikeTrains |
11 | 10 | from bmtk.utils.reports.spike_trains import plotting |
| 11 | +from bmtk.simulator.utils import simulation_reports |
12 | 12 |
|
13 | 13 |
|
14 | | -def load_spikes_file(config_file=None, spikes_file=None): |
15 | | - if spikes_file is not None: |
16 | | - return SpikeTrains.load(spikes_file) |
| 14 | +def _find_spikes(spikes_file=None, config_file=None, population=None): |
| 15 | + candidate_spikes = [] |
| 16 | + |
| 17 | + # Get spikes file(s) |
| 18 | + if spikes_file: |
| 19 | + # User has explicity set the location of the spike files |
| 20 | + candidate_spikes.append(spikes_file) |
17 | 21 |
|
18 | 22 | elif config_file is not None: |
19 | | - config = ConfigDict.from_json(config_file) |
20 | | - return SpikeTrains.load(config['output']['spikes_file']) |
| 23 | + # Otherwise search the config.json for all possible output spikes_files. We can use the simulation_reports |
| 24 | + # module to find any spikes output file specified in config's "output" or "reports" section. |
| 25 | + config = SonataConfig.from_json(config_file) |
| 26 | + sim_reports = simulation_reports.from_config(config) |
| 27 | + for report in sim_reports: |
| 28 | + if report.module == 'spikes_report': |
| 29 | + # BMTK can end up output the same spikes file in SONATA, CSV, and NWB format. Try fetching the SONATA |
| 30 | + # version first, then CSV, and finally NWB if it exists. |
| 31 | + spikes_sonata = report.params.get('spikes_file', None) |
| 32 | + spikes_csv = report.params.get('spikes_file_csv', None) |
| 33 | + spikes_nwb = report.params.get('spikes_file_nwb', None) |
| 34 | + |
| 35 | + if spikes_sonata is not None: |
| 36 | + candidate_spikes.append(spikes_sonata) |
| 37 | + elif spikes_csv is not None: |
| 38 | + candidate_spikes.append(spikes_csv) |
| 39 | + elif spikes_csv is not None: |
| 40 | + candidate_spikes.append(spikes_nwb) |
| 41 | + |
| 42 | + # TODO: Should we also look in the "inputs" for displaying input spike statistics? |
| 43 | + |
| 44 | + if not candidate_spikes: |
| 45 | + raise ValueError('Could not find an output spikes-file. Use "spikes_file" parameter option.') |
| 46 | + |
| 47 | + # Find file that contains spikes for the specified "population" of nodes. If "population" parameter is not |
| 48 | + # specified try to guess that spikes that the user wants to visualize. |
| 49 | + if population is not None: |
| 50 | + spikes_obj = None |
| 51 | + for spikes_f in candidate_spikes: |
| 52 | + st = SpikeTrains.load(spikes_f) |
| 53 | + if population in st.populations: |
| 54 | + if spikes_obj is None: |
| 55 | + spikes_obj = st |
| 56 | + else: |
| 57 | + spikes_obj.merge(st) |
| 58 | + |
| 59 | + if spikes_obj is None: |
| 60 | + raise ValueError('Could not fine spikes file with node population "{}".'.format(population)) |
| 61 | + else: |
| 62 | + return population, spikes_obj |
21 | 63 |
|
| 64 | + else: |
| 65 | + if len(candidate_spikes) > 1: |
| 66 | + raise ValueError('Found more than one spike-trains file') |
22 | 67 |
|
23 | | -def to_dataframe(config_file, spikes_file=None): |
24 | | - spike_trains = load_spikes_file(config_file=config_file, spikes_file=spikes_file) |
25 | | - return spike_trains.to_dataframe() |
| 68 | + spikes_f = candidate_spikes[0] |
| 69 | + if not os.path.exists(spikes_f): |
| 70 | + raise ValueError('Did not find spike-trains file {}. Make sure the simulation has completed.'.format( |
| 71 | + spikes_f)) |
| 72 | + |
| 73 | + spikes_obj = SpikeTrains.load(spikes_f) |
| 74 | + if len(spikes_obj.populations) > 1: |
| 75 | + raise ValueError('Spikes file {} contains more than one node population.'.format(spikes_f)) |
| 76 | + else: |
| 77 | + return spikes_obj.populations[0], spikes_obj |
| 78 | + |
| 79 | + |
| 80 | +def _find_nodes(population, config=None, nodes_file=None, node_types_file=None): |
| 81 | + if nodes_file is not None: |
| 82 | + network = sonata.File(data_files=nodes_file, data_type_files=node_types_file) |
| 83 | + if population not in network.nodes.population_names: |
| 84 | + raise ValueError('node population "{}" not found in {}'.format(population, nodes_file)) |
| 85 | + return network.nodes[population] |
| 86 | + |
| 87 | + elif config is not None: |
| 88 | + for nodes_grp in config.nodes: |
| 89 | + network = sonata.File(data_files=nodes_grp['nodes_file'], data_type_files=nodes_grp['node_types_file']) |
| 90 | + if population in network.nodes.population_names: |
| 91 | + return network.nodes[population] |
| 92 | + |
| 93 | + raise ValueError('Could not find nodes file with node population "{}".'.format(population)) |
| 94 | + |
| 95 | + |
| 96 | +def _plot_helper(plot_fnc, config_file=None, population=None, times=None, title=None, show=True, |
| 97 | + group_by=None, group_excludes=None, |
| 98 | + spikes_file=None, nodes_file=None, node_types_file=None): |
| 99 | + sonata_config = SonataConfig.from_json(config_file) if config_file else None |
| 100 | + pop, spike_trains = _find_spikes(config_file=config_file, spikes_file=spikes_file, population=population) |
| 101 | + |
| 102 | + # Create the title |
| 103 | + title = title if title is not None else '{} Nodes'.format(pop) |
| 104 | + |
| 105 | + # Get start and stop times from config if needed |
| 106 | + if sonata_config and times is None: |
| 107 | + times = (sonata_config.tstart, sonata_config.tstop) |
| 108 | + |
| 109 | + # Create node-groups |
| 110 | + if group_by is not None: |
| 111 | + node_groups = [] |
| 112 | + nodes = _find_nodes(population=pop, config=sonata_config, nodes_file=nodes_file, |
| 113 | + node_types_file=node_types_file) |
| 114 | + grouped_df = None |
| 115 | + for grp in nodes.groups: |
| 116 | + if group_by in grp.all_columns: |
| 117 | + grp_df = grp.to_dataframe() |
| 118 | + grp_df = grp_df[['node_id', group_by]] |
| 119 | + grouped_df = grp_df if grouped_df is None else grouped_df.append(grp_df, ignore_index=True) |
| 120 | + |
| 121 | + if grouped_df is None: |
| 122 | + raise ValueError('Could not find any nodes with group_by attribute "{}"'.format(group_by)) |
| 123 | + |
| 124 | + # Convert from string to list so we can always use the isin() method for filtering |
| 125 | + if isinstance(group_excludes, string_types): |
| 126 | + group_excludes = [group_excludes] |
| 127 | + elif group_excludes is None: |
| 128 | + group_excludes = [] |
| 129 | + |
| 130 | + for grp_key, grp in grouped_df.groupby(group_by): |
| 131 | + if grp_key in group_excludes: |
| 132 | + continue |
| 133 | + node_groups.append({'node_ids': np.array(grp['node_id']), 'label': grp_key}) |
| 134 | + |
| 135 | + else: |
| 136 | + node_groups = None |
| 137 | + |
| 138 | + plot_fnc(spike_trains=spike_trains, node_groups=node_groups, population=pop, times=times, title=title, show=show) |
| 139 | + |
| 140 | + |
| 141 | +def plot_raster(config_file=None, population=None, with_histogram=True, times=None, title=None, show=True, |
| 142 | + group_by=None, group_excludes=None, |
| 143 | + spikes_file=None, nodes_file=None, node_types_file=None): |
| 144 | + |
| 145 | + plot_fnc = partial(plotting.plot_raster, with_histogram=with_histogram) |
| 146 | + return _plot_helper(plot_fnc, |
| 147 | + config_file=config_file, population=population, times=times, title=title, show=show, |
| 148 | + group_by=group_by, group_excludes=group_excludes, |
| 149 | + spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file |
| 150 | + ) |
26 | 151 |
|
27 | 152 |
|
28 | | -def plot_raster(config_file, spikes_file=None): |
29 | | - spike_trains = load_spikes_file(config_file=config_file, spikes_file=spikes_file) |
30 | | - plotting.plot_raster(spike_trains) |
31 | | - plt.show() |
| 153 | +def plot_rates(config_file=None, population=None, smoothing=False, smoothing_params=None, times=None, title=None, |
| 154 | + show=True, group_by=None, group_excludes=None, spikes_file=None, nodes_file=None, node_types_file=None): |
32 | 155 |
|
| 156 | + plot_fnc = partial(plotting.plot_rates, smoothing=smoothing, smoothing_params=smoothing_params) |
| 157 | + return _plot_helper(plot_fnc, |
| 158 | + config_file=config_file, population=population, times=times, title=title, show=show, |
| 159 | + group_by=group_by, group_excludes=group_excludes, |
| 160 | + spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file |
| 161 | + ) |
33 | 162 |
|
34 | | -def plot_rates(config_file): |
35 | | - spike_trains = load_spikes_file(config_file) |
36 | | - plotting.plot_rates(spike_trains) |
| 163 | + |
| 164 | +def plot_rates_boxplot(config_file=None, population=None, times=None, title=None, show=True, |
| 165 | + group_by=None, group_excludes=None, |
| 166 | + spikes_file=None, nodes_file=None, node_types_file=None): |
| 167 | + |
| 168 | + plot_fnc = partial(plotting.plot_rates_boxplot) |
| 169 | + return _plot_helper(plot_fnc, |
| 170 | + config_file=config_file, population=population, times=times, title=title, show=show, |
| 171 | + group_by=group_by, group_excludes=group_excludes, |
| 172 | + spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file |
| 173 | + ) |
37 | 174 |
|
38 | 175 |
|
39 | 176 | def spike_statistics(spikes_file, simulation=None, simulation_time=None, groupby=None, network=None, **filterparams): |
@@ -68,3 +205,8 @@ def calc_stats(r): |
68 | 205 | return vals_df |
69 | 206 | else: |
70 | 207 | return spike_counts_df |
| 208 | + |
| 209 | + |
| 210 | +def to_dataframe(config_file, spikes_file=None, population=None): |
| 211 | + _, spike_trains = _find_spikes(config_file=config_file, spikes_file=spikes_file, population=population) |
| 212 | + return spike_trains.to_dataframe() |
0 commit comments