-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmodel.py
More file actions
29 lines (25 loc) · 888 Bytes
/
model.py
File metadata and controls
29 lines (25 loc) · 888 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
import torch
import torch.nn as nn
import math
# weights_init_uniform_() was taken from https://stackoverflow.com/questions/49433936/how-do-i-initialize-weights-in-pytorch
class SimpleModel(nn.Module):
    """Small fully connected network (MLP) with two ReLU hidden layers.

    Layout: ``Linear(in, hidden) -> ReLU -> Linear(hidden, hidden) -> ReLU
    -> Linear(hidden, out)``. The defaults reproduce the original fixed
    2 -> 5 -> 5 -> 1 network, so ``SimpleModel()`` behaves as before.

    Args:
        in_features: Size of the last dimension of the input.
        hidden_features: Width of both hidden layers.
        out_features: Size of the last dimension of the output.
    """

    def __init__(
        self,
        in_features: int = 2,
        hidden_features: int = 5,
        out_features: int = 1,
    ) -> None:
        super().__init__()
        self.model = nn.Sequential(
            nn.Linear(in_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, hidden_features),
            nn.ReLU(),
            nn.Linear(hidden_features, out_features),
        )
        # apply() calls the hook on every submodule (including the Sequential
        # itself); the hook filters for nn.Linear layers.
        self.model.apply(self.weights_init_uniform_)

    def weights_init_uniform_(self, m: nn.Module) -> None:
        """Initialise ``m`` in place if it is an ``nn.Linear``; otherwise no-op.

        Weights are drawn from U(-1/sqrt(in_features), 1/sqrt(in_features))
        and the bias, when the layer has one, is set to zero.

        Args:
            m: Any module; non-Linear modules are ignored.
        """
        # isinstance rather than matching 'Linear' in the class name: name
        # matching also picks up unrelated modules (e.g. a container called
        # "LinearBlock") that have no `in_features` and would crash here.
        if not isinstance(m, nn.Linear):
            return
        bound = 1.0 / math.sqrt(m.in_features)
        nn.init.uniform_(m.weight, -bound, bound)
        # Layers built with bias=False have `bias is None`.
        if m.bias is not None:
            nn.init.zeros_(m.bias)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run the MLP on ``x``; its last dimension must equal ``in_features``."""
        return self.model(x)