|
14 | 14 | from itertools import compress |
15 | 15 | from torch_geometric.data import HeteroData |
16 | 16 | from torch_geometric.transforms import RandomLinkSplit |
| 17 | +from pqdm.threads import pqdm |
17 | 18 | import torch |
18 | 19 | import random |
19 | 20 | from segger.data.parquet.transcript_embedding import TranscriptEmbedding |
| 21 | +# import re |
20 | 22 |
|
21 | 23 |
|
22 | 24 | # TODO: Add documentation for settings |
@@ -203,7 +205,8 @@ def transcripts_metadata(self) -> dict: |
203 | 205 | missing_genes = list(set(names_str) - set(self._emb_genes)) |
204 | 206 | logging.warning(f"Number of missing genes: {len(missing_genes)}") |
205 | 207 | self.settings.transcripts.filter_substrings.extend(missing_genes) |
206 | | - pattern = "|".join(self.settings.transcripts.filter_substrings) |
| 208 | + # pattern = "|".join(self.settings.transcripts.filter_substrings) |
| 209 | + pattern = "|".join(f"^{s}" for s in self.settings.transcripts.filter_substrings) |
207 | 210 | mask = pc.invert(pc.match_substring_regex(names, pattern)) |
208 | 211 | filtered_names = pc.filter(names, mask).to_pylist() |
209 | 212 | metadata["feature_names"] = [ |
@@ -674,6 +677,7 @@ def _load_transcripts(self, path: os.PathLike, min_qv: float = 30.0): |
674 | 677 | transcripts[self.settings.transcripts.label] = transcripts[ |
675 | 678 | self.settings.transcripts.label |
676 | 679 | ].apply(lambda x: x.decode("utf-8") if isinstance(x, bytes) else x) |
| 680 | + qv_column = getattr(self.settings.transcripts, "qv_column", None) |
677 | 681 | transcripts = utils.filter_transcripts( |
678 | 682 | transcripts, |
679 | 683 | self.settings.transcripts.label, |
|
0 commit comments