Skip to content

Commit 224b892

Browse files
committed
fixing all tests so they can run in CI and updating the pyproject.toml to encompass a non-deprecated dev environment and to enable rust-script to compile SIMD insructions
1 parent 5725d6d commit 224b892

14 files changed

+200
-166
lines changed

.github/workflows/test.yml

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -31,19 +31,11 @@ jobs:
3131

3232
- name: Install dependencies with UV
3333
run: |
34-
uv sync --dev --frozen
34+
uv sync --frozen
3535
3636
- name: Run Python tests with pytest
3737
run: |
38-
uv run pytest bin/ -v --cov=bin --cov-report=xml --cov-report=term-missing
39-
40-
- name: Upload coverage reports
41-
uses: codecov/codecov-action@v4
42-
with:
43-
files: ./coverage.xml
44-
flags: python-${{ matrix.python-version }}
45-
name: Python ${{ matrix.python-version }}
46-
if: matrix.python-version == '3.12'
38+
uv run pytest
4739
4840
python-tests-tox:
4941
runs-on: ubuntu-latest

bin/generate_variant_pivot.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,7 +44,7 @@ def main() -> None:
4444
separator="\t",
4545
has_header=False,
4646
skip_rows=1,
47-
columns=[
47+
new_columns=[
4848
"contig",
4949
"ref",
5050
"pos",
@@ -60,7 +60,7 @@ def main() -> None:
6060
"aa_pos",
6161
],
6262
).with_columns(pl.col("aa_effect").str.replace(".p", "").alias("aa_effect"))
63-
logger.info("Hi mom!")
63+
logger.info("Pivot implementation coming soon!")
6464

6565

6666
if __name__ == "__main__":

bin/ivar_variants_to_vcf.py

Lines changed: 83 additions & 35 deletions
Original file line numberDiff line numberDiff line change
@@ -20,6 +20,7 @@
2020
for validation. Features a beautiful command-line interface built with Typer and Rich.
2121
"""
2222

23+
import gzip
2324
from enum import Enum
2425
from pathlib import Path
2526
from typing import Annotated, cast
@@ -97,9 +98,17 @@ class IvarVariant(BaseModel):
9798
@field_validator("ref_rv", "alt_rv")
9899
@classmethod
99100
def validate_reverse_depth(cls, v: int, info: ValidationInfo) -> int:
100-
"""Ensure reverse depth doesn't exceed total depth."""
101-
if "ref_dp" in info.data and v > info.data["ref_dp"]:
102-
msg = "Reverse depth cannot exceed total depth"
101+
"""Ensure reverse depth doesn't exceed total depth for that allele."""
102+
if info.field_name == "ref_rv":
103+
depth_field = "ref_dp"
104+
elif info.field_name == "alt_rv":
105+
depth_field = "alt_dp"
106+
else:
107+
msg = f"Unexpected field in reverse depth validator: {info.field_name}"
108+
raise ValueError(msg)
109+
110+
if depth_field in info.data and v > info.data[depth_field]:
111+
msg = f"Reverse depth ({info.field_name}) cannot exceed total depth ({depth_field})"
103112
raise ValueError(msg)
104113
return v
105114

@@ -178,7 +187,10 @@ def validate_output_dir(cls, v: Path) -> Path:
178187

179188

