File tree: 4 files changed, +23 -5 lines changed
Original file line number Diff line number Diff line change
@@ -868,6 +868,8 @@ def __len__(self):
868868 def validate_allele_seq (self , gene , variant ):
869869 seq = self .result .gene_sequence (gene , genome = self .genome )
870870 pos = variant .rel_pos
871+ if variant .strand == "-" :
872+ pos = pos - len (variant .ref ) + 1
871873 ref_match = seq [pos : pos + len (variant .ref )] == variant .ref_tx
872874 alt_match = seq [pos : pos + len (variant .alt )] == variant .alt_tx
873875 return ref_match , alt_match
@@ -889,6 +891,9 @@ def __getitem__(self, idx):
889891 variant = self .variants .iloc [seq_idx ]
890892 rel_pos = variant .rel_pos + self .max_seq_shift
891893
894+ if variant .strand == "-" :
895+ rel_pos = rel_pos - len (variant .ref ) + 1
896+
892897 # by default cache values are nan if matched with reference genome
893898 # then it will be replaced with the predicted expression from cache.
894899 pred_expr = {model_name : torch .full ((self .result .shape [0 ],), torch .nan ) for model_name in self .model_names }
Original file line number Diff line number Diff line change
@@ -44,12 +44,14 @@ def __enter__(self):
4444 return self
4545
4646 def __exit__ (self , exc_type , exc_val , exc_tb ):
47+ # breakpoint()
4748 if self .writer is not None :
4849 if self .metadata is not None :
4950 self .writer .add_key_value_metadata ({str (k ): str (v ) for k , v in self .metadata .items ()})
5051 self .writer .close ()
5152 else :
5253 warnings .warn ("NoDataFrameWrittenError: No dataframe was written to the parquet file." )
54+ pd .DataFrame ({}).to_parquet (self .output_path )
5355 self .first_chunk = True
5456
5557 def write (self , chunk : pd .DataFrame ) -> None :
Original file line number Diff line number Diff line change
3131from decima .utils .io import read_vcf_chunks , VariantAttributionWriter
3232from decima .core .result import DecimaResult
3333from decima .data .dataset import VariantDataset
34+ from decima .hub import load_decima_model
3435from decima .interpret .attributer import DecimaAttributer
3536from decima .model .metrics import WarningCounter
3637from decima .vep .vep import _log_vep_warnings , _write_vep_warnings
@@ -158,16 +159,15 @@ def variant_effect_attribution(
158159 f"Unsupported input type: { type (variants )} . Must be pd.DataFrame or str (path to .tsv or .vcf)."
159160 )
160161
161- result = DecimaResult . load ( metadata_anndata )
162-
162+ model = load_decima_model ( model , device = device )
163+ result = DecimaResult . load ( metadata_anndata or model . name )
163164 tasks , off_tasks = _get_on_off_tasks (result , tasks , off_tasks )
164- attributer = DecimaAttributer . load_decima_attributer (
165- model_name = model ,
165+ attributer = DecimaAttributer (
166+ model = model ,
166167 tasks = tasks ,
167168 off_tasks = off_tasks ,
168169 method = method ,
169170 transform = transform ,
170- device = device ,
171171 )
172172
173173 warning_counter = WarningCounter ()
Original file line number Diff line number Diff line change
@@ -82,6 +82,17 @@ def test_VariantDataset_overlap_genes(df_variant):
8282 })
8383 df = VariantDataset .overlap_genes (df_variant , df_genes )
8484
85+ def test_VariantDataset_validate_allele_seq ():
86+ df_variant = pd .DataFrame ({
87+ "chrom" : ["chr15" ],
88+ "pos" : [44715509 ],
89+ "ref" : ["CC" ],
90+ "alt" : ["TT" ]
91+ })
92+ dataset = VariantDataset (df_variant )
93+ ref_match , _ = dataset .validate_allele_seq ("SPG11" , dataset .variants .iloc [1 ])
94+ assert ref_match
95+
8596def test_VariantDataset (df_variant ):
8697
8798 dataset = VariantDataset (df_variant , model_name = "v1_rep0" )
You can’t perform that action at this time.
0 commit comments