Skip to content

Commit f768f24

Browse files
committed
Updates
1 parent 741ab52 commit f768f24

File tree

1 file changed

+66
-39
lines changed

1 file changed

+66
-39
lines changed

sclitr/tl.py

Lines changed: 66 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -425,60 +425,65 @@ def clonal_nn(
425425
None or AnnData
426426
Updates `adata` in place unless `copy=True`, in which case a new AnnData object is returned.
427427
"""
428-
from scipy.sparse import csr_matrix
428+
from scipy.sparse import csr_matrix, coo_matrix
429429
import pynndescent
430430

431-
# Removing clones with small size
431+
adata_to_update = adata.copy() if copy else adata
432+
433+
# 1. Preprocessing (vectorized)
432434
clonal_obs = adata.obs[obs_name].copy()
433435
clones_counts = clonal_obs.value_counts()
434436
small_clones = clones_counts[clones_counts < min_size].index
435-
clonal_obs = pd.Series([
436-
clone if clone not in small_clones else non_clonal_str for clone in clonal_obs
437-
]).astype(str).astype("category")
438437

439-
var_mapping = dict(zip(
440-
clonal_obs.cat.categories[clonal_obs.cat.categories != non_clonal_str],
441-
range(len(clonal_obs.cat.categories[clonal_obs.cat.categories != non_clonal_str])),
442-
))
443-
444-
train = adata[clonal_obs != non_clonal_str].obsm[use_rep]
445-
obs_col = clonal_obs[clonal_obs != non_clonal_str].astype(str)
446-
obs_col.index = range(len(obs_col))
447-
test = adata.obsm[use_rep]
448-
index = pynndescent.NNDescent(train, random_state=random_state, **kwargs)
438+
clonal_obs[clonal_obs.isin(small_clones)] = non_clonal_str
439+
clonal_obs = clonal_obs.astype("category")
440+
441+
# 2. Prepare data for kNN
442+
valid_clones = clonal_obs.cat.categories[clonal_obs.cat.categories != non_clonal_str]
443+
n_clones = len(valid_clones)
444+
clone_to_int = pd.Series(np.arange(n_clones), index=valid_clones)
445+
446+
is_clonal = (clonal_obs != non_clonal_str).to_numpy()
447+
train_data = adata.obsm[use_rep][is_clonal]
448+
449+
if train_data.shape[0] == 0:
450+
adata_to_update.obsm[obsm_name] = csr_matrix((adata.shape[0], n_clones))
451+
adata_to_update.uns[obsm_name + "_names"] = valid_clones.to_list()
452+
return adata_to_update if copy else None
453+
454+
# 3. Build and query the kNN index
455+
index = pynndescent.NNDescent(train_data, random_state=random_state, **kwargs)
449456
index.prepare()
450-
neighbors = index.query(test, k=k)[0]
451457

452-
col_ind = []
453-
row_ind = []
454-
data = []
458+
test_data = adata.obsm[use_rep]
459+
neighbors_indices, _ = index.query(test_data, k=k)
455460

456-
for i in (tqdm(range(len(neighbors))) if tqdm_bar else range(len(neighbors))):
457-
nn = obs_col[neighbors[i]].value_counts()
458-
nn = nn[nn > 0]
459-
col_ind += [var_mapping[var] for var in nn.index]
460-
row_ind += [i] * len(nn)
461-
data += list(nn.values)
461+
# 4. Vectorized aggregation to build the sparse matrix
462+
train_labels = clonal_obs[is_clonal]
463+
train_labels_encoded = clone_to_int[train_labels].to_numpy()
462464

463-
if copy:
464-
adata = adata.copy()
465+
neighbor_labels_encoded = train_labels_encoded[neighbors_indices]
465466

466-
adata.obsm[obsm_name] = csr_matrix((data, (row_ind, col_ind)))
467-
adata.uns[obsm_name + "_names"] = list(var_mapping.keys())
467+
n_test = test_data.shape[0]
468+
row_ind = np.repeat(np.arange(n_test), k)
469+
col_ind = neighbor_labels_encoded.flatten()
470+
data = np.ones(n_test * k, dtype=np.float32)
471+
472+
bag_of_clones_matrix = coo_matrix(
473+
(data, (row_ind, col_ind)), shape=(n_test, n_clones)
474+
).tocsr()
475+
476+
# 5. Update the AnnData object
477+
adata_to_update.obsm[obsm_name] = bag_of_clones_matrix
478+
adata_to_update.uns[f"{obsm_name}_names"] = valid_clones.to_list()
468479

469-
adata_clonal = sc.AnnData(
470-
X=csr_matrix((data, (row_ind, col_ind))),
471-
obs=pd.DataFrame(index=adata.obs_names),
472-
var=pd.DataFrame(index=list(var_mapping.keys())),
473-
)
474-
475480
if copy:
476-
return adata
481+
return adata_to_update
477482

478483

479484
def clone2vec(
480485
adata: AnnData,
481-
obs_name: str,
486+
obs_name: str = "clone",
482487
z_dim: int = 10,
483488
n_epochs: int = 100,
484489
batch_size: int = 64,
@@ -489,6 +494,8 @@ def clone2vec(
489494
obsm_key: str = "clone2vec",
490495
uns_key: str = "clone2vec_mean_loss",
491496
random_state: None | int = 4,
497+
early_stopping_patience: int = 5,
498+
early_stopping_min_delta: float = 1e-4,
492499
) -> AnnData:
493500
"""
494501
Learn a clonal embedding using a SkipGram model and return the resulting clone embeddings.
@@ -575,6 +582,9 @@ def clone2vec(
575582
criterion = nn.NLLLoss()
576583

577584
epochs_mean_loss = []
585+
best_loss = np.inf
586+
patience_counter = 0
587+
578588
for epoch in (tqdm(range(n_epochs)) if tqdm_bar else range(n_epochs)):
579589
losses = []
580590
for batch_idx, data in enumerate(train_loader):
@@ -591,13 +601,30 @@ def clone2vec(
591601
optimizer.step()
592602

593603
losses.append(loss.item())
594-
epochs_mean_loss.append(np.mean(losses))
604+
605+
current_loss = np.mean(losses)
606+
epochs_mean_loss.append(current_loss)
607+
608+
if current_loss < best_loss - early_stopping_min_delta:
609+
best_loss = current_loss
610+
patience_counter = 0
611+
else:
612+
patience_counter += 1
613+
614+
if patience_counter >= early_stopping_patience:
615+
if tqdm_bar:
616+
print(f"\nEarly stopping triggered at epoch {epoch + 1}.")
617+
break
595618

596619
clone2vec = model.embedding.weight.data.cpu().numpy()
620+
if not (fill_ct is None) and not (fill_ct in adata.obs.columns):
621+
print(f"{fill_ct} isn't in the `adata.obs.columns`. Keeping `clones.X` empty.")
622+
fill_ct = None
623+
597624
if not (fill_ct is None):
598625
cell_counts = adata_only_clones.obs.groupby(
599626
[fill_ct, obs_name]
600-
).size().unstack()[adata_only_clones.uns[f"{obsm_name}_names"]]
627+
).size().unstack(fill_value=0)[adata_only_clones.uns[f"{obsm_name}_names"]]
601628

602629
var_names = list(cell_counts.index)
603630
obs_names = list(cell_counts.columns)

0 commit comments

Comments
 (0)