180189
def calculate_strand_bias_pvalue(
181-
ref_dp: int, ref_rv: int, alt_dp: int, alt_rv: int
190+
ref_dp: int,
191+
ref_rv: int,
192+
alt_dp: int,
193+
alt_rv: int,
182194
) -> float:
183195
"""Calculate p-value for strand bias using Fisher's exact test.
184196
@@ -199,7 +211,7 @@ def calculate_strand_bias_pvalue(
199211
)
200212
_odds_ratio, pvalue = cast(
201213
"tuple[float, float]",
202-
fisher_exact(contingency_table, alternative="greater"),
214+
fisher_exact(contingency_table, alternative="two-sided"),
203215
)
204216
return pvalue
205217

@@ -289,9 +301,7 @@ def create_filter_expr(config: ConversionConfig) -> pl.Expr:
289301

290302
# iVar PASS filter
291303
filters.append(
292-
pl.when(pl.col("PASS"))
293-
.then(pl.lit(""))
294-
.otherwise(pl.lit(FilterType.FAIL_TEST.value)),
304+
pl.when(pl.col("PASS")).then(pl.lit("")).otherwise(pl.lit(FilterType.FAIL_TEST.value)),
295305
)
296306

297307
# Quality filter
@@ -317,7 +327,8 @@ def create_filter_expr(config: ConversionConfig) -> pl.Expr:
317327
.list.join(";")
318328
.fill_null("")
319329
.map_elements(
320-
lambda x: FilterType.PASS.value if x == "" else x, return_dtype=pl.Utf8
330+
lambda x: FilterType.PASS.value if x == "" else x,
331+
return_dtype=pl.Utf8,
321332
)
322333
)
323334

@@ -345,7 +356,8 @@ def create_sample_info_expr() -> pl.Expr:
345356

346357

347358
def transform_ivar_to_vcf(
348-
ivar_lf: pl.LazyFrame, config: ConversionConfig
359+
ivar_lf: pl.LazyFrame,
360+
config: ConversionConfig,
349361
) -> pl.LazyFrame:
350362
"""Transform iVar data to VCF format using pure expressions.
351363
@@ -382,7 +394,7 @@ def transform_ivar_to_vcf(
382394
],
383395
).alias("INFO"),
384396
pl.lit(
385-
"GT:DP:REF_DP:REF_RV:REF_QUAL:ALT_DP:ALT_RV:ALT_QUAL:ALT_FREQ"
397+
"GT:DP:REF_DP:REF_RV:REF_QUAL:ALT_DP:ALT_RV:ALT_QUAL:ALT_FREQ",
386398
).alias("FORMAT"),
387399
create_sample_info_expr().alias("SAMPLE"),
388400
],
@@ -400,7 +412,8 @@ def find_consecutive_variants_expr() -> pl.Expr:
400412

401413

