Skip to content

Commit 763706f

Browse files
committed
updating plotting functionality
1 parent 164c5f1 commit 763706f

File tree

28 files changed

+1100
-374
lines changed

28 files changed

+1100
-374
lines changed

bmtk/analyzer/compartment.py

Lines changed: 128 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,128 @@
1+
import os
2+
from six import string_types
3+
import numpy as np
4+
5+
from bmtk.utils import sonata
6+
from bmtk.utils.sonata.config import SonataConfig
7+
from bmtk.utils.reports.compartment import CompartmentReport
8+
from bmtk.utils.reports.compartment import plotting
9+
from bmtk.simulator.utils import simulation_reports
10+
11+
12+
def _get_report(report_path=None, config=None, report_name=None):
13+
if report_path is not None:
14+
return report_path, CompartmentReport.load(report=report_path)
15+
16+
elif config is not None:
17+
selected_reports = []
18+
sim_reports = simulation_reports.from_config(config)
19+
for report in sim_reports:
20+
if report.module in ['membrane_report', 'multimeter_report']:
21+
rname = report.report_name
22+
rfile = report.params['file_name']
23+
# TODO: Full path should be determined by config/simulation_reports module
24+
rpath = rfile if os.path.isabs(rfile) else os.path.join(report.params['tmp_dir'], rfile)
25+
if report_name is not None and report_name == rname:
26+
selected_reports.append((rname, CompartmentReport.load(rpath)))
27+
elif report_name is None:
28+
selected_reports.append((rname, CompartmentReport.load(rpath)))
29+
30+
if len(selected_reports) == 0:
31+
msg = 'Could not find a report '
32+
msg += '' if report_name is None else 'with report_name "{}"'.format(report_name)
33+
msg += ' from configuration file. . Use "report_path" parameter instead.'
34+
raise ValueError(msg)
35+
36+
elif len(selected_reports) > 1:
37+
avail_reports = ', '.join(s[0] for s in selected_reports)
38+
raise ValueError('Configuration file contained multiple "membrane_reports", use "report_name" or'
39+
'"report_path" to pick which one to plot. Option values: {}'.format(avail_reports))
40+
41+
else:
42+
return selected_reports[0]
43+
44+
else:
45+
raise AttributeError('Could not find a compartment report SONATA file. Please user "config_file" or '
46+
'"report_path" options.')
47+
48+
49+
def _find_nodes(population, config=None, nodes_file=None, node_types_file=None):
50+
if nodes_file is not None:
51+
network = sonata.File(data_files=nodes_file, data_type_files=node_types_file)
52+
if population not in network.nodes.population_names:
53+
raise ValueError('node population "{}" not found in {}'.format(population, nodes_file))
54+
return network.nodes[population]
55+
56+
elif config is not None:
57+
for nodes_grp in config.nodes:
58+
network = sonata.File(data_files=nodes_grp['nodes_file'], data_type_files=nodes_grp['node_types_file'])
59+
if population in network.nodes.population_names:
60+
return network.nodes[population]
61+
62+
raise ValueError('Could not find nodes file with node population "{}".'.format(population))
63+
64+
65+
def plot_traces(report_path=None, config_file=None, report_name=None, population=None, group_by=None,
66+
group_excludes=None, nodes_file=None, node_types_file=None,
67+
node_ids=None, sections='origin', average=False, times=None, title=None,
68+
show_legend=None, show=True):
69+
sonata_config = SonataConfig.from_json(config_file) if config_file else None
70+
report_name, cr = _get_report(report_path=report_path, config=sonata_config, report_name=report_name)
71+
72+
if population is None:
73+
pops = cr.populations
74+
if len(pops) > 1:
75+
raise ValueError('Report {} contains more than population of nodes ({}). Use population parameter'.format(
76+
report_name, pops
77+
))
78+
population = pops[0]
79+
80+
if title is None:
81+
title = '{} ({})'.format(report_name, population)
82+
83+
# Create node-groups
84+
if group_by is not None:
85+
node_groups = []
86+
nodes = _find_nodes(population=population, config=sonata_config, nodes_file=nodes_file,
87+
node_types_file=node_types_file)
88+
89+
grouped_df = None
90+
for grp in nodes.groups:
91+
if group_by in grp.all_columns:
92+
grp_df = grp.to_dataframe()
93+
grp_df = grp_df[['node_id', group_by]]
94+
grouped_df = grp_df if grouped_df is None else grouped_df.append(grp_df, ignore_index=True)
95+
96+
if grouped_df is None:
97+
raise ValueError('Could not find any nodes with group_by attribute "{}"'.format(group_by))
98+
99+
# Convert from string to list so we can always use the isin() method for filtering
100+
if isinstance(group_excludes, string_types):
101+
group_excludes = [group_excludes]
102+
elif group_excludes is None:
103+
group_excludes = []
104+
105+
for grp_key, grp in grouped_df.groupby(group_by):
106+
if grp_key in group_excludes:
107+
continue
108+
node_groups.append({'node_ids': np.array(grp['node_id']), 'label': grp_key})
109+
110+
if len(node_groups) == 0:
111+
exclude_str = ' excluding values {}'.format(', '.join(group_excludes)) if len(group_excludes) > 0 else ''
112+
raise ValueError('Could not find any node-groups using group_by="{}"{}.'.format(group_by, exclude_str))
113+
114+
else:
115+
node_groups = None
116+
117+
return plotting.plot_traces(
118+
report=cr,
119+
population=population,
120+
node_ids=node_ids,
121+
sections=sections,
122+
average=average,
123+
node_groups=node_groups,
124+
times=times,
125+
title=title,
126+
show_legend=show_legend,
127+
show=show
128+
)

