Skip to content

Commit 169a809

Browse files
committed
Merge branch 'develop' into olivier
2 parents 208aa6c + d61ee27 commit 169a809

File tree

13 files changed

+617
-348
lines changed

13 files changed

+617
-348
lines changed

CITATION.cff

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
cff-version: 0.0.0
2+
message: "If you use this software, please cite it as below."
3+
authors:
4+
- family-names: International Brain Laboratory
5+
given-names: The
6+
orcid:
7+
title: "ibllib"
8+
version:
9+
doi:
10+
date-released: 2021-12-09
11+
url: "https://github.com/int-brain-lab/ibllib"

brainbox/behavior/wheel.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,7 @@
44
import numpy as np
55
from numpy import pi
66
import scipy.interpolate as interpolate
7-
from scipy.signal import convolve, gaussian
7+
from scipy.signal import convolve, windows
88
from scipy.linalg import hankel
99
import matplotlib.pyplot as plt
1010
from matplotlib.collections import LineCollection
@@ -114,7 +114,7 @@ def velocity_smoothed(pos, freq, smooth_size=0.03):
114114
std_samps = np.round(smooth_size * freq) # Standard deviation relative to sampling frequency
115115
N = std_samps * 6 # Number of points in the Gaussian covering +/-3 standard deviations
116116
gauss_std = (N - 1) / 6
117-
win = gaussian(N, gauss_std)
117+
win = windows.gaussian(N, gauss_std)
118118
win = win / win.sum() # Normalize amplitude
119119

