@@ -11,7 +11,7 @@ class DesignMatrix:
1111 and allow the generation of a design matrix with specified regressors
1212 """
1313
14- def __init__ (self , trialsdf , vartypes = None , binwidth = 0.02 ):
14+ def __init__ (self , trialsdf , vartypes , binwidth = 0.02 ):
1515 """
1616 Class for generating design matrices to model neural data. Provides handy routines for
1717 describing neural spiking activity using basis functions and other primitives.
@@ -31,7 +31,7 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
3131
3232 Obligatory columns for the dataframe are "trial_start" and "trial_end", which tell the
3333 constructor which time points to associate with that trial.
34- vartypes : dict, optional
34+ vartypes : dict
3535 Dictionary of types for each of the columns in trialsdf. Columns must be of the types:
3636 -- timing: timing events, in which the column values are times since the start of the
3737 session of an event within that trial, e.g. stimulus onset.
@@ -41,46 +41,44 @@ def __init__(self, trialsdf, vartypes=None, binwidth=0.02):
4141 changes within the trial. e.g. pupil diameter.
4242 Dictionary keys should be columns in trialsdf, values should be strings that are equal
4343 to one of the above.
44-
45- If vartypes is not passed, the constructor will assume you know what you are doing. Be
46- warned that this can result in the class failing in spectacular and vindictive ways.
47- by default None
4844 binwidth : float, optional
4945 Length of time bins which will be used for design matrix, by default 0.02
5046 """
5147 # Data checks #
52- if vartypes is not None :
53- validtypes = ('timing' , 'continuous' , 'value' )
54- if not all ([name in vartypes for name in trialsdf .columns ]):
55- raise KeyError ("Some columns were not described in vartypes" )
56- if not all ([value in validtypes for value in vartypes .values ()]):
57- raise ValueError ("Invalid values were passed in vartypes" )
48+ validtypes = ('timing' , 'continuous' , 'value' )
49+ if not all ([name in vartypes for name in trialsdf .columns ]):
50+ raise KeyError ("Some columns were not described in vartypes" )
51+ if not all ([value in validtypes for value in vartypes .values ()]):
52+ raise ValueError ("Invalid values were passed in vartypes" )
5853
5954 # Filter out cells which don't meet the criteria for minimum spiking, while doing trial
6055 # assignment
61- self .vartypes = vartypes
62- if vartypes is not None :
63- self .vartypes ['duration' ] = 'value'
56+ vartypes ['duration' ] = 'value'
6457 base_df = trialsdf .copy ()
6558 trialsdf = trialsdf .copy () # Make sure we don't modify the original dataframe
6659 trbounds = trialsdf [['trial_start' , 'trial_end' ]] # Get the start/end of trials
6760 # Empty trial duration value to use later
6861 trialsdf ['duration' ] = np .nan
62+ # Figure out which columns are registered as timing variables in vartypes
6963 timingvars = [col for col in trialsdf .columns if vartypes [col ] == 'timing' ]
64+
7065 for i , (start , end ) in trbounds .iterrows ():
7166 if any (np .isnan ((start , end ))):
7267 warn (f"NaN values found in trial start or end at trial number { i } . "
7368 "Discarding trial." )
7469 trialsdf .drop (i , inplace = True )
7570 continue
7671 for col in timingvars :
72+ # Subtract the trial start time from each timing variable, then round the
73+ # result to the 5th decimal place.
7774 trialsdf .at [i , col ] = np .round (trialsdf .at [i , col ] - start , decimals = 5 )
7875 trialsdf .at [i , 'duration' ] = end - start
7976
8077 # Set model parameters to begin with
8178 self .binwidth = binwidth
8279 self .covar = {}
8380 self .trialsdf = trialsdf
81+ self .vartypes = vartypes
8482 self .base_df = base_df
8583 self .compiled = False
8684 return
@@ -155,7 +153,7 @@ def add_covariate_timing(self, covlabel, eventname, bases,
155153 else :
156154 raise TypeError ('deltaval must be None, pandas series, or string reference'
157155 f' to trialsdf column. { type (deltaval )} was passed instead.' )
158- if eventname in self . vartypes and self .vartypes [eventname ] != 'timing' :
156+ if self .vartypes [eventname ] != 'timing' :
159157 raise TypeError (f'Column { eventname } in trialsdf is not registered as a timing' )
160158
161159 vecsizes = self .trialsdf ['duration' ].apply (self .binf )
@@ -174,6 +172,33 @@ def add_covariate_timing(self, covlabel, eventname, bases,
174172
175173 def add_covariate_boxcar (self , covlabel , boxstart , boxend ,
176174 cond = None , height = None , desc = '' ):
175+ """
176+ Convenience wrapper around add_covariate to add a boxcar covariate on the given start and end
177+ variables, such that the covariate is a step function with non-zero value between those
178+ values.
179+
180+ Note: This has not been tested yet and is not guaranteed to work, or work correctly.
181+
182+ Parameters
183+ ----------
184+ covlabel : str
185+ Name of the covariate for accessing later. Can be accessed via dot syntax of the
186+ instance usually.
187+ boxstart : str
188+ Column name in trialsdf which will be used to define the start of the boxcar
189+ boxend : str
190+ Column name in trialsdf which defines the end of boxcar variable
191+ cond : None, list, or func, optional
192+ Condition in which to apply this covariate. Can either be a list of trial indices, or
193+ a function which takes in a row of the trialsdf and returns a boolean on inclusion,
194+ by default None
195+ height : None, str, or pandas series, optional
196+ Values for the height of the boxcar during the period defined per trial. Can be a
197+ reference to a column in trialsdf or a separate series, by default None
198+ desc : str, optional
199+ Additional information about the covariate to store as a string, by default ''
200+
201+ """
177202 if covlabel in self .covar :
178203 raise AttributeError (f'Covariate { covlabel } already exists in model.' )
179204 self ._compile_check ()
@@ -210,6 +235,27 @@ def add_covariate_boxcar(self, covlabel, boxstart, boxend,
210235
211236 def add_covariate_raw (self , covlabel , raw ,
212237 cond = None , desc = '' ):
238+ """
239+ Convenience wrapper to add a 'raw' covariate, that is to say a covariate which is a
240+ continuous value that changes with time during the course of a trial.
241+
242+ Note: This has not been tested and is not guaranteed to work or to work correctly.
243+
244+ Parameters
245+ ----------
246+ covlabel : str
247+ String used to reference covariate, can usually be accessed by instance's dot syntax
248+ raw : str, func, or pandas series
249+ The covariate to add to the design matrix. Can be a str reference to a column in
250+ trialsdf, a function which takes in rows of trialsdf and produces a vector for each
251+ row of the appropriate size given binwidth and trial duration, or a pandas series
252+ of vectors of said appropriate type.
253+ cond : None, list, or func, optional
254+ Trials in which to apply the given covariate. Can be a list of trial numbers,
255+ or a function which accepts rows of the trialsdf and returns a boolean, by default None
256+ desc : str, optional
257+ Additional information about the covariate for access later, by default ''
258+ """
213259 stimlens = self .trialsdf .duration .apply (self .binf )
214260 if isinstance (raw , str ):
215261 if raw not in self .trialsdf .columns :
@@ -354,7 +400,6 @@ def compile_design_matrix(self, dense=True):
354400 assert self .binnedspikes .shape [0 ] == dm .shape [0 ], "Oh shit. Indexing error."
355401 self .dm = dm
356402 self .trlabels = trlabels
357- # self.dm = np.roll(dm, -1, axis=0) # Fix weird +1 offset bug in design matrix
358403 self .compiled = True
359404 return
360405
@@ -384,7 +429,7 @@ def denseconv(X, bases):
384429 A = np .zeros ((T + TB - 1 , int (np .sum (indices [kCov , :]))))
385430 for i , j in enumerate (np .argwhere (indices [kCov , :]).flat ):
386431 A [:, i ] = np .convolve (X [:, kCov ], bases [:, j ])
387- BX [:, k : sI [kCov ]] = A [: T , :]
432+ BX [:, k : sI [kCov ]] = A [:T , :]
388433 k = sI [kCov ]
389434 return BX
390435
@@ -400,5 +445,5 @@ def convbasis(stim, bases, offset=0):
400445 if offset < 0 :
401446 X = X [- offset :, :]
402447 elif offset > 0 :
403- X = X [: - ( 1 + offset ) , :]
448+ X = X [:- offset , :]
404449 return X
0 commit comments