33# requires-python = ">=3.10"
44# dependencies = [
55# "biopython",
6+ # "patito",
67# "loguru",
78# "numpy",
8- # "polars",
9+ # "polars-lts-cpu ",
910# "pydantic",
1011# "scipy",
1112# "typer",
2021for validation. Features a beautiful command-line interface built with Typer and Rich.
2122"""
2223
24+ from __future__ import annotations
25+
2326from enum import Enum
24- from pathlib import Path
27+ from pathlib import Path # noqa: TC003
2528from typing import Annotated , cast
2629
2730import numpy as np
31+ import patito as pt
2832import polars as pl
2933import typer
3034from Bio import SeqIO
@@ -104,7 +108,7 @@ def validate_reverse_depth(cls, v: int, info: ValidationInfo) -> int:
104108 return v
105109
106110 @model_validator (mode = "after" )
107- def validate_total_depth (self ) -> " IvarVariant" :
111+ def validate_total_depth (self ) -> IvarVariant :
108112 """Ensure total depth is consistent."""
109113 if self .total_dp < self .ref_dp + self .alt_dp :
110114 msg = "Total depth must be at least ref_dp + alt_dp"
@@ -174,11 +178,69 @@ def validate_output_dir(cls, v: Path) -> Path:
174178 return v
175179
176180
181+ # ===== Patito Schema for iVar TSV Validation =====
182+
183+
class IvarTsvSchema(pt.Model):
    """Schema for validating iVar TSV input data.

    Each attribute names one column of the iVar variants table and carries the
    dtype and value constraints that patito checks when
    ``IvarTsvSchema.validate`` is called on a frame.

    Cross-column relationships (for example reverse-read counts not exceeding
    the corresponding depth) cannot be expressed per field and are therefore
    not part of this schema.
    """

    # --- Variant locus and alleles ---
    REGION: str = pt.Field(description="Reference sequence name")
    POS: int = pt.Field(gt=0, description="Position in reference")
    REF: str = pt.Field(min_length=1, description="Reference allele")
    ALT: str = pt.Field(description="Alternative allele (can be +/- for indels)")

    # --- Reference allele support ---
    REF_DP: int = pt.Field(ge=0, description="Reference depth")
    REF_RV: int = pt.Field(ge=0, description="Reference reverse reads")
    REF_QUAL: float = pt.Field(ge=0, le=100, description="Reference quality")

    # --- Alternative allele support ---
    ALT_DP: int = pt.Field(ge=0, description="Alternative depth")
    ALT_RV: int = pt.Field(ge=0, description="Alternative reverse reads")
    ALT_QUAL: float = pt.Field(ge=0, le=100, description="Alternative quality")
    ALT_FREQ: float = pt.Field(ge=0, le=1, description="Alternative frequency")

    # --- Site-level statistics ---
    TOTAL_DP: int = pt.Field(ge=0, description="Total depth")
    # No bounds are enforced on PVAL here.
    PVAL: float = pt.Field(description="P-value from Fisher's exact test")
    PASS: bool = pt.Field(description="Whether variant passed iVar filters")

    # --- Codon annotation; values may be null ---
    REF_CODON: str | None = pt.Field(None, description="Reference codon")
    ALT_CODON: str | None = pt.Field(None, description="Alternative codon")
203+
204+
def validate_ivar_data(unvalidated_lf: pl.LazyFrame) -> pl.LazyFrame:
    """Validate iVar TSV data against ``IvarTsvSchema`` and drop inconsistent rows.

    The schema check is performed on a sample (the first 1000 rows) so the
    whole input does not have to be materialised; rows beyond the sample are
    not schema-checked. Only the columns declared on ``IvarTsvSchema`` are
    handed to the validator, so additional columns present in the input
    (for example annotation columns the schema does not declare) do not cause
    a validation failure. Columns required by the schema but absent from the
    input are still reported by the validator.

    After the schema check, rows violating the cross-column depth invariants
    are filtered out lazily rather than raising.

    Args:
        unvalidated_lf: Lazy frame with iVar data.

    Returns:
        A lazy frame restricted to rows where ``REF_RV <= REF_DP``,
        ``ALT_RV <= ALT_DP`` and ``TOTAL_DP >= REF_DP + ALT_DP``. All input
        columns are retained.

    Raises:
        ValueError: If the sampled rows fail schema validation.
    """
    # Collect a small sample to validate the schema without reading everything.
    sample_df = unvalidated_lf.head(1000).collect()

    # Validate only schema-declared columns: by default the validator treats
    # undeclared columns as errors, which would reject otherwise valid input
    # that carries extra columns.
    schema_columns = [name for name in IvarTsvSchema.columns if name in sample_df.columns]

    try:
        IvarTsvSchema.validate(sample_df.select(schema_columns))
    except Exception as e:
        msg = f"iVar data validation failed:\n{e}"
        raise ValueError(msg) from e

    # Cross-column consistency: reverse-strand counts cannot exceed the depth
    # they are a part of. Offending rows are dropped, not reported.
    strand_consistent_lf = unvalidated_lf.filter(
        (pl.col("REF_RV") <= pl.col("REF_DP")) & (pl.col("ALT_RV") <= pl.col("ALT_DP")),
    )

    # Total depth must cover at least the reference plus alternative depth.
    return strand_consistent_lf.filter(
        pl.col("TOTAL_DP") >= (pl.col("REF_DP") + pl.col("ALT_DP")),
    )
234+
235+
177236# ===== Pure Functions for Data Transformation =====
178237
179238
180239def calculate_strand_bias_pvalue (
181- ref_dp : int , ref_rv : int , alt_dp : int , alt_rv : int
240+ ref_dp : int ,
241+ ref_rv : int ,
242+ alt_dp : int ,
243+ alt_rv : int ,
182244) -> float :
183245 """Calculate p-value for strand bias using Fisher's exact test.
184246
@@ -289,9 +351,7 @@ def create_filter_expr(config: ConversionConfig) -> pl.Expr:
289351
290352 # iVar PASS filter
291353 filters .append (
292- pl .when (pl .col ("PASS" ))
293- .then (pl .lit ("" ))
294- .otherwise (pl .lit (FilterType .FAIL_TEST .value )),
354+ pl .when (pl .col ("PASS" )).then (pl .lit ("" )).otherwise (pl .lit (FilterType .FAIL_TEST .value )),
295355 )
296356
297357 # Quality filter
@@ -317,7 +377,8 @@ def create_filter_expr(config: ConversionConfig) -> pl.Expr:
317377 .list .join (";" )
318378 .fill_null ("" )
319379 .map_elements (
320- lambda x : FilterType .PASS .value if x == "" else x , return_dtype = pl .Utf8
380+ lambda x : FilterType .PASS .value if x == "" else x ,
381+ return_dtype = pl .Utf8 ,
321382 )
322383 )
323384
@@ -345,7 +406,8 @@ def create_sample_info_expr() -> pl.Expr:
345406
346407
347408def transform_ivar_to_vcf (
348- ivar_lf : pl .LazyFrame , config : ConversionConfig
409+ ivar_lf : pl .LazyFrame ,
410+ config : ConversionConfig ,
349411) -> pl .LazyFrame :
350412 """Transform iVar data to VCF format using pure expressions.
351413
@@ -382,7 +444,7 @@ def transform_ivar_to_vcf(
382444 ],
383445 ).alias ("INFO" ),
384446 pl .lit (
385- "GT:DP:REF_DP:REF_RV:REF_QUAL:ALT_DP:ALT_RV:ALT_QUAL:ALT_FREQ"
447+ "GT:DP:REF_DP:REF_RV:REF_QUAL:ALT_DP:ALT_RV:ALT_QUAL:ALT_FREQ" ,
386448 ).alias ("FORMAT" ),
387449 create_sample_info_expr ().alias ("SAMPLE" ),
388450 ],
@@ -400,7 +462,8 @@ def find_consecutive_variants_expr() -> pl.Expr:
400462
401463
402464def process_consecutive_snps (
403- ivar_lf : pl .LazyFrame , config : ConversionConfig
465+ ivar_lf : pl .LazyFrame ,
466+ config : ConversionConfig ,
404467) -> pl .LazyFrame :
405468 """Process consecutive SNPs for potential merging.
406469
@@ -539,6 +602,10 @@ def process_ivar_file(config: ConversionConfig) -> None:
539602 task = progress .add_task ("[cyan]Loading iVar data..." , total = None )
540603 ivar_df = pl .scan_csv (str (config .file_in ), separator = "\t " )
541604
605+ # Validate data
606+ progress .update (task , description = "[yellow]Validating iVar data..." )
607+ ivar_df = validate_ivar_data (ivar_df )
608+
542609 # Transform to VCF format
543610 progress .update (task , description = "[yellow]Transforming to VCF format..." )
544611 vcf_df = transform_ivar_to_vcf (ivar_df , config )
@@ -565,13 +632,14 @@ def process_ivar_file(config: ConversionConfig) -> None:
565632 # Write all haplotypes output
566633 progress .update (task , description = "[green]Writing all haplotypes VCF..." )
567634 all_hap_path = (
568- config .file_out .parent
569- / f"{ config .file_out .stem } _all_hap{ config .file_out .suffix } "
635+ config .file_out .parent / f"{ config .file_out .stem } _all_hap{ config .file_out .suffix } "
570636 )
571637 write_vcf_file (result_df , all_hap_path , headers , sample_name )
572638
573639 progress .update (
574- task , description = "[bold green]✓ Conversion complete!" , completed = True
640+ task ,
641+ description = "[bold green]✓ Conversion complete!" ,
642+ completed = True ,
575643 )
576644
577645
@@ -758,14 +826,13 @@ def convert( # noqa: PLR0913
758826
759827 # Success message
760828 console .print (
761- f"\n [bold green]✓[/bold green] Successfully converted to { config .file_out } "
829+ f"\n [bold green]✓[/bold green] Successfully converted to { config .file_out } " ,
762830 )
763831 all_hap_path = (
764- config .file_out .parent
765- / f"{ config .file_out .stem } _all_hap{ config .file_out .suffix } "
832+ config .file_out .parent / f"{ config .file_out .stem } _all_hap{ config .file_out .suffix } "
766833 )
767834 console .print (
768- f"[bold green]✓[/bold green] All haplotypes written to { all_hap_path } "
835+ f"[bold green]✓[/bold green] All haplotypes written to { all_hap_path } " ,
769836 )
770837
771838 except Exception as e : # noqa: BLE001
@@ -774,7 +841,7 @@ def convert( # noqa: PLR0913
774841
775842
776843@app .command ()
777- def validate (
844+ def validate ( # noqa: C901, PLR0912
778845 file_path : Annotated [
779846 Path ,
780847 typer .Argument (
@@ -841,7 +908,7 @@ def validate(
841908
842909 if missing_cols :
843910 console .print (
844- f"\n [red]✗ Missing required columns:[/red] { ', ' .join (missing_cols )} "
911+ f"\n [red]✗ Missing required columns:[/red] { ', ' .join (missing_cols )} " ,
845912 )
846913 else :
847914 console .print ("\n [green]✓ All required VCF columns present[/green]" )
@@ -905,8 +972,7 @@ def stats(
905972 console .print ("\n [bold]Variant Types:[/bold]" )
906973 snp_count = len (
907974 ivar_df .filter (
908- ~ pl .col ("ALT" ).str .starts_with ("+" )
909- & ~ pl .col ("ALT" ).str .starts_with ("-" ),
975+ ~ pl .col ("ALT" ).str .starts_with ("+" ) & ~ pl .col ("ALT" ).str .starts_with ("-" ),
910976 ),
911977 )
912978 ins_count = len (ivar_df .filter (pl .col ("ALT" ).str .starts_with ("+" )))
@@ -929,8 +995,8 @@ def stats(
929995 for low , high in freq_bins :
930996 count = len (
931997 ivar_df .filter (
932- (pl .col ("ALT_FREQ" ) >= low ) & (pl .col ("ALT_FREQ" ) < high )
933- )
998+ (pl .col ("ALT_FREQ" ) >= low ) & (pl .col ("ALT_FREQ" ) < high ),
999+ ),
9341000 )
9351001 if count > 0 :
9361002 console .print (f" • { low :.0%} -{ high :.0%} : { count :,} variants" )
0 commit comments