Skip to content

Commit cfcb258

Browse files
authored
Merge pull request #326 from int-brain-lab/glm_improvements
Bugfix to Sequential Selection of (G)LM features
2 parents 20feca0 + a9cbfa2 commit cfcb258

File tree

5 files changed

+39
-12
lines changed

5 files changed

+39
-12
lines changed

brainbox/io/one.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -70,11 +70,11 @@ def load_channel_locations(eid, one=None, probe=None, aligned=False):
7070
counts = [0]
7171
else:
7272
tracing = [(insertions.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}).
73-
get('tracing_exists', False))]
73+
get('tracing_exists', False))]
7474
resolved = [(insertions.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}).
75-
get('alignment_resolved', False))]
75+
get('alignment_resolved', False))]
7676
counts = [(insertions.get('json', {'temp': 0}).get('extended_qc', {'temp': 0}).
77-
get('alignment_count', 0))]
77+
get('alignment_count', 0))]
7878
probe_id = [insertions['id']]
7979
# No specific probe specified, load any that is available
8080
# Need to catch for the case where we have two of the same probe insertions
@@ -420,7 +420,7 @@ def load_wheel_reaction_times(eid, one=None):
420420

421421

422422
def load_trials_df(eid, one=None, maxlen=None, t_before=0., t_after=0., ret_wheel=False,
423-
ret_abswheel=False, wheel_binsize=0.02, addtl_types=()):
423+
ret_abswheel=False, ext_DLC=False, wheel_binsize=0.02, addtl_types=[]):
424424
"""
425425
TODO Test this with new ONE
426426
Generate a pandas dataframe of per-trial timing information about a given session.
@@ -451,6 +451,8 @@ def load_trials_df(eid, one=None, maxlen=None, t_before=0., t_after=0., ret_whee
451451
Whether to return the time-resampled wheel velocity trace, by default False
452452
ret_abswheel : bool, optional
453453
Whether to return the time-resampled absolute wheel velocity trace, by default False
454+
ext_DLC : bool, optional
455+
Whether to extract DLC data, by default False
454456
wheel_binsize : float, optional
455457
Time bins to resample wheel velocity to, by default 0.02
456458
addtl_types : list, optional

brainbox/modeling/linear.py

Lines changed: 23 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,16 +59,38 @@ def __init__(self, design_matrix, spk_times, spk_clu,
5959
def _fit(self, dm, binned, cells=None):
6060
"""
6161
Fitting primitive that brainbox.NeuralModel.fit method will call
62+
63+
Parameters
64+
----------
65+
dm : np.ndarray
66+
Design matrix to use for fitting
67+
binned : np.ndarray
68+
Array of binned spike times. Must share first dimension with dm
69+
cells : iterable with .shape attribute, optional
70+
List of cells which are being fit. Used to generate index for output
71+
coefficients and intercepts, must share shape with second dimension
72+
of binned. When None will default to a list of all cells in the model object,
73+
by default None
74+
75+
Returns
76+
-------
77+
coefs : pd.Series
78+
Series containing fit coefficients for cells
79+
intercepts : pd.Series
80+
Series containing intercepts for fits.
6281
"""
6382
if cells is None:
6483
cells = self.clu_ids.flatten()
84+
if cells.shape[0] != binned.shape[1]:
85+
raise ValueError('Length of cells does not match shape of binned')
86+
6587
coefs = pd.Series(index=cells, name='coefficients', dtype=object)
6688
intercepts = pd.Series(index=cells, name='intercepts')
6789

6890
lm = self.estimator.fit(dm, binned)
6991
weight, intercept = lm.coef_, lm.intercept_
7092
for cell in cells:
71-
cell_idx = np.argwhere(self.clu_ids == cell)[0, 0]
93+
cell_idx = np.argwhere(cells == cell)[0, 0]
7294
coefs.at[cell] = weight[cell_idx, :]
7395
intercepts.at[cell] = intercept[cell_idx]
7496
return coefs, intercepts
@@ -84,7 +106,6 @@ def score(self):
84106
"""
85107
if not hasattr(self, 'coefs'):
86108
raise AttributeError('Model has not been fit yet.')
87-
88109
testmask = np.isin(self.design.trlabels, self.testinds).flatten()
89110
dm, binned = self.design[testmask, :], self.binnedspikes[testmask]
90111

brainbox/modeling/neural_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -103,7 +103,7 @@ def __init__(self, design_matrix, spk_times, spk_clu,
103103
self.design = design_matrix
104104
self.spikes = spks
105105
self.clu = clu
106-
self.clu_ids = np.argwhere(np.sum(trialspiking, axis=0) > mintrials)
106+
self.clu_ids = np.argwhere(np.sum(trialspiking, axis=0) > mintrials).flatten()
107107
self.traininds = traininds
108108
self.testinds = testinds
109109
self.stepwise = stepwise

brainbox/modeling/poisson.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -38,16 +38,19 @@ def _fit(self, dm, binned, cells=None, noncovwarn=False):
3838
alpha : float
3939
Regularization strength, applied as multiplicative constant on ridge regularization.
4040
cells : list
41-
List of cells which should be fit. If None is passed, will default to fitting all cells
42-
in clu_ids
41+
List of cell labels for columns in binned. Will default to all cells in model if None
42+
is passed. Must be of the same length as columns in binned. By default None.
4343
"""
4444
if cells is None:
4545
cells = self.clu_ids.flatten()
46+
if cells.shape[0] != binned.shape[1]:
47+
raise ValueError('Length of cells does not match shape of binned')
48+
4649
coefs = pd.Series(index=cells, name='coefficients', dtype=object)
4750
intercepts = pd.Series(index=cells, name='intercepts')
4851
nonconverged = []
4952
for cell in tqdm(cells, 'Fitting units:', leave=False):
50-
cell_idx = np.argwhere(self.clu_ids == cell)[0, 0]
53+
cell_idx = np.argwhere(cells == cell)[0, 0]
5154
cellbinned = binned[:, cell_idx]
5255
with catch_warnings(record=True) as w:
5356
fitobj = PoissonRegressor(alpha=self.alpha,

brainbox/modeling/utils.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -101,7 +101,7 @@ def fit(self, progress=False):
101101
new_feature_idx, nf_score = self._get_best_new_feature(current_mask, cells)
102102
for cell in cells:
103103
maskdf.at[cell, self.features[new_feature_idx.loc[cell]]] = True
104-
seqdf.loc[cell, i] = self.features[new_feature_idx]
104+
seqdf.loc[cell, i] = self.features[new_feature_idx.loc[cell]]
105105
scoredf.loc[cell, i] = nf_score.loc[cell]
106106
self.support_ = maskdf
107107
self.sequences_ = seqdf
@@ -110,7 +110,8 @@ def fit(self, progress=False):
110110
def _get_best_new_feature(self, mask, cells):
111111
mask = np.array(mask)
112112
candidate_features = np.flatnonzero(~mask)
113-
my = self.model.binnedspikes[self.train]
113+
cell_idxs = np.argwhere(np.isin(self.model.clu_ids, cells)).flatten()
114+
my = self.model.binnedspikes[np.ix_(self.train, cell_idxs)]
114115
scores = pd.DataFrame(index=cells, columns=candidate_features, dtype=float)
115116
for feature_idx in candidate_features:
116117
candidate_mask = mask.copy()

0 commit comments

Comments
 (0)