Skip to content

Commit 093161c

Browse files
committed
Better support for undecided trials
1 parent 551a3bc commit 093161c

File tree

2 files changed

+56
-24
lines changed

2 files changed

+56
-24
lines changed

pyddm/sample.py

Lines changed: 41 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -183,11 +183,11 @@ def err(self):
183183
@accepts(NDArray(d=2), List(String), Tuple(String, String))
184184
@returns(Self)
185185
@requires('data.shape[1] >= 2')
186-
@requires('set(list(data[:,1])) - {0, 1} == set()')
187-
@requires('all(data[:,0].astype("float") == data[:,0])')
186+
@requires('set(list(data[~np.isnan(data).any(axis=1)][:,1])) - {0, 1} == set()')
187+
@requires('all(data[~np.isnan(data).any(axis=1)][:,0].astype("float") == data[~np.isnan(data).any(axis=1)][:,0])')
188188
@requires('data.shape[1] - 2 == len(column_names)')
189189
@ensures('len(column_names) == len(return.condition_names())')
190-
def from_numpy_array(data, column_names, choice_names=("correct", "error")):
190+
def from_numpy_array(data, column_names=[], choice_names=("correct", "error")):
191191
"""Generate a Sample object from a numpy array.
192192
193193
`data` should be an n x m array (n rows, m columns) where m>=2. The
@@ -203,6 +203,11 @@ def from_numpy_array(data, column_names, choice_names=("correct", "error")):
203203
correspond to the order of the columns. Undecided trials are given as
204204
rows with nan in both of the first two columns.
205205
"""
206+
assert len(column_names) == data.shape[1] - 2, "Invalid number of column names for conditions"
207+
undecided = np.isnan(data[:,0]) & np.isnan(data[:,1])
208+
undecided_data = data[undecided]
209+
data = data[~undecided]
210+
assert not np.any(np.isnan(data[:,0:2])), "First two columns must be either both nan (for undecided trials) not neither nan"
206211
c = data[:,1].astype(bool)
207212
nc = (1-data[:,1]).astype(bool)
208213
def pt(x): # Pythonic types
@@ -219,17 +224,16 @@ def pt(x): # Pythonic types
219224
pass
220225
return arr
221226

222-
conditions = {k: (pt(data[c,i+2]), pt(data[nc,i+2]), np.asarray([])) for i,k in enumerate(column_names)}
223-
return Sample(pt(data[c,0]), pt(data[nc,0]), 0, **conditions)
227+
conditions = {k: (pt(data[c,i+2]), pt(data[nc,i+2]), pt(undecided_data[:,i+2])) for i,k in enumerate(column_names)}
228+
return Sample(pt(data[c,0]), pt(data[nc,0]), undecided_data.shape[0], **conditions)
224229
@staticmethod
225230
@accepts(Unchecked, String, Maybe(String), Unchecked, Maybe(String)) # TODO change unchecked to pandas
226231
@returns(Self)
227232
@requires('df.shape[1] >= 2')
228233
@requires('rt_column_name in df')
229234
@requires('choice_column_name in df or correct_column_name in df')
230-
@requires('not np.any(df.isnull())')
231-
@requires('len(np.setdiff1d(df[choice_column_name if choice_column_name is not None else correct_column_name], [0, 1])) == 0')
232-
@requires('all(df[rt_column_name].astype("float") == df[rt_column_name])')
235+
@requires('len(np.setdiff1d(df[~df.isna().any(axis=1)][choice_column_name if choice_column_name is not None else correct_column_name], [0, 1])) == 0')
236+
@requires('all(df[~df.isna().any(axis=1)][rt_column_name].astype("float") == df[~df.isna().any(axis=1)][rt_column_name])')
233237
@ensures('len(df) == len(return)')
234238
def from_pandas_dataframe(df, rt_column_name, choice_column_name=None, choice_names=("correct", "error"), correct_column_name=None):
235239
"""Generate a Sample object from a pandas dataframe.
@@ -254,16 +258,20 @@ def from_pandas_dataframe(df, rt_column_name, choice_column_name=None, choice_na
254258
"""
255259
if len(df) == 0:
256260
_logger.warning("Empty DataFrame")
257-
if np.mean(df[rt_column_name]) > 50:
258-
_logger.warning("RTs should be specified in seconds, not milliseconds")
259-
for _,col in df.items():
260-
if len(df) > 0 and isinstance(col.iloc[0], (list, np.ndarray)):
261-
raise ValueError("Conditions should not be lists or ndarrays. Please convert to a tuple instead.")
262261
if choice_column_name is None:
263262
assert correct_column_name is not None
264263
assert choice_names == ("correct", "error")
265264
choice_column_name = correct_column_name
266265
deprecation_warning("the choice_column_name argument")
266+
undecided_rows = df[rt_column_name].isna() & df[choice_column_name].isna()
267+
df_undecided = df[undecided_rows]
268+
df = df[~undecided_rows]
269+
assert not df[[choice_column_name,rt_column_name]].isna().any().any(), "Undecided trials must have nan for both RT and choice, and all other rows must not have nans for these columns"
270+
if np.mean(df[rt_column_name]) > 50:
271+
_logger.warning("RTs should be specified in seconds, not milliseconds")
272+
for _,col in df.items():
273+
if len(df) > 0 and isinstance(col.iloc[0], (list, np.ndarray)):
274+
raise ValueError("Conditions should not be lists or ndarrays. Please convert to a tuple instead.")
267275
assert np.all(np.isin(df[choice_column_name], [0, 1, True, False])), "Choice must be specified as True/False or 0/1"
268276
c = df[choice_column_name].astype(bool)
269277
nc = (1-df[choice_column_name]).astype(bool)
@@ -282,8 +290,8 @@ def pt(x): # Pythonic types
282290
return arr
283291

