Skip to content

Commit 5089cd7

Browse files
committed
Add DataFrame validation with patito
1 parent 8ed974d commit 5089cd7

14 files changed

+639
-182
lines changed

bin/generate_variant_pivot.py

Lines changed: 54 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,15 +3,18 @@
33
# /// script
44
# requires-python = ">= 3.10"
55
# dependencies = [
6+
# "patito",
67
# "polars-lts-cpu",
78
# "loguru",
89
# ]
910
# ///
1011

12+
from __future__ import annotations
1113

1214
import argparse
1315
from pathlib import Path
1416

17+
import patito as pt
1518
import polars as pl
1619
from loguru import logger
1720

@@ -32,14 +35,55 @@ def parse_command_line_args() -> argparse.Namespace:
3235
return parser.parse_args()
3336

3437

38+
class VariantPivotSchema(pt.Model):
39+
"""Schema for validating variant pivot table input data."""
40+
41+
contig: str = pt.Field(description="Contig/chromosome name")
42+
ref: str = pt.Field(min_length=1, description="Reference allele")
43+
pos: int = pt.Field(gt=0, description="Genomic position")
44+
alt: str = pt.Field(min_length=1, description="Alternative allele")
45+
af: float | None = pt.Field(ge=0.0, le=1.0, description="Allele frequency")
46+
ac: int | None = pt.Field(ge=0, description="Allele count")
47+
dp: int | None = pt.Field(ge=0, description="Read depth")
48+
mq: float | None = pt.Field(ge=0, description="Mapping quality")
49+
gene: str | None = pt.Field(description="Gene name")
50+
aa_effect: str | None = pt.Field(description="Amino acid effect")
51+
ref_codon_alt: str | None = pt.Field(description="Reference codon/alternative codon")
52+
cds_pos: int | None = pt.Field(gt=0, description="CDS position")
53+
aa_pos: int | None = pt.Field(gt=0, description="Amino acid position")
54+
55+
56+
def validate_variant_data(df: pl.DataFrame) -> pl.DataFrame:
57+
"""Validate variant pivot data using patito schema.
58+
59+
Args:
60+
df: DataFrame with variant data
61+
62+
Returns:
63+
Validated DataFrame
64+
65+
Raises:
66+
ValueError: If validation fails
67+
"""
68+
# Validate using patito
69+
try:
70+
VariantPivotSchema.validate(df)
71+
except Exception as e:
72+
msg = f"Variant data validation failed: {e!s}"
73+
raise ValueError(msg) from e
74+
75+
return df
76+
77+
3578
def main() -> None:
3679
"""
37-
TODO
80+
Generate variant pivot table from VCF data extracted with SnpSift.
3881
"""
3982
args = parse_command_line_args()
4083
input_table = args.input_table
4184

42-
_ = pl.read_csv(
85+
# Read the data
86+
pivot_df = pl.read_csv(
4387
input_table,
4488
separator="\t",
4589
has_header=False,
@@ -60,7 +104,14 @@ def main() -> None:
60104
"aa_pos",
61105
],
62106
).with_columns(pl.col("aa_effect").str.replace(".p", "").alias("aa_effect"))
63-
logger.info("Hi mom!")
107+
108+
# Validate the data
109+
logger.info(f"Validating data from {input_table}")
110+
pivot_df = validate_variant_data(pivot_df)
111+
logger.success(f"Validation passed for {len(pivot_df)} variants")
112+
113+
# TODO: Add actual pivot table generation logic here
114+
logger.info("Variant pivot table generation complete")
64115

65116

66117
if __name__ == "__main__":

bin/ivar_variants_to_vcf.py

