1+ """Model definitions for the CycleGAN-style architecture."""
2+
3+ from torch import Tensor
14import torch .nn as nn
25import torch .nn .functional as F
36
69
710
811class ResidualBlock (nn .Module ):
9- def __init__ (self , in_features ):
12+ """Simple residual block with two conv layers."""
13+
14+ def __init__ (self , in_features : int ) -> None :
1015 super ().__init__ ()
1116
1217 conv_block = [
@@ -21,12 +26,15 @@ def __init__(self, in_features):
2126
2227 self .conv_block = nn .Sequential (* conv_block )
2328
24- def forward (self , x ):
25- return x + self .conv_block (x ) # skip connection
29+ def forward (self , x : Tensor ) -> Tensor :
30+ """Apply the residual block."""
31+ return x + self .conv_block (x )
2632
2733
2834class Generator (nn .Module ):
29- def __init__ (self , ngf , n_residual_blocks = 9 ):
35+ """U-Net style generator used for domain translation."""
36+
37+ def __init__ (self , ngf : int , n_residual_blocks : int = 9 ) -> None :
3038 super ().__init__ ()
3139
3240 # Initial convlution block
@@ -85,12 +93,15 @@ def __init__(self, ngf, n_residual_blocks=9):
8593
8694 self .model = nn .Sequential (* model )
8795
88- def forward (self , x ):
96+ def forward (self , x : Tensor ) -> Tensor :
97+ """Generate an image from ``x``."""
8998 return self .model (x )
9099
91100
92101class Discriminator (nn .Module ):
93- def __init__ (self , ndf ):
102+ """PatchGAN discriminator."""
103+
104+ def __init__ (self , ndf : int ) -> None :
94105 super ().__init__ ()
95106
96107 model = [
@@ -125,13 +136,10 @@ def __init__(self, ndf):
125136
126137 self .model = nn .Sequential (* model )
127138
128- def forward (self , x ):
129- # x: (B, 3, H, W)
130- x = self .model (x ) # (B, 1, H//8-2, W//8-2)
131- # Average pooling and flatten
132- return F .avg_pool2d (x , x .size ()[2 :]).view (
133- x .size ()[0 ], - 1
134- ) # global average -> (B, 1, 1, 1) -> flatten to (B, 1)
139+ def forward (self , x : Tensor ) -> Tensor :
140+ """Return discriminator logits for input ``x``."""
141+ x = self .model (x )
142+ return F .avg_pool2d (x , x .size ()[2 :]).view (x .size ()[0 ], - 1 )
135143
136144
137145# # Discriminator: PatchGAN 70x70
@@ -187,7 +195,8 @@ def initialize_models(
187195 ngf : int = 32 ,
188196 ndf : int = 32 ,
189197 n_blocks : int = 9 ,
190- ):
198+ ) -> tuple [Generator , Generator , Discriminator , Discriminator ]:
199+ """Instantiate generators and discriminators with default sizes."""
191200 # G = smp.Unet(
192201 # encoder_name="resnet34",
193202 # encoder_weights="imagenet", # preload low-level filters
0 commit comments