1
1
"""
2
2
Discriminator and Generator implementation from DCGAN paper
3
+
4
+ Programmed by Aladdin Persson <aladdin.persson at hotmail dot com>
5
+ * 2020-11-01: Initial coding
6
+ * 2022-12-20: Small revision of code, checked that it works with latest PyTorch version
3
7
"""
4
8
5
9
import torch
@@ -11,9 +15,7 @@ def __init__(self, channels_img, features_d):
11
15
super (Discriminator , self ).__init__ ()
12
16
self .disc = nn .Sequential (
13
17
# input: N x channels_img x 64 x 64
14
- nn .Conv2d (
15
- channels_img , features_d , kernel_size = 4 , stride = 2 , padding = 1
16
- ),
18
+ nn .Conv2d (channels_img , features_d , kernel_size = 4 , stride = 2 , padding = 1 ),
17
19
nn .LeakyReLU (0.2 ),
18
20
# _block(in_channels, out_channels, kernel_size, stride, padding)
19
21
self ._block (features_d , features_d * 2 , 4 , 2 , 1 ),
@@ -34,7 +36,7 @@ def _block(self, in_channels, out_channels, kernel_size, stride, padding):
34
36
padding ,
35
37
bias = False ,
36
38
),
37
- #nn.BatchNorm2d(out_channels),
39
+ # nn.BatchNorm2d(out_channels),
38
40
nn .LeakyReLU (0.2 ),
39
41
)
40
42
@@ -68,7 +70,7 @@ def _block(self, in_channels, out_channels, kernel_size, stride, padding):
68
70
padding ,
69
71
bias = False ,
70
72
),
71
- #nn.BatchNorm2d(out_channels),
73
+ # nn.BatchNorm2d(out_channels),
72
74
nn .ReLU (),
73
75
)
74
76
@@ -82,6 +84,7 @@ def initialize_weights(model):
82
84
if isinstance (m , (nn .Conv2d , nn .ConvTranspose2d , nn .BatchNorm2d )):
83
85
nn .init .normal_ (m .weight .data , 0.0 , 0.02 )
84
86
87
+
85
88
def test ():
86
89
N , in_channels , H , W = 8 , 3 , 64 , 64
87
90
noise_dim = 100
@@ -91,6 +94,8 @@ def test():
91
94
gen = Generator (noise_dim , in_channels , 8 )
92
95
z = torch .randn ((N , noise_dim , 1 , 1 ))
93
96
assert gen (z ).shape == (N , in_channels , H , W ), "Generator test failed"
97
+ print ("Success, tests passed!" )
94
98
95
99
96
- # test()
100
+ if __name__ == "__main__" :
101
+ test ()
0 commit comments