284292
column_names = [e for e in df.columns if not e in [rt_column_name, choice_column_name]]
285-
conditions = {k: (pt(df[c][k]), pt(df[nc][k]), np.asarray([])) for k in column_names}
286-
return Sample(pt(df[c][rt_column_name]), pt(df[nc][rt_column_name]), 0, choice_names=choice_names, **conditions)
293+
conditions = {k: (pt(df[c][k]), pt(df[nc][k]), pt(df_undecided[k])) for k in column_names}
294+
return Sample(pt(df[c][rt_column_name]), pt(df[nc][rt_column_name]), len(df_undecided), choice_names=choice_names, **conditions)
287295
def to_pandas_dataframe(self, rt_column_name='RT', choice_column_name='choice', drop_undecided=False, correct_column_name=None):
288296
"""Convert the sample to a Pandas dataframe.
289297
@@ -303,9 +311,10 @@ def to_pandas_dataframe(self, rt_column_name='RT', choice_column_name='choice',
303311
choice_column_name = correct_column_name
304312
import pandas
305313
all_trials = []
306-
if self.undecided != 0 and drop_undecided is False:
307-
raise ValueError("The sample object has undecided trials. These do not have an RT or a P(correct), so they cannot be converted to a data frame. Please use the 'drop_undecided' flag when calling this function.")
308314
conditions = list(self.condition_names())
315+
if self.undecided != 0 and drop_undecided is False:
316+
for trial in self.items("undecided"):
317+
all_trials.append([np.nan, np.nan] + [trial[1][c] for c in conditions])
309318
columns = [choice_column_name, rt_column_name] + conditions
310319
for trial in self.items("_top"):
311320
all_trials.append([1, trial[0]] + [trial[1][c] for c in conditions])
@@ -315,10 +324,11 @@ def to_pandas_dataframe(self, rt_column_name='RT', choice_column_name='choice',
315324
def items(self, choice=None, correct=None):
316325
"""Iterate through the reaction times.
317326
318-
`choice` is whether to iterate through RTs corresponding to the upper
319-
or lower boundary, given as the name of the choice, e.g. "correct",
327+
`choice` is whether to iterate through RTs corresponding to the upper or
328+
lower boundary, given as the name of the choice, e.g. "correct",
320329
"error", or the choice names specified in the model's choice_names
321-
parameter.
330+
parameter. This can also be "undecided" to iterate through undecided
331+
trials.
322332
323333
`correct` is a deprecated parameter for backward compatibility, please
324334
use `choice` instead.
@@ -329,6 +339,7 @@ def items(self, choice=None, correct=None):
329339
330340
If you just want the list of RTs, you can directly iterate
331341
through "sample.corr" and "sample.err".
342+
332343
"""
333344
if correct is not None:
334345
assert choice is None, "Either choice or correct argument must be None"
@@ -337,7 +348,10 @@ def items(self, choice=None, correct=None):
337348
use_choice_upper = correct
338349
else:
339350
assert choice is not None, "Choice and correct arguments cannot both be None"
340-
use_choice_upper = (self._choice_name_to_id(choice) == 1)
351+
if choice == "undecided":
352+
use_choice_upper = "undecided"
353+
else:
354+
use_choice_upper = (self._choice_name_to_id(choice) == 1)
341355
return _Sample_Iter_Wraper(self, use_choice_upper=use_choice_upper)
342356
@accepts(Self)
343357
@returns(Self)
@@ -456,7 +470,7 @@ def t_domain(dt=.01, T_dur=2):
456470
return np.linspace(0, T_dur, int(T_dur/dt)+1)
457471

458472
@accepts(Self, Choice)
459-
@returns(Set([1, 2]))
473+
@returns(Set([1, 2, -1]))
460474
def _choice_name_to_id(self, choice):
461475
"""Get an ID from the choice name.
462476
@@ -694,12 +708,15 @@ def __init__(self, sample_obj, use_choice_upper):
694708
self.sample = sample_obj
695709
self.i = 0
696710
self.use_choice_upper = use_choice_upper
697-
if self.use_choice_upper:
711+
if self.use_choice_upper is True:
698712
self.rt = self.sample.choice_upper
699713
self.ind = 0
700-
elif not self.use_choice_upper:
714+
elif self.use_choice_upper is False:
701715
self.rt = self.sample.choice_lower
702716
self.ind = 1
717+
elif self.use_choice_upper == "undecided":
718+
self.rt = [np.nan]*self.sample.undecided
719+
self.ind = 2
703720
def __iter__(self):
704721
return self
705722
def __next__(self):

