Skip to content

Commit da2b7e8

Browse files
committed
flake
1 parent 1af65c2 commit da2b7e8

File tree

3 files changed

+6
-10
lines changed

3 files changed

+6
-10
lines changed

brainbox/behavior/training.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -430,15 +430,13 @@ def compute_performance(trials, signed_contrast=None, block=None, prob_right=Fal
430430

431431
contrasts, n_contrasts = np.unique(signed_contrast[block_idx], return_counts=True)
432432

433-
rightward = trials.choice == -1
434-
# Calculate the proportion rightward for each contrast type
435-
performance = np.vectorize(lambda x: np.mean(rightward[(x == signed_contrast) &
436-
block_idx]))(contrasts)
437-
438433
if not prob_right:
439434
correct = trials.feedbackType == 1
440-
performance = np.vectorize(lambda x: np.mean(correct[(x == signed_contrast) &
441-
block_idx]))(contrasts)
435+
performance = np.vectorize(lambda x: np.mean(correct[(x == signed_contrast) & block_idx]))(contrasts)
436+
else:
437+
rightward = trials.choice == -1
438+
# Calculate the proportion rightward for each contrast type
439+
performance = np.vectorize(lambda x: np.mean(rightward[(x == signed_contrast) & block_idx]))(contrasts)
442440

443441
return performance, contrasts, n_contrasts
444442

brainbox/task/trials.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -328,5 +328,3 @@ def filter_trials(trials, event_raster, event, contrast=(1, 0.5, 0.25, 0.125, 0.
328328
raster, psth = filter_correct_incorrect_left_right(trials, event_raster, event, contrast, order)
329329

330330
return raster, psth
331-
332-

brainbox/tests/test_trials.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -167,4 +167,4 @@ def test_get_event_aligned_rasters(self):
167167
raster, t = get_event_aligned_raster(spikes, use_trials)
168168
assert (raster.shape[0] == len(use_trials))
169169
assert (all(np.isnan(raster[0:2, :]).ravel()))
170-
assert (all(np.isnan(raster[-5:, :]).ravel()))
170+
assert (all(np.isnan(raster[-5:, :]).ravel()))

0 commit comments

Comments
 (0)