Skip to content

Commit b6beec7

Browse files
authored
Merge pull request #219 from jkmckenna/codex/refactor-branch-for-optional-dependencies
Standardize optional dependency loading with `require`, add TYPE_CHECKING hints, and simplify lazy exports
2 parents 619bcee + 980fdf3 commit b6beec7

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

41 files changed

+296
-238
lines changed

src/smftools/cli/hmm_adata.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,18 +3,22 @@
33
import copy
44
from dataclasses import dataclass
55
from pathlib import Path
6-
from typing import Any, List, Optional, Sequence, Tuple, Union
6+
from typing import Any, List, Optional, Sequence, Tuple, Union, TYPE_CHECKING
77

88
import numpy as np
9-
import torch
10-
119
from smftools.logging_utils import get_logger
10+
from smftools.optional_imports import require
1211

1312
# FIX: import _to_dense_np to avoid NameError
1413
from ..hmm.HMM import _safe_int_coords, _to_dense_np, create_hmm, normalize_hmm_feature_sets
1514

1615
logger = get_logger(__name__)
1716

17+
if TYPE_CHECKING:
18+
import torch as torch_types
19+
20+
torch = require("torch", extra="torch", purpose="HMM CLI")
21+
1822
# =============================================================================
1923
# Helpers: extracting training arrays
2024
# =============================================================================

src/smftools/cli/spatial_adata.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
import anndata as ad
77

88
from smftools.logging_utils import get_logger
9+
from smftools.optional_imports import require
910

1011
logger = get_logger(__name__)
1112