120120
# Convolve and multiply by sampling frequency to restore original units
@@ -274,7 +274,7 @@ def movements(t, pos, freq=1000, pos_thresh=8, t_thresh=.2, min_gap=.1, pos_thre
274274
peak_amps = np.fromiter(peaks, dtype=float, count=onsets.size)
275275
N = 10 # Number of points in the Gaussian
276276
STDEV = 1.8 # Equivalent to a width factor (alpha value) of 2.5
277-
gauss = gaussian(N, STDEV) # A 10-point Gaussian window of a given s.d.
277+
gauss = windows.gaussian(N, STDEV) # A 10-point Gaussian window of a given s.d.
278278
vel = convolve(np.diff(np.insert(pos, 0, 0)), gauss, mode='same')
279279
# For each movement period, find the timestamp where the absolute velocity was greatest
280280
peaks = (t[m + np.abs(vel[m:n]).argmax()] for m, n in zip(onset_samps, offset_samps))

brainbox/modeling/design_matrix.py

Lines changed: 64 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class DesignMatrix:
1111
and allow the generation of a design matrix with specified regressors
1212
"""
1313

14-
def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
14+
def __init__(self, trialsdf, vartypes, binwidth=0.02):
1515
"""
1616
Class for generating design matrices to model neural data. Provides handy routines for
1717
describing neural spiking activity using basis functions and other primitives.
@@ -31,7 +31,7 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
3131
3232
Obligatory columns for the dataframe are "trial_start" and "trial_end", which tell the
3333
constructor which time points to associate with that trial.
34-
vartypes : dict, optional
34+
vartypes : dict
3535
Dictionary of types for each of the columns in trialsdf. Columns must be of the types:
3636
-- timing: timing events, in which the column values are times since the start of the
3737
session of an event within that trial, e.g. stimulus onset.
@@ -41,46 +41,44 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
4141
changes within the trial. e.g. pupil diameter.
4242
Dictionary keys should be columns in trialsdf, values should be strings that are equal
4343
to one of the above.
44-
45-
If vartypes is not passed, the constructor will assume you know what you are doing. Be
46-
warned that this can result in the class failing in spectacular and vindictive ways.
47-
by default None
4844
binwidth : float, optional
4945
Length of time bins which will be used for design matrix, by default 0.02
5046
"""
5147
# Data checks #
52-
if vartypes is not None:
53-
validtypes = ('timing', 'continuous', 'value')
54-
if not all([name in vartypes for name in trialsdf.columns]):
55-
raise KeyError("Some columns were not described in vartypes")
56-
if not all([value in validtypes for value in vartypes.values()]):
57-
raise ValueError("Invalid values were passed in vartypes")
48+
validtypes = ('timing', 'continuous', 'value')
49+
if not all([name in vartypes for name in trialsdf.columns]):
50+
raise KeyError("Some columns were not described in vartypes")
51+
if not all([value in validtypes for value in vartypes.values()]):
52+
raise ValueError("Invalid values were passed in vartypes")
5853

5954
# Filter out cells which don't meet the criteria for minimum spiking, while doing trial
6055
# assignment
61-
self.vartypes = vartypes
62-
if vartypes is not None:
63-
self.vartypes['duration'] = 'value'
56+
vartypes['duration'] = 'value'
6457
base_df = trialsdf.copy()
6558
trialsdf = trialsdf.copy() # Make sure we don't modify the original dataframe
6659
trbounds = trialsdf[['trial_start', 'trial_end']] # Get the start/end of trials
6760
# Empty trial duration value to use later
6861
trialsdf['duration'] = np.nan
62+
# Figure out which columns are timing variables if vartypes was passed
6963
timingvars = [col for col in trialsdf.columns if vartypes[col] == 'timing']
64+
7065
for i, (start, end) in trbounds.iterrows():
7166
if any(np.isnan((start, end))):
7267
warn(f"NaN values found in trial start or end at trial number {i}. "
7368
"Discarding trial.")
7469
trialsdf.drop(i, inplace=True)
7570
continue
7671
for col in timingvars:
72+
# Round values for the timing variables to the 5th decimal place and subtract
73+
# trial start time.
7774
trialsdf.at[i, col] = np.round(trialsdf.at[i, col] - start, decimals=5)
7875
trialsdf.at[i, 'duration'] = end - start
7976

8077
# Set model parameters to begin with
8178
self.binwidth = binwidth
8279
self.covar = {}
8380
self.trialsdf = trialsdf
81+
self.vartypes = vartypes
8482
self.base_df = base_df
8583
self.compiled = False
8684
return
@@ -155,7 +153,7 @@ def add_covariate_timing(self, covlabel, eventname, bases,
155153
else:
156154
raise TypeError('deltaval must be None, pandas series, or string reference'
157155
f' to trialsdf column. {type(deltaval)} was passed instead.')
158-
if eventname in self.vartypes and self.vartypes[eventname] != 'timing':
156+
if self.vartypes[eventname] != 'timing':
159157
raise TypeError(f'Column {eventname} in trialsdf is not registered as a timing')
160158

161159
vecsizes = self.trialsdf['duration'].apply(self.binf)
@@ -174,6 +172,33 @@ def add_covariate_timing(self, covlabel, eventname, bases,
174172

175173
def add_covariate_boxcar(self, covlabel, boxstart, boxend,
176174
cond=None, height=None, desc=''):
175+
"""
176+
Convenience wrapped on add_covariate to add a boxcar covariate on the given start and end
177+
variables, such that the covariate is a step function with non-zero value between those
178+
values.
179+
180+
Note: This has not been tested yet and is not guaranteed to work, or work correctly.
181+
182+
Parameters
183+
----------
184+
covlabel : str
185+
Name of the covariate for accessing later. Can be accessed via dot syntax of the
186+
instance usually.
187+
boxstart : str
188+
Column name in trialsdf which will be used to define the start of the boxcar
189+
boxend : str
190+
Column name in trialsdf which defines the end of boxcar variable
191+
cond : None, list, or func, optional
192+
Condition in which to apply this covariate. Can either be a list of trial indices, or
193+
a function which takes in a row of the trialsdf and returns a boolen on inclusion,
194+
by default None
195+
height : None, str, or pandas series, optional
196+
Values for the height of the boxcar during the period defined per trial. Can be a
197+
reference to a column in trialsdf or a separate series, by default None
198+
desc : str, optional
199+
Additional information about the covariate to store as a string, by default ''
200+
201+
"""
177202
if covlabel in self.covar:
178203
raise AttributeError(f'Covariate {covlabel} already exists in model.')
179204
self._compile_check()
@@ -210,6 +235,27 @@ def add_covariate_boxcar(self, covlabel, boxstart, boxend,
210235

211236
def add_covariate_raw(self, covlabel, raw,
212237
cond=None, desc=''):
238+
"""
239+
Convenience wrapper to add a 'raw' covariate, that is to say a covariate which is a
240+
continuous value that changes with time during the course of a trial.
241+
242+
Note: This has not been tested and is not guaranteed to work or to work correctly.
243+
244+
Parameters
245+
----------
246+
covlabel : str
247+
String used to reference covariate, can usually be accessed by instance's dot syntax
248+
raw : str, func, or pandas series
249+
The covariate to add to the design matrix. Can be a str reference to a column in
250+
trialsdf, a function which takes in rows of trialsdf and produces a vector for each
251+
row of the appropriate size given binwidth and trial duration, or a pandas series
252+
of vectors of said appropriate type.
253+
cond : None, list, or func, optional
254+
Trials in which to apply the given covariate. Can be a list of trial numbers,
255+
or a function which accepts rows of the trialsdf and returns a boolean, by default None
256+
desc : str, optional
257+
Additional information about the covariate for access later, by default ''
258+
"""
213259
stimlens = self.trialsdf.duration.apply(self.binf)
214260
if isinstance(raw, str):
215261
if raw not in self.trialsdf.columns:
@@ -354,7 +400,6 @@ def compile_design_matrix(self, dense=True):
354400
assert self.binnedspikes.shape[0] == dm.shape[0], "Oh shit. Indexing error."
355401
self.dm = dm
356402
self.trlabels = trlabels
357-
# self.dm = np.roll(dm, -1, axis=0) # Fix weird +1 offset bug in design matrix
358403
self.compiled = True
359404
return
360405

@@ -384,7 +429,7 @@ def denseconv(X, bases):
384429
A = np.zeros((T + TB - 1, int(np.sum(indices[kCov, :]))))
385430
for i, j in enumerate(np.argwhere(indices[kCov, :]).flat):
386431
A[:, i] = np.convolve(X[:, kCov], bases[:, j])
387-
BX[:, k: sI[kCov]] = A[: T, :]
432+
BX[:, k: sI[kCov]] = A[:T, :]
388433
k = sI[kCov]
389434
return BX
390435

@@ -400,5 +445,5 @@ def convbasis(stim, bases, offset=0):
400445
if offset < 0:
401446
X = X[-offset:, :]
402447
elif offset > 0:
403-
X = X[: -(1 + offset), :]
448+
X = X[:-offset, :]
404449
return X

brainbox/modeling/linear.py

Lines changed: 4 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717
class LinearGLM(NeuralModel):
1818
def __init__(self, design_matrix, spk_times, spk_clu,
1919
binwidth=0.02, metric='rsq', estimator=None,
20-
train=0.8, blocktrain=False, mintrials=100):
20+
mintrials=100):
2121
"""
2222
Fit a linear model using a DesignMatrix object and spike data. Can use ridge regression
2323
or pure linear regression
@@ -48,13 +48,15 @@ def __init__(self, design_matrix, spk_times, spk_clu,
4848
fitting, by default 100
4949
"""
5050
super().__init__(design_matrix, spk_times, spk_clu,
51-
binwidth, train, blocktrain, mintrials)
51+
binwidth, mintrials)
5252
if estimator is None:
5353
estimator = LinearRegression()
5454
if not isinstance(estimator, BaseEstimator):
5555
raise ValueError('Estimator must be a scikit-learn estimator, e.g. LinearRegression')
5656
self.metric = metric
5757
self.estimator = estimator
58+
self.link = lambda x: x
59+
self.invlink = self.link
5860

5961
def _fit(self, dm, binned, cells=None):
6062
"""
@@ -94,26 +96,3 @@ def _fit(self, dm, binned, cells=None):
9496
coefs.at[cell] = weight[cell_idx, :]
9597
intercepts.at[cell] = intercept[cell_idx]
9698
return coefs, intercepts
97-
98-
def score(self):
99-
"""
100-
Score model using chosen metric
101-
102-
Returns
103-
-------
104-
pandas.Series
105-
Score using chosen metric (defined at instantiation) for each unit fit by the model.
106-
"""
107-
if not hasattr(self, 'coefs'):
108-
raise AttributeError('Model has not been fit yet.')
109-
testmask = np.isin(self.design.trlabels, self.testinds).flatten()
110-
dm, binned = self.design[testmask, :], self.binnedspikes[testmask]
111-
112-
scores = pd.Series(index=self.coefs.index, name='scores')
113-
for cell in self.coefs.index:
114-
cell_idx = np.argwhere(self.clu_ids == cell)[0, 0]
115-
wt = self.coefs.loc[cell].reshape(-1, 1)
116-
bias = self.intercepts.loc[cell]
117-
y = binned[:, cell_idx]
118-
scores.at[cell] = self._scorer(wt, bias, dm, y)
119-
return scores

0 commit comments

Comments
 (0)