Skip to content

Commit 9fe219d

Browse files
committed
WIP bufixing : fixed errors in inference + instance seg +
1 parent a7a1366 commit 9fe219d

File tree

4 files changed

+22
-10
lines changed

4 files changed

+22
-10
lines changed

src/napari_cellseg3d/model_instance_seg.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -21,7 +21,7 @@ def binary_connected(
2121
thres_small (int): size threshold of small objects to remove. Default: 128
2222
scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0)
2323
"""
24-
semantic = volume[0]
24+
semantic = np.squeeze(volume)
2525
foreground = semantic > thres # int(255 * thres)
2626
segm = label(foreground)
2727
segm = remove_small_objects(segm, thres_small)
@@ -64,9 +64,9 @@ def binary_watershed(
6464
thres_objects (float): threshold for foreground objects. Default: 0.3
6565
thres_small (int): size threshold of small objects removal. Default: 10
6666
scale_factors (tuple): scale factors for resizing in :math:`(Z, Y, X)` order. Default: (1.0, 1.0, 1.0)
67-
rem_seed_thres: threshold for small seeds removal. Default : 3
67+
rem_seed_thres (int): threshold for small seeds removal. Default : 3
6868
"""
69-
semantic = volume[0]
69+
semantic = np.squeeze(volume)
7070
seed_map = semantic > thres_seeding
7171
foreground = semantic > thres_objects
7272
seed = label(seed_map)

src/napari_cellseg3d/model_workers.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -197,11 +197,13 @@ def inference(self):
197197
self.log("\nChecking dimensions...")
198198
pad = utils.get_padding_dim(check)
199199
# print(pad)
200+
dims =128
201+
dims=64
200202

201203
model = self.model_dict["class"].get_net()
202204
if self.model_dict["name"] == "SegResNet":
203205
model = self.model_dict["class"].get_net()(
204-
input_image_size=[128, 128, 128], # TODO FIX !
206+
input_image_size=[dims,dims,dims], # TODO FIX !
205207
out_channels=1,
206208
# dropout_prob=0.3,
207209
)
@@ -294,6 +296,7 @@ def inference(self):
294296

295297
out = post_process_transforms(out)
296298
out = np.array(out).astype(np.float32)
299+
out = np.squeeze(out)
297300

298301
# batch_len = out.shape[1]
299302
# print("trying to check len")
@@ -486,7 +489,7 @@ def log_parameters(self):
486489
if self.weights_path is not None:
487490
self.log(f"Using weights from : {self.weights_path}")
488491

489-
self.log("\n")
492+
# self.log("\n")
490493

491494
def train(self):
492495
"""Trains the Pytorch model for the given number of epochs, with the selected model and data,
@@ -716,6 +719,7 @@ def train(self):
716719
self.log_parameters()
717720

718721
for epoch in range(self.max_epochs):
722+
# self.log("\n")
719723
self.log("-" * 10)
720724
self.log(f"Epoch {epoch + 1}/{self.max_epochs}")
721725
if self.device.type == "cuda":

src/napari_cellseg3d/plugin_crop.py

Lines changed: 12 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -227,11 +227,17 @@ def start(self):
227227
self.image_path, self.filetype, self.as_folder
228228
)
229229

230+
if len(self.image.shape) > 3:
231+
self.image = np.squeeze(self.image)
232+
230233
if self.crop_labels:
231234
self.label = utils.load_images(
232235
self.label_path, self.filetype, self.as_folder
233236
)
234237

238+
if len(self.label.shape) > 3:
239+
self.label = np.squeeze(self.label)
240+
235241
vw = self._viewer
236242

237243
vw.dims.ndisplay = 3
@@ -274,6 +280,11 @@ def add_crop_sliders(
274280
self._y = 0
275281
self._z = 0
276282

283+
print(f"Crop variables")
284+
print(image_stack.shape)
285+
286+
287+
277288
# define crop sizes and boundaries for the image
278289
crop_sizes = [self._crop_size_x, self._crop_size_y, self._crop_size_z]
279290
for i in range(len(crop_sizes)):
@@ -284,13 +295,10 @@ def add_crop_sliders(
284295
# shapez, shapey, shapex = image_stack.shape
285296
ends = np.asarray(image_stack.shape) - np.asarray(crop_sizes) + 1
286297

287-
288-
289298
stepsizes = ends // 100
290299

291-
print(f"Crop variables")
292300
print(crop_sizes)
293-
print(image_stack.shape)
301+
294302
print(ends)
295303
print(stepsizes)
296304

src/napari_cellseg3d/plugin_model_training.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -883,7 +883,7 @@ def on_yield(data, widget):
883883
data["weights"],
884884
os.path.join(
885885
widget.results_path,
886-
f"latest_weights_aborted_training_{utils.get_time()}.pth",
886+
f"latest_weights_aborted_training_{utils.get_time_filepath()}.pth",
887887
),
888888
)
889889
widget.stop_requested = False

0 commit comments

Comments
 (0)