-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathdqn.py
More file actions
19 lines (17 loc) · 787 Bytes
/
dqn.py
File metadata and controls
19 lines (17 loc) · 787 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
import torch
class DQN(torch.nn.Module):
    """Fully connected network mapping an input vector to ``output_size`` outputs.

    Architecture: ``input_size -> hidden_size -> hidden_size -> hidden_size
    -> output_size``, with a ReLU after each of the first three Linear layers
    and no activation on the final (output) layer. Given the class name, the
    outputs are presumably Q-value estimates, one per action — confirm with
    the training code.

    Args:
        input_size: Number of features in the last dimension of the input.
        output_size: Number of values produced per input row.
        hidden_size: Width of every hidden layer. Defaults to 32, the width
            that was previously hard-coded, so existing callers are unaffected.
    """

    def __init__(self, input_size: int, output_size: int, hidden_size: int = 32):
        super().__init__()
        self.model = torch.nn.Sequential(
            torch.nn.Linear(input_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, hidden_size),
            torch.nn.ReLU(),
            torch.nn.Linear(hidden_size, output_size),
        )
        # Loss module stored on the network for use by training code; it is
        # not applied anywhere in forward().
        self.loss = torch.nn.MSELoss()
        # NOTE: this ModuleList holds the *same* layer objects as self.model,
        # so each layer is registered twice (state_dict contains both
        # "model.*" and "param.*" keys). It is retained only so previously
        # saved checkpoints still load with strict=True and any caller using
        # ``.param`` keeps working. parameters() de-duplicates shared
        # tensors, so optimizers still see each weight exactly once.
        self.param = torch.nn.ModuleList(self.model.children())

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the raw outputs.

        ``x`` must have ``input_size`` as its last dimension (required by the
        first Linear layer); any leading dimensions are carried through.
        Calling the Sequential applies exactly the same layers, in the same
        order, as iterating ``self.param``.
        """
        return self.model(x)