Skip to content

Commit 6f9d412

Browse files
committed
♻️ Move argmax postprocessing to utils.
1 parent 2b342f4 commit 6f9d412

File tree

3 files changed

+32
-25
lines changed

3 files changed

+32
-25
lines changed

tiatoolbox/models/architecture/unet.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2,18 +2,24 @@
22

33
from __future__ import annotations
44

5-
from typing import Any
5+
from typing import TYPE_CHECKING, Any
66

7-
import numpy as np
87
import torch
98
import torch.nn.functional as F # noqa: N812
109
from torch import nn
1110
from torchvision.models.resnet import Bottleneck as ResNetBottleneck
1211
from torchvision.models.resnet import ResNet
1312

14-
from tiatoolbox.models.architecture.utils import UpSample2x, centre_crop
13+
from tiatoolbox.models.architecture.utils import (
14+
UpSample2x,
15+
argmax_last_axis,
16+
centre_crop,
17+
)
1518
from tiatoolbox.models.models_abc import ModelABC
1619

20+
if TYPE_CHECKING: # pragma: no cover
21+
import numpy as np
22+
1723

1824
class ResNetEncoder(ResNet):
1925
"""A subclass of ResNet defined in torch.
@@ -463,9 +469,9 @@ def infer_batch(
463469

464470
@staticmethod
465471
def postproc(image: np.ndarray) -> np.ndarray:
466-
"""Define the post-processing of this class of model.
472+
"""Define post-processing of this class of model.
467473
468474
This simply applies argmax along last axis of the input.
469475
470476
"""
471-
return np.argmax(image, axis=-1)
477+
return argmax_last_axis(image=image)

tiatoolbox/models/architecture/utils.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -233,3 +233,20 @@ def forward(self: UpSample2x, x: torch.Tensor) -> torch.Tensor:
233233
ret = torch.tensordot(x, mat, dims=1) # bxcxhxwxshxsw
234234
ret = ret.permute(0, 1, 2, 4, 3, 5)
235235
return ret.reshape((-1, input_shape[1], input_shape[2] * 2, input_shape[3] * 2))
236+
237+
238+
def argmax_last_axis(image: np.ndarray) -> np.ndarray:
239+
"""Define the post-processing of this class of model.
240+
241+
This simply applies argmax along last axis of the input.
242+
243+
Args:
244+
image (np.ndarray):
245+
The input image array.
246+
247+
Returns:
248+
np.ndarray:
249+
The post-processed image array.
250+
251+
"""
252+
return np.argmax(image, axis=-1)

tiatoolbox/models/architecture/vanilla.py

Lines changed: 4 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,17 @@
44

55
from typing import TYPE_CHECKING
66

7-
import numpy as np
87
import timm
98
import torch
109
import torchvision.models as torch_models
1110
from timm.layers import SwiGLUPacked
1211
from torch import nn
1312

13+
from tiatoolbox.models.architecture.utils import argmax_last_axis
1414
from tiatoolbox.models.models_abc import ModelABC
1515

1616
if TYPE_CHECKING: # pragma: no cover
17+
import numpy as np
1718
from torchvision.models import WeightsEnum
1819

1920

@@ -205,23 +206,6 @@ def _get_timm_architecture(
205206
raise ValueError(msg)
206207

207208

208-
def _postproc(image: np.ndarray) -> np.ndarray:
209-
"""Define the post-processing of this class of model.
210-
211-
This simply applies argmax along last axis of the input.
212-
213-
Args:
214-
image (np.ndarray):
215-
The input image array.
216-
217-
Returns:
218-
np.ndarray:
219-
The post-processed image array.
220-
221-
"""
222-
return np.argmax(image, axis=-1)
223-
224-
225209
def _infer_batch(
226210
model: nn.Module,
227211
batch_data: torch.Tensor,
@@ -339,7 +323,7 @@ def postproc(image: np.ndarray) -> np.ndarray:
339323
The post-processed image array.
340324
341325
"""
342-
return _postproc(image=image)
326+
return argmax_last_axis(image=image)
343327

344328
@staticmethod
345329
def infer_batch(
@@ -463,7 +447,7 @@ def postproc(image: np.ndarray) -> np.ndarray:
463447
The post-processed image array.
464448
465449
"""
466-
return _postproc(image=image)
450+
return argmax_last_axis(image=image)
467451

468452
@staticmethod
469453
def infer_batch(

0 commit comments

Comments
 (0)