@@ -425,60 +425,65 @@ def clonal_nn(
425425 None or AnnData
426426 Updates `adata` in place unless `copy=True`, in which case a new AnnData object is returned.
427427 """
428- from scipy .sparse import csr_matrix
428+ from scipy .sparse import csr_matrix , coo_matrix
429429 import pynndescent
430430
431- # Removing clones with small size
431+ adata_to_update = adata .copy () if copy else adata
432+
433+ # 1. Preprocessing (vectorized)
432434 clonal_obs = adata .obs [obs_name ].copy ()
433435 clones_counts = clonal_obs .value_counts ()
434436 small_clones = clones_counts [clones_counts < min_size ].index
435- clonal_obs = pd .Series ([
436- clone if clone not in small_clones else non_clonal_str for clone in clonal_obs
437- ]).astype (str ).astype ("category" )
438437
439- var_mapping = dict (zip (
440- clonal_obs .cat .categories [clonal_obs .cat .categories != non_clonal_str ],
441- range (len (clonal_obs .cat .categories [clonal_obs .cat .categories != non_clonal_str ])),
442- ))
443-
444- train = adata [clonal_obs != non_clonal_str ].obsm [use_rep ]
445- obs_col = clonal_obs [clonal_obs != non_clonal_str ].astype (str )
446- obs_col .index = range (len (obs_col ))
447- test = adata .obsm [use_rep ]
448- index = pynndescent .NNDescent (train , random_state = random_state , ** kwargs )
438+ clonal_obs [clonal_obs .isin (small_clones )] = non_clonal_str
439+ clonal_obs = clonal_obs .astype ("category" )
440+
441+ # 2. Prepare data for kNN
442+ valid_clones = clonal_obs .cat .categories [clonal_obs .cat .categories != non_clonal_str ]
443+ n_clones = len (valid_clones )
444+ clone_to_int = pd .Series (np .arange (n_clones ), index = valid_clones )
445+
446+ is_clonal = (clonal_obs != non_clonal_str ).to_numpy ()
447+ train_data = adata .obsm [use_rep ][is_clonal ]
448+
449+ if train_data .shape [0 ] == 0 :
450+ adata_to_update .obsm [obsm_name ] = csr_matrix ((adata .shape [0 ], n_clones ))
451+ adata_to_update .uns [obsm_name + "_names" ] = valid_clones .to_list ()
452+ return adata_to_update if copy else None
453+
454+ # 3. Build and query the kNN index
455+ index = pynndescent .NNDescent (train_data , random_state = random_state , ** kwargs )
449456 index .prepare ()
450- neighbors = index .query (test , k = k )[0 ]
451457
452- col_ind = []
453- row_ind = []
454- data = []
458+ test_data = adata .obsm [use_rep ]
459+ neighbors_indices , _ = index .query (test_data , k = k )
455460
456- for i in (tqdm (range (len (neighbors ))) if tqdm_bar else range (len (neighbors ))):
457- nn = obs_col [neighbors [i ]].value_counts ()
458- nn = nn [nn > 0 ]
459- col_ind += [var_mapping [var ] for var in nn .index ]
460- row_ind += [i ] * len (nn )
461- data += list (nn .values )
461+ # 4. Vectorized aggregation to build the sparse matrix
462+ train_labels = clonal_obs [is_clonal ]
463+ train_labels_encoded = clone_to_int [train_labels ].to_numpy ()
462464
463- if copy :
464- adata = adata .copy ()
465+ neighbor_labels_encoded = train_labels_encoded [neighbors_indices ]
465466
466- adata .obsm [obsm_name ] = csr_matrix ((data , (row_ind , col_ind )))
467- adata .uns [obsm_name + "_names" ] = list (var_mapping .keys ())
467+ n_test = test_data .shape [0 ]
468+ row_ind = np .repeat (np .arange (n_test ), k )
469+ col_ind = neighbor_labels_encoded .flatten ()
470+ data = np .ones (n_test * k , dtype = np .float32 )
471+
472+ bag_of_clones_matrix = coo_matrix (
473+ (data , (row_ind , col_ind )), shape = (n_test , n_clones )
474+ ).tocsr ()
475+
476+ # 5. Update the AnnData object
477+ adata_to_update .obsm [obsm_name ] = bag_of_clones_matrix
478+ adata_to_update .uns [f"{ obsm_name } _names" ] = valid_clones .to_list ()
468479
469- adata_clonal = sc .AnnData (
470- X = csr_matrix ((data , (row_ind , col_ind ))),
471- obs = pd .DataFrame (index = adata .obs_names ),
472- var = pd .DataFrame (index = list (var_mapping .keys ())),
473- )
474-
475480 if copy :
476- return adata
481+ return adata_to_update
477482
478483
479484def clone2vec (
480485 adata : AnnData ,
481- obs_name : str ,
486+ obs_name : str = "clone" ,
482487 z_dim : int = 10 ,
483488 n_epochs : int = 100 ,
484489 batch_size : int = 64 ,
@@ -489,6 +494,8 @@ def clone2vec(
489494 obsm_key : str = "clone2vec" ,
490495 uns_key : str = "clone2vec_mean_loss" ,
491496 random_state : None | int = 4 ,
497+ early_stopping_patience : int = 5 ,
498+ early_stopping_min_delta : float = 1e-4 ,
492499) -> AnnData :
493500 """
494501 Learn a clonal embedding using a SkipGram model and return the resulting clone embeddings.
@@ -575,6 +582,9 @@ def clone2vec(
575582 criterion = nn .NLLLoss ()
576583
577584 epochs_mean_loss = []
585+ best_loss = np .inf
586+ patience_counter = 0
587+
578588 for epoch in (tqdm (range (n_epochs )) if tqdm_bar else range (n_epochs )):
579589 losses = []
580590 for batch_idx , data in enumerate (train_loader ):
@@ -591,13 +601,30 @@ def clone2vec(
591601 optimizer .step ()
592602
593603 losses .append (loss .item ())
594- epochs_mean_loss .append (np .mean (losses ))
604+
605+ current_loss = np .mean (losses )
606+ epochs_mean_loss .append (current_loss )
607+
608+ if current_loss < best_loss - early_stopping_min_delta :
609+ best_loss = current_loss
610+ patience_counter = 0
611+ else :
612+ patience_counter += 1
613+
614+ if patience_counter >= early_stopping_patience :
615+ if tqdm_bar :
616+ print (f"\n Early stopping triggered at epoch { epoch + 1 } ." )
617+ break
595618
596619 clone2vec = model .embedding .weight .data .cpu ().numpy ()
620+ if not (fill_ct is None ) and not (fill_ct in adata .obs .columns ):
621+ print (f"{ fill_ct } isn't in the `adata.obs.columns`. Keeping `clones.X` empty." )
622+ fill_ct = None
623+
597624 if not (fill_ct is None ):
598625 cell_counts = adata_only_clones .obs .groupby (
599626 [fill_ct , obs_name ]
600- ).size ().unstack ()[adata_only_clones .uns [f"{ obsm_name } _names" ]]
627+ ).size ().unstack (fill_value = 0 )[adata_only_clones .uns [f"{ obsm_name } _names" ]]
601628
602629 var_names = list (cell_counts .index )
603630 obs_names = list (cell_counts .columns )
0 commit comments