Commit ffdd9e0

Merge pull request #100 from EliHei2/segger_taylors_version
Segger Taylor's Version

TL;DR:
- Simplified package dependencies and installation.
- Rolled back the model definition, training, and prediction to previous working versions.
- Note: this PR adds guides for pre-processing CosMX and MERSCOPE data for segger, but these platforms are still not supported.

In more detail, made the following changes for usability and bug fixing:
- Added uv for dependency management; removed dask and RAPIDS as dependencies.
- Updated the README with installation instructions.
- Changes to the model were untested and broke model evaluation; reverted to the previous model definition.
- Added testing to ensure the full pipeline works with a minimal Xenium example.
- Removed redundant global imports in __init__.py files, which otherwise required the entire segger package and all of its dependencies to be imported everywhere in the repo (see the lazy-import sketch below).
- Removed out-of-date tests.
- Added utility functions and a README.md guide for preprocessing CosMX data to output nuclear boundaries (needed for shape estimation during dataset creation).
- Added an installation guide and a param file for preprocessing MERSCOPE data using VPT to obtain a nuclear segmentation.
2 parents: b629684 + 34ce2aa
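On the __init__.py point above: eager re-exports in a package's __init__.py make any import of the package pull in the full dependency stack. Per the description, the PR simply deleted those re-exports; the sketch below shows a PEP 562 lazy alternative to the same problem, with illustrative submodule names that are not segger's actual layout.

# Hypothetical segger/__init__.py using PEP 562 lazy attribute access.
# Submodule names are illustrative; the PR itself removed the eager
# re-exports (e.g. `from segger.models import ...`) that forced torch
# and friends to load on any `import segger`.
import importlib
from typing import Any

_SUBMODULES = {"data", "models", "prediction", "training"}  # illustrative

def __getattr__(name: str) -> Any:
    # Import a heavy submodule only on first attribute access,
    # so a bare `import segger` stays cheap.
    if name in _SUBMODULES:
        return importlib.import_module(f"{__name__}.{name}")
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")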


58 files changed: +5532 / -1601 lines

.dev/_get_imports.py

Lines changed: 128 additions & 0 deletions
@@ -0,0 +1,128 @@
"""Detect third-party imports that are not declared in pyproject.toml."""
import ast
import importlib.metadata
import os
import pathlib
import re
import sys
import tomllib

import pandas as pd


def extract_third_party_imports(root_dir: str) -> pd.DataFrame:
    """
    Walk the codebase and collect third-party root import names.
    """
    # sys.stdlib_module_names is available on Python 3.10+
    stdlib = (
        set(sys.stdlib_module_names) if hasattr(sys, "stdlib_module_names") else set()
    )
    rows = []

    for dirpath, _, filenames in os.walk(root_dir):
        for filename in filenames:
            if not filename.endswith(".py"):
                continue
            full_path = os.path.join(dirpath, filename)
            try:
                with open(full_path, "r", encoding="utf-8") as f:
                    tree = ast.parse(f.read(), filename=full_path)
                for node in ast.walk(tree):
                    if isinstance(node, ast.Import):
                        for alias in node.names:
                            name = alias.name.split(".")[0]
                            if name not in stdlib:
                                rows.append((full_path, name))
                    elif isinstance(node, ast.ImportFrom):
                        if node.module:
                            name = node.module.split(".")[0]
                            if name not in stdlib:
                                rows.append((full_path, name))
            except (SyntaxError, UnicodeDecodeError):
                continue

    return pd.DataFrame(rows, columns=["filename", "root_package"]).drop_duplicates()


def _extract_pkg_name(dep: str) -> str:
    """Strip version specifiers, e.g. 'torch>=2.0' -> 'torch'."""
    return re.split(r"[<>=~! ]", dep, maxsplit=1)[0].strip().lower()


def _get_import_names(declared: set[str]) -> set[str]:
    """
    Given a set of declared package names, return the set of all import names
    associated with those packages based on installed distributions.
    """
    dist_map = importlib.metadata.packages_distributions()
    import_names = set()

    for dep in declared:
        # Match case-insensitively: distribution names (e.g. "PyYAML") need not
        # match the lowercased names declared in pyproject.toml.
        dep_matches = {k for k, v in dist_map.items() if dep in {d.lower() for d in v}}
        if dep_matches:
            import_names.update(dep_matches)
        else:
            # Fall back to the usual convention: dashes become underscores.
            import_names.add(dep.replace("-", "_"))

    return import_names


def find_missing_dependencies(project_path: os.PathLike) -> pd.DataFrame:
    """
    Compare third-party imports with declared dependencies in pyproject.toml.

    Parameters
    ----------
    project_path : os.PathLike
        Base path of the Python project.

    Returns
    -------
    pd.DataFrame
        Subset of the discovered imports where the root_package is used but
        not declared in pyproject.toml.
    """
    project_path = pathlib.Path(project_path)
    with open(project_path / "pyproject.toml", "rb") as f:
        toml = tomllib.load(f)

    declared = {_extract_pkg_name(d) for d in toml["project"]["dependencies"]}
    optional = toml["project"].get("optional-dependencies", {})
    for group in optional.values():
        declared.update(_extract_pkg_name(d) for d in group)
    project_name = toml["project"]["name"].replace("-", "_").lower()
    declared.add(project_name)
    declared = _get_import_names(declared)

    imports = extract_third_party_imports(project_path / "src")

    return imports[~imports["root_package"].isin(declared)]


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(description="Find undeclared third-party imports.")
    parser.add_argument(
        "--base",
        type=str,
        help="Path to the base Python package or source root.",
        default="./",
    )
    parser.add_argument(
        "--exclude",
        nargs="*",
        default=[],
        help="List of package names to exclude from the check.",
    )

    args = parser.parse_args()
    missing_df = find_missing_dependencies(pathlib.Path(args.base))

    if args.exclude:
        missing_df = missing_df[~missing_df["root_package"].isin(args.exclude)]

    if missing_df.empty:
        print("No missing dependencies found.")
    else:
        print("Missing dependencies:")
        print(missing_df.sort_values("root_package").to_string(index=False))
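The translation step in _get_import_names exists because the name a dependency is declared under is not always the name it is imported as. A minimal illustration of importlib.metadata.packages_distributions() (Python 3.10+), assuming PyYAML happens to be installed in the environment:

# Why _get_import_names() consults packages_distributions(): pyproject.toml
# declares distribution names, while source files use import names.
import importlib.metadata

dist_map = importlib.metadata.packages_distributions()
# With PyYAML installed, the mapping contains an entry like {"yaml": ["PyYAML"]}:
# the dependency is declared as "pyyaml" but imported as `yaml`.
print({imp: dists for imp, dists in dist_map.items() if "PyYAML" in dists})

Run from the repository root, the checker defaults to --base ./ and scans the src/ tree, e.g. python .dev/_get_imports.py --exclude torch (the --exclude value here is only an example).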

.gitignore

Lines changed: 5 additions & 0 deletions
@@ -172,3 +172,8 @@ figure*
 dev*
 .DS_Store
 .idea/
+
+
+# Custom
+*_old*
+.dev

.python_version

Lines changed: 1 addition & 0 deletions
@@ -0,0 +1 @@
+3.11.11

.scripts/create_dataset.py

Lines changed: 24 additions & 8 deletions
@@ -81,7 +81,9 @@ def main(args):
8181

8282

8383
if __name__ == "__main__":
84-
parser = argparse.ArgumentParser(description="Create dataset from Xenium Human Pancreatic data.")
84+
parser = argparse.ArgumentParser(
85+
description="Create dataset from Xenium Human Pancreatic data."
86+
)
8587
parser.add_argument(
8688
"--raw_data_dir",
8789
type=str,
@@ -100,7 +102,9 @@ def main(args):
100102
required=True,
101103
help="URL for transcripts data.",
102104
)
103-
parser.add_argument("--nuclei_url", type=str, required=True, help="URL for nuclei data.")
105+
parser.add_argument(
106+
"--nuclei_url", type=str, required=True, help="URL for nuclei data."
107+
)
104108
parser.add_argument(
105109
"--min_qv",
106110
type=int,
@@ -121,9 +125,15 @@ def main(args):
121125
)
122126
parser.add_argument("--x_size", type=int, default=200, help="Width of each tile.")
123127
parser.add_argument("--y_size", type=int, default=200, help="Height of each tile.")
124-
parser.add_argument("--margin_x", type=int, default=None, help="Margin in x direction.")
125-
parser.add_argument("--margin_y", type=int, default=None, help="Margin in y direction.")
126-
parser.add_argument("--r_tx", type=int, default=3, help="Radius for building the graph.")
128+
parser.add_argument(
129+
"--margin_x", type=int, default=None, help="Margin in x direction."
130+
)
131+
parser.add_argument(
132+
"--margin_y", type=int, default=None, help="Margin in y direction."
133+
)
134+
parser.add_argument(
135+
"--r_tx", type=int, default=3, help="Radius for building the graph."
136+
)
127137
parser.add_argument(
128138
"--val_prob",
129139
type=float,
@@ -142,7 +152,9 @@ def main(args):
142152
default=3,
143153
help="Number of nearest neighbors for nuclei.",
144154
)
145-
parser.add_argument("--dist_nc", type=int, default=10, help="Distance threshold for nuclei.")
155+
parser.add_argument(
156+
"--dist_nc", type=int, default=10, help="Distance threshold for nuclei."
157+
)
146158
parser.add_argument(
147159
"--k_tx",
148160
type=int,
@@ -161,8 +173,12 @@ def main(args):
161173
default=True,
162174
help="Whether to compute edge labels.",
163175
)
164-
parser.add_argument("--sampling_rate", type=float, default=1, help="Rate of sampling tiles.")
165-
parser.add_argument("--parallel", action="store_true", help="Use parallel processing.")
176+
parser.add_argument(
177+
"--sampling_rate", type=float, default=1, help="Rate of sampling tiles."
178+
)
179+
parser.add_argument(
180+
"--parallel", action="store_true", help="Use parallel processing."
181+
)
166182
parser.add_argument(
167183
"--num_workers",
168184
type=int,

.scripts/predict.py

Lines changed: 12 additions & 4 deletions
@@ -49,15 +49,21 @@ def main(args: argparse.Namespace) -> None:
4949
required=True,
5050
help="Path to the model checkpoint",
5151
)
52-
parser.add_argument("--init_emb", type=int, default=8, help="Initial embedding size")
52+
parser.add_argument(
53+
"--init_emb", type=int, default=8, help="Initial embedding size"
54+
)
5355
parser.add_argument(
5456
"--hidden_channels",
5557
type=int,
5658
default=64,
5759
help="Number of hidden channels",
5860
)
59-
parser.add_argument("--out_channels", type=int, default=16, help="Number of output channels")
60-
parser.add_argument("--heads", type=int, default=4, help="Number of attention heads")
61+
parser.add_argument(
62+
"--out_channels", type=int, default=16, help="Number of output channels"
63+
)
64+
parser.add_argument(
65+
"--heads", type=int, default=4, help="Number of attention heads"
66+
)
6167
parser.add_argument("--aggr", type=str, default="sum", help="Aggregation method")
6268
parser.add_argument(
6369
"--score_cut",
@@ -71,7 +77,9 @@ def main(args: argparse.Namespace) -> None:
7177
default=4,
7278
help="Number of nearest neighbors for nuclei",
7379
)
74-
parser.add_argument("--dist_nc", type=int, default=20, help="Distance threshold for nuclei")
80+
parser.add_argument(
81+
"--dist_nc", type=int, default=20, help="Distance threshold for nuclei"
82+
)
7583
parser.add_argument(
7684
"--k_tx",
7785
type=int,

.scripts/train_model.py

Lines changed: 18 additions & 6 deletions
@@ -95,19 +95,31 @@ def main(args):
9595
default=4,
9696
help="Batch size for validation",
9797
)
98-
parser.add_argument("--init_emb", type=int, default=8, help="Initial embedding size")
98+
parser.add_argument(
99+
"--init_emb", type=int, default=8, help="Initial embedding size"
100+
)
99101
parser.add_argument(
100102
"--hidden_channels",
101103
type=int,
102104
default=64,
103105
help="Number of hidden channels",
104106
)
105-
parser.add_argument("--out_channels", type=int, default=16, help="Number of output channels")
106-
parser.add_argument("--heads", type=int, default=4, help="Number of attention heads")
107+
parser.add_argument(
108+
"--out_channels", type=int, default=16, help="Number of output channels"
109+
)
110+
parser.add_argument(
111+
"--heads", type=int, default=4, help="Number of attention heads"
112+
)
107113
parser.add_argument("--aggr", type=str, default="sum", help="Aggregation method")
108-
parser.add_argument("--accelerator", type=str, default="cuda", help="Type of accelerator")
109-
parser.add_argument("--strategy", type=str, default="auto", help="Training strategy")
110-
parser.add_argument("--precision", type=str, default="16-mixed", help="Precision mode")
114+
parser.add_argument(
115+
"--accelerator", type=str, default="cuda", help="Type of accelerator"
116+
)
117+
parser.add_argument(
118+
"--strategy", type=str, default="auto", help="Training strategy"
119+
)
120+
parser.add_argument(
121+
"--precision", type=str, default="16-mixed", help="Precision mode"
122+
)
111123
parser.add_argument("--devices", type=int, default=4, help="Number of devices")
112124
parser.add_argument("--epochs", type=int, default=100, help="Number of epochs")
113125
parser.add_argument(
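The --accelerator, --strategy, --precision, and --devices flags above mirror PyTorch Lightning's Trainer signature; this diff only touches the argparse definitions, so the wiring below is a sketch under that assumption rather than segger's actual main():

# Hypothetical consumption of the flags above by a PyTorch Lightning Trainer.
from lightning import Trainer

def make_trainer(args) -> Trainer:
    return Trainer(
        accelerator=args.accelerator,  # e.g. "cuda"
        strategy=args.strategy,        # e.g. "auto" or "ddp"
        precision=args.precision,      # e.g. "16-mixed"
        devices=args.devices,          # device count per node
        max_epochs=args.epochs,
    )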
