Skip to content

Commit 894cb6b

Browse files
committed
fixed compare func
1 parent 70ee3fd commit 894cb6b

File tree

1 file changed

+30
-15
lines changed

1 file changed

+30
-15
lines changed

src/grelu/interpret/motifs.py

Lines changed: 30 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -345,22 +345,37 @@ def compare_motifs(
345345
pthresh=pthresh,
346346
rc=rc, # Scan both strands
347347
)
348-
349-
# Compare the results for alt and ref sequences
350-
scan = (
351-
scan.pivot_table(
352-
index=["motif", "start", "end", "strand"],
353-
columns=["sequence"],
354-
values="score",
348+
if len(scan) > 0:
349+
350+
# Compare the results for alt and ref sequences
351+
scan = (
352+
scan.pivot_table(
353+
index=["motif", "start", "end", "strand"],
354+
columns=["sequence"],
355+
values=["score", "p-value"],
356+
)
357+
.reset_index()
355358
)
356-
.fillna(0)
357-
.reset_index()
358-
)
359-
360-
# Compute fold change
361-
scan["foldChange"] = scan.alt / scan.ref
362-
scan = scan.sort_values("foldChange").reset_index(drop=True)
363-
return scan
359+
scan.columns = [col[0] if col[1] == '' else '_'.join(col) for col in scan.columns]
360+
for col in ["p-value_alt", "p-value_ref", "score_alt", "score_ref"]:
361+
if col not in scan.columns:
362+
scan[col] = np.nan
363+
364+
# Fill in missing scores: rescan each site in the sequence (alt or ref) where the motif was not reported, using pthresh=1
365+
for row in scan[scan.score_alt.isna()].itertuples():
366+
sc = scan_sequences(seqs=alt_seq[row.start:row.end+1], motifs=motifs, names=[row.motif], pthresh=1, rc=row.strand=='-').iloc[0]
367+
scan.loc[row.Index, 'score_alt'] = sc.score
368+
scan.loc[row.Index, 'p-value_alt'] = sc['p-value']
369+
370+
for row in scan[scan.score_ref.isna()].itertuples():
371+
sc = scan_sequences(seqs=ref_seq[row.start:row.end+1], motifs=motifs, names=[row.motif], pthresh=1, rc=row.strand=='-').iloc[0]
372+
scan.loc[row.Index, 'score_ref'] = sc.score
373+
scan.loc[row.Index, 'p-value_ref'] = sc['p-value']
374+
375+
# Compute the score difference (alt - ref)
376+
scan["score_diff"] = scan.score_alt - scan.score_ref
377+
scan = scan.sort_values("score_diff").reset_index(drop=True)
378+
return scan
364379

365380

366381
def run_tomtom(motifs: Dict[str, np.ndarray], meme_file: str) -> pd.DataFrame:

0 commit comments

Comments
 (0)