Lines changed: 90 additions & 24 deletions
Original file line numberDiff line numberDiff line change
@@ -3,9 +3,10 @@
33
# requires-python = ">=3.10"
44
# dependencies = [
55
# "biopython",
6+
# "patito",
67
# "loguru",
78
# "numpy",
8-
# "polars",
9+
# "polars-lts-cpu",
910
# "pydantic",
1011
# "scipy",
1112
# "typer",
@@ -20,11 +21,14 @@
2021
for validation. Features a beautiful command-line interface built with Typer and Rich.
2122
"""
2223

24+
from __future__ import annotations
25+
2326
from enum import Enum
24-
from pathlib import Path
27+
from pathlib import Path # noqa: TC003
2528
from typing import Annotated, cast
2629

2730
import numpy as np
31+
import patito as pt
2832
import polars as pl
2933
import typer
3034
from Bio import SeqIO
@@ -104,7 +108,7 @@ def validate_reverse_depth(cls, v: int, info: ValidationInfo) -> int:
104108
return v
105109

106110
@model_validator(mode="after")
107-
def validate_total_depth(self) -> "IvarVariant":
111+
def validate_total_depth(self) -> IvarVariant:
108112
"""Ensure total depth is consistent."""
109113
if self.total_dp < self.ref_dp + self.alt_dp:
110114
msg = "Total depth must be at least ref_dp + alt_dp"
@@ -174,11 +178,69 @@ def validate_output_dir(cls, v: Path) -> Path:
174178
return v
175179

176180

181+
# ===== Patito Schema for iVar TSV Validation =====
182+
183+
184+
class IvarTsvSchema(pt.Model):
185+
"""Schema for validating iVar TSV input data."""
186+
187+
REGION: str = pt.Field(description="Reference sequence name")
188+
POS: int = pt.Field(gt=0, description="Position in reference")
189+
REF: str = pt.Field(min_length=1, description="Reference allele")
190+
ALT: str = pt.Field(description="Alternative allele (can be +/- for indels)")
191+
REF_DP: int = pt.Field(ge=0, description="Reference depth")
192+
REF_RV: int = pt.Field(ge=0, description="Reference reverse reads")
193+
REF_QUAL: float = pt.Field(ge=0, le=100, description="Reference quality")
194+
ALT_DP: int = pt.Field(ge=0, description="Alternative depth")
195+
ALT_RV: int = pt.Field(ge=0, description="Alternative reverse reads")
196+
ALT_QUAL: float = pt.Field(ge=0, le=100, description="Alternative quality")
197+
ALT_FREQ: float = pt.Field(ge=0, le=1, description="Alternative frequency")
198+
TOTAL_DP: int = pt.Field(ge=0, description="Total depth")
199+
PVAL: float = pt.Field(description="P-value from Fisher's exact test")
200+
PASS: bool = pt.Field(description="Whether variant passed iVar filters")
201+
REF_CODON: str | None = pt.Field(None, description="Reference codon")
202+
ALT_CODON: str | None = pt.Field(None, description="Alternative codon")
203+
204+
205+
def validate_ivar_data(unvalidated_lf: pl.LazyFrame) -> pl.LazyFrame:
206+
"""Validate iVar TSV data using patito schema.
207+
208+
Args:
209+
unvalidated_lf: Lazy DataFrame with iVar data
210+
211+
Returns:
212+
Validated DataFrame
213+
214+
Raises:
215+
ValueError: If validation fails
216+
"""
217+
# Collect a small sample to validate schema
218+
sample_df = unvalidated_lf.head(1000).collect()
219+
220+
# Validate using patito
221+
try:
222+
IvarTsvSchema.validate(sample_df)
223+
except Exception as e:
224+
msg = f"iVar data validation failed:\n{e}"
225+
raise ValueError(msg) from e
226+
227+
# Additional validation: ensure REF_RV <= REF_DP and ALT_RV <= ALT_DP
228+
validated_lf = unvalidated_lf.filter(
229+
(pl.col("REF_RV") <= pl.col("REF_DP")) & (pl.col("ALT_RV") <= pl.col("ALT_DP")),
230+
)
231+
232+
# Ensure TOTAL_DP >= REF_DP + ALT_DP
233+
return validated_lf.filter(pl.col("TOTAL_DP") >= (pl.col("REF_DP") + pl.col("ALT_DP")))
234+
235+
177236
# ===== Pure Functions for Data Transformation =====
178237

179238

180239
def calculate_strand_bias_pvalue(
181-
ref_dp: int, ref_rv: int, alt_dp: int, alt_rv: int
240+
ref_dp: int,
241+
ref_rv: int,
242+
alt_dp: int,
243+
alt_rv: int,
182244
) -> float:
183245
"""Calculate p-value for strand bias using Fisher's exact test.
184246
@@ -289,9 +351,7 @@ def create_filter_expr(config: ConversionConfig) -> pl.Expr:
289351

290352
# iVar PASS filter
291353
filters.append(
292-
pl.when(pl.col("PASS"))
293-
.then(pl.lit(""))
294-
.otherwise(pl.lit(FilterType.FAIL_TEST.value)),
354+
pl.when(pl.col("PASS")).then(pl.lit("")).otherwise(pl.lit(FilterType.FAIL_TEST.value)),
295355
)
296356

297357
# Quality filter
@@ -317,7 +377,8 @@ def create_filter_expr(config: ConversionConfig) -> pl.Expr:
317377
.list.join(";")
318378
.fill_null("")
319379
.map_elements(
320-
lambda x: FilterType.PASS.value if x == "" else x, return_dtype=pl.Utf8
380+
lambda x: FilterType.PASS.value if x == "" else x,
381+
return_dtype=pl.Utf8,
321382
)
322383
)
323384

@@ -345,7 +406,8 @@ def create_sample_info_expr() -> pl.Expr:
345406

346407

