@@ -11,7 +11,7 @@ class DesignMatrix:
1111 and allow the generation of a design matrix with specified regressors
1212 """
1313
14- def __init__ (self , trialsdf , vartypes = None , binwidth = 0.02 ):
14+ def __init__ (self , trialsdf , vartypes , binwidth = 0.02 ):
1515 """
1616 Class for generating design matrices to model neural data. Provides handy routines for
1717 describing neural spiking activity using basis functions and other primitives.
@@ -31,7 +31,7 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
3131
3232 Obligatory columns for the dataframe are "trial_start" and "trial_end", which tell the
3333 constructor which time points to associate with that trial.
34- vartypes : dict, optional
34+ vartypes : dict
3535 Dictionary of types for each of the columns in trialsdf. Columns must be of the types:
3636 -- timing: timing events, in which the column values are times since the start of the
3737 session of an event within that trial, e.g. stimulus onset.
@@ -41,38 +41,25 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
4141 changes within the trial. e.g. pupil diameter.
4242 Dictionary keys should be columns in trialsdf, values should be strings that are equal
4343 to one of the above.
44-
45- If vartypes is not passed, the constructor will assume you know what you are doing. Be
46- warned that this can result in the class failing in spectacular and vindictive ways.
47- by default None
4844 binwidth : float, optional
4945 Length of time bins which will be used for design matrix, by default 0.02
5046 """
5147 # Data checks #
52- if vartypes is not None :
53- validtypes = ('timing' , 'continuous' , 'value' )
54- if not all ([name in vartypes for name in trialsdf .columns ]):
55- raise KeyError ("Some columns were not described in vartypes" )
56- if not all ([value in validtypes for value in vartypes .values ()]):
57- raise ValueError ("Invalid values were passed in vartypes" )
48+ validtypes = ('timing' , 'continuous' , 'value' )
49+ if not all ([name in vartypes for name in trialsdf .columns ]):
50+ raise KeyError ("Some columns were not described in vartypes" )
51+ if not all ([value in validtypes for value in vartypes .values ()]):
52+ raise ValueError ("Invalid values were passed in vartypes" )
5853
5954 # Filter out cells which don't meet the criteria for minimum spiking, while doing trial
6055 # assignment
61- self .vartypes = vartypes if vartypes is not None else {}
62- if vartypes is not None :
63- self .vartypes ['duration' ] = 'value'
64- base_df = trialsdf .copy ()
65- trialsdf = trialsdf .copy () # Make sure we don't modify the original dataframe
56+ vartypes ['duration' ] = 'value'
57+ base_df = trialsdf .copy () # Make sure we don't modify the original dataframe
6658 trbounds = trialsdf [['trial_start' , 'trial_end' ]] # Get the start/end of trials
6759 # Empty trial duration value to use later
6860 trialsdf ['duration' ] = np .nan
6961 # Figure out which columns are timing variables, as declared in vartypes
70- if vartypes is not None :
71- timingvars = [col for col in trialsdf .columns if vartypes [col ] == 'timing' ]
72- self .timingsub = {x : True if x in timingvars else False for x in trialsdf .columns }
73- else :
74- timingvars = []
75- self .timingsub = {x : False for x in trialsdf .columns }
62+ timingvars = [col for col in trialsdf .columns if vartypes [col ] == 'timing' ]
7663
7764 for i , (start , end ) in trbounds .iterrows ():
7865 if any (np .isnan ((start , end ))):
@@ -90,6 +77,7 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
9077 self .binwidth = binwidth
9178 self .covar = {}
9279 self .trialsdf = trialsdf
80+ self .vartypes = vartypes
9381 self .base_df = base_df
9482 self .compiled = False
9583 return
@@ -164,18 +152,9 @@ def add_covariate_timing(self, covlabel, eventname, bases,
164152 else :
165153 raise TypeError ('deltaval must be None, pandas series, or string reference'
166154 f' to trialsdf column. { type (deltaval )} was passed instead.' )
167- if eventname in self . vartypes and self .vartypes [eventname ] != 'timing' :
155+ if self .vartypes [eventname ] != 'timing' :
168156 raise TypeError (f'Column { eventname } in trialsdf is not registered as a timing' )
169157
170- if not self .timingsub [eventname ]:
171- self .timingsub [eventname ] = True
172- col = eventname
173- for i , (start , end ) in self .trialsdf [['trial_start' , 'trial_end' ]].iterrows ():
174- # Round values for the timing variables to the 5th decimal place and subtract
175- # trial start time.
176- self .trialsdf .at [i , col ] = np .round (self .trialsdf .at [i , col ] - start , decimals = 5 )
177- self .trialsdf .at [i , 'duration' ] = end - start
178-
179158 vecsizes = self .trialsdf ['duration' ].apply (self .binf )
180159 stiminds = self .trialsdf [eventname ].apply (self .binf )
181160 stimvecs = []
@@ -224,10 +203,10 @@ def add_covariate_boxcar(self, covlabel, boxstart, boxend,
224203 self ._compile_check ()
225204 if boxstart not in self .trialsdf .columns or boxend not in self .trialsdf .columns :
226205 raise KeyError ('boxstart or boxend not found in trialsdf columns.' )
227- if boxstart in self . vartypes and self .vartypes [boxstart ] != 'timing' :
206+ if self .vartypes [boxstart ] != 'timing' :
228207 raise TypeError (f'Column { boxstart } in trialsdf is not registered as a timing. '
229208 'boxstart and boxend need to refer to timing events in trialsdf.' )
230- if boxend in self . vartypes and self .vartypes [boxend ] != 'timing' :
209+ if self .vartypes [boxend ] != 'timing' :
231210 raise TypeError (f'Column { boxend } in trialsdf is not registered as a timing. '
232211 'boxstart and boxend need to refer to timing events in trialsdf.' )
233212
0 commit comments