unit_tests.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -201,6 +201,8 @@ def test_OverlayNone(self):
201201
assert s == ddm.models.OverlayNone().apply(s)
202202
s = self.FakePointModel().solve()
203203
assert s == ddm.models.OverlayNone().apply(s)
204+
s = self.FakeUndecidedModel().solve()
205+
assert s == ddm.models.OverlayNone().apply(s)
204206
def test_OverlayUniformMixture(self):
205207
"""Uniform mixture model overlay: a uniform distribution plus the model's solved distribution"""
206208
# Do nothing with 0 probability
@@ -570,13 +572,20 @@ def test_subset(self):
570572
assert len(self.samps['two'].subset(conda=["a", "z"])) == 2
571573
# Query by function
572574
assert len(self.samps['two'].subset(conda=lambda x : True if x=="a" else False)) == 2
575+
# Undecided
576+
assert len(self.samps['undeccond'].subset(cond1=2)) == 3
573577
def test_from_numpy_array(self):
574578
"""Create a sample from a numpy array"""
575579
simple_ndarray = np.asarray([[1, 1], [.5, 0], [.7, 0], [2, 1]])
576580
assert ddm.Sample.from_numpy_array(simple_ndarray, []) == self.samps['simple']
577581
conds_ndarray = np.asarray([[1, 1, 1], [2, 1, 1], [3, 1, 2]])
578582
assert ddm.Sample.from_numpy_array(conds_ndarray, ["cond1"]) == self.samps['conds']
579583
assert ddm.Sample.from_numpy_array(conds_ndarray, ["cond1"]) == self.samps['condsexp']
584+
# Undecided trials
585+
conds_ndarray = np.asarray([[np.nan, np.nan, 1], [2, 1, 3], [3, 1, 2]])
586+
samp = ddm.Sample.from_numpy_array(conds_ndarray, column_names=["cond1"])
587+
assert samp.undecided == 1, "One undecided trial"
588+
assert list(samp.items("undecided"))[0][1]['cond1'] == 1
580589
def test_from_pandas(self):
581590
"""Create a sample from a pandas dataframe"""
582591
simple_df = pandas.DataFrame({'corr': [1, 0, 0, 1], 'resptime': [1, .5, .7, 2]})
@@ -589,6 +598,12 @@ def test_from_pandas(self):
589598
assert ddm.Sample.from_pandas_dataframe(cond_df, choice_column_name='c', rt_column_name='rt', choice_names=("c", "d")) != self.samps['condsexp']
590599
condsstr_df = pandas.DataFrame({'c': [1, 1, 1], 'rt': [1, 2, 3], 'cond1': ["x", "yy", "z z z"]})
591600
assert ddm.Sample.from_pandas_dataframe(condsstr_df, 'rt', 'c', choice_names=("x", "Y with space")) == self.samps['condsstr']
601+
# Undecided
602+
df = pandas.DataFrame([[1, True, "x"], [2, False, "x"], [3, True, "z"], [np.nan, np.nan, "y"], [np.nan,np.nan, "z"]], columns=["RT", "choice", "cond"])
603+
samp = ddm.Sample.from_pandas_dataframe(df, choice_column_name='choice', rt_column_name='RT')
604+
assert samp.undecided == 2
605+
assert len(samp.choice_upper) == 2
606+
assert samp == ddm.Sample.from_pandas_dataframe(samp.to_pandas_dataframe(), choice_column_name='choice', rt_column_name='RT')
592607
def test_to_pandas(self):
593608
for sname,s in self.samps.items():
594609
if s.undecided == 0:

0 commit comments

Comments
 (0)