@@ -134,6 +134,7 @@ def __init__(
134134 Drop path rate of the encoder. Default is 0.0.
135135 pretrained (bool):
136136 Whether to use pretrained weights. Default is True.
137+
137138 """
138139 super ().__init__ ()
139140 if drop_path_rate is None :
@@ -165,11 +166,14 @@ def forward(self, x: torch.Tensor) -> list[torch.Tensor]:
165166 """Forward pass through the encoder.
166167
167168 Args:
168- x (torch.Tensor): Input tensor of shape (B, C, H, W)
169+ x (torch.Tensor):
170+ Input tensor of shape (B, C, H, W)
169171
170172 Returns:
171- List[torch.Tensor]: List of feature tensors at different scales,
173+ list[torch.Tensor]:
174+ List of feature tensors at different scales,
172175 including the input as the first element
176+
173177 """
174178 features = self .model (x )
175179 return [x , * features ]
@@ -179,7 +183,9 @@ def out_channels(self) -> list[int]:
179183 """Get output channels for each feature level.
180184
181185 Returns:
182- List[int]: Number of channels at each feature level
186+ list[int]:
187+ Number of channels at each feature level
188+
183189 """
184190 return self ._out_channels
185191
@@ -188,7 +194,9 @@ def output_stride(self) -> int:
188194 """Get the output stride of the encoder.
189195
190196 Returns:
191- int: Output stride value
197+ int:
198+ Output stride value
199+
192200 """
193201 return min (self ._output_stride , 2 ** self ._depth )
194202
@@ -200,9 +208,13 @@ class SubPixelUpsample(nn.Module):
200208 which is more efficient than transposed convolution and produces better results.
201209
202210 Args:
203- in_channels (int): Number of input channels
204- out_channels (int): Number of output channels
205- upscale_factor (int): Factor to increase spatial resolution. Default: 2
211+ in_channels (int):
212+ Number of input channels
213+ out_channels (int):
214+ Number of output channels
215+ upscale_factor (int):
216+ Factor to increase spatial resolution. Default: 2
217+
206218 """
207219
208220 def __init__ (
@@ -248,6 +260,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
248260 torch.Tensor:
249261 Upsampled tensor of shape
250262 (B, out_channels, H*upscale_factor, W*upscale_factor)
263+
251264 """
252265 x = self .conv1 (x )
253266 x = self .pixel_shuffle (x )
@@ -262,10 +275,15 @@ class DecoderBlock(nn.Module):
262275 and processes through convolutions.
263276
264277 Args:
265- in_channels (int): Number of input channels
266- skip_channels (int): Number of channels from skip connection
267- out_channels (int): Number of output channels
268- attention_type (str): Type of attention mechanism. Default: 'scse'
278+ in_channels (int):
279+ Number of input channels
280+ skip_channels (int):
281+ Number of channels from skip connection
282+ out_channels (int):
283+ Number of output channels
284+ attention_type (str):
285+ Type of attention mechanism. Default: 'scse'.
286+
269287 """
270288
271289 def __init__ (
@@ -285,7 +303,8 @@ def __init__(
285303 out_channels (int):
286304 Number of output channels
287305 attention_type (str):
288- Type of attention mechanism. Default: 'scse'
306+ Type of attention mechanism. Default: 'scse'.
307+
289308 """
290309 super ().__init__ ()
291310 self .up = SubPixelUpsample (in_channels , in_channels , upscale_factor = 2 )
@@ -322,7 +341,9 @@ def forward(
322341 Skip connection tensor from encoder. Default: None
323342
324343 Returns:
325- torch.Tensor: Processed output tensor
344+ torch.Tensor:
345+ Processed output tensor
346+
326347 """
327348 x = self .up (x )
328349 if skip is not None :
@@ -340,14 +361,18 @@ class CenterBlock(nn.Module):
340361 to enhance feature representation using attention mechanisms.
341362
342363 Args:
343- in_channels (int): Number of input channels
364+ in_channels (int):
365+ Number of input channels
366+
344367 """
345368
346369 def __init__ (self , in_channels : int ) -> None :
347370 """Initialize CenterBlock with attention.
348371
349372 Args:
350- in_channels (int): Number of input channels
373+ in_channels (int):
374+ Number of input channels.
375+
351376 """
352377 super ().__init__ ()
353378 self .attention = AttentionModule (name = "scse" , in_channels = in_channels )
@@ -356,10 +381,13 @@ def forward(self, x: torch.Tensor) -> torch.Tensor:
356381 """Forward pass through center block.
357382
358383 Args:
359- x (torch.Tensor): Input tensor
384+ x (torch.Tensor):
385+ Input tensor.
360386
361387 Returns:
362- torch.Tensor: Output tensor with attention applied
388+ torch.Tensor:
389+ Output tensor with attention applied.
390+
363391 """
364392 return self .attention (x )
365393
@@ -371,14 +399,21 @@ class KongNetDecoder(nn.Module):
371399 attention mechanisms, and optional center block at the bottleneck.
372400
373401 Args:
374- encoder_channels (List[int]): Number of channels at each encoder level
375- decoder_channels (Tuple[int, ...]): Number of channels at each decoder level
376- n_blocks (int): Number of decoder blocks. Default: 5
377- attention_type (str): Type of attention mechanism. Default: 'scse'
378- center (bool): Whether to use center block at bottleneck. Default: True
402+ encoder_channels (List[int]):
403+ Number of channels at each encoder level
404+ decoder_channels (Tuple[int, ...]):
405+ Number of channels at each decoder level
406+ n_blocks (int):
407+ Number of decoder blocks. Default: 5
408+ attention_type (str):
409+ Type of attention mechanism. Default: 'scse'
410+ center (bool):
411+ Whether to use center block at bottleneck. Default: True
379412
380413 Raises:
381- ValueError: If n_blocks doesn't match length of decoder_channels
414+ ValueError:
415+ If n_blocks doesn't match length of decoder_channels
416+
382417 """
383418
384419 def __init__ (
@@ -404,6 +439,7 @@ def __init__(
404439 center (bool):
405440 Whether to include a center block at the bottleneck.
406441 Default is True.
442+
407443 """
408444 super ().__init__ ()
409445
@@ -442,10 +478,13 @@ def forward(self, *features: torch.Tensor) -> torch.Tensor:
442478 """Forward pass through the decoder.
443479
444480 Args:
445- *features: Feature tensors from encoder at different scales
481+ *features:
482+ Feature tensors from encoder at different scales
446483
447484 Returns:
448- torch.Tensor: Decoded output tensor
485+ torch.Tensor:
486+ Decoded output tensor
487+
449488 """
450489 features = features [1 :] # remove first skip with same spatial resolution
451490 features = features [::- 1 ] # reverse channels to start from head of encoder
@@ -471,14 +510,22 @@ class KongNet(ModelABC):
471510
472511
473512 Attributes:
474- encoder: Encoder module (e.g., TimmEncoderFixed)
475- decoders: List of decoder modules (KongNetDecoder)
476- heads: List of segmentation head modules (SegmentationHead)
477- min_distance: Minimum distance between peaks in post-processing
478- threshold_abs: Absolute threshold for peak detection in post-processing
479- target_channels: List of target channel indices for post-processing
480- output_class_dict: Optional dictionary mapping class names to indices
481- postproc_tile_shape: Tile shape for post-processing with dask
513+ encoder:
514+ Encoder module (e.g., TimmEncoderFixed)
515+ decoders:
516+ List of decoder modules (KongNetDecoder)
517+ heads:
518+ List of segmentation head modules (SegmentationHead)
519+ min_distance:
520+ Minimum distance between peaks in post-processing
521+ threshold_abs:
522+ Absolute threshold for peak detection in post-processing
523+ target_channels:
524+ List of target channel indices for post-processing
525+ output_class_dict:
526+ Optional dictionary mapping class names to indices
527+ postproc_tile_shape:
528+ Tile shape for post-processing with dask
482529
483530 Example:
484531 >>> from tiatoolbox.models.engine.nucleus_detector import NucleusDetector
@@ -497,6 +544,7 @@ class KongNet(ModelABC):
497544 and Classification of Nuclei in Histopathology Images.", 2025,
498545 arXiv preprint arXiv:2510.23559.,
499546 URL: https://arxiv.org/abs/2510.23559
547+
500548 """
501549
502550 def __init__ (
@@ -530,6 +578,7 @@ def __init__(
530578 Whether to use a wider decoder architecture. Defaults to False.
531579 class_dict (dict | None):
532580 Optional dictionary mapping class names to indices. Defaults to None.
581+
533582 """
534583 super ().__init__ ()
535584
@@ -617,7 +666,8 @@ def forward( # skipcq: PYL-W0613
617666 """Forward pass through the model.
618667
619668 Args:
620- x (torch.Tensor): Input tensor of shape (B, C, H, W)
669+ x (torch.Tensor):
670+ Input tensor of shape (B, C, H, W)
621671 *args (tuple):
622672 Additional positional arguments (unused).
623673 **kwargs (dict):
@@ -626,6 +676,7 @@ def forward( # skipcq: PYL-W0613
626676 Returns:
627677 torch.Tensor: Concatenated output from all heads of shape
628678 (B, sum(num_channels_per_head), H, W)
679+
629680 """
630681 features = self .encoder (x )
631682 decoder_outputs = [decoder (* features ) for decoder in self .decoders ]
@@ -720,7 +771,9 @@ def postproc(
720771 when it's called from dask.array.map_overlap.
721772
722773 Returns:
723- out: NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
774+ out:
775+ NumPy array (H, W, C) with 1.0 at peaks, 0 elsewhere.
776+
724777 """
725778 min_distance_to_use = (
726779 self .min_distance if min_distance is None else min_distance
0 commit comments