347408
def transform_ivar_to_vcf(
348-
ivar_lf: pl.LazyFrame, config: ConversionConfig
409+
ivar_lf: pl.LazyFrame,
410+
config: ConversionConfig,
349411
) -> pl.LazyFrame:
350412
"""Transform iVar data to VCF format using pure expressions.
351413
@@ -382,7 +444,7 @@ def transform_ivar_to_vcf(
382444
],
383445
).alias("INFO"),
384446
pl.lit(
385-
"GT:DP:REF_DP:REF_RV:REF_QUAL:ALT_DP:ALT_RV:ALT_QUAL:ALT_FREQ"
447+
"GT:DP:REF_DP:REF_RV:REF_QUAL:ALT_DP:ALT_RV:ALT_QUAL:ALT_FREQ",
386448
).alias("FORMAT"),
387449
create_sample_info_expr().alias("SAMPLE"),
388450
],
@@ -400,7 +462,8 @@ def find_consecutive_variants_expr() -> pl.Expr:
400462

401463

402464
def process_consecutive_snps(
403-
ivar_lf: pl.LazyFrame, config: ConversionConfig
465+
ivar_lf: pl.LazyFrame,
466+
config: ConversionConfig,
404467
) -> pl.LazyFrame:
405468
"""Process consecutive SNPs for potential merging.
406469
@@ -539,6 +602,10 @@ def process_ivar_file(config: ConversionConfig) -> None:
539602
task = progress.add_task("[cyan]Loading iVar data...", total=None)
540603
ivar_df = pl.scan_csv(str(config.file_in), separator="\t")
541604

605+
# Validate data
606+
progress.update(task, description="[yellow]Validating iVar data...")
607+
ivar_df = validate_ivar_data(ivar_df)
608+
542609
# Transform to VCF format
543610
progress.update(task, description="[yellow]Transforming to VCF format...")
544611
vcf_df = transform_ivar_to_vcf(ivar_df, config)
@@ -565,13 +632,14 @@ def process_ivar_file(config: ConversionConfig) -> None:
565632
# Write all haplotypes output
566633
progress.update(task, description="[green]Writing all haplotypes VCF...")
567634
all_hap_path = (
568-
config.file_out.parent
569-
/ f"{config.file_out.stem}_all_hap{config.file_out.suffix}"
635+
config.file_out.parent / f"{config.file_out.stem}_all_hap{config.file_out.suffix}"
570636
)
571637
write_vcf_file(result_df, all_hap_path, headers, sample_name)
572638

573639
progress.update(
574-
task, description="[bold green]✓ Conversion complete!", completed=True
640+
task,
641+
description="[bold green]✓ Conversion complete!",
642+
completed=True,
575643
)
576644

577645

@@ -758,14 +826,13 @@ def convert( # noqa: PLR0913
758826

759827
# Success message
760828
console.print(
761-
f"\n[bold green]✓[/bold green] Successfully converted to {config.file_out}"
829+
f"\n[bold green]✓[/bold green] Successfully converted to {config.file_out}",
762830
)
763831
all_hap_path = (
764-
config.file_out.parent
765-
/ f"{config.file_out.stem}_all_hap{config.file_out.suffix}"
832+
config.file_out.parent / f"{config.file_out.stem}_all_hap{config.file_out.suffix}"
766833
)
767834
console.print(
768-
f"[bold green]✓[/bold green] All haplotypes written to {all_hap_path}"
835+
f"[bold green]✓[/bold green] All haplotypes written to {all_hap_path}",
769836
)
770837

771838
except Exception as e: # noqa: BLE001
@@ -774,7 +841,7 @@ def convert( # noqa: PLR0913
774841

775842

776843
@app.command()
777-
def validate(
844+
def validate( # noqa: C901, PLR0912
778845
file_path: Annotated[
779846
Path,
780847
typer.Argument(
@@ -841,7 +908,7 @@ def validate(
841908

842909
if missing_cols:
843910
console.print(
844-
f"\n[red]✗ Missing required columns:[/red] {', '.join(missing_cols)}"
911+
f"\n[red]✗ Missing required columns:[/red] {', '.join(missing_cols)}",
845912
)
846913
else:
847914
console.print("\n[green]✓ All required VCF columns present[/green]")
@@ -905,8 +972,7 @@ def stats(
905972
console.print("\n[bold]Variant Types:[/bold]")
906973
snp_count = len(
907974
ivar_df.filter(
908-
~pl.col("ALT").str.starts_with("+")
909-
& ~pl.col("ALT").str.starts_with("-"),
975+
~pl.col("ALT").str.starts_with("+") & ~pl.col("ALT").str.starts_with("-"),
910976
),
911977
)
912978
ins_count = len(ivar_df.filter(pl.col("ALT").str.starts_with("+")))
@@ -929,8 +995,8 @@ def stats(
929995
for low, high in freq_bins:
930996
count = len(
931997
ivar_df.filter(
932-
(pl.col("ALT_FREQ") >= low) & (pl.col("ALT_FREQ") < high)
933-
)
998+
(pl.col("ALT_FREQ") >= low) & (pl.col("ALT_FREQ") < high),
999+
),
9341000
)
9351001
if count > 0:
9361002
console.print(f" • {low:.0%}-{high:.0%}: {count:,} variants")

0 commit comments

Comments (0)