11"""GrandQC Tissue Detection Model Architecture.
22
33This module defines the GrandQC model for tissue detection in digital pathology.
4- It implements a UNet++ architecture with an EfficientNet encoder and a segmentation
4+ It implements a UNet++ architecture with an EfficientNetB0 encoder and a segmentation
55head for high-resolution tissue segmentation. The model is designed to identify
66tissue regions and background areas for quality control in whole slide images (WSIs).
77
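The module docstring describes a UNet++ decoder paired with an EfficientNet-B0 encoder and a segmentation head. A minimal sketch of an equivalent two-class (tissue vs. background) model, assuming the segmentation_models_pytorch API rather than this module's own classes:

    import segmentation_models_pytorch as smp
    import torch

    # Assumed equivalent of the architecture named in the docstring,
    # not the construction used by this module.
    model = smp.UnetPlusPlus(
        encoder_name="efficientnet-b0",
        encoder_weights="imagenet",
        in_channels=3,
        classes=2,
    )
    logits = model(torch.randn(1, 3, 512, 512))  # (1, 2, 512, 512)
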
@@ -205,7 +205,7 @@ class DecoderBlock(nn.Module):
 
     This block performs upsampling and feature fusion using skip connections
     from the encoder. It consists of two convolutional layers with ReLU activation
-    and optional attention mechanisms.
+    and optional attention mechanisms (not implemented).
 
     Attributes:
         conv1 (Conv2dReLU):
@@ -222,9 +222,9 @@ class DecoderBlock(nn.Module):
 
     Example:
         >>> block = DecoderBlock(in_channels=128, skip_channels=64, out_channels=64)
-        >>> x = torch.randn(1, 128, 64, 64)
+        >>> input_tensor = torch.randn(1, 128, 64, 64)
         >>> skip = torch.randn(1, 64, 128, 128)
-        >>> output = block(x, skip)
+        >>> output = block(input_tensor, skip)
         >>> output.shape
         ... torch.Size([1, 64, 128, 128])
 
@@ -268,7 +268,7 @@ def __init__(
 
     def forward(
         self: DecoderBlock,
-        x: torch.Tensor,
+        input_tensor: torch.Tensor,
         skip: torch.Tensor | None = None,
     ) -> torch.Tensor:
         """Forward pass through the decoder block.
@@ -277,29 +277,33 @@ def forward(
         (if provided), and applies two convolutional layers with attention.
 
         Args:
-            x (torch.Tensor):
-                Input tensor from the previous decoder layer.
+            input_tensor (torch.Tensor):
+                (B, C_in, H, W). Input tensor from the previous decoder layer.
             skip (torch.Tensor | None):
+                (B, C_skip, H*2, W*2).
                 Skip connection tensor from the encoder. Defaults to None.
 
         Returns:
             torch.Tensor:
+                (B, C_out, H*2, W*2).
                 Output tensor after decoding and feature refinement.
 
         """
-        x = torch.nn.functional.interpolate(x, scale_factor=2.0, mode="nearest")
+        input_tensor = torch.nn.functional.interpolate(
+            input_tensor, scale_factor=2.0, mode="nearest"
+        )
         if skip is not None:
-            x = torch.cat([x, skip], dim=1)
-        x = self.attention1(x)
-        x = self.conv1(x)
-        x = self.conv2(x)
-        return self.attention2(x)
+            input_tensor = torch.cat([input_tensor, skip], dim=1)
+        input_tensor = self.attention1(input_tensor)
+        input_tensor = self.conv1(input_tensor)
+        input_tensor = self.conv2(input_tensor)
+        return self.attention2(input_tensor)
 
 
 class CenterBlock(nn.Sequential):
     """Center block for UNet++ architecture.
 
-    This block is placed at the bottleneck of the UNet++ architecture.
+    This block can be placed at the bottleneck of the UNet++ architecture.
     It consists of two convolutional layers with ReLU activation, used
     to process the deepest feature maps before decoding begins.
 
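To sanity-check the shape annotations added to the docstring above, the upsample, concatenate, and convolve sequence in DecoderBlock.forward can be traced with plain torch operations. A minimal sketch, assuming Conv2dReLU behaves like a 3x3 convolution with padding 1 followed by ReLU (its definition is outside this diff):

    import torch
    import torch.nn.functional as F

    input_tensor = torch.randn(1, 128, 64, 64)  # (B, C_in, H, W) from the previous decoder stage
    skip = torch.randn(1, 64, 128, 128)          # (B, C_skip, H*2, W*2) encoder skip connection

    x = F.interpolate(input_tensor, scale_factor=2.0, mode="nearest")  # (1, 128, 128, 128)
    x = torch.cat([x, skip], dim=1)                                    # (1, 192, 128, 128)

    # Stand-ins for conv1/conv2 (Conv2dReLU); 3x3 convs with padding 1 keep H and W.
    conv1 = torch.nn.Sequential(torch.nn.Conv2d(192, 64, 3, padding=1), torch.nn.ReLU())
    conv2 = torch.nn.Sequential(torch.nn.Conv2d(64, 64, 3, padding=1), torch.nn.ReLU())
    out = conv2(conv1(x))
    print(out.shape)  # torch.Size([1, 64, 128, 128]), i.e. (B, C_out, H*2, W*2)
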
@@ -311,8 +315,8 @@ class CenterBlock(nn.Sequential):
 
     Example:
         >>> center = CenterBlock(in_channels=256, out_channels=512)
-        >>> x = torch.randn(1, 256, 32, 32)
-        >>> output = center(x)
+        >>> input_tensor = torch.randn(1, 256, 32, 32)
+        >>> output = center(input_tensor)
         >>> output.shape
         ... torch.Size([1, 512, 32, 32])
 
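The CenterBlock doctest shape can be checked the same way. A minimal sketch, assuming the block stacks two 3x3 convolution plus ReLU layers as a stand-in for the module's Conv2dReLU helper:

    import torch
    import torch.nn as nn

    # Assumed stand-in for CenterBlock(in_channels=256, out_channels=512):
    # two 3x3 conv + ReLU layers, so the spatial size is preserved.
    center = nn.Sequential(
        nn.Conv2d(256, 512, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
        nn.Conv2d(512, 512, kernel_size=3, padding=1),
        nn.ReLU(inplace=True),
    )
    output = center(torch.randn(1, 256, 32, 32))
    print(output.shape)  # torch.Size([1, 512, 32, 32])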