Skip to content

Commit 238598b

Browse files
committed
fix: minor bug fixes
1 parent 7cc1349 commit 238598b

File tree

4 files changed

+34
-9
lines changed

4 files changed

+34
-9
lines changed

cellseg_models_pytorch/decoders/long_skips/cross_attn_skip.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -17,12 +17,12 @@ def __init__(
1717
skip_channels: Tuple[int, ...] = None,
1818
num_heads: int = 8,
1919
head_dim: int = 64,
20-
n_blocks: int = 1,
21-
block_types: Tuple[str, ...] = ("exact",),
22-
computation_types: Tuple[str, ...] = ("basic",),
23-
dropouts: Tuple[float, ...] = (0.0,),
24-
biases: Tuple[bool, ...] = (False,),
25-
layer_scales: Tuple[bool, ...] = (False,),
20+
n_blocks: int = 2,
21+
block_types: Tuple[str, ...] = ("exact", "exact"),
22+
computation_types: Tuple[str, ...] = ("basic", "basic"),
23+
dropouts: Tuple[float, ...] = (0.0, 0.0),
24+
biases: Tuple[bool, ...] = (False, False),
25+
layer_scales: Tuple[bool, ...] = (False, False),
2626
activation: str = "star_relu",
2727
mlp_ratio: int = 2,
2828
slice_size: int = 4,

cellseg_models_pytorch/inference/_base_inferer.py

Lines changed: 5 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414

1515
from ..utils import FileHandler, tensor_to_ndarray
1616
from .folder_dataset_infer import FolderDatasetInfer
17-
from .hdf5_dataset_infer import HDF5DatasetInfer
1817
from .post_processor import PostProcessor
1918
from .predictor import Predictor
2019

@@ -126,6 +125,8 @@ def __init__(
126125
" `save_dir` argument."
127126
)
128127
elif self.path.is_file() and self.path.suffix in (".h5", ".hdf5"):
128+
from .hdf5_dataset_infer import HDF5DatasetInfer
129+
129130
ds = HDF5DatasetInfer(self.path, n_images=n_images)
130131
else:
131132
raise ValueError(
@@ -270,7 +271,9 @@ def _strip_state_dict(self, ckpt: Dict) -> OrderedDict:
270271
state_dict = OrderedDict()
271272
for k, w in ckpt["state_dict"].items():
272273
if "num_batches_track" not in k:
273-
new_key = k.strip("model")[1:]
274+
# new_key = k.strip("model")[1:]
275+
spl = ["".join(kk) for kk in k.split(".")]
276+
new_key = ".".join(spl[1:])
274277
state_dict[new_key] = w
275278
ckpt["state_dict"] = state_dict
276279

cellseg_models_pytorch/models/base/_timm_encoder.py

Lines changed: 21 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,9 +14,25 @@ def __init__(
1414
pretrained: bool = True,
1515
in_channels: int = 3,
1616
depth: int = 5,
17+
out_indices: List[int] = None,
1718
**kwargs
1819
) -> None:
19-
"""Import any encoder from timm package."""
20+
"""Import any encoder from timm package.
21+
22+
Parameters
23+
----------
24+
name : str
25+
Name of the encoder.
26+
pretrained : bool, optional
27+
If True, load pretrained weights, by default True.
28+
in_channels : int, optional
29+
Number of input channels, by default 3.
30+
depth : int, optional
31+
Number of output features, by default 5.
32+
out_indices : List[int], optional
33+
Indices of the output features, by default None. If None, all the
34+
features are returned.
35+
"""
2036
super().__init__()
2137

2238
kwargs = dict(
@@ -28,9 +44,13 @@ def __init__(
2844

2945
self.model = timm.create_model(name, **kwargs)
3046

47+
self.out_indices = out_indices
3148
self.in_channels = in_channels
3249
self.out_channels = tuple(self.model.feature_info.channels()[::-1])
3350

51+
if self.out_indices is not None:
52+
self.out_channels = tuple(self.out_channels[i] for i in self.out_indices)
53+
3454
def forward(self, x: torch.Tensor) -> List[torch.Tensor]:
3555
"""Forward pass of the encoder and return all the features."""
3656
features = self.model(x)

cellseg_models_pytorch/modules/mlp.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -49,6 +49,8 @@ def __init__(
4949
act_kwargs = act_kwargs if act_kwargs is not None else {}
5050
self.out_channels = in_channels if out_channels is None else out_channels
5151
hidden_channels = int(mlp_ratio * in_channels)
52+
act_kwargs["dim_in"] = hidden_channels
53+
act_kwargs["dim_out"] = hidden_channels
5254

5355
self.fc1 = nn.Linear(in_channels, hidden_channels, bias=bias)
5456
self.act = Activation(activation, **act_kwargs)

0 commit comments

Comments
 (0)