-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCustomDiscriminator.py
More file actions
69 lines (60 loc) · 2.74 KB
/
CustomDiscriminator.py
File metadata and controls
69 lines (60 loc) · 2.74 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
import torch
class CustomDiscriminator:
    """A minimal MLP discriminator built from raw tensors (no torch.nn).

    Layers are stored as a flat list of tagged tuples:
      ('linear', weight, bias), ('batchnorm', gamma, beta),
      ('relu',), ('leakyrelu',), ('sigmoid',)
    The final linear layer outputs a single logit and has no activation.
    """

    def __init__(self, input_dim, hidden_dims, activation):
        # activation: one of 'relu', 'leakyrelu', 'sigmoid' (applied to
        # every layer except the last).
        self.layers = self.get_custom_discriminator(input_dim, hidden_dims, activation)
        self.activation = activation

    def get_custom_discriminator(self, input_dim, hidden_dims, activation):
        """Build the layer list: hidden blocks plus a final 1-unit linear layer."""
        layers = []
        # Copy before extending: the original appended to the caller's list,
        # so constructing twice with the same list kept growing it.
        dims = list(hidden_dims) + [1]  # append the output layer's size
        for i, out_dim in enumerate(dims):
            # The last layer gets no activation (raw logit output).
            act = activation if i < len(dims) - 1 else None
            layers += self.get_dense_block(input_dim, out_dim, act)
            input_dim = out_dim
        return layers

    def get_dense_block(self, input_dim, output_dim, activation, use_batchnorm=False):
        """Return the tagged tuples for one linear (+batchnorm) (+activation) block."""
        # Scale first, THEN mark as requiring grad:
        # `torch.randn(..., requires_grad=True) * 0.02` stores the product,
        # a NON-leaf tensor whose .grad is never populated, so optimizers
        # could not train these weights.
        weight = (torch.randn(output_dim, input_dim) * 0.02).requires_grad_()
        bias = torch.zeros(output_dim, requires_grad=True)
        block = [('linear', weight, bias)]
        if use_batchnorm:
            gamma = torch.ones(output_dim, requires_grad=True)
            beta = torch.zeros(output_dim, requires_grad=True)
            block.append(('batchnorm', gamma, beta))
        if activation == 'relu':
            block.append(('relu',))
        elif activation == 'leakyrelu':
            block.append(('leakyrelu',))
        elif activation == 'sigmoid':
            block.append(('sigmoid',))
        return block

    def forward(self, x):
        """Run x (batch, input_dim) through every layer; returns (batch, 1) logits."""
        for layer in self.layers:
            if layer[0] == 'linear':
                weight, bias = layer[1], layer[2]
                x = x @ weight.t() + bias
            elif layer[0] == 'batchnorm':
                gamma, beta = layer[1], layer[2]
                # Per-feature normalization over the batch dimension.
                mean = x.mean(dim=0, keepdim=True)
                std = x.std(dim=0, unbiased=False, keepdim=True)
                x = (x - mean) / (std + 1e-5) * gamma + beta
            elif layer[0] == 'relu':
                x = torch.clamp(x, min=0)
            elif layer[0] == 'leakyrelu':
                x = torch.clamp(x, min=0) + 0.01 * torch.clamp(x, max=0)
            elif layer[0] == 'sigmoid':
                # Previously missing: 'sigmoid' blocks were silently skipped.
                x = torch.sigmoid(x)
        return x

    def to(self, device):
        """Move all parameters to `device`, keeping them leaf tensors.

        A plain `.to(device)` returns a non-leaf copy tracked by autograd,
        which breaks training after a device move; detach + requires_grad_
        yields a fresh leaf on the target device.
        """
        def _move(t):
            return t.detach().to(device).requires_grad_()

        for i, layer in enumerate(self.layers):
            if layer[0] == 'linear':
                self.layers[i] = ('linear', _move(layer[1]), _move(layer[2]))
            elif layer[0] == 'batchnorm':
                self.layers[i] = ('batchnorm', _move(layer[1]), _move(layer[2]))
        return self

    def parameters(self):
        """Return all trainable tensors (previously omitted batchnorm params)."""
        params = []
        for layer in self.layers:
            if layer[0] in ('linear', 'batchnorm'):
                params.append(layer[1])
                params.append(layer[2])
        return params