
Commit d61ee27

Merge pull request #425 from int-brain-lab/glm_improvements
Changes to GLM code structure, added tests for design matrices
2 parents 24d33e8 + 087e6f3 commit d61ee27

7 files changed: +213 additions, −138 deletions


brainbox/modeling/design_matrix.py

Lines changed: 64 additions & 19 deletions
@@ -11,7 +11,7 @@ class DesignMatrix:
     and allow the generation of a design matrix with specified regressors
     """
 
-    def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
+    def __init__(self, trialsdf, vartypes, binwidth=0.02):
         """
         Class for generating design matrices to model neural data. Provides handy routines for
         describing neural spiking activity using basis functions and other primitives.
@@ -31,7 +31,7 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
 
             Obligatory columns for the dataframe are "trial_start" and "trial_end", which tell the
             constructor which time points to associate with that trial.
-        vartypes : dict, optional
+        vartypes : dict
             Dictionary of types for each of the columns in trialsdf. Columns must be of the types:
             -- timing: timing events, in which the column values are times since the start of the
                session of an event within that trial, e.g. stimulus onset.
@@ -41,46 +41,44 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
                changes within the trial. e.g. pupil diameter.
             Dictionary keys should be columns in trialsdf, values should be strings that are equal
             to one of the above.
-
-            If vartypes is not passed, the constructor will assume you know what you are doing. Be
-            warned that this can result in the class failing in spectacular and vindictive ways.
-            by default None
         binwidth : float, optional
             Length of time bins which will be used for design matrix, by default 0.02
         """
         # Data checks #
-        if vartypes is not None:
-            validtypes = ('timing', 'continuous', 'value')
-            if not all([name in vartypes for name in trialsdf.columns]):
-                raise KeyError("Some columns were not described in vartypes")
-            if not all([value in validtypes for value in vartypes.values()]):
-                raise ValueError("Invalid values were passed in vartypes")
+        validtypes = ('timing', 'continuous', 'value')
+        if not all([name in vartypes for name in trialsdf.columns]):
+            raise KeyError("Some columns were not described in vartypes")
+        if not all([value in validtypes for value in vartypes.values()]):
+            raise ValueError("Invalid values were passed in vartypes")
 
         # Filter out cells which don't meet the criteria for minimum spiking, while doing trial
         # assignment
-        self.vartypes = vartypes
-        if vartypes is not None:
-            self.vartypes['duration'] = 'value'
+        vartypes['duration'] = 'value'
         base_df = trialsdf.copy()
         trialsdf = trialsdf.copy()  # Make sure we don't modify the original dataframe
         trbounds = trialsdf[['trial_start', 'trial_end']]  # Get the start/end of trials
         # Empty trial duration value to use later
         trialsdf['duration'] = np.nan
+        # Figure out which columns are timing variables if vartypes was passed
         timingvars = [col for col in trialsdf.columns if vartypes[col] == 'timing']
+
         for i, (start, end) in trbounds.iterrows():
             if any(np.isnan((start, end))):
                 warn(f"NaN values found in trial start or end at trial number {i}. "
                      "Discarding trial.")
                 trialsdf.drop(i, inplace=True)
                 continue
             for col in timingvars:
+                # Round values for the timing variables to the 5th decimal place and subtract
+                # trial start time.
                 trialsdf.at[i, col] = np.round(trialsdf.at[i, col] - start, decimals=5)
             trialsdf.at[i, 'duration'] = end - start
 
         # Set model parameters to begin with
         self.binwidth = binwidth
         self.covar = {}
         self.trialsdf = trialsdf
+        self.vartypes = vartypes
         self.base_df = base_df
         self.compiled = False
         return
@@ -155,7 +153,7 @@ def add_covariate_timing(self, covlabel, eventname, bases,
         else:
             raise TypeError('deltaval must be None, pandas series, or string reference'
                             f' to trialsdf column. {type(deltaval)} was passed instead.')
-        if eventname in self.vartypes and self.vartypes[eventname] != 'timing':
+        if self.vartypes[eventname] != 'timing':
             raise TypeError(f'Column {eventname} in trialsdf is not registered as a timing')
 
         vecsizes = self.trialsdf['duration'].apply(self.binf)
@@ -174,6 +172,33 @@ def add_covariate_timing(self, covlabel, eventname, bases,
 
     def add_covariate_boxcar(self, covlabel, boxstart, boxend,
                              cond=None, height=None, desc=''):
+        """
+        Convenience wrapper on add_covariate to add a boxcar covariate on the given start and end
+        variables, such that the covariate is a step function with non-zero value between those
+        values.
+
+        Note: This has not been tested yet and is not guaranteed to work, or work correctly.
+
+        Parameters
+        ----------
+        covlabel : str
+            Name of the covariate for accessing later. Can be accessed via dot syntax of the
+            instance usually.
+        boxstart : str
+            Column name in trialsdf which will be used to define the start of the boxcar
+        boxend : str
+            Column name in trialsdf which defines the end of the boxcar variable
+        cond : None, list, or func, optional
+            Condition in which to apply this covariate. Can either be a list of trial indices, or
+            a function which takes in a row of the trialsdf and returns a boolean on inclusion,
+            by default None
+        height : None, str, or pandas series, optional
+            Values for the height of the boxcar during the period defined per trial. Can be a
+            reference to a column in trialsdf or a separate series, by default None
+        desc : str, optional
+            Additional information about the covariate to store as a string, by default ''
+
+        """
         if covlabel in self.covar:
             raise AttributeError(f'Covariate {covlabel} already exists in model.')
         self._compile_check()
@@ -210,6 +235,27 @@ def add_covariate_boxcar(self, covlabel, boxstart, boxend,
 
     def add_covariate_raw(self, covlabel, raw,
                           cond=None, desc=''):
+        """
+        Convenience wrapper to add a 'raw' covariate, that is to say a covariate which is a
+        continuous value that changes with time during the course of a trial.
+
+        Note: This has not been tested and is not guaranteed to work or to work correctly.
+
+        Parameters
+        ----------
+        covlabel : str
+            String used to reference covariate, can usually be accessed by instance's dot syntax
+        raw : str, func, or pandas series
+            The covariate to add to the design matrix. Can be a str reference to a column in
+            trialsdf, a function which takes in rows of trialsdf and produces a vector for each
+            row of the appropriate size given binwidth and trial duration, or a pandas series
+            of vectors of said appropriate type.
+        cond : None, list, or func, optional
+            Trials in which to apply the given covariate. Can be a list of trial numbers,
+            or a function which accepts rows of the trialsdf and returns a boolean, by default None
+        desc : str, optional
+            Additional information about the covariate for access later, by default ''
+        """
         stimlens = self.trialsdf.duration.apply(self.binf)
         if isinstance(raw, str):
             if raw not in self.trialsdf.columns:
@@ -354,7 +400,6 @@ def compile_design_matrix(self, dense=True):
         assert self.binnedspikes.shape[0] == dm.shape[0], "Oh shit. Indexing error."
         self.dm = dm
         self.trlabels = trlabels
-        # self.dm = np.roll(dm, -1, axis=0)  # Fix weird +1 offset bug in design matrix
         self.compiled = True
         return
 
@@ -384,7 +429,7 @@ def denseconv(X, bases):
         A = np.zeros((T + TB - 1, int(np.sum(indices[kCov, :]))))
         for i, j in enumerate(np.argwhere(indices[kCov, :]).flat):
             A[:, i] = np.convolve(X[:, kCov], bases[:, j])
-        BX[:, k: sI[kCov]] = A[: T, :]
+        BX[:, k: sI[kCov]] = A[:T, :]
         k = sI[kCov]
     return BX
 
@@ -400,5 +445,5 @@ def convbasis(stim, bases, offset=0):
     if offset < 0:
         X = X[-offset:, :]
     elif offset > 0:
-        X = X[: -(1 + offset), :]
+        X = X[:-offset, :]
     return X
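
For orientation, here is a minimal sketch (not part of the diff) of how the updated DesignMatrix constructor behaves now that vartypes is required; the column names and values below are hypothetical placeholders:

    import pandas as pd
    from brainbox.modeling.design_matrix import DesignMatrix

    # Hypothetical trials table; times are in seconds from session start
    trialsdf = pd.DataFrame({'trial_start': [0.0, 2.0, 4.0],
                             'trial_end': [1.5, 3.5, 5.5],
                             'stim_on': [0.3, 2.4, 4.2]})

    # vartypes is now a required argument and must label every column in trialsdf
    # with one of 'timing', 'continuous', or 'value'
    vartypes = {'trial_start': 'timing', 'trial_end': 'timing', 'stim_on': 'timing'}
    design = DesignMatrix(trialsdf, vartypes, binwidth=0.02)

    # Leaving a column out of vartypes now fails immediately instead of deferring
    # the error until fit time
    try:
        DesignMatrix(trialsdf, {'trial_start': 'timing', 'trial_end': 'timing'})
    except KeyError as err:
        print(err)  # "Some columns were not described in vartypes"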

brainbox/modeling/linear.py

Lines changed: 4 additions & 25 deletions
@@ -17,7 +17,7 @@
 class LinearGLM(NeuralModel):
     def __init__(self, design_matrix, spk_times, spk_clu,
                  binwidth=0.02, metric='rsq', estimator=None,
-                 train=0.8, blocktrain=False, mintrials=100):
+                 mintrials=100):
         """
         Fit a linear model using a DesignMatrix object and spike data. Can use ridge regression
         or pure linear regression
@@ -48,13 +48,15 @@ def __init__(self, design_matrix, spk_times, spk_clu,
             fitting, by default 100
         """
         super().__init__(design_matrix, spk_times, spk_clu,
-                         binwidth, train, blocktrain, mintrials)
+                         binwidth, mintrials)
         if estimator is None:
             estimator = LinearRegression()
         if not isinstance(estimator, BaseEstimator):
             raise ValueError('Estimator must be a scikit-learn estimator, e.g. LinearRegression')
         self.metric = metric
         self.estimator = estimator
+        self.link = lambda x: x
+        self.invlink = self.link
 
     def _fit(self, dm, binned, cells=None):
         """
@@ -94,26 +96,3 @@ def _fit(self, dm, binned, cells=None):
             coefs.at[cell] = weight[cell_idx, :]
             intercepts.at[cell] = intercept[cell_idx]
         return coefs, intercepts
-
-    def score(self):
-        """
-        Score model using chosen metric
-
-        Returns
-        -------
-        pandas.Series
-            Score using chosen metric (defined at instantiation) for each unit fit by the model.
-        """
-        if not hasattr(self, 'coefs'):
-            raise AttributeError('Model has not been fit yet.')
-        testmask = np.isin(self.design.trlabels, self.testinds).flatten()
-        dm, binned = self.design[testmask, :], self.binnedspikes[testmask]
-
-        scores = pd.Series(index=self.coefs.index, name='scores')
-        for cell in self.coefs.index:
-            cell_idx = np.argwhere(self.clu_ids == cell)[0, 0]
-            wt = self.coefs.loc[cell].reshape(-1, 1)
-            bias = self.intercepts.loc[cell]
-            y = binned[:, cell_idx]
-            scores.at[cell] = self._scorer(wt, bias, dm, y)
-        return scores
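
With the train and blocktrain arguments removed and score() moved up to the base class, constructing a LinearGLM reduces to the sketch below (not part of the diff; design is assumed to be a compiled DesignMatrix such as the one above, and the spike arrays are placeholders):

    import numpy as np
    from sklearn.linear_model import Ridge
    from brainbox.modeling.linear import LinearGLM

    # Placeholder spike data: spike times in seconds and matching cluster IDs
    spk_times = np.array([0.11, 0.35, 2.41, 4.22])
    spk_clu = np.array([0, 1, 0, 1])

    # Train/test splitting is no longer configured here; any scikit-learn
    # estimator (e.g. Ridge) can be supplied in place of plain LinearRegression
    model = LinearGLM(design, spk_times, spk_clu, binwidth=0.02,
                      metric='rsq', estimator=Ridge(alpha=1.0), mintrials=1)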

brainbox/modeling/neural_model.py

Lines changed: 49 additions & 31 deletions
@@ -23,7 +23,7 @@ class NeuralModel:
     """
 
     def __init__(self, design_matrix, spk_times, spk_clu,
-                 binwidth=0.02, train=0.8, blocktrain=False, mintrials=100, stepwise=False):
+                 binwidth=0.02, mintrials=100, stepwise=False):
         """
         Construct GLM object using information about all trials, and the relevant spike times.
         Only ingests data, and further object methods must be called to describe kernels, gain
@@ -38,10 +38,8 @@ def __init__(self, design_matrix, spk_times, spk_clu,
         spk_clu: numpy.array of integers
             1-D array of same shape as spk_times, with integer cluster IDs identifying which
             cluster a spike time belonged to.
-        train: float
-            Float in (0, 1] indicating proportion of data to use for training GLM vs testing
-            (using the NeuralGLM.score method). Trials to keep will be randomly sampled, by default
-            0.8
+        binwidth : float
+            Size of bins to put spikes into, in seconds.
         mintrials: int
             Minimum number of trials in which neurons fired a spike in order to be fit. Defaults
             to 100 trials.
@@ -54,10 +52,6 @@ def __init__(self, design_matrix, spk_times, spk_clu,
         # Data checks #
         if not len(spk_times) == len(spk_clu):
             raise IndexError("Spike times and cluster IDs are not same length")
-        if not isinstance(train, float) and not train == 1:
-            raise TypeError('train must be a float between 0 and 1')
-        if not ((train > 0) & (train <= 1)):
-            raise ValueError('train must be between 0 and 1')
         if not design_matrix.compiled:
             raise AttributeError('Design matrix object must be compiled before passing to fit')
 
@@ -83,29 +77,11 @@ def __init__(self, design_matrix, spk_times, spk_clu,
             spks[i] = spk_times[st_startind:st_endind] - start
             clu[i] = spk_clu[st_startind:st_endind]
 
-        # Break the data into test and train sections for cross-validation
-        if train == 1:
-            print('Training fraction set to 1. Training on all data.')
-            traininds = base_df.index
-            testinds = base_df.index
-        else:
-            trainlen = int(np.floor(len(base_df) * train))
-            if blocktrain:
-                testlen, midpoint = len(base_df) - trainlen, len(base_df) // 2
-                starttest, endtest = midpoint - (testlen // 2), midpoint + (testlen // 2)
-                testinds = base_df.index[starttest:endtest]
-                traininds = base_df.index[~np.isin(base_df.index, testinds)]
-            else:
-                traininds = sorted(np.random.choice(base_df.index, trainlen, replace=False))
-                testinds = base_df.index[~base_df.index.isin(traininds)]
-
         # Set model parameters to begin with
         self.design = design_matrix
         self.spikes = spks
         self.clu = clu
         self.clu_ids = np.argwhere(np.sum(trialspiking, axis=0) > mintrials).flatten()
-        self.traininds = traininds
-        self.testinds = testinds
         self.stepwise = stepwise
         self.binwidth = binwidth
 
@@ -168,7 +144,7 @@ def _scorer(self, wt, bias, dm, y):
         """
         Score a single target y
         """
-        pred = (dm @ wt + bias).flatten()
+        pred = self.link(dm @ wt + bias).flatten()
         if self.metric == 'dsq':
             null_pred = np.ones_like(pred) * np.mean(y)
             null_deviance = 2 * np.sum(xlogy(y, y / null_pred.flat) - y + null_pred.flat)
@@ -186,7 +162,7 @@ def _scorer(self, wt, bias, dm, y):
         else:
             raise AttributeError('No valid metric exists in the instance for use by _scorer()')
 
-    def fit(self, printcond=True):
+    def fit(self, train_idx=None, printcond=True):
         """
         Fit the current set of binned spikes as a function of the current design matrix. Requires
         NeuralGLM.bin_spike_trains and NeuralGLM.compile_design_matrix to be run first. Will store
@@ -195,6 +171,9 @@ def fit(self, printcond=True):
 
         Parameters
         ----------
+        train_idx : array-like of trial indices, optional
+            List of which trials to use to train the model. Defaults to None, which indicates all
+            indices in the trialsdf will be used (100% train)
         printcond : bool
             Whether or not to print the condition number of the design matrix. Defaults to True
 
@@ -204,10 +183,24 @@ def fit(self, printcond=True):
             List of coefficients fit. Not recommended to use these for interpretation. Use
             the .combine_weights() method instead.
         intercepts : list
-            List of intercepts (bias terms) fit. Not recommended to use these for interpretation.
+            List of intercepts (bias terms) fit.
         """
+        # Input checks
+        if train_idx is None:
+            train_idx = self.design.trialsdf.index
+        if not np.all(np.isin(train_idx, self.design.trialsdf.index)):
+            raise IndexError('Not all train indices in the trials of design matrix')
+
+        # Store training and test indices for self so that .score() method will know what to
+        # operate on. If all data indices are in train indices, train and test are the same set.
+        self.traininds = train_idx
+        if not np.all(np.isin(self.design.trialsdf.index, train_idx)):
+            self.testinds = self.design.trialsdf.index[~self.trialsdf.index.isin(train_idx)]
+        else:
+            self.testinds = train_idx
+
         # Mask for training data
-        trainmask = np.isin(self.design.trlabels, self.traininds).flatten()
+        trainmask = np.isin(self.design.trlabels, train_idx).flatten()
         trainbinned = self.binnedspikes[trainmask]
         if printcond:
             print(f'Condition of design matrix is {np.linalg.cond(self.design[trainmask])}')
@@ -217,6 +210,31 @@ def fit(self, printcond=True):
         self.coefs, self.intercepts = coefs, intercepts
         return
 
+    def score(self, testinds=None):
+        """
+        Score model using chosen metric
+
+        Returns
+        -------
+        pandas.Series
+            Score using chosen metric (defined at instantiation) for each unit fit by the model.
+        """
+        if not hasattr(self, 'coefs'):
+            raise AttributeError('Model has not been fit yet.')
+        if testinds is None:
+            testinds = self.testinds
+        testmask = np.isin(self.design.trlabels, testinds).flatten()
+        dm, binned = self.design[testmask, :], self.binnedspikes[testmask]
+
+        scores = pd.Series(index=self.coefs.index, name='scores')
+        for cell in self.coefs.index:
+            cell_idx = np.argwhere(self.clu_ids == cell)[0, 0]
+            wt = self.coefs.loc[cell].reshape(-1, 1)
+            bias = self.intercepts.loc[cell]
+            y = binned[:, cell_idx]
+            scores.at[cell] = self._scorer(wt, bias, dm, y)
+        return scores
+
     def binf(self, t):
         """
         Bin function for a given timestep. Returns the number of bins after trial start a given t
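
Taken together, the new fit/score flow looks roughly like the sketch below (illustrative only, continuing the placeholder model from the earlier sketches): the caller now picks the training trials explicitly, and the held-out trials are inferred for scoring.

    import numpy as np

    # Choose an 80% training split by trial index; the remaining trials become
    # the test set that score() will evaluate on
    all_trials = model.design.trialsdf.index
    train_idx = np.random.choice(all_trials, size=int(0.8 * len(all_trials)),
                                 replace=False)

    model.fit(train_idx=train_idx, printcond=False)
    test_scores = model.score()  # one score per fit unit, on held-out trials

    # Calling fit() with no train_idx trains on every trial; score() then reports
    # in-sample scores, since train and test coincide in that case
    model.fit(printcond=False)
    insample_scores = model.score()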
