Skip to content

Commit 7490f47

Browse files
committed
Realized implicit vartypes was a terrible idea.
1 parent cb62b15 commit 7490f47

File tree

1 file changed

+14
-35
lines changed

1 file changed

+14
-35
lines changed

brainbox/modeling/design_matrix.py

Lines changed: 14 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ class DesignMatrix:
1111
and allow the generation of a design matrix with specified regressors
1212
"""
1313

14-
def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
14+
def __init__(self, trialsdf, vartypes, binwidth=0.02):
1515
"""
1616
Class for generating design matrices to model neural data. Provides handy routines for
1717
describing neural spiking activity using basis functions and other primitives.
@@ -31,7 +31,7 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
3131
3232
Obligatory columns for the dataframe are "trial_start" and "trial_end", which tell the
3333
constructor which time points to associate with that trial.
34-
vartypes : dict, optional
34+
vartypes : dict
3535
Dictionary of types for each of the columns in trialsdf. Columns must be of the types:
3636
-- timing: timing events, in which the column values are times since the start of the
3737
session of an event within that trial, e.g. stimulus onset.
@@ -41,38 +41,25 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
4141
changes within the trial. e.g. pupil diameter.
4242
Dictionary keys should be columns in trialsdf, values should be strings that are equal
4343
to one of the above.
44-
45-
If vartypes is not passed, the constructor will assume you know what you are doing. Be
46-
warned that this can result in the class failing in spectacular and vindictive ways.
47-
by default None
4844
binwidth : float, optional
4945
Length of time bins which will be used for design matrix, by default 0.02
5046
"""
5147
# Data checks #
52-
if vartypes is not None:
53-
validtypes = ('timing', 'continuous', 'value')
54-
if not all([name in vartypes for name in trialsdf.columns]):
55-
raise KeyError("Some columns were not described in vartypes")
56-
if not all([value in validtypes for value in vartypes.values()]):
57-
raise ValueError("Invalid values were passed in vartypes")
48+
validtypes = ('timing', 'continuous', 'value')
49+
if not all([name in vartypes for name in trialsdf.columns]):
50+
raise KeyError("Some columns were not described in vartypes")
51+
if not all([value in validtypes for value in vartypes.values()]):
52+
raise ValueError("Invalid values were passed in vartypes")
5853

5954
# Filter out cells which don't meet the criteria for minimum spiking, while doing trial
6055
# assignment
61-
self.vartypes = vartypes if vartypes is not None else {}
62-
if vartypes is not None:
63-
self.vartypes['duration'] = 'value'
64-
base_df = trialsdf.copy()
65-
trialsdf = trialsdf.copy() # Make sure we don't modify the original dataframe
56+
vartypes['duration'] = 'value'
57+
base_df = trialsdf.copy() # Make sure we don't modify the original dataframe
6658
trbounds = trialsdf[['trial_start', 'trial_end']] # Get the start/end of trials
6759
# Empty trial duration value to use later
6860
trialsdf['duration'] = np.nan
6961
# Figure out which columns are timing variables if vartypes was passed
70-
if vartypes is not None:
71-
timingvars = [col for col in trialsdf.columns if vartypes[col] == 'timing']
72-
self.timingsub = {x: True if x in timingvars else False for x in trialsdf.columns}
73-
else:
74-
timingvars = []
75-
self.timingsub = {x: False for x in trialsdf.columns}
62+
timingvars = [col for col in trialsdf.columns if vartypes[col] == 'timing']
7663

7764
for i, (start, end) in trbounds.iterrows():
7865
if any(np.isnan((start, end))):
@@ -90,6 +77,7 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
9077
self.binwidth = binwidth
9178
self.covar = {}
9279
self.trialsdf = trialsdf
80+
self.vartypes = vartypes
9381
self.base_df = base_df
9482
self.compiled = False
9583
return
@@ -164,18 +152,9 @@ def add_covariate_timing(self, covlabel, eventname, bases,
164152
else:
165153
raise TypeError('deltaval must be None, pandas series, or string reference'
166154
f' to trialsdf column. {type(deltaval)} was passed instead.')
167-
if eventname in self.vartypes and self.vartypes[eventname] != 'timing':
155+
if self.vartypes[eventname] != 'timing':
168156
raise TypeError(f'Column {eventname} in trialsdf is not registered as a timing')
169157

170-
if not self.timingsub[eventname]:
171-
self.timingsub[eventname] = True
172-
col = eventname
173-
for i, (start, end) in self.trialsdf[['trial_start', 'trial_end']].iterrows():
174-
# Round values for the timing variables to the 5th decimal place and subtract
175-
# trial start time.
176-
self.trialsdf.at[i, col] = np.round(self.trialsdf.at[i, col] - start, decimals=5)
177-
self.trialsdf.at[i, 'duration'] = end - start
178-
179158
vecsizes = self.trialsdf['duration'].apply(self.binf)
180159
stiminds = self.trialsdf[eventname].apply(self.binf)
181160
stimvecs = []
@@ -224,10 +203,10 @@ def add_covariate_boxcar(self, covlabel, boxstart, boxend,
224203
self._compile_check()
225204
if boxstart not in self.trialsdf.columns or boxend not in self.trialsdf.columns:
226205
raise KeyError('boxstart or boxend not found in trialsdf columns.')
227-
if boxstart in self.vartypes and self.vartypes[boxstart] != 'timing':
206+
if self.vartypes[boxstart] != 'timing':
228207
raise TypeError(f'Column {boxstart} in trialsdf is not registered as a timing. '
229208
'boxstart and boxend need to refer to timing events in trialsdf.')
230-
if boxend in self.vartypes and self.vartypes[boxend] != 'timing':
209+
if self.vartypes[boxend] != 'timing':
231210
raise TypeError(f'Column {boxend} in trialsdf is not registered as a timing. '
232211
'boxstart and boxend need to refer to timing events in trialsdf.')
233212

0 commit comments

Comments
 (0)