Skip to content

Commit b888276

Browse files
committed
fix normal dist for ellipse
1 parent 71a9d2f commit b888276

File tree

3 files changed

+14
-5
lines changed

3 files changed

+14
-5
lines changed

src/polygraph/models.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,7 @@ def load_enformer():
4141
)
4242

4343

44-
def enformer_embed(sequences, model):
44+
def enformer_embed(sequences, model, device="cpu"):
4545
"""
4646
Embed a batch of sequences using pretrained or fine-tuned enformer
4747
@@ -52,6 +52,9 @@ def enformer_embed(sequences, model):
5252
Returns:
5353
np.array of shape (n_seqs x 3072)
5454
"""
55+
if isinstance(device, int):
56+
device = torch.device(device)
57+
sequences = str_to_one_hot(sequences).to(device)
5558
return model(sequences, return_only_embeddings=True).mean(1).cpu().detach().numpy()
5659

5760

@@ -128,7 +131,7 @@ def _get_embeddings(seqs, model, drop_last_layers=1, device="cpu", swapaxes=Fals
128131
np.array of shape (n_seqs x n_features)
129132
"""
130133
if isinstance(model, Enformer):
131-
return enformer_embed(seqs, model)
134+
return enformer_embed(seqs, model, device=device)
132135
elif isinstance(model, EsmForMaskedLM):
133136
return nucleotide_transformer_embed(seqs, model)
134137
elif isinstance(model, nn.Sequential):

src/polygraph/motifs.py

Lines changed: 5 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -50,7 +50,7 @@ def scan(seqs, meme_file, group_col="Group", pthresh=1e-3, rc=True):
5050
for m in match:
5151
out["MotifID"].append(motif.name.decode())
5252
out["SeqID"].append(m.source.accession.decode())
53-
if m.strand=='+':
53+
if m.strand == "+":
5454
out["start"].append(m.start)
5555
out["end"].append(m.stop)
5656
else:
@@ -122,6 +122,10 @@ def nmf(counts, seqs, reference_group, group_col="Group", n_components=10):
122122
H.index = factors
123123
H.columns = counts.columns
124124

125+
# Normalize W and H matrices
126+
H = H.div(H.sum(axis=1), axis=0)
127+
W = W * H.sum(1)
128+
125129
# Add group IDs to W
126130
W[group_col] = seqs[group_col].tolist()
127131

src/polygraph/visualize.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -146,7 +146,9 @@ def pca_plot(
146146
group_col (str): Column containing group IDs.
147147
components (list): PCA components to plot
148148
size (float): Size of points
149-
show_ellipse (bool): Outline each group with an ellipse.
149+
show_ellipse (bool): Fit each group with a multivariate normal
150+
distribution and display an ellipse representing the 95%
151+
confidence level.
150152
reference_group (str): Group to use as reference. This group will
151153
be plotted first.
152154
"""
@@ -175,7 +177,7 @@ def pca_plot(
175177
+ p9.theme_classic()
176178
)
177179
if show_ellipse:
178-
g = g + p9.stat_ellipse()
180+
g = g + p9.stat_ellipse(type="norm")
179181
return g
180182

181183

0 commit comments

Comments
 (0)