Skip to content

Commit 95216eb

Browse files
Fix test failure due to change in torchvision (#375)
1 parent 8151561 commit 95216eb

File tree

1 file changed

+5
-2
lines changed

1 file changed

+5
-2
lines changed

micro_sam/prompt_based_segmentation.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,7 @@
66
from typing import Optional, Tuple
77

88
import numpy as np
9+
import torch
910
from nifty.tools import blocking
1011
from skimage.feature import peak_local_max
1112
from skimage.filters import gaussian
@@ -87,12 +88,14 @@ def inv_sigmoid(x):
8788

8889
elif logits.shape[0] == logits.shape[1]: # shape is square
8990
trafo = ResizeLongestSide(expected_shape[0])
90-
logits = trafo.apply_image(logits[..., None])
91+
logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None]))
92+
logits = logits.numpy().squeeze()
9193

9294
else: # shape is not square
9395
# resize the longest side to expected shape
9496
trafo = ResizeLongestSide(expected_shape[0])
95-
logits = trafo.apply_image(logits[..., None])
97+
logits = trafo.apply_image_torch(torch.from_numpy(logits[None, None]))
98+
logits = logits.numpy().squeeze()
9699

97100
# pad the other side
98101
h, w = logits.shape

0 commit comments

Comments
 (0)