Skip to content

Commit 920230b

Browse files
Merge branch 'intensity-masking' of https://github.com/computational-cell-analytics/lightsheet-moser into intensity-masking
2 parents c0d2820 + cf7d738 commit 920230b

File tree

6 files changed

+38
-55
lines changed

6 files changed

+38
-55
lines changed

reproducibility/label_components/IHC_v4c_fig2.json

Lines changed: 0 additions & 32 deletions
This file was deleted.

scripts/la-vision/check_detections.py

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
import napari
22
import zarr
33

4+
from skimage.feature import peak_local_max
5+
46

57
resolution = [3.0, 1.887779, 1.887779]
68
positions = [
@@ -29,7 +31,7 @@ def _load_prediction(bb):
2931

3032

3133
def _load_prediction_debug():
32-
path = "./debug-pred/pred-v5.h5"
34+
path = "./debug-pred/pred-v7.h5"
3335
with zarr.open(path, "r") as f:
3436
pred = f["pred"][:]
3537
return pred
@@ -45,10 +47,14 @@ def check_detection(position, halo=[32, 384, 384]):
4547
# pred = _load_prediction(bb)
4648
pred = _load_prediction_debug()
4749

50+
print("Runnign detection")
51+
det_new = peak_local_max(pred, min_distance=4, threshold_abs=0.5)
52+
4853
v = napari.Viewer()
4954
v.add_image(pv)
5055
v.add_image(pred)
5156
v.add_labels(detections_mobie)
57+
v.add_points(det_new)
5258
napari.run()
5359

5460

scripts/la-vision/debug_prediction.py

Lines changed: 11 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -30,21 +30,23 @@ def run_prediction(position, halo=[32, 384, 384]):
3030
print(mean, std)
3131
preproc = partial(standardize, mean=mean, std=std)
3232

33-
block_shape = (24, 256, 256)
34-
halo = (8, 64, 64)
33+
block_shape = (12, 128, 128)
34+
halo = (10, 64, 64)
3535

36-
model_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/sgn-detection-v1.pt"
36+
# model_path = "/mnt/vast-nhr/projects/nim00007/data/moser/cochlea-lightsheet/trained_models/SGN/sgn-detection-v1.pt"
37+
model_path = "/mnt/vast-nhr/home/pape41/u12086/Work/my_projects/flamingo-tools/scripts/la-vision/checkpoints/sgn-detection-v3.pt"
3738
model = torch.load(model_path, weights_only=False)
3839

39-
def postproc(x):
40-
x = np.clip(x, 0, 1)
41-
max_ = np.percentile(x, 99)
42-
x = x / max_
43-
return x
40+
# def postproc(x):
41+
# x = np.clip(x, 0, 1)
42+
# max_ = np.percentile(x, 99)
43+
# x = x / max_
44+
# return x
45+
postproc = None
4446

4547
pred = predict_with_halo(pv, model, [0], block_shape, halo, preprocess=preproc, postprocess=postproc).squeeze()
4648

47-
pred_name = "pred-v5"
49+
pred_name = "pred-v7"
4850
out_folder = "./debug-pred"
4951
os.makedirs(out_folder, exist_ok=True)
5052

scripts/la-vision/export_model.py

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -9,17 +9,18 @@
99
sys.path.append("../synapse_marker_detection")
1010

1111

12-
def export_model(input_, output):
13-
model = load_model(input_, device="cpu")
12+
def export_model(input_, output, latest):
13+
model = load_model(input_, device="cpu", name="latest" if latest else "best")
1414
torch.save(model, output)
1515

1616

1717
def main():
1818
parser = argparse.ArgumentParser()
1919
parser.add_argument("-i", "--input", required=True)
2020
parser.add_argument("-o", "--output", required=True)
21+
parser.add_argument("--latest", action="store_true")
2122
args = parser.parse_args()
22-
export_model(args.input, args.output)
23+
export_model(args.input, args.output, latest=args.latest)
2324

2425

2526
if __name__ == "__main__":

scripts/la-vision/train_sgn_detection.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
TRAIN = os.path.join(ROOT, "images")
1919
TRAIN_EMPTY = os.path.join(ROOT, "empty_images")
2020

21-
LABEL = os.path.join(ROOT, "centroids")
21+
LABEL = os.path.join(ROOT, "centroids_v2")
2222
LABEL_EMPTY = os.path.join(ROOT, "empty_centroids")
2323

2424

@@ -46,13 +46,13 @@ def _get_paths(split, train_folder, label_folder, n=None):
4646

4747
def get_paths(split):
4848
image_paths, label_paths = _get_paths(split, TRAIN, LABEL)
49-
empty_image_paths, empty_label_paths = _get_paths(split, TRAIN_EMPTY, LABEL_EMPTY, n=4)
49+
empty_image_paths, empty_label_paths = _get_paths(split, TRAIN_EMPTY, LABEL_EMPTY, n=6)
5050
return image_paths + empty_image_paths, label_paths + empty_label_paths
5151

5252

5353
def train():
5454

55-
model_name = "sgn-low-res-detection-v1"
55+
model_name = "sgn-low-res-detection-v6"
5656

5757
train_paths, train_label_paths = get_paths("train")
5858
val_paths, val_label_paths = get_paths("val")
@@ -63,8 +63,8 @@ def train():
6363
print(len(train_paths), "tomograms for training")
6464
print(len(val_paths), "tomograms for validation")
6565

66-
patch_shape = [48, 256, 256]
67-
batch_size = 8
66+
patch_shape = [32, 144, 144]
67+
batch_size = 16
6868
check = False
6969

7070
checkpoint_path = f"./checkpoints/{model_name}"
@@ -78,6 +78,10 @@ def train():
7878
f, indent=2, sort_keys=True
7979
)
8080

81+
# For marmoset model
82+
sigma = (0.6, 3, 3)
83+
# For mouse model
84+
# sigma = (1, 4, 4)
8185
supervised_training(
8286
name=model_name,
8387
train_paths=train_paths,
@@ -92,7 +96,7 @@ def train():
9296
out_channels=1,
9397
augmentations=None,
9498
eps=1e-5,
95-
sigma=4,
99+
sigma=sigma,
96100
lower_bound=None,
97101
upper_bound=None,
98102
test_paths=test_paths,

scripts/synapse_marker_detection/detection_dataset.py

Lines changed: 6 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -206,12 +206,14 @@ def _get_sample(self, index):
206206
raise RuntimeError(f"Could not sample a valid batch in {self.max_sampling_attempts} attempts")
207207

208208
# For synapse detection.
209-
label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
209+
# label = process_labels(coords, shape, self.sigma, self.eps, bb=bb)
210210

211211
# For SGN detection with data specfic hacks
212-
# label = process_labels_hacky(coords, shape, self.sigma, self.eps, bb=bb)
213-
# gap = 8
214-
# raw_patch, label = raw_patch[gap:-gap], label[gap:-gap]
212+
label = process_labels_hacky(coords, shape, self.sigma, self.eps, bb=bb)
213+
# Having this halo actually makes sense in general!
214+
gap = 8
215+
gap_bb = np.s_[gap:-gap, gap:-gap, gap:-gap]
216+
raw_patch, label = raw_patch[gap_bb], label[gap_bb]
215217

216218
have_label_channels = label.ndim == 4
217219
if have_label_channels:

0 commit comments

Comments
 (0)