402414
def process_consecutive_snps(
403-
ivar_lf: pl.LazyFrame, config: ConversionConfig
415+
ivar_lf: pl.LazyFrame,
416+
config: ConversionConfig,
404417
) -> pl.LazyFrame:
405418
"""Process consecutive SNPs for potential merging.
406419
@@ -509,7 +522,6 @@ def write_vcf_file(
509522
# Write file (handle gzipped output)
510523
if str(filepath).endswith(".gz"):
511524
# For gzip, we need to write everything as text to the same handle
512-
import gzip
513525

514526
with gzip.open(filepath, "wt") as f:
515527
f.write(header_text)
@@ -539,6 +551,52 @@ def process_ivar_file(config: ConversionConfig) -> None:
539551
task = progress.add_task("[cyan]Loading iVar data...", total=None)
540552
ivar_df = pl.scan_csv(str(config.file_in), separator="\t")
541553

554+
# Check if input has any data rows (collect schema to check)
555+
progress.update(task, description="[yellow]Checking input data...")
556+
try:
557+
row_count = ivar_df.select(pl.len()).collect().item()
558+
except pl.exceptions.NoDataError:
559+
row_count = 0
560+
561+
# Generate headers (needed regardless of data)
562+
progress.update(task, description="[yellow]Generating VCF headers...")
563+
headers = generate_vcf_header(config)
564+
sample_name = config.file_in.stem
565+
566+
if row_count == 0:
567+
# Handle empty input: write VCF with headers only
568+
progress.update(
569+
task,
570+
description="[yellow]No variants found, writing empty VCF...",
571+
)
572+
empty_df = pl.DataFrame(
573+
schema={
574+
"CHROM": pl.Utf8,
575+
"POS": pl.Int64,
576+
"ID": pl.Utf8,
577+
"REF": pl.Utf8,
578+
"ALT": pl.Utf8,
579+
"QUAL": pl.Utf8,
580+
"FILTER": pl.Utf8,
581+
"INFO": pl.Utf8,
582+
"FORMAT": pl.Utf8,
583+
"SAMPLE": pl.Utf8,
584+
},
585+
)
586+
write_vcf_file(empty_df, config.file_out, headers, sample_name)
587+
588+
all_hap_path = (
589+
config.file_out.parent / f"{config.file_out.stem}_all_hap{config.file_out.suffix}"
590+
)
591+
write_vcf_file(empty_df, all_hap_path, headers, sample_name)
592+
593+
progress.update(
594+
task,
595+
description="[bold yellow]✓ No variants found, empty VCF written",
596+
completed=True,
597+
)
598+
return
599+
542600
# Transform to VCF format
543601
progress.update(task, description="[yellow]Transforming to VCF format...")
544602
vcf_df = transform_ivar_to_vcf(ivar_df, config)
@@ -551,27 +609,21 @@ def process_ivar_file(config: ConversionConfig) -> None:
551609
progress.update(task, description="[yellow]Collecting results...")
552610
result_df = processed_df.collect()
553611

554-
# Generate headers
555-
progress.update(task, description="[yellow]Generating VCF headers...")
556-
headers = generate_vcf_header(config)
557-
558-
# Get sample name from input file
559-
sample_name = config.file_in.stem
560-
561612
# Write consensus output
562613
progress.update(task, description="[green]Writing consensus VCF...")
563614
write_vcf_file(result_df, config.file_out, headers, sample_name)
564615

565616
# Write all haplotypes output
566617
progress.update(task, description="[green]Writing all haplotypes VCF...")
567618
all_hap_path = (
568-
config.file_out.parent
569-
/ f"{config.file_out.stem}_all_hap{config.file_out.suffix}"
619+
config.file_out.parent / f"{config.file_out.stem}_all_hap{config.file_out.suffix}"
570620
)
571621
write_vcf_file(result_df, all_hap_path, headers, sample_name)
572622

573623
progress.update(
574-
task, description="[bold green]✓ Conversion complete!", completed=True
624+
task,
625+
description="[bold green]✓ Conversion complete!",
626+
completed=True,
575627
)
576628

577629

@@ -758,14 +810,13 @@ def convert( # noqa: PLR0913
758810

759811
# Success message
760812
console.print(
761-
f"\n[bold green]✓[/bold green] Successfully converted to {config.file_out}"
813+
f"\n[bold green]✓[/bold green] Successfully converted to {config.file_out}",
762814
)
763815
all_hap_path = (
764-
config.file_out.parent
765-
/ f"{config.file_out.stem}_all_hap{config.file_out.suffix}"
816+
config.file_out.parent / f"{config.file_out.stem}_all_hap{config.file_out.suffix}"
766817
)
767818
console.print(
768-
f"[bold green]✓[/bold green] All haplotypes written to {all_hap_path}"
819+
f"[bold green]✓[/bold green] All haplotypes written to {all_hap_path}",
769820
)
770821

771822
except Exception as e: # noqa: BLE001
@@ -774,7 +825,7 @@ def convert( # noqa: PLR0913
774825

775826

776827
@app.command()
777-
def validate(
828+
def validate( # noqa: C901, PLR0912
778829
file_path: Annotated[
779830
Path,
780831
typer.Argument(
@@ -792,8 +843,6 @@ def validate(
792843
console.print(f"[cyan]Validating VCF file:[/cyan] {file_path}")
793844

794845
try:
795-
import gzip
796-
797846
# Count header lines (handle gzipped files)
798847
header_count = 0
799848
if str(file_path).endswith(".gz"):
@@ -841,7 +890,7 @@ def validate(
841890

842891
if missing_cols:
843892
console.print(
844-
f"\n[red]✗ Missing required columns:[/red] {', '.join(missing_cols)}"
893+
f"\n[red]✗ Missing required columns:[/red] {', '.join(missing_cols)}",
845894
)
846895
else:
847896
console.print("\n[green]✓ All required VCF columns present[/green]")
@@ -905,8 +954,7 @@ def stats(
905954
console.print("\n[bold]Variant Types:[/bold]")
906955
snp_count = len(
907956
ivar_df.filter(
908-
~pl.col("ALT").str.starts_with("+")
909-
& ~pl.col("ALT").str.starts_with("-"),
957+
~pl.col("ALT").str.starts_with("+") & ~pl.col("ALT").str.starts_with("-"),
910958
),
911959
)
912960
ins_count = len(ivar_df.filter(pl.col("ALT").str.starts_with("+")))
@@ -929,8 +977,8 @@ def stats(
929977
for low, high in freq_bins:
930978
count = len(
931979
ivar_df.filter(
932-
(pl.col("ALT_FREQ") >= low) & (pl.col("ALT_FREQ") < high)
933-
)
980+
(pl.col("ALT_FREQ") >= low) & (pl.col("ALT_FREQ") < high),
981+
),
934982
)
935983
if count > 0:
936984
console.print(f" • {low:.0%}-{high:.0%}: {count:,} variants")

bin/prepare_primers.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -492,7 +492,10 @@ def generate_splice_combinations(
492492

493493
result: list[tuple[BedRecord, str]] = []
494494

495-
for splice_idx, ((fwd_rec, fwd_name), (rev_rec, rev_name)) in enumerate(all_pairs, 1):
495+
for splice_idx, ((fwd_rec, fwd_name), (rev_rec, rev_name)) in enumerate(
496+
all_pairs,
497+
1,
498+
):
496499
# Strip the index suffix and add splice identifier
497500
fwd_base = fwd_name.rsplit("-", 1)[0]
498501
rev_base = rev_name.rsplit("-", 1)[0]

bin/test_concat_consensus.py

Lines changed: 20 additions & 13 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
from pathlib import Path
1010
from textwrap import dedent
1111
from unittest.mock import patch
12+
1213
import pytest
1314
from Bio import SeqIO
1415

@@ -174,24 +175,25 @@ def test_main_with_malformed_fasta(self, malformed_fasta_file, monkeypatch):
174175
assert str(records[0].seq) == "" # No valid sequences parsed
175176

176177
def test_file_naming_with_different_extensions(self, temp_dir, monkeypatch):
177-
"""Test that only .consensus.fasta files are processed."""
178+
"""Test that .consensus.fasta and .consensus.fa files are processed."""
178179
monkeypatch.chdir(temp_dir)
179180

180181
# Create files with different extensions
181182
(temp_dir / "sample.consensus.fasta").write_text(">seq\nATCG")
182-
(temp_dir / "other.fasta").write_text(">seq\nGCTA")
183+
(temp_dir / "other.fasta").write_text(">seq\nGCTA") # Should NOT match
183184
(temp_dir / "another.consensus.fa").write_text(
184-
">seq\nTTTT"
185-
) # Note: .fa not .fasta
185+
">seq\nTTTT",
186+
) # Should match - .fa is valid
186187

187188
main()
188189

189190
output_file = Path("all_sample_consensus.fasta")
190191
records = list(SeqIO.parse(output_file, "fasta"))
191192

192-
# Only the .consensus.fasta file should be processed
193-
assert len(records) == 1
194-
assert records[0].id == "sample"
193+
# Both .consensus.fasta and .consensus.fa files should be processed
194+
assert len(records) == 2
195+
record_ids = {r.id for r in records}
196+
assert record_ids == {"sample", "another.consensus.fa"}
195197

196198
def test_sample_name_extraction(self, temp_dir, monkeypatch):
197199
"""Test correct extraction of sample names from file paths."""
@@ -327,17 +329,18 @@ def test_glob_pattern_matching(self, temp_dir, monkeypatch):
327329
monkeypatch.chdir(temp_dir)
328330

329331
# Create files that should and shouldn't match
332+
# Pattern is *.consensus.fa* so matches .fasta, .fa, .fastq etc.
330333
matching_files = [
331334
"sample1.consensus.fasta",
332335
"sample2.consensus.fasta",
333336
"SAMPLE3.consensus.fasta", # Test case sensitivity
337+
"sample4.consensus.fa", # .fa extension also matches
334338
]
335339

336340
non_matching_files = [
337-
"sample.consensus.fastq", # Wrong extension
338-
"sample_consensus.fasta", # Missing dot
339-
"sample.consensus", # Missing extension
340-
"consensus.fasta", # Missing sample name
341+
"sample_consensus.fasta", # Missing dot before consensus
342+
"sample.consensus", # Missing extension after .consensus
343+
"consensus.fasta", # Missing sample name prefix
341344
]
342345

343346
for filename in matching_files:
@@ -354,7 +357,8 @@ def test_glob_pattern_matching(self, temp_dir, monkeypatch):
354357
# Only matching files should be processed
355358
assert len(records) == len(matching_files)
356359
record_ids = {r.id for r in records}
357-
expected_ids = {"sample1", "sample2", "SAMPLE3"}
360+
# Note: .consensus.fasta is stripped, but .consensus.fa is not
361+
expected_ids = {"sample1", "sample2", "SAMPLE3", "sample4.consensus.fa"}
358362
assert record_ids == expected_ids
359363

360364

@@ -367,7 +371,10 @@ def test_glob_pattern_matching(self, temp_dir, monkeypatch):
367371
],
368372
)
369373
def test_performance_with_many_files(
370-
temp_dir, monkeypatch, num_files, num_seqs_per_file
374+
temp_dir,
375+
monkeypatch,
376+
num_files,
377+
num_seqs_per_file,
371378
):
372379
"""Test performance with many input files."""
373380
monkeypatch.chdir(temp_dir)

0 commit comments

Comments
 (0)