1717from replay_trajectory_classification .misc import NumbaKDE
1818from replay_trajectory_classification .multiunit_likelihood import (
1919 estimate_multiunit_likelihood , fit_multiunit_likelihood )
20+ from replay_trajectory_classification .multiunit_likelihood_integer import (
21+ estimate_multiunit_likelihood_integer , fit_multiunit_likelihood_integer )
22+ from replay_trajectory_classification .multiunit_likelihood_integer_no_dask import (
23+ estimate_multiunit_likelihood_integer_no_dask ,
24+ fit_multiunit_likelihood_integer_no_dask )
25+ from replay_trajectory_classification .multiunit_likelihood_integer_pass_position import (
26+ estimate_multiunit_likelihood_integer_pass_position ,
27+ fit_multiunit_likelihood_integer_pass_position )
2028from replay_trajectory_classification .spiking_likelihood import (
2129 estimate_place_fields , estimate_spiking_likelihood )
2230from replay_trajectory_classification .state_transition import (
2735
2836sklearn .set_config (print_changed_only = False )
2937
30- _DEFAULT_CLUSTERLESS_MODEL_KWARGS = dict (
31- bandwidth = np .array ([24.0 , 24.0 , 24.0 , 24.0 , 6.0 , 6.0 ]))
38+ _DEFAULT_CLUSTERLESS_MODEL_KWARGS = {
39+ 'model' : NumbaKDE ,
40+ 'model_kwargs' : {
41+ 'bandwidth' : np .array ([24.0 , 24.0 , 24.0 , 24.0 , 6.0 , 6.0 ])
42+ }
43+ }
44+
45+ _ClUSTERLESS_ALGORITHMS = {
46+ 'multiunit_likelihood' : (
47+ fit_multiunit_likelihood ,
48+ estimate_multiunit_likelihood ),
49+ 'multiunit_likelihood_integer' : (
50+ fit_multiunit_likelihood_integer ,
51+ estimate_multiunit_likelihood_integer ),
52+ 'multiunit_likelihood_integer_no_dask' : (
53+ fit_multiunit_likelihood_integer_no_dask ,
54+ estimate_multiunit_likelihood_integer_no_dask ),
55+ 'multiunit_likelihood_integer_pass_position' : (
56+ fit_multiunit_likelihood_integer_pass_position ,
57+ estimate_multiunit_likelihood_integer_pass_position ),
58+ }
59+
3260_DEFAULT_CONTINUOUS_TRANSITIONS = (
3361 [['random_walk' , 'uniform' , 'identity' ],
3462 ['uniform' , 'uniform' , 'uniform' ],
3765
3866
3967class _ClassifierBase (BaseEstimator ):
40- def __init__ (self , place_bin_size = 2.0 , replay_speed = 40 , movement_var = 0.05 ,
68+ def __init__ (self , place_bin_size = 2.0 , replay_speed = 1 , movement_var = 6.0 ,
4169 position_range = None ,
4270 continuous_transition_types = _DEFAULT_CONTINUOUS_TRANSITIONS ,
4371 discrete_transition_type = 'strong_diagonal' ,
@@ -255,7 +283,7 @@ class SortedSpikesClassifier(_ClassifierBase):
255283
256284 '''
257285
258- def __init__ (self , place_bin_size = 2.0 , replay_speed = 40 , movement_var = 0.05 ,
286+ def __init__ (self , place_bin_size = 2.0 , replay_speed = 1 , movement_var = 6.0 ,
259287 position_range = None ,
260288 continuous_transition_types = _DEFAULT_CONTINUOUS_TRANSITIONS ,
261289 discrete_transition_type = 'strong_diagonal' ,
@@ -392,6 +420,13 @@ def predict(self, spikes, time=None, is_compute_acausal=True,
392420
393421 '''
394422 spikes = np .asarray (spikes )
423+ is_track_interior = self .is_track_interior_ .ravel (order = 'F' )
424+ n_time = spikes .shape [0 ]
425+ n_position_bins = is_track_interior .shape [0 ]
426+ n_states = self .discrete_state_transition_ .shape [0 ]
427+ is_states = np .ones ((n_states ,), dtype = bool )
428+ st_interior_ind = np .ix_ (
429+ is_states , is_states , is_track_interior , is_track_interior )
395430
396431 results = {}
397432
@@ -401,7 +436,7 @@ def predict(self, spikes, time=None, is_compute_acausal=True,
401436 spikes ,
402437 np .asarray (self .place_fields_ .sel (
403438 encoding_group = encoding_group )),
404- self . is_track_interior_ )
439+ is_track_interior )
405440
406441 results ['likelihood' ] = np .stack (
407442 [likelihood [encoding_group ]
@@ -410,13 +445,20 @@ def predict(self, spikes, time=None, is_compute_acausal=True,
410445 results ['likelihood' ] = scaled_likelihood (
411446 results ['likelihood' ], axis = (1 , 2 ))[..., np .newaxis ]
412447
413- results ['causal_posterior' ] = _causal_classify (
414- self .initial_conditions_ , self .continuous_state_transition_ ,
415- self .discrete_state_transition_ , results ['likelihood' ])
448+ results ['causal_posterior' ] = np .full (
449+ (n_time , n_states , n_position_bins , 1 ), np .nan )
450+ results ['causal_posterior' ][:, :, is_track_interior ] = _causal_classify (
451+ self .initial_conditions_ [:, is_track_interior ],
452+ self .continuous_state_transition_ [st_interior_ind ],
453+ self .discrete_state_transition_ ,
454+ results ['likelihood' ][:, :, is_track_interior ])
416455
417456 if is_compute_acausal :
418- results ['acausal_posterior' ] = _acausal_classify (
419- results ['causal_posterior' ], self .continuous_state_transition_ ,
457+ results ['acausal_posterior' ] = np .full (
458+ (n_time , n_states , n_position_bins , 1 ), np .nan )
459+ results ['acausal_posterior' ][:, :, is_track_interior ] = _acausal_classify (
460+ results ['causal_posterior' ][:, :, is_track_interior ],
461+ self .continuous_state_transition_ [st_interior_ind ],
420462 self .discrete_state_transition_ )
421463
422464 n_time = spikes .shape [0 ]
@@ -465,31 +507,23 @@ class ClusterlessClassifier(_ClassifierBase):
465507
466508 '''
467509
468- def __init__ (self , place_bin_size = 2.0 , replay_speed = 40 , movement_var = 0.05 ,
510+ def __init__ (self , place_bin_size = 2.0 , replay_speed = 1 , movement_var = 6.0 ,
469511 position_range = None ,
470512 continuous_transition_types = _DEFAULT_CONTINUOUS_TRANSITIONS ,
471513 discrete_transition_type = 'strong_diagonal' ,
472514 initial_conditions_type = 'uniform_on_track' ,
473515 discrete_transition_diag = _DISCRETE_DIAG ,
474516 infer_track_interior = True ,
475- model = NumbaKDE ,
476- model_kwargs = _DEFAULT_CLUSTERLESS_MODEL_KWARGS ,
477- occupancy_model = None ,
478- occupancy_kwargs = None ):
517+ clusterless_algorithm = 'multiunit_likelihood' ,
518+ clusterless_algorithm_params = _DEFAULT_CLUSTERLESS_MODEL_KWARGS
519+ ):
479520 super ().__init__ (place_bin_size , replay_speed , movement_var ,
480521 position_range , continuous_transition_types ,
481522 discrete_transition_type , initial_conditions_type ,
482523 discrete_transition_diag , infer_track_interior )
483524
484- self .model = model
485- self .model_kwargs = model_kwargs
486-
487- if occupancy_model is None :
488- self .occupancy_model = model
489- self .occupancy_kwargs = model_kwargs
490- else :
491- self .occupancy_model = occupancy_model
492- self .occupancy_kwargs = occupancy_kwargs
525+ self .clusterless_algorithm = clusterless_algorithm
526+ self .clusterless_algorithm_params = clusterless_algorithm_params
493527
494528 def fit_multiunits (self , position , multiunits , is_training = None ,
495529 encoding_group_labels = None ,
@@ -519,29 +553,25 @@ def fit_multiunits(self, position, multiunits, is_training=None,
519553 else :
520554 self .encoding_group_to_state_ = np .asarray (encoding_group_to_state )
521555
556+ kwargs = self .clusterless_algorithm_params
557+ if kwargs is None :
558+ kwargs = {}
559+
522560 is_training = np .asarray (is_training ).squeeze ()
523561
524- self .joint_pdf_models_ = {}
525- self .ground_process_intensities_ = {}
526- self .occupancy_ = {}
527- self .mean_rates_ = {}
562+ self .encoding_model_ = {}
528563
529564 for encoding_group in np .unique (encoding_group_labels [is_training ]):
530565 is_group = is_training & (
531566 encoding_group == encoding_group_labels )
532- (self .joint_pdf_models_ [encoding_group ],
533- self .ground_process_intensities_ [encoding_group ],
534- self .occupancy_ [encoding_group ],
535- self .mean_rates_ [encoding_group ]
536- ) = fit_multiunit_likelihood (
537- position [is_group ],
538- multiunits [is_group ],
539- self .place_bin_centers_ ,
540- self .model ,
541- self .model_kwargs ,
542- self .occupancy_model ,
543- self .occupancy_kwargs ,
544- self .is_track_interior_ .ravel (order = 'F' ))
567+ self .encoding_model_ [encoding_group ] = _ClUSTERLESS_ALGORITHMS [
568+ self .clusterless_algorithm ][0 ](
569+ position = position [is_group ],
570+ multiunits = multiunits [is_group ],
571+ place_bin_centers = self .place_bin_centers_ ,
572+ is_track_interior = self .is_track_interior_ .ravel (order = 'F' ),
573+ ** kwargs
574+ )
545575
546576 def fit (self ,
547577 position ,
@@ -606,19 +636,25 @@ def predict(self, multiunits, time=None, is_compute_acausal=True,
606636
607637 '''
608638 multiunits = np .asarray (multiunits )
639+ is_track_interior = self .is_track_interior_ .ravel (order = 'F' )
640+ n_time = multiunits .shape [0 ]
641+ n_position_bins = is_track_interior .shape [0 ]
642+ n_states = self .discrete_state_transition_ .shape [0 ]
643+ is_states = np .ones ((n_states ,), dtype = bool )
644+ st_interior_ind = np .ix_ (
645+ is_states , is_states , is_track_interior , is_track_interior )
609646
610647 results = {}
611648
612649 likelihood = {}
613- for encoding_group in self .joint_pdf_models_ :
614- likelihood [encoding_group ] = estimate_multiunit_likelihood (
615- multiunits ,
616- self .place_bin_centers_ ,
617- self .joint_pdf_models_ [encoding_group ],
618- self .ground_process_intensities_ [encoding_group ],
619- self .occupancy_ [encoding_group ],
620- self .mean_rates_ [encoding_group ],
621- self .is_track_interior_ .ravel (order = 'F' ))
650+ for encoding_group , encoding_params in self .encoding_model_ .items ():
651+ likelihood [encoding_group ] = _ClUSTERLESS_ALGORITHMS [
652+ self .clusterless_algorithm ][1 ](
653+ multiunits = multiunits ,
654+ place_bin_centers = self .place_bin_centers_ ,
655+ is_track_interior = is_track_interior ,
656+ ** encoding_params
657+ )
622658
623659 results ['likelihood' ] = np .stack (
624660 [likelihood [encoding_group ]
@@ -627,17 +663,22 @@ def predict(self, multiunits, time=None, is_compute_acausal=True,
627663 results ['likelihood' ] = scaled_likelihood (
628664 results ['likelihood' ], axis = (1 , 2 ))[..., np .newaxis ]
629665
630- results ['causal_posterior' ] = _causal_classify (
631- self .initial_conditions_ , self .continuous_state_transition_ ,
632- self .discrete_state_transition_ , results ['likelihood' ])
666+ results ['causal_posterior' ] = np .full (
667+ (n_time , n_states , n_position_bins , 1 ), np .nan )
668+ results ['causal_posterior' ][:, :, is_track_interior ] = _causal_classify (
669+ self .initial_conditions_ [:, is_track_interior ],
670+ self .continuous_state_transition_ [st_interior_ind ],
671+ self .discrete_state_transition_ ,
672+ results ['likelihood' ][:, :, is_track_interior ])
633673
634674 if is_compute_acausal :
635- results ['acausal_posterior' ] = _acausal_classify (
636- results ['causal_posterior' ], self .continuous_state_transition_ ,
675+ results ['acausal_posterior' ] = np .full (
676+ (n_time , n_states , n_position_bins , 1 ), np .nan )
677+ results ['acausal_posterior' ][:, :, is_track_interior ] = _acausal_classify (
678+ results ['causal_posterior' ][:, :, is_track_interior ],
679+ self .continuous_state_transition_ [st_interior_ind ],
637680 self .discrete_state_transition_ )
638681
639- n_time = multiunits .shape [0 ]
640-
641682 if time is None :
642683 time = np .arange (n_time )
643684
0 commit comments