Skip to content

Commit 51e52d4

Browse files
authored
fixed prediction bug for duplicate cell ids
1 parent 443744a commit 51e52d4

File tree

1 file changed

+9
-9
lines changed

1 file changed

+9
-9
lines changed

src/segger/prediction/predict_parquet.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
format_time,
1919
create_anndata,
2020
coo_to_dense_adj,
21+
filter_transcripts
2122
)
2223
from segger.training.train import LitSegger
2324
from segger.training.segger_data_module import SeggerDataModule
@@ -441,6 +442,7 @@ def segment(
441442
seg_tag: str,
442443
transcript_file: Union[str, Path],
443444
score_cut: float = 0.5,
445+
qv: float = 30,
444446
use_cc: bool = True,
445447
file_format: str = "",
446448
save_transcripts: bool = True,
@@ -464,6 +466,7 @@ def segment(
464466
transcript_file (Union[str, Path]): Path to the transcripts Parquet file.
465467
score_cut (float, optional): The threshold for assigning transcripts to cells based on
466468
similarity scores. Defaults to 0.5.
469+
qv (float, optional): The minimum quality value threshold for filtering transcripts. Defaults to 30.
467470
use_cc (bool, optional): If True, perform connected components analysis for unassigned
468471
transcripts. Defaults to True.
469472
save_transcripts (bool, optional): Whether to save the transcripts as Parquet. Defaults to True.
@@ -538,20 +541,16 @@ def segment(
538541
print(f"Batch processing completed in {elapsed_time:.2f} seconds.")
539542

540543
seg_final_dd = pd.read_parquet(output_ddf_save_path)
541-
seg_final_dd = seg_final_dd.set_index("transcript_id")
542544

543545
step_start_time = time()
544546
if verbose:
545547
print(f"Applying max score selection logic...")
546-
547-
max_bound = seg_final_dd[seg_final_dd["bound"] == 1]
548-
max_bound = max_bound.loc[max_bound.groupby("transcript_id")["score"].idxmax()]
548+
output_ddf_save_path = save_dir / "transcripts_df.parquet"
549549

550-
# Step 2: Filter by 'bound' == 0 and find the maximum score for each transcript_id
551-
max_unbound = seg_final_dd[seg_final_dd["bound"] != 1]
552-
max_unbound = max_unbound.loc[max_unbound.groupby("transcript_id")["score"].idxmax()]
553-
554-
seg_final_filtered = pd.concat([max_bound, max_unbound]).sort_values(
550+
551+
seg_final_dd = pd.read_parquet(output_ddf_save_path)
552+
553+
seg_final_filtered = seg_final_dd.sort_values(
555554
"score", ascending=False
556555
).drop_duplicates(subset="transcript_id", keep="first")
557556

@@ -658,6 +657,7 @@ def _get_id():
658657

659658
# Step 5: Save the merged results based on options
660659
transcripts_df_filtered["segger_cell_id"] = transcripts_df_filtered["segger_cell_id"].fillna("UNASSIGNED")
660+
transcripts_df_filtered = filter_transcripts(transcripts_df_filtered, qv=qv)
661661

662662
if save_transcripts:
663663
if verbose:

0 commit comments

Comments
 (0)