-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathCustomGenerator.py
More file actions
73 lines (59 loc) · 2.66 KB
/
CustomGenerator.py
File metadata and controls
73 lines (59 loc) · 2.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
import os
import torch
# Workaround for the "OMP: Error #15: Initializing libiomp5 ... already initialized"
# crash seen when multiple copies of Intel's OpenMP runtime are loaded
# (common with PyTorch + MKL on Windows/conda). Allows duplicate libs to coexist.
os.environ['KMP_DUPLICATE_LIB_OK'] = 'TRUE'
class CustomGenerator:
    """A simple MLP generator built from raw tensors instead of nn.Module.

    The network is stored as ``self.layers``: a flat list of tuples, each
    tagged with an op name: ``('linear', weight, bias)``,
    ``('batchnorm', gamma, beta)``, ``('relu',)``, ``('leakyrelu',)``,
    ``('sigmoid',)``.  Hidden layers use ReLU; the output layer uses sigmoid.
    """

    def __init__(self, input_dim, output_dim, hidden_dims):
        """Build the layer list for an MLP input_dim -> hidden_dims... -> output_dim."""
        self.layers = self.get_custom_generator(input_dim, output_dim, hidden_dims)

    def get_custom_generator(self, input_dim, output_dim, hidden_dims):
        """Return the flat layer list: ReLU hidden blocks, sigmoid output block."""
        layers = []
        for hidden_dim in hidden_dims:
            layers += self.get_dense_block(input_dim, hidden_dim, activation='relu')
            input_dim = hidden_dim
        layers += self.get_dense_block(input_dim, output_dim, activation='sigmoid')
        return layers

    def get_dense_block(self, input_dim, output_dim, activation, use_batchnorm=False):
        """Return one dense block: linear [+ batchnorm] [+ activation] tuples.

        ``activation`` is one of 'relu', 'leakyrelu', 'sigmoid'; any other
        value yields a linear(-plus-batchnorm) block with no activation.
        """
        # BUG FIX: the original did `torch.randn(..., requires_grad=True) * 0.02`,
        # which makes `weight` a NON-leaf tensor (the product of an autograd op):
        # its .grad is never populated and optimizers cannot update it.
        # Scale first, then mark the result as a leaf parameter.
        weight = (torch.randn((output_dim, input_dim)) * 0.02).requires_grad_()
        bias = torch.zeros(output_dim, requires_grad=True)
        block = [('linear', weight, bias)]
        if use_batchnorm:
            gamma = torch.ones(output_dim, requires_grad=True)
            beta = torch.zeros(output_dim, requires_grad=True)
            block.append(('batchnorm', gamma, beta))
        if activation == 'relu':
            block.append(('relu',))
        elif activation == 'leakyrelu':
            block.append(('leakyrelu',))
        elif activation == 'sigmoid':
            block.append(('sigmoid',))
        return block

    def forward(self, x):
        """Run ``x`` (shape: (batch, input_dim)) through every layer in order."""
        for layer in self.layers:
            if layer[0] == 'linear':
                weight, bias = layer[1], layer[2]
                x = x.matmul(weight.t()) + bias
            elif layer[0] == 'batchnorm':
                gamma, beta = layer[1], layer[2]
                # Per-batch normalization over the batch dimension.
                # NOTE(review): this divides by (std + eps) rather than the
                # conventional sqrt(var + eps); kept as-is to preserve behavior.
                mean = x.mean(dim=0, keepdim=True)
                std = x.std(dim=0, unbiased=False, keepdim=True)
                x = (x - mean) / (std + 1e-5) * gamma + beta
            elif layer[0] == 'relu':
                x = torch.relu(x)
            elif layer[0] == 'leakyrelu':
                # BUG FIX: the original forward had no 'leakyrelu' branch, so
                # blocks built with activation='leakyrelu' were silently skipped
                # (identity). Apply leaky ReLU (default negative slope 0.01).
                x = torch.nn.functional.leaky_relu(x)
            elif layer[0] == 'sigmoid':
                x = torch.sigmoid(x)
        return x

    def to(self, device):
        """Move all parameter tensors to ``device`` in place.

        BUG FIX: ``tensor.to(device)`` on a leaf returns a NON-leaf copy when
        the device changes, so gradients would no longer accumulate in the
        moved parameters. Detach and re-mark each moved tensor as a leaf.
        """
        for i, layer in enumerate(self.layers):
            if layer[0] in ('linear', 'batchnorm'):
                a = layer[1].detach().to(device).requires_grad_()
                b = layer[2].detach().to(device).requires_grad_()
                self.layers[i] = (layer[0], a, b)

    def parameters(self):
        """Return every trainable tensor, for handing to an optimizer.

        BUG FIX: the original only collected linear weights/biases, so
        batchnorm gamma/beta (when use_batchnorm=True) were never trained.
        """
        params = []
        for layer in self.layers:
            if layer[0] in ('linear', 'batchnorm'):
                params.append(layer[1])
                params.append(layer[2])
        return params