bmtk/analyzer/spike_trains.py

Lines changed: 163 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,176 @@
1+
import os
12
import numpy as np
23
import pandas as pd
3-
import h5py
4+
from functools import partial
5+
from six import string_types
46

5-
6-
import matplotlib.pyplot as plt
7-
8-
9-
from bmtk.utils.sonata.config import SonataConfig as ConfigDict
7+
from bmtk.utils import sonata
8+
from bmtk.utils.sonata.config import SonataConfig
109
from bmtk.utils.reports import SpikeTrains
1110
from bmtk.utils.reports.spike_trains import plotting
11+
from bmtk.simulator.utils import simulation_reports
1212

1313

14-
def load_spikes_file(config_file=None, spikes_file=None):
15-
if spikes_file is not None:
16-
return SpikeTrains.load(spikes_file)
14+
def _find_spikes(spikes_file=None, config_file=None, population=None):
15+
candidate_spikes = []
16+
17+
# Get spikes file(s)
18+
if spikes_file:
19+
# User has explicity set the location of the spike files
20+
candidate_spikes.append(spikes_file)
1721

1822
elif config_file is not None:
19-
config = ConfigDict.from_json(config_file)
20-
return SpikeTrains.load(config['output']['spikes_file'])
23+
# Otherwise search the config.json for all possible output spikes_files. We can use the simulation_reports
24+
# module to find any spikes output file specified in config's "output" or "reports" section.
25+
config = SonataConfig.from_json(config_file)
26+
sim_reports = simulation_reports.from_config(config)
27+
for report in sim_reports:
28+
if report.module == 'spikes_report':
29+
# BMTK can end up output the same spikes file in SONATA, CSV, and NWB format. Try fetching the SONATA
30+
# version first, then CSV, and finally NWB if it exists.
31+
spikes_sonata = report.params.get('spikes_file', None)
32+
spikes_csv = report.params.get('spikes_file_csv', None)
33+
spikes_nwb = report.params.get('spikes_file_nwb', None)
34+
35+
if spikes_sonata is not None:
36+
candidate_spikes.append(spikes_sonata)
37+
elif spikes_csv is not None:
38+
candidate_spikes.append(spikes_csv)
39+
elif spikes_csv is not None:
40+
candidate_spikes.append(spikes_nwb)
41+
42+
# TODO: Should we also look in the "inputs" for displaying input spike statistics?
43+
44+
if not candidate_spikes:
45+
raise ValueError('Could not find an output spikes-file. Use "spikes_file" parameter option.')
46+
47+
# Find file that contains spikes for the specified "population" of nodes. If "population" parameter is not
48+
# specified try to guess that spikes that the user wants to visualize.
49+
if population is not None:
50+
spikes_obj = None
51+
for spikes_f in candidate_spikes:
52+
st = SpikeTrains.load(spikes_f)
53+
if population in st.populations:
54+
if spikes_obj is None:
55+
spikes_obj = st
56+
else:
57+
spikes_obj.merge(st)
58+
59+
if spikes_obj is None:
60+
raise ValueError('Could not fine spikes file with node population "{}".'.format(population))
61+
else:
62+
return population, spikes_obj
2163

