|
1 | 1 | import torch.nn as nn |
2 | 2 | import torch.nn.utils as nn_utils |
3 | 3 | import segmentation_models_pytorch as smp |
4 | | -import torch.nn.functional as F |
5 | | - |
6 | | -# ------------------------------------------------------------ |
7 | | -# 9‑residual‑block ResNet generator (CycleGAN, 256×256) |
8 | | -# ------------------------------------------------------------ |
class ResnetBlock(nn.Module):
    """Residual block: two padded 3x3 conv + InstanceNorm stages plus a skip connection.

    The stack is pad -> conv -> norm -> ReLU -> dropout(0.5) -> pad -> conv -> norm,
    and its output is added to the block input.

    Args:
        channels: Number of channels of the input; the output has the same count.
        padding_type: ``"reflect"`` selects reflection padding; any other value
            selects zero padding.
    """

    def __init__(self, channels, padding_type="reflect"):
        super().__init__()

        if padding_type == "reflect":
            make_pad = nn.ReflectionPad2d
        else:
            make_pad = nn.ZeroPad2d

        def conv_stage():
            # One pixel of explicit padding ahead of an unpadded 3x3 conv keeps
            # height and width unchanged. The conv carries no bias; the affine
            # InstanceNorm that follows has its own learnable shift.
            return [
                make_pad(1),
                nn.Conv2d(channels, channels, 3, bias=False),
                nn.InstanceNorm2d(channels, affine=True),
            ]

        stages = conv_stage()
        stages.append(nn.ReLU())
        stages.append(nn.Dropout(0.5))
        stages.extend(conv_stage())

        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Return ``x`` plus the result of running ``x`` through the conv stack."""
        residual = self.block(x)
        return x + residual
27 | | - |
28 | | - |
class ResnetGenerator(nn.Module):
    """ResNet-style image-to-image generator.

    Layout: 7x7 stem conv -> two stride-2 downsampling convs -> ``n_blocks``
    residual blocks -> two stride-2 transposed-conv upsamplings -> 7x7 output
    conv followed by ``tanh``.

    Args:
        in_c: Number of channels of the input image.
        out_c: Number of channels of the generated image.
        n_blocks: Number of ``ResnetBlock`` modules at the bottleneck; must be
            at least 1.
        ngf: Channel count after the stem conv; it doubles at each
            downsampling step and halves at each upsampling step.

    Raises:
        ValueError: If ``n_blocks`` is smaller than 1.
    """

    def __init__(self, in_c=3, out_c=3, n_blocks=9, ngf=64):
        super().__init__()
        # Explicit check instead of ``assert``: assertions are stripped when
        # Python runs with -O, which would silently allow an empty bottleneck.
        if n_blocks < 1:
            raise ValueError(f"n_blocks must be >= 1, got {n_blocks}")

        # Stem: reflection-pad by 3 so the unpadded 7x7 conv keeps spatial size.
        layers = [
            nn.ReflectionPad2d(3),
            nn.Conv2d(in_c, ngf, 7, bias=False),
            nn.InstanceNorm2d(ngf, affine=True),
            nn.ReLU(),
        ]

        # Downsample twice: each stride-2 conv doubles the channel count
        # (ngf -> 2*ngf -> 4*ngf).
        mult = 1
        for _ in range(2):
            layers += [
                nn.Conv2d(ngf * mult, ngf * mult * 2, 3, 2, 1, bias=False),
                nn.InstanceNorm2d(ngf * mult * 2, affine=True),
                nn.ReLU(),
            ]
            mult *= 2  # 1 -> 2 -> 4

        # Residual blocks at the bottleneck width.
        layers += [ResnetBlock(ngf * mult) for _ in range(n_blocks)]

        # Upsample twice: each stride-2 transposed conv halves the channel
        # count (4*ngf -> 2*ngf -> ngf); output_padding=1 makes it an exact 2x.
        for _ in range(2):
            layers += [
                nn.ConvTranspose2d(
                    ngf * mult, ngf * mult // 2,
                    3, 2, 1, output_padding=1, bias=False,
                ),
                nn.InstanceNorm2d(ngf * mult // 2, affine=True),
                nn.ReLU(),
            ]
            mult //= 2  # 4 -> 2 -> 1

        # Output head: the final conv keeps its default bias because no
        # normalization layer follows it; tanh bounds the output to [-1, 1].
        layers += [
            nn.ReflectionPad2d(3),
            nn.Conv2d(ngf, out_c, 7),
            nn.Tanh(),
        ]
        self.model = nn.Sequential(*layers)

        # Single pass over all submodules: N(0, 0.02) for conv / transposed-conv
        # weights, identity affine parameters for InstanceNorm. The constant
        # init draws no random numbers, so the random stream matches a
        # two-pass initialization.
        for m in self.modules():
            if isinstance(m, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.normal_(m.weight, 0.0, 0.02)
            elif isinstance(m, nn.InstanceNorm2d):
                nn.init.constant_(m.weight, 1.0)
                nn.init.constant_(m.bias, 0.0)

    def forward(self, x):
        """Run ``x`` through the full generator stack and return the result."""
        return self.model(x)
85 | | - |
86 | | - |
87 | 4 |
|
88 | 5 | # Discriminator: PatchGAN 70x70 |
89 | 6 | class PatchDiscriminator(nn.Module): |
@@ -138,22 +55,19 @@ def unfreeze_encoders(G, F): |
138 | 55 | # Initialize and return the generators and discriminators used for training |
139 | 56 | def initialize_models(): |
140 | 57 | # initialize the generators |
141 | | - # G = smp.Unet( |
142 | | - # encoder_name="resnet34", |
143 | | - # encoder_weights="imagenet", # preload low-level filters |
144 | | - # in_channels=3, # RGB input |
145 | | - # classes=3, # RGB output |
146 | | - # ) |
147 | | - |
148 | | - # F = smp.Unet( |
149 | | - # encoder_name="resnet34", |
150 | | - # encoder_weights="imagenet", # preload low-level filters |
151 | | - # in_channels=3, # RGB input |
152 | | - # classes=3, # RGB output |
153 | | - # ) |
154 | | - |
155 | | - G = ResnetGenerator() |
156 | | - F = ResnetGenerator() |
| 58 | + G = smp.Unet( |
| 59 | + encoder_name="resnet34", |
| 60 | + encoder_weights="imagenet", # preload low-level filters |
| 61 | + in_channels=3, # RGB input |
| 62 | + classes=3, # RGB output |
| 63 | + ) |
| 64 | + |
| 65 | + F = smp.Unet( |
| 66 | + encoder_name="resnet34", |
| 67 | + encoder_weights="imagenet", # preload low-level filters |
| 68 | + in_channels=3, # RGB input |
| 69 | + classes=3, # RGB output |
| 70 | + ) |
157 | 71 |
|
158 | 72 | # initialize the discriminator |
159 | 73 | DX = PatchDiscriminator() |
|
0 commit comments