Commit 7ae99aa

📝 Update docstring formatting.
1 parent 6263244 commit 7ae99aa

File tree

1 file changed: +88 −35 lines


tiatoolbox/models/architecture/kongnet.py

Lines changed: 88 additions & 35 deletions
@@ -134,6 +134,7 @@ def __init__(
                 Drop path rate of the encoder. Default is 0.0.
             pretrained (bool):
                 Whether to use pretrained weights. Default is True.
+
         """
         super().__init__()
         if drop_path_rate is None:
@@ -165,11 +166,14 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
         """Forward pass through the encoder.

         Args:
-            x (torch.Tensor): Input tensor of shape (B, C, H, W)
+            x (torch.Tensor):
+                Input tensor of shape (B, C, H, W)

         Returns:
-            List[torch.Tensor]: List of feature tensors at different scales,
+            list[torch.Tensor]:
+                List of feature tensors at different scales,
                 including the input as the first element
+
         """
         features = self.model(x)
         return [x, *features]
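
The encoder wraps a timm backbone and prepends the raw input to the returned feature pyramid, as the forward above shows. A minimal sketch of the same pattern, assuming a generic backbone in timm's features_only mode (the model name below is illustrative, not the one KongNet ships with):

    import timm
    import torch

    # features_only=True makes the backbone return a list of feature maps,
    # one per stage, at progressively coarser resolution.
    backbone = timm.create_model("resnet18", features_only=True, pretrained=False)
    x = torch.randn(1, 3, 256, 256)
    features = [x, *backbone(x)]  # input first, then each encoder stage
    print([f.shape for f in features])
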
@@ -179,7 +183,9 @@ def out_channels(self) -> list[int]:
         """Get output channels for each feature level.

         Returns:
-            List[int]: Number of channels at each feature level
+            list[int]:
+                Number of channels at each feature level
+
         """
         return self._out_channels

@@ -188,7 +194,9 @@ def output_stride(self) -> int:
         """Get the output stride of the encoder.

         Returns:
-            int: Output stride value
+            int:
+                Output stride value
+
         """
         return min(self._output_stride, 2**self._depth)

@@ -200,9 +208,13 @@ class SubPixelUpsample(nn.Module):
     which is more efficient than transposed convolution and produces better results.

     Args:
-        in_channels (int): Number of input channels
-        out_channels (int): Number of output channels
-        upscale_factor (int): Factor to increase spatial resolution. Default: 2
+        in_channels (int):
+            Number of input channels
+        out_channels (int):
+            Number of output channels
+        upscale_factor (int):
+            Factor to increase spatial resolution. Default: 2
+
     """

     def __init__(
@@ -248,6 +260,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
             torch.Tensor:
                 Upsampled tensor of shape
                 (B, out_channels, H*upscale_factor, W*upscale_factor)
+
         """
         x = self.conv1(x)
         x = self.pixel_shuffle(x)
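
Per the docstring, this block performs sub-pixel (pixel-shuffle) upsampling. A minimal sketch of the technique, assuming conv1 expands channels by upscale_factor**2 before nn.PixelShuffle folds them into spatial resolution (the kernel size and layer names here are illustrative):

    import torch
    from torch import nn

    class SubPixelUpsampleSketch(nn.Module):
        """Sub-pixel upsampling: a conv expands channels, PixelShuffle trades them for resolution."""

        def __init__(self, in_channels: int, out_channels: int, upscale_factor: int = 2) -> None:
            super().__init__()
            # Produce out_channels * r**2 channels so PixelShuffle can fold
            # them into an r-times larger spatial grid.
            self.conv1 = nn.Conv2d(in_channels, out_channels * upscale_factor**2, kernel_size=3, padding=1)
            self.pixel_shuffle = nn.PixelShuffle(upscale_factor)

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # (B, C_in, H, W) -> (B, C_out * r**2, H, W) -> (B, C_out, H * r, W * r)
            return self.pixel_shuffle(self.conv1(x))

    assert SubPixelUpsampleSketch(64, 32)(torch.randn(1, 64, 16, 16)).shape == (1, 32, 32, 32)
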
@@ -262,10 +275,15 @@ class DecoderBlock(nn.Module):
     and processes through convolutions.

     Args:
-        in_channels (int): Number of input channels
-        skip_channels (int): Number of channels from skip connection
-        out_channels (int): Number of output channels
-        attention_type (str): Type of attention mechanism. Default: 'scse'
+        in_channels (int):
+            Number of input channels
+        skip_channels (int):
+            Number of channels from skip connection
+        out_channels (int):
+            Number of output channels
+        attention_type (str):
+            Type of attention mechanism. Default: 'scse'.
+
     """

     def __init__(
@@ -285,7 +303,8 @@ def __init__(
             out_channels (int):
                 Number of output channels
             attention_type (str):
-                Type of attention mechanism. Default: 'scse'
+                Type of attention mechanism. Default: 'scse'.
+
         """
         super().__init__()
         self.up = SubPixelUpsample(in_channels, in_channels, upscale_factor=2)
@@ -322,7 +341,9 @@ def forward(
                 Skip connection tensor from encoder. Default: None

         Returns:
-            torch.Tensor: Processed output tensor
+            torch.Tensor:
+                Processed output tensor
+
         """
         x = self.up(x)
         if skip is not None:
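
The forward above captures the documented flow: upsample, concatenate the skip connection when present, then convolve. A hedged sketch of that wiring (the conv stack is an assumption, not the file's exact layers; pass skip_channels=0 for blocks without a skip):

    import torch
    from torch import nn

    class DecoderBlockSketch(nn.Module):
        """Upsample, optionally concatenate an encoder skip, then convolve."""

        def __init__(self, in_channels: int, skip_channels: int, out_channels: int) -> None:
            super().__init__()
            self.up = nn.Upsample(scale_factor=2, mode="nearest")  # stand-in for SubPixelUpsample
            self.conv = nn.Sequential(
                nn.Conv2d(in_channels + skip_channels, out_channels, kernel_size=3, padding=1),
                nn.BatchNorm2d(out_channels),
                nn.ReLU(inplace=True),
            )

        def forward(self, x: torch.Tensor, skip: torch.Tensor | None = None) -> torch.Tensor:
            x = self.up(x)
            if skip is not None:
                x = torch.cat([x, skip], dim=1)  # fuse encoder detail along the channel axis
            return self.conv(x)
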
@@ -340,14 +361,18 @@ class CenterBlock(nn.Module):
     to enhance feature representation using attention mechanisms.

     Args:
-        in_channels (int): Number of input channels
+        in_channels (int):
+            Number of input channels
+
     """

     def __init__(self, in_channels: int) -> None:
         """Initialize CenterBlock with attention.

         Args:
-            in_channels (int): Number of input channels
+            in_channels (int):
+                Number of input channels.
+
         """
         super().__init__()
         self.attention = AttentionModule(name="scse", in_channels=in_channels)
@@ -356,10 +381,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
         """Forward pass through center block.

         Args:
-            x (torch.Tensor): Input tensor
+            x (torch.Tensor):
+                Input tensor.

         Returns:
-            torch.Tensor: Output tensor with attention applied
+            torch.Tensor:
+                Output tensor with attention applied.
+
         """
         return self.attention(x)
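
The 'scse' attention referenced throughout is concurrent spatial and channel squeeze-and-excitation (Roy et al., 2018). A compact sketch of the standard formulation, assuming the toolbox's AttentionModule follows it (the reduction ratio is illustrative):

    import torch
    from torch import nn

    class SCSESketch(nn.Module):
        """Concurrent spatial and channel squeeze-and-excitation (scSE)."""

        def __init__(self, in_channels: int, reduction: int = 16) -> None:
            super().__init__()
            # Channel SE: global-average-pool to (B, C, 1, 1), bottleneck MLP,
            # then a per-channel sigmoid gate.
            self.cse = nn.Sequential(
                nn.AdaptiveAvgPool2d(1),
                nn.Conv2d(in_channels, in_channels // reduction, kernel_size=1),
                nn.ReLU(inplace=True),
                nn.Conv2d(in_channels // reduction, in_channels, kernel_size=1),
                nn.Sigmoid(),
            )
            # Spatial SE: 1x1 conv to a single map, then a per-pixel sigmoid gate.
            self.sse = nn.Sequential(nn.Conv2d(in_channels, 1, kernel_size=1), nn.Sigmoid())

        def forward(self, x: torch.Tensor) -> torch.Tensor:
            # Sum of the channel-gated and spatially-gated branches.
            return x * self.cse(x) + x * self.sse(x)
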

@@ -371,14 +399,21 @@ class KongNetDecoder(nn.Module):
     attention mechanisms, and optional center block at the bottleneck.

     Args:
-        encoder_channels (List[int]): Number of channels at each encoder level
-        decoder_channels (Tuple[int, ...]): Number of channels at each decoder level
-        n_blocks (int): Number of decoder blocks. Default: 5
-        attention_type (str): Type of attention mechanism. Default: 'scse'
-        center (bool): Whether to use center block at bottleneck. Default: True
+        encoder_channels (List[int]):
+            Number of channels at each encoder level
+        decoder_channels (Tuple[int, ...]):
+            Number of channels at each decoder level
+        n_blocks (int):
+            Number of decoder blocks. Default: 5
+        attention_type (str):
+            Type of attention mechanism. Default: 'scse'
+        center (bool):
+            Whether to use center block at bottleneck. Default: True

     Raises:
-        ValueError: If n_blocks doesn't match length of decoder_channels
+        ValueError:
+            If n_blocks doesn't match length of decoder_channels
+
     """

     def __init__(
@@ -404,6 +439,7 @@ def __init__(
             center (bool):
                 Whether to include a center block at the bottleneck.
                 Default is True.
+
         """
         super().__init__()

@@ -442,10 +478,13 @@ def forward(self, *features: torch.Tensor) -> torch.Tensor:
         """Forward pass through the decoder.

         Args:
-            *features: Feature tensors from encoder at different scales
+            *features:
+                Feature tensors from encoder at different scales

         Returns:
-            torch.Tensor: Decoded output tensor
+            torch.Tensor:
+                Decoded output tensor
+
         """
         features = features[1:]  # remove first skip with same spatial resolution
         features = features[::-1]  # reverse channels to start from head of encoder
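
The two slicing lines above set up U-Net-style routing: drop the input-resolution tensor, reverse so the deepest feature leads, then pair each decoder block with its skip. A schematic rendering of that loop (the block and center wiring are assumptions based on the docstrings, not the file's exact code):

    def decode(features, center, blocks):
        """Schematic of the decoder's feature routing; shapes are illustrative."""
        features = features[1:]    # drop the raw-input "feature"
        features = features[::-1]  # deepest (lowest-resolution) feature first
        head, *skips = features
        x = center(head)           # optional attention block at the bottleneck
        for i, block in enumerate(blocks):
            skip = skips[i] if i < len(skips) else None  # shallowest blocks may lack a skip
            x = block(x, skip)
        return x
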
@@ -471,14 +510,22 @@ class KongNet(ModelABC):


     Attributes:
-        encoder: Encoder module (e.g., TimmEncoderFixed)
-        decoders: List of decoder modules (KongNetDecoder)
-        heads: List of segmentation head modules (SegmentationHead)
-        min_distance: Minimum distance between peaks in post-processing
-        threshold_abs: Absolute threshold for peak detection in post-processing
-        target_channels: List of target channel indices for post-processing
-        output_class_dict: Optional dictionary mapping class names to indices
-        postproc_tile_shape: Tile shape for post-processing with dask
+        encoder:
+            Encoder module (e.g., TimmEncoderFixed)
+        decoders:
+            List of decoder modules (KongNetDecoder)
+        heads:
+            List of segmentation head modules (SegmentationHead)
+        min_distance:
+            Minimum distance between peaks in post-processing
+        threshold_abs:
+            Absolute threshold for peak detection in post-processing
+        target_channels:
+            List of target channel indices for post-processing
+        output_class_dict:
+            Optional dictionary mapping class names to indices
+        postproc_tile_shape:
+            Tile shape for post-processing with dask

     Example:
         >>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
@@ -497,6 +544,7 @@ class KongNet(ModelABC):
         and Classification of Nuclei in Histopathology Images.", 2025,
         arXiv preprint arXiv:2510.23559.,
         URL: https://arxiv.org/abs/2510.23559
+
     """

     def __init__(
@@ -530,6 +578,7 @@ def __init__(
                 Whether to use a wider decoder architecture. Defaults to False.
             class_dict (dict | None):
                 Optional dictionary mapping class names to indices. Defaults to None.
+
         """
         super().__init__()

@@ -617,7 +666,8 @@ def forward( # skipcq: PYL-W0613
         """Forward pass through the model.

         Args:
-            x (torch.Tensor): Input tensor of shape (B, C, H, W)
+            x (torch.Tensor):
+                Input tensor of shape (B, C, H, W)
             *args (tuple):
                 Additional positional arguments (unused).
             **kwargs (dict):
@@ -626,6 +676,7 @@ def forward( # skipcq: PYL-W0613
         Returns:
             torch.Tensor: Concatenated output from all heads of shape
                 (B, sum(num_channels_per_head), H, W)
+
         """
         features = self.encoder(x)
         decoder_outputs = [decoder(*features) for decoder in self.decoders]
@@ -720,7 +771,9 @@ def postproc(
             when it's called from dask.array.map_overlap.

         Returns:
-            out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
+            out:
+                NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
+
         """
         min_distance_to_use = (
             self.min_distance if min_distance is None else min_distance
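
The Returns block above describes the peak map postproc emits, gated by the min_distance and threshold_abs attributes documented on the class. A minimal sketch of that kind of per-channel peak detection with skimage.feature.peak_local_max (the defaults below are illustrative, and the method's actual channel selection and dask tiling are not shown):

    import numpy as np
    from skimage.feature import peak_local_max

    def peaks_to_map(raw: np.ndarray, min_distance: int = 11, threshold_abs: float = 0.5) -> np.ndarray:
        """Return an (H, W, C) array with 1.0 at detected peaks and 0 elsewhere."""
        out = np.zeros_like(raw, dtype=np.float32)
        for c in range(raw.shape[-1]):  # detect peaks independently per channel
            coords = peak_local_max(raw[..., c], min_distance=min_distance, threshold_abs=threshold_abs)
            out[coords[:, 0], coords[:, 1], c] = 1.0
        return out
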