@@ -155,7 +156,8 @@ def spatial_adata_core(
155156

156157
import numpy as np
157158
import pandas as pd
158-
import scanpy as sc
159+
160+
sc = require("scanpy", extra="scanpy", purpose="spatial analyses")
159161

160162
from ..metadata import record_smftools_metadata
161163
from ..plotting import (

src/smftools/hmm/HMM.py

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,20 @@
33
import ast
44
import json
55
from pathlib import Path
6-
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
6+
from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, TYPE_CHECKING
77

88
import numpy as np
9-
import torch
10-
import torch.nn as nn
119
from scipy.sparse import issparse
1210

1311
from smftools.logging_utils import get_logger
12+
from smftools.optional_imports import require
13+
14+
if TYPE_CHECKING:
15+
import torch as torch_types
16+
import torch.nn as nn_types
17+
18+
torch = require("torch", extra="torch", purpose="HMM modeling")
19+
nn = torch.nn
1420

1521
logger = get_logger(__name__)
1622
# =============================================================================

src/smftools/hmm/__init__.py

Lines changed: 1 addition & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -21,11 +21,4 @@ def __getattr__(name: str):
2121
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
2222

2323

24-
__all__ = [
25-
"call_hmm_peaks",
26-
"display_hmm",
27-
"load_hmm",
28-
"refine_nucleosome_calls",
29-
"infer_nucleosomes_in_large_bound",
30-
"save_hmm",
31-
]
24+
__all__ = list(_LAZY_ATTRS.keys())

src/smftools/hmm/call_hmm_peaks.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55
from typing import Any, Dict, Optional, Sequence, Union
66

77
from smftools.logging_utils import get_logger
8+
from smftools.optional_imports import require
89

910
logger = get_logger(__name__)
1011

@@ -36,12 +37,13 @@ def call_hmm_peaks(
3637
- adata.var["is_in_any_{layer}_peak_{ref}"]
3738
- adata.var["is_in_any_peak"] (global)
3839
"""
39-
import matplotlib.pyplot as plt
4040
import numpy as np
4141
import pandas as pd
4242
from scipy.signal import find_peaks
4343
from scipy.sparse import issparse
4444

45+
plt = require("matplotlib.pyplot", extra="plotting", purpose="HMM peak plots")
46+
4547
if not inplace:
4648
adata = adata.copy()
4749

src/smftools/hmm/display_hmm.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
from smftools.logging_utils import get_logger
4+
from smftools.optional_imports import require
45

56
logger = get_logger(__name__)
67

@@ -13,7 +14,7 @@ def display_hmm(hmm, state_labels=["Non-Methylated", "Methylated"], obs_labels=[
1314
state_labels: Optional labels for states.
1415
obs_labels: Optional labels for observations.
1516
"""
16-
import torch
17+
torch = require("torch", extra="torch", purpose="HMM display")
1718

1819
logger.info("**HMM Model Overview**")
1920
logger.info("%s", hmm)

src/smftools/hmm/hmm_readwrite.py

Lines changed: 4 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
from __future__ import annotations
22

3+
from smftools.optional_imports import require
4+
35

46
def load_hmm(model_path, device="cpu"):
57
"""
@@ -8,7 +10,7 @@ def load_hmm(model_path, device="cpu"):
810
Parameters:
911
model_path (str): Path to a pretrained HMM
1012
"""
11-
import torch
13+
torch = require("torch", extra="torch", purpose="HMM read/write")
1214

1315
# Load model using PyTorch
1416
hmm = torch.load(model_path)
@@ -23,6 +25,6 @@ def save_hmm(model, model_path):
2325
model: HMM model instance.
2426
model_path: Output path for the model.
2527
"""
26-
import torch
28+
torch = require("torch", extra="torch", purpose="HMM read/write")
2729

2830
torch.save(model, model_path)

src/smftools/informatics/converted_BAM_to_adata.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -7,14 +7,13 @@
77
import traceback
88
from multiprocessing import Manager, Pool, current_process
99
from pathlib import Path
10-
from typing import Iterable, Optional, Union
10+
from typing import Iterable, Optional, Union, TYPE_CHECKING
1111

1212
import anndata as ad
1313
import numpy as np
1414
import pandas as pd
15-
import torch
16-
1715
from smftools.logging_utils import get_logger
16+
from smftools.optional_imports import require
1817

1918
from ..readwrite import make_dirs
2019
from .bam_functions import count_aligned_reads, extract_base_identities
@@ -24,6 +23,11 @@
2423

2524
logger = get_logger(__name__)
2625

26+
if TYPE_CHECKING:
27+
import torch as torch_types
28+
29+
torch = require("torch", extra="torch", purpose="converted BAM processing")
30+
2731
if __name__ == "__main__":
2832
multiprocessing.set_start_method("forkserver", force=True)
2933

src/smftools/informatics/fasta_functions.py

Lines changed: 23 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,39 +1,44 @@
11
from __future__ import annotations
22

33
import gzip
4+
import shutil
5+
import subprocess
46
from concurrent.futures import ProcessPoolExecutor
7+
from importlib.util import find_spec
58
from pathlib import Path
6-
from typing import Dict, Iterable, Tuple
9+
from typing import Dict, Iterable, Tuple, TYPE_CHECKING
710

811
import numpy as np
912
from Bio import SeqIO
1013
from Bio.Seq import Seq
1114
from Bio.SeqRecord import SeqRecord
1215

1316
from smftools.logging_utils import get_logger
17+
from smftools.optional_imports import require
1418

1519
from ..readwrite import time_string
1620

1721
logger = get_logger(__name__)
1822

19-
try:
20-
import pysam
21-
except Exception:
22-
pysam = None # type: ignore
23+
if TYPE_CHECKING:
24+
import pysam as pysam_module
2325

24-
try:
25-
import shutil
26-
import subprocess
27-
except Exception: # pragma: no cover - stdlib
28-
shutil = None # type: ignore
29-
subprocess = None # type: ignore
26+
27+
def _require_pysam() -> "pysam_module":
28+
if pysam_types is not None:
29+
return pysam_types
30+
return require("pysam", extra="pysam", purpose="FASTA access")
31+
32+
pysam_types = None
33+
if find_spec("pysam") is not None:
34+
pysam_types = require("pysam", extra="pysam", purpose="FASTA access")
3035

3136

3237
def _resolve_fasta_backend() -> str:
3338
"""Resolve the backend to use for FASTA access."""
3439
if shutil is not None and shutil.which("samtools"):
3540
return "cli"
36-
if pysam is not None:
41+
if pysam_types is not None:
3742
return "python"
3843
raise RuntimeError("FASTA access requires pysam or samtools in PATH.")
3944

@@ -43,10 +48,9 @@ def _ensure_fasta_index(fasta: Path) -> None:
4348
if fai.exists():
4449
return
4550
if subprocess is None or shutil is None or not shutil.which("samtools"):
46-
if pysam is not None:
47-
pysam.faidx(str(fasta))
48-
return
49-
raise RuntimeError("FASTA indexing requires pysam or samtools in PATH.")
51+
pysam_mod = _require_pysam()
52+
pysam_mod.faidx(str(fasta))
53+
return
5054
cp = subprocess.run(
5155
["samtools", "faidx", str(fasta)],
5256
stdout=subprocess.DEVNULL,
@@ -225,7 +229,7 @@ def index_fasta(fasta: str | Path, write_chrom_sizes: bool = True) -> Path:
225229
Path: Path to the index file or chromosome sizes file.
226230
"""
227231
fasta = Path(fasta)
228-
pysam.faidx(str(fasta)) # creates <fasta>.fai
232+
_require_pysam().faidx(str(fasta)) # creates <fasta>.fai
229233

230234
fai = fasta.with_suffix(fasta.suffix + ".fai")
231235
if write_chrom_sizes:
@@ -377,8 +381,8 @@ def subsample_fasta_from_bed(
377381

378382
fasta_handle = None
379383
if backend == "python":
380-
assert pysam is not None
381-
fasta_handle = pysam.FastaFile(str(input_FASTA))
384+
pysam_mod = _require_pysam()
385+
fasta_handle = pysam_mod.FastaFile(str(input_FASTA))
382386

383387
# Open BED + output FASTA
384388
with input_bed.open("r") as bed, output_FASTA.open("w") as out_fasta:

src/smftools/informatics/h5ad_functions.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,12 +9,15 @@
99
import numpy as np
1010
import pandas as pd
1111
import scipy.sparse as sp
12-
from pod5 import Reader
1312

1413
from smftools.logging_utils import get_logger
14+
from smftools.optional_imports import require
1515

1616
logger = get_logger(__name__)
1717

18+
p5 = require("pod5", extra="ont", purpose="POD5 metadata")
19+
Reader = p5.Reader
20+
1821

1922
def add_demux_type_annotation(
2023
adata,

0 commit comments

Comments
 (0)