Skip to content

Commit 7e95c88

Browse files
committed
Merge branch 'clusterless-with-integers'
2 parents 5997b32 + 9da5214 commit 7e95c88

14 files changed

+1699
-248
lines changed

.gitignore

Lines changed: 1 addition & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -59,5 +59,4 @@ coverage.xml
5959
*.lock
6060
*.dirlock
6161
*.nc
62-
63-
!fra_11_04_0001.gif
62+
*.prof

notebooks/tutorial/03-Decoding_with_Clusterless_Spikes.ipynb

Lines changed: 48 additions & 34 deletions
Large diffs are not rendered by default.

notebooks/tutorial/05-Classifying_with_Clusterless_Spikes.ipynb

Lines changed: 31 additions & 20 deletions
Large diffs are not rendered by default.

replay_trajectory_classification/classifier.py

Lines changed: 98 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,14 @@
1717
from replay_trajectory_classification.misc import NumbaKDE
1818
from replay_trajectory_classification.multiunit_likelihood import (
1919
estimate_multiunit_likelihood, fit_multiunit_likelihood)
20+
from replay_trajectory_classification.multiunit_likelihood_integer import (
21+
estimate_multiunit_likelihood_integer, fit_multiunit_likelihood_integer)
22+
from replay_trajectory_classification.multiunit_likelihood_integer_no_dask import (
23+
estimate_multiunit_likelihood_integer_no_dask,
24+
fit_multiunit_likelihood_integer_no_dask)
25+
from replay_trajectory_classification.multiunit_likelihood_integer_pass_position import (
26+
estimate_multiunit_likelihood_integer_pass_position,
27+
fit_multiunit_likelihood_integer_pass_position)
2028
from replay_trajectory_classification.spiking_likelihood import (
2129
estimate_place_fields, estimate_spiking_likelihood)
2230
from replay_trajectory_classification.state_transition import (
@@ -27,8 +35,28 @@
2735

2836
sklearn.set_config(print_changed_only=False)
2937

30-
_DEFAULT_CLUSTERLESS_MODEL_KWARGS = dict(
31-
bandwidth=np.array([24.0, 24.0, 24.0, 24.0, 6.0, 6.0]))
38+
_DEFAULT_CLUSTERLESS_MODEL_KWARGS = {
39+
'model': NumbaKDE,
40+
'model_kwargs': {
41+
'bandwidth': np.array([24.0, 24.0, 24.0, 24.0, 6.0, 6.0])
42+
}
43+
}
44+
45+
_ClUSTERLESS_ALGORITHMS = {
46+
'multiunit_likelihood': (
47+
fit_multiunit_likelihood,
48+
estimate_multiunit_likelihood),
49+
'multiunit_likelihood_integer': (
50+
fit_multiunit_likelihood_integer,
51+
estimate_multiunit_likelihood_integer),
52+
'multiunit_likelihood_integer_no_dask': (
53+
fit_multiunit_likelihood_integer_no_dask,
54+
estimate_multiunit_likelihood_integer_no_dask),
55+
'multiunit_likelihood_integer_pass_position': (
56+
fit_multiunit_likelihood_integer_pass_position,
57+
estimate_multiunit_likelihood_integer_pass_position),
58+
}
59+
3260
_DEFAULT_CONTINUOUS_TRANSITIONS = (
3361
[['random_walk', 'uniform', 'identity'],
3462
['uniform', 'uniform', 'uniform'],
@@ -37,7 +65,7 @@
3765

3866

3967
class _ClassifierBase(BaseEstimator):
40-
def __init__(self, place_bin_size=2.0, replay_speed=40, movement_var=0.05,
68+
def __init__(self, place_bin_size=2.0, replay_speed=1, movement_var=6.0,
4169
position_range=None,
4270
continuous_transition_types=_DEFAULT_CONTINUOUS_TRANSITIONS,
4371
discrete_transition_type='strong_diagonal',
@@ -255,7 +283,7 @@ class SortedSpikesClassifier(_ClassifierBase):
255283
256284
'''
257285

258-
def __init__(self, place_bin_size=2.0, replay_speed=40, movement_var=0.05,
286+
def __init__(self, place_bin_size=2.0, replay_speed=1, movement_var=6.0,
259287
position_range=None,
260288
continuous_transition_types=_DEFAULT_CONTINUOUS_TRANSITIONS,
261289
discrete_transition_type='strong_diagonal',
@@ -392,6 +420,13 @@ def predict(self, spikes, time=None, is_compute_acausal=True,
392420
393421
'''
394422
spikes = np.asarray(spikes)
423+
is_track_interior = self.is_track_interior_.ravel(order='F')
424+
n_time = spikes.shape[0]
425+
n_position_bins = is_track_interior.shape[0]
426+
n_states = self.discrete_state_transition_.shape[0]
427+
is_states = np.ones((n_states,), dtype=bool)
428+
st_interior_ind = np.ix_(
429+
is_states, is_states, is_track_interior, is_track_interior)
395430

396431
results = {}
397432

@@ -401,7 +436,7 @@ def predict(self, spikes, time=None, is_compute_acausal=True,
401436
spikes,
402437
np.asarray(self.place_fields_.sel(
403438
encoding_group=encoding_group)),
404-
self.is_track_interior_)
439+
is_track_interior)
405440

406441
results['likelihood'] = np.stack(
407442
[likelihood[encoding_group]
@@ -410,13 +445,20 @@ def predict(self, spikes, time=None, is_compute_acausal=True,
410445
results['likelihood'] = scaled_likelihood(
411446
results['likelihood'], axis=(1, 2))[..., np.newaxis]
412447

413-
results['causal_posterior'] = _causal_classify(
414-
self.initial_conditions_, self.continuous_state_transition_,
415-
self.discrete_state_transition_, results['likelihood'])
448+
results['causal_posterior'] = np.full(
449+
(n_time, n_states, n_position_bins, 1), np.nan)
450+
results['causal_posterior'][:, :, is_track_interior] = _causal_classify(
451+
self.initial_conditions_[:, is_track_interior],
452+
self.continuous_state_transition_[st_interior_ind],
453+
self.discrete_state_transition_,
454+
results['likelihood'][:, :, is_track_interior])
416455

417456
if is_compute_acausal:
418-
results['acausal_posterior'] = _acausal_classify(
419-
results['causal_posterior'], self.continuous_state_transition_,
457+
results['acausal_posterior'] = np.full(
458+
(n_time, n_states, n_position_bins, 1), np.nan)
459+
results['acausal_posterior'][:, :, is_track_interior] = _acausal_classify(
460+
results['causal_posterior'][:, :, is_track_interior],
461+
self.continuous_state_transition_[st_interior_ind],
420462
self.discrete_state_transition_)
421463

422464
n_time = spikes.shape[0]
@@ -465,31 +507,23 @@ class ClusterlessClassifier(_ClassifierBase):
465507
466508
'''
467509

468-
def __init__(self, place_bin_size=2.0, replay_speed=40, movement_var=0.05,
510+
def __init__(self, place_bin_size=2.0, replay_speed=1, movement_var=6.0,
469511
position_range=None,
470512
continuous_transition_types=_DEFAULT_CONTINUOUS_TRANSITIONS,
471513
discrete_transition_type='strong_diagonal',
472514
initial_conditions_type='uniform_on_track',
473515
discrete_transition_diag=_DISCRETE_DIAG,
474516
infer_track_interior=True,
475-
model=NumbaKDE,
476-
model_kwargs=_DEFAULT_CLUSTERLESS_MODEL_KWARGS,
477-
occupancy_model=None,
478-
occupancy_kwargs=None):
517+
clusterless_algorithm='multiunit_likelihood',
518+
clusterless_algorithm_params=_DEFAULT_CLUSTERLESS_MODEL_KWARGS
519+
):
479520
super().__init__(place_bin_size, replay_speed, movement_var,
480521
position_range, continuous_transition_types,
481522
discrete_transition_type, initial_conditions_type,
482523
discrete_transition_diag, infer_track_interior)
483524

484-
self.model = model
485-
self.model_kwargs = model_kwargs
486-
487-
if occupancy_model is None:
488-
self.occupancy_model = model
489-
self.occupancy_kwargs = model_kwargs
490-
else:
491-
self.occupancy_model = occupancy_model
492-
self.occupancy_kwargs = occupancy_kwargs
525+
self.clusterless_algorithm = clusterless_algorithm
526+
self.clusterless_algorithm_params = clusterless_algorithm_params
493527

494528
def fit_multiunits(self, position, multiunits, is_training=None,
495529
encoding_group_labels=None,
@@ -519,29 +553,25 @@ def fit_multiunits(self, position, multiunits, is_training=None,
519553
else:
520554
self.encoding_group_to_state_ = np.asarray(encoding_group_to_state)
521555

556+
kwargs = self.clusterless_algorithm_params
557+
if kwargs is None:
558+
kwargs = {}
559+
522560
is_training = np.asarray(is_training).squeeze()
523561

524-
self.joint_pdf_models_ = {}
525-
self.ground_process_intensities_ = {}
526-
self.occupancy_ = {}
527-
self.mean_rates_ = {}
562+
self.encoding_model_ = {}
528563

529564
for encoding_group in np.unique(encoding_group_labels[is_training]):
530565
is_group = is_training & (
531566
encoding_group == encoding_group_labels)
532-
(self.joint_pdf_models_[encoding_group],
533-
self.ground_process_intensities_[encoding_group],
534-
self.occupancy_[encoding_group],
535-
self.mean_rates_[encoding_group]
536-
) = fit_multiunit_likelihood(
537-
position[is_group],
538-
multiunits[is_group],
539-
self.place_bin_centers_,
540-
self.model,
541-
self.model_kwargs,
542-
self.occupancy_model,
543-
self.occupancy_kwargs,
544-
self.is_track_interior_.ravel(order='F'))
567+
self.encoding_model_[encoding_group] = _ClUSTERLESS_ALGORITHMS[
568+
self.clusterless_algorithm][0](
569+
position=position[is_group],
570+
multiunits=multiunits[is_group],
571+
place_bin_centers=self.place_bin_centers_,
572+
is_track_interior=self.is_track_interior_.ravel(order='F'),
573+
**kwargs
574+
)
545575

546576
def fit(self,
547577
position,
@@ -606,19 +636,25 @@ def predict(self, multiunits, time=None, is_compute_acausal=True,
606636
607637
'''
608638
multiunits = np.asarray(multiunits)
639+
is_track_interior = self.is_track_interior_.ravel(order='F')
640+
n_time = multiunits.shape[0]
641+
n_position_bins = is_track_interior.shape[0]
642+
n_states = self.discrete_state_transition_.shape[0]
643+
is_states = np.ones((n_states,), dtype=bool)
644+
st_interior_ind = np.ix_(
645+
is_states, is_states, is_track_interior, is_track_interior)
609646

610647
results = {}
611648

612649
likelihood = {}
613-
for encoding_group in self.joint_pdf_models_:
614-
likelihood[encoding_group] = estimate_multiunit_likelihood(
615-
multiunits,
616-
self.place_bin_centers_,
617-
self.joint_pdf_models_[encoding_group],
618-
self.ground_process_intensities_[encoding_group],
619-
self.occupancy_[encoding_group],
620-
self.mean_rates_[encoding_group],
621-
self.is_track_interior_.ravel(order='F'))
650+
for encoding_group, encoding_params in self.encoding_model_.items():
651+
likelihood[encoding_group] = _ClUSTERLESS_ALGORITHMS[
652+
self.clusterless_algorithm][1](
653+
multiunits=multiunits,
654+
place_bin_centers=self.place_bin_centers_,
655+
is_track_interior=is_track_interior,
656+
**encoding_params
657+
)
622658

623659
results['likelihood'] = np.stack(
624660
[likelihood[encoding_group]
@@ -627,17 +663,22 @@ def predict(self, multiunits, time=None, is_compute_acausal=True,
627663
results['likelihood'] = scaled_likelihood(
628664
results['likelihood'], axis=(1, 2))[..., np.newaxis]
629665

630-
results['causal_posterior'] = _causal_classify(
631-
self.initial_conditions_, self.continuous_state_transition_,
632-
self.discrete_state_transition_, results['likelihood'])
666+
results['causal_posterior'] = np.full(
667+
(n_time, n_states, n_position_bins, 1), np.nan)
668+
results['causal_posterior'][:, :, is_track_interior] = _causal_classify(
669+
self.initial_conditions_[:, is_track_interior],
670+
self.continuous_state_transition_[st_interior_ind],
671+
self.discrete_state_transition_,
672+
results['likelihood'][:, :, is_track_interior])
633673

634674
if is_compute_acausal:
635-
results['acausal_posterior'] = _acausal_classify(
636-
results['causal_posterior'], self.continuous_state_transition_,
675+
results['acausal_posterior'] = np.full(
676+
(n_time, n_states, n_position_bins, 1), np.nan)
677+
results['acausal_posterior'][:, :, is_track_interior] = _acausal_classify(
678+
results['causal_posterior'][:, :, is_track_interior],
679+
self.continuous_state_transition_[st_interior_ind],
637680
self.discrete_state_transition_)
638681

639-
n_time = multiunits.shape[0]
640-
641682
if time is None:
642683
time = np.arange(n_time)
643684

0 commit comments

Comments
 (0)