64+
else:
65+
if len(candidate_spikes) > 1:
66+
raise ValueError('Found more than one spike-trains file')
2267

23-
def to_dataframe(config_file, spikes_file=None):
24-
spike_trains = load_spikes_file(config_file=config_file, spikes_file=spikes_file)
25-
return spike_trains.to_dataframe()
68+
spikes_f = candidate_spikes[0]
69+
if not os.path.exists(spikes_f):
70+
raise ValueError('Did not find spike-trains file {}. Make sure the simulation has completed.'.format(
71+
spikes_f))
72+
73+
spikes_obj = SpikeTrains.load(spikes_f)
74+
if len(spikes_obj.populations) > 1:
75+
raise ValueError('Spikes file {} contains more than one node population.'.format(spikes_f))
76+
else:
77+
return spikes_obj.populations[0], spikes_obj
78+
79+
80+
def _find_nodes(population, config=None, nodes_file=None, node_types_file=None):
81+
if nodes_file is not None:
82+
network = sonata.File(data_files=nodes_file, data_type_files=node_types_file)
83+
if population not in network.nodes.population_names:
84+
raise ValueError('node population "{}" not found in {}'.format(population, nodes_file))
85+
return network.nodes[population]
86+
87+
elif config is not None:
88+
for nodes_grp in config.nodes:
89+
network = sonata.File(data_files=nodes_grp['nodes_file'], data_type_files=nodes_grp['node_types_file'])
90+
if population in network.nodes.population_names:
91+
return network.nodes[population]
92+
93+
raise ValueError('Could not find nodes file with node population "{}".'.format(population))
94+
95+
96+
def _plot_helper(plot_fnc, config_file=None, population=None, times=None, title=None, show=True,
97+
group_by=None, group_excludes=None,
98+
spikes_file=None, nodes_file=None, node_types_file=None):
99+
sonata_config = SonataConfig.from_json(config_file) if config_file else None
100+
pop, spike_trains = _find_spikes(config_file=config_file, spikes_file=spikes_file, population=population)
101+
102+
# Create the title
103+
title = title if title is not None else '{} Nodes'.format(pop)
104+
105+
# Get start and stop times from config if needed
106+
if sonata_config and times is None:
107+
times = (sonata_config.tstart, sonata_config.tstop)
108+
109+
# Create node-groups
110+
if group_by is not None:
111+
node_groups = []
112+
nodes = _find_nodes(population=pop, config=sonata_config, nodes_file=nodes_file,
113+
node_types_file=node_types_file)
114+
grouped_df = None
115+
for grp in nodes.groups:
116+
if group_by in grp.all_columns:
117+
grp_df = grp.to_dataframe()
118+
grp_df = grp_df[['node_id', group_by]]
119+
grouped_df = grp_df if grouped_df is None else grouped_df.append(grp_df, ignore_index=True)
120+
121+
if grouped_df is None:
122+
raise ValueError('Could not find any nodes with group_by attribute "{}"'.format(group_by))
123+
124+
# Convert from string to list so we can always use the isin() method for filtering
125+
if isinstance(group_excludes, string_types):
126+
group_excludes = [group_excludes]
127+
elif group_excludes is None:
128+
group_excludes = []
129+
130+
for grp_key, grp in grouped_df.groupby(group_by):
131+
if grp_key in group_excludes:
132+
continue
133+
node_groups.append({'node_ids': np.array(grp['node_id']), 'label': grp_key})
134+
135+
else:
136+
node_groups = None
137+
138+
plot_fnc(spike_trains=spike_trains, node_groups=node_groups, population=pop, times=times, title=title, show=show)
139+
140+
141+
def plot_raster(config_file=None, population=None, with_histogram=True, times=None, title=None, show=True,
142+
group_by=None, group_excludes=None,
143+
spikes_file=None, nodes_file=None, node_types_file=None):
144+
145+
plot_fnc = partial(plotting.plot_raster, with_histogram=with_histogram)
146+
return _plot_helper(plot_fnc,
147+
config_file=config_file, population=population, times=times, title=title, show=show,
148+
group_by=group_by, group_excludes=group_excludes,
149+
spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file
150+
)
26151

27152

