
Commit 2cc8f92

Merge branch 'main' into more-inner-ear-analysis
2 parents 69d7a3d + 8dbe7d3 commit 2cc8f92

File tree: 13 files changed, +1013 −25 lines

.github/doc_env.yaml

Lines changed: 23 additions & 0 deletions
@@ -0,0 +1,23 @@
channels:
  - conda-forge
  - pytorch
name:
  synaptic-reconstruction
dependencies:
  - python=3.12
  - python-elf
  - torch_em
  - napari
  - pip
  - pyqt
  - magicgui
  - pytorch
  - bioimageio.core
  - kornia
  - tensorboard
  - pdoc
  - scikit-learn
  - mrcfile
  - trimesh
  - pip:
      - napari-skimage-regionprops
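This environment drives the docs workflow below, and can also be created locally, e.g. with `conda env create -f .github/doc_env.yaml`. A quick parse check, as a minimal sketch assuming PyYAML is available (it is not part of this environment file):

import yaml  # assumption: PyYAML is installed

with open(".github/doc_env.yaml") as f:
    env = yaml.safe_load(f)

print("name:", env["name"])
print("channels:", env["channels"])
# Conda dependencies are plain strings; the final entry is a dict
# holding the pip-installed requirements.
for dep in env["dependencies"]:
    print("-", dep)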

.github/workflows/build_docs.yaml

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
name: Build and Deploy Docs

on:
  push:
    paths:
      - "doc/*.md"  # Trigger on changes to any markdown file
      - "**/*.py"   # Optionally include changes in Python files
    branches:
      - main  # Run the workflow only on pushes to the main branch
  workflow_dispatch:

# security: restrict permissions for CI jobs.
permissions:
  contents: read

jobs:
  build:
    name: Build Documentation
    runs-on: ubuntu-latest

    steps:
      - name: Checkout Code
        uses: actions/checkout@v4

      - name: Set up Micromamba
        uses: mamba-org/setup-micromamba@v2
        with:
          micromamba-version: "latest"  # Use the latest version of micromamba
          environment-file: .github/doc_env.yaml  # Reference the docs environment file defined above
          init-shell: bash
          cache-environment: true
          post-cleanup: 'all'
          # cache: true  # Cache the micromamba environment

      - name: Install package
        shell: bash -l {0}  # Login shell, so the micromamba environment is active
        run: pip install -e .

      - name: Generate Documentation
        shell: bash -l {0}
        run: pdoc synaptic_reconstruction -o doc/

      - name: Verify Documentation Output
        run: ls -la doc/

      - name: Upload Documentation Artifact
        uses: actions/upload-pages-artifact@v3
        with:
          # name: documentation
          path: doc/

  deploy:
    name: Deploy Documentation
    needs: build
    runs-on: ubuntu-latest
    permissions:
      pages: write
      id-token: write
    environment:
      name: github-pages
      url: ${{ steps.deployment.outputs.page_url }}
    steps:
      # - name: Download Documentation Artifact
      #   uses: actions/download-artifact@v4
      #   with:
      #     name: documentation
      #     path: .

      - name: Deploy to GitHub Pages
        id: deployment  # Needed so the environment url above can resolve
        uses: actions/deploy-pages@v4
        # with:
        #   artifact_name: documentation

build_doc.py

Lines changed: 46 additions & 0 deletions
@@ -0,0 +1,46 @@
import argparse
import glob
import os
import warnings

from subprocess import run


def check_docs_completeness():
    """@private
    All markdown and RST documentation files **SHOULD** be included in the module
    docstring at micro_sam/__init__.py
    """
    import micro_sam

    # We don't search in subfolders anymore, to allow putting additional documentation
    # (e.g. for bioimage.io models) that should not be included in the main documentation here.
    markdown_doc_files = glob.glob("doc/*.md", recursive=True)
    rst_doc_files = glob.glob("doc/*.rst", recursive=True)
    all_doc_files = markdown_doc_files + rst_doc_files
    missing_from_docs = [f for f in all_doc_files if os.path.basename(f) not in micro_sam.__doc__]
    if len(missing_from_docs) > 0:
        warnings.warn(
            "Documentation files missing! Please add include statements "
            "to the docstring in micro_sam/__init__.py for every file, e.g.: "
            "'.. include:: ../doc/filename.md'. "
            "List of missing files: "
            f"{missing_from_docs}"
        )


if __name__ == "__main__":
    check_docs_completeness()

    parser = argparse.ArgumentParser()
    parser.add_argument("--out", "-o", action="store_true")
    args = parser.parse_args()

    logo_url = "https://raw.githubusercontent.com/computational-cell-analytics/micro-sam/master/doc/logo/logo_and_text.png"
    cmd = ["pdoc", "--docformat", "google", "--logo", logo_url]

    if args.out:
        cmd.extend(["-o", "tmp/"])
    cmd.append("micro_sam")

    run(cmd)
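Run as `python build_doc.py`, the script performs the completeness check and then lets pdoc serve the documentation from a local web server (pdoc's default when no output directory is given); with `--out`/`-o` it writes static HTML to tmp/ instead. The check compares doc file basenames against micro_sam.__doc__, so each page must be pulled in with an include directive. An illustrative sketch of what the check expects in micro_sam/__init__.py:

# micro_sam/__init__.py (sketch) -- the module docstring pulls in each doc page:
"""
.. include:: ../doc/start_page.md
"""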

doc/start_page.md

Lines changed: 2 additions & 0 deletions
@@ -0,0 +1,2 @@
# Synaptic Reconstruction
lorem ipsum...

environment.yaml

Lines changed: 1 addition & 0 deletions
@@ -12,5 +12,6 @@ dependencies:
   - bioimageio.core
   - kornia
   - tensorboard
+  - trimesh
   - pip:
     - napari-skimage-regionprops
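The trimesh addition mirrors its entry in the new .github/doc_env.yaml above, keeping the development and docs environments in sync.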
Lines changed: 196 additions & 0 deletions
@@ -0,0 +1,196 @@
import argparse
import os
import time

import torch
import torch_em

from torch_em.data import MinInstanceSampler
from torch_em.model import AnisotropicUNet
# from torch_em.util.debug import check_loader, check_trainer

# Local helpers for data loading (util.py next to this script).
import util


def main():
    parser = argparse.ArgumentParser(description="3D UNet training for mitochondrial segmentation")
    parser.add_argument(
        "--data_dir", type=str, default="/scratch-grete/projects/nim00007/data/mitochondria/cooper/fidi_down_s2",
        help="Path to the data directory"
    )
    parser.add_argument(
        "--patch_shape", type=int, nargs=3, default=(32, 256, 256), help="Patch shape for data loading (3D tuple)"
    )
    parser.add_argument(
        "--n_iterations", type=int, default=10000, help="Number of training iterations"
    )
    parser.add_argument(
        "--learning_rate", type=float, default=1e-4, help="Learning rate"
    )
    parser.add_argument(
        "--checkpoint_path", type=str, default="", help="Path to checkpoint used to load the model's state_dict"
    )
    parser.add_argument(
        "--experiment_name", type=str, default="default-mito-net",
        help="Name used for the experiment and for storing the model's weights"
    )
    parser.add_argument(
        "--batch_size", type=int, default=1, help="Batch size to be used"
    )
    parser.add_argument(
        "--feature_size", type=int, default=32, help="Initial feature size of the 3D UNet"
    )
    parser.add_argument(
        # action="store_true" instead of type=bool: argparse would otherwise
        # treat any non-empty string (including "False") as True.
        "--without_rois", action="store_true", help="Train without regions of interest (ROIs)"
    )
    parser.add_argument(
        "--early_stopping", type=int, default=10, help="Number of epochs without improvement before stopping training"
    )
    parser.add_argument(
        "--save_dir", type=str, default="./", help="Path where the model checkpoints will be saved."
    )

    # Parse arguments.
    args = parser.parse_args()
    checkpoint_path = args.checkpoint_path
    n_iterations = args.n_iterations
    learning_rate = args.learning_rate
    data_dir = args.data_dir
    save_dir = args.save_dir
    experiment_name = args.experiment_name
    batch_size = args.batch_size
    patch_shape = args.patch_shape
    initial_features = args.feature_size
    with_rois = not args.without_rois
    early_stopping = args.early_stopping

    n_workers = 12 if torch.cuda.is_available() else 1
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"\nExperiment: {experiment_name}\n")
    print(f"Using {device} with {n_workers} workers.")
    label_transform = torch_em.transform.label.BoundaryTransform(add_binary_target=True)

    loss_name = "dice"
    metric_name = "dice"
    ndim = 3

    loss_function = util.get_loss_function(loss_name)
    metric_function = util.get_loss_function(metric_name)
    in_channels, out_channels = 1, 2
    gain = 2

    # The first two levels downsample only in-plane ([1, 2, 2]), the usual
    # choice for volumes with anisotropic resolution along z.
    scale_factors = [
        [1, 2, 2],
        [1, 2, 2],
        [2, 2, 2],
        [2, 2, 2],
    ]

    final_activation = None
    if final_activation is None and loss_name == "dice":
        final_activation = "Sigmoid"

    # Load data paths and, if requested, the ROIs.
    start_time = time.time()
    print(f"Start time {time.ctime()}")
    print(f"Loading data paths, and ROIs if with_rois={with_rois} ...")

    if with_rois:
        data_paths, rois_dict = util.get_data_paths_and_rois(
            data_dir, min_shape=patch_shape, with_thresholds=True
        )
        data, rois_dict = util.split_data_paths_to_dict(
            data_paths, rois_dict, train_ratio=0.8, val_ratio=0.2, test_ratio=0
        )
    else:
        data_paths = util.get_data_paths(data_dir)

        # Filter instead of calling remove() while iterating, which skips entries.
        data_paths = [path for path in data_paths if "combined" not in path]
        data_paths.sort(reverse=True)
        data = util.split_data_paths_to_dict(
            data_paths, rois_list=None, train_ratio=0.8, val_ratio=0.15, test_ratio=0.05
        )

    end_time = time.time()
    # Calculate execution time in seconds.
    execution_time = end_time - start_time
    print(f"Data and ROI preprocessing execution time: {execution_time:.6f} seconds")

    print("Creating 3D UNet with", in_channels, "input channels and", out_channels, "output channels.")
    model = AnisotropicUNet(
        in_channels=in_channels, out_channels=out_channels, initial_features=initial_features,
        final_activation=final_activation, gain=gain, scale_factors=scale_factors
    )
    best_checkpoint = os.path.join(save_dir, "checkpoints", experiment_name, "best.pt")
    print("Does a checkpoint exist at", best_checkpoint, "?", os.path.exists(best_checkpoint))
    if checkpoint_path or os.path.exists(best_checkpoint):
        if not checkpoint_path:
            checkpoint_path = os.path.join(save_dir, "checkpoints", experiment_name)
        model = torch_em.util.load_model(checkpoint=checkpoint_path, device=device)
        print("Loaded model from checkpoint:", checkpoint_path)
    model.to(device)
    print(model)
    with_channels = False
    with_label_channels = False
    sampler = MinInstanceSampler(p_reject=0.95)

    print("train", len(data["train"]), "val", len(data["val"]), "test", len(data["test"]))
    print("data['test']", data["test"])

    if with_rois:
        train_loader = torch_em.default_segmentation_loader(
            raw_paths=data["train"], raw_key="raw",
            label_paths=data["train"], label_key="labels/mitochondria",
            patch_shape=patch_shape, ndim=ndim, batch_size=batch_size,
            label_transform=label_transform, num_workers=n_workers,
            with_channels=with_channels, with_label_channels=with_label_channels,
            rois=rois_dict["train"]
        )
        val_loader = torch_em.default_segmentation_loader(
            raw_paths=data["val"], raw_key="raw",
            label_paths=data["val"], label_key="labels/mitochondria",
            patch_shape=patch_shape, ndim=ndim, batch_size=batch_size,
            label_transform=label_transform, num_workers=n_workers,
            with_channels=with_channels, with_label_channels=with_label_channels,
            rois=rois_dict["val"]
        )
    else:
        train_loader = torch_em.default_segmentation_loader(
            raw_paths=data["train"], raw_key="raw",
            label_paths=data["train"], label_key="labels/mitochondria",
            patch_shape=patch_shape, ndim=ndim, batch_size=batch_size,
            label_transform=label_transform, num_workers=n_workers,
            with_channels=with_channels, with_label_channels=with_label_channels,
            sampler=sampler
        )
        val_loader = torch_em.default_segmentation_loader(
            raw_paths=data["val"], raw_key="raw",
            label_paths=data["val"], label_key="labels/mitochondria",
            patch_shape=patch_shape, ndim=ndim, batch_size=batch_size,
            label_transform=label_transform, num_workers=n_workers,
            with_channels=with_channels, with_label_channels=with_label_channels,
            sampler=sampler
        )

    trainer = torch_em.default_segmentation_trainer(
        name=experiment_name, model=model,
        train_loader=train_loader, val_loader=val_loader,
        loss=loss_function, metric=metric_function,
        learning_rate=learning_rate,
        mixed_precision=True,
        log_image_interval=50,
        device=device,
        compile_model=False,
        save_root=save_dir,
        early_stopping=early_stopping,
        # logger=None
    )

    trainer.fit(n_iterations)


if __name__ == "__main__":
    main()
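With the default patch shape (32, 256, 256) and the scale factors above, the bottleneck resolution works out to (8, 16, 16): the two [1, 2, 2] levels halve only the in-plane axes, while the two [2, 2, 2] levels halve all three. Likewise, BoundaryTransform(add_binary_target=True) produces two label channels (a binary foreground mask plus instance boundaries), matching out_channels=2. The diff header above omits the script's file name; using train_mito.py as a placeholder, a typical run would look like `python train_mito.py --data_dir /path/to/data --experiment_name my-mito-net --without_rois`.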