28-
def plot_raster(config_file, spikes_file=None):
29-
spike_trains = load_spikes_file(config_file=config_file, spikes_file=spikes_file)
30-
plotting.plot_raster(spike_trains)
31-
plt.show()
153+
def plot_rates(config_file=None, population=None, smoothing=False, smoothing_params=None, times=None, title=None,
154+
show=True, group_by=None, group_excludes=None, spikes_file=None, nodes_file=None, node_types_file=None):
32155

156+
plot_fnc = partial(plotting.plot_rates, smoothing=smoothing, smoothing_params=smoothing_params)
157+
return _plot_helper(plot_fnc,
158+
config_file=config_file, population=population, times=times, title=title, show=show,
159+
group_by=group_by, group_excludes=group_excludes,
160+
spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file
161+
)
33162

34-
def plot_rates(config_file):
35-
spike_trains = load_spikes_file(config_file)
36-
plotting.plot_rates(spike_trains)
163+
164+
def plot_rates_boxplot(config_file=None, population=None, times=None, title=None, show=True,
165+
group_by=None, group_excludes=None,
166+
spikes_file=None, nodes_file=None, node_types_file=None):
167+
168+
plot_fnc = partial(plotting.plot_rates_boxplot)
169+
return _plot_helper(plot_fnc,
170+
config_file=config_file, population=population, times=times, title=title, show=show,
171+
group_by=group_by, group_excludes=group_excludes,
172+
spikes_file=spikes_file, nodes_file=nodes_file, node_types_file=node_types_file
173+
)
37174

38175

39176
def spike_statistics(spikes_file, simulation=None, simulation_time=None, groupby=None, network=None, **filterparams):
@@ -68,3 +205,8 @@ def calc_stats(r):
68205
return vals_df
69206
else:
70207
return spike_counts_df
208+
209+
210+
def to_dataframe(config_file, spikes_file=None, population=None):
211+
_, spike_trains = _find_spikes(config_file=config_file, spikes_file=spikes_file, population=population)
212+
return spike_trains.to_dataframe()

bmtk/analyzer/visualization/spikes.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,13 +26,14 @@
2626
from six import string_types
2727
import pandas as pd
2828
import numpy as np
29+
import warnings
2930
import matplotlib.pyplot as plt
3031
import matplotlib.cm as cmx
3132
import matplotlib.colors as colors
3233
import matplotlib.gridspec as gridspec
3334

3435
import bmtk.simulator.utils.config as config
35-
from bmtk.utils.reports.spike_trains.plotting import plot_raster, plot_rates, plot_raster_cmp
36+
from bmtk.utils.reports.spike_trains.plotting import plot_raster, plot_rates # , plot_raster_cmp
3637

3738

3839
from mpl_toolkits.axes_grid1 import make_axes_locatable
@@ -119,6 +120,7 @@ def parse_line(line):
119120

120121

121122
def plot_spikes_config(configure, group_key=None, exclude=[], save_as=None, show_plot=True):
123+
warnings.warn('Deprecated: Please use bmtk.analyzer.spike_trains.plot_raster instead.', DeprecationWarning)
122124
if isinstance(configure, string_types):
123125
conf = config.from_json(configure)
124126
elif isinstance(configure, dict):
@@ -135,6 +137,7 @@ def plot_spikes_config(configure, group_key=None, exclude=[], save_as=None, show
135137

136138
def plot_spikes(cells_file, cell_models_file, spikes_file, population=None, group_key=None, exclude=[], save_as=None,
137139
show=True, title=None, legend=True, font_size=None):
140+
warnings.warn('Deprecated: Please use bmtk.analyzer.spike_trains.plot_raster instead.', DeprecationWarning)
138141
# check if can be shown and/or saved
139142
#if save_as is not None:
140143
# if os.path.exists(save_as):

bmtk/simulator/bionet/modules/record_cellvars.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -81,6 +81,7 @@ def __init__(self, tmp_dir, file_name, variable_name, cells=None, gids=None, sec
8181

8282
self._tmp_dir = tmp_dir
8383

84+
# TODO: Full path should be determined by config/simulation_reports module
8485
self._file_name = file_name if os.path.isabs(file_name) else os.path.join(tmp_dir, file_name)
8586
self._all_gids = cells
8687
self._local_gids = []

bmtk/tests/utils/reports/compartment/__init__.py

Whitespace-only changes.

0 commit comments

Comments
 (0)