-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathLSTM_model.py
More file actions
42 lines (32 loc) · 1.35 KB
/
LSTM_model.py
File metadata and controls
42 lines (32 loc) · 1.35 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
"""
@author: Anish Cheraku
@website: NA
"""
import torch
import torch.nn as nn
import numpy as np
# Run this command in the terminal: python LSTM_train.py --data_file history_price.csv --num_epochs 100 --batch_size 32 --checkpoint_path lstm_checkpoint.pth
class PricePredictionModel(nn.Module):
    """LSTM followed by dropout and a linear layer on the last time step.

    Args:
        input_size: Number of features per time step fed to the LSTM.
        hidden_size: Size of the LSTM hidden state.
        num_layers: Number of stacked LSTM layers.
        output_size: Number of values produced by the final linear layer.
        dropout_rate: Probability for the dropout applied to the last LSTM
            output and, when ``num_layers > 1``, between stacked LSTM layers.
    """

    def __init__(self, input_size, hidden_size, num_layers, output_size, dropout_rate=0.2):
        super().__init__()
        self.hidden_size = hidden_size
        self.num_layers = num_layers
        # nn.LSTM only applies dropout *between* stacked layers; with a single
        # layer a non-zero value does nothing and just triggers a UserWarning.
        inter_layer_dropout = dropout_rate if num_layers > 1 else 0.0
        self.lstm = nn.LSTM(
            input_size,
            hidden_size,
            num_layers,
            batch_first=True,
            dropout=inter_layer_dropout,
        )
        self.dropout = nn.Dropout(dropout_rate)
        self.fc = nn.Linear(hidden_size, output_size)

    def forward(self, x):
        """Run the model on a batch of sequences.

        Args:
            x: Tensor laid out as ``(batch, seq_len, input_size)`` (the LSTM
                is built with ``batch_first=True``).

        Returns:
            Tensor of shape ``(batch, output_size)`` computed from the LSTM
            output at the last time step.
        """
        # Allocate the zero initial states directly on the input's device and
        # dtype instead of creating them on the default device and copying.
        state_shape = (self.num_layers, x.size(0), self.hidden_size)
        h0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        c0 = torch.zeros(state_shape, device=x.device, dtype=x.dtype)
        out, _ = self.lstm(x, (h0, c0))
        # Keep only the final time step of every sequence in the batch.
        out = self.dropout(out[:, -1, :])
        return self.fc(out)
def create_sequences(data, seq_length):
    """Split ``data`` into sliding windows and the item following each window.

    Args:
        data: Indexable sequence supporting ``len``, slicing and integer
            indexing (e.g. a list or a NumPy array).
        seq_length: Number of consecutive items per window; must be positive.

    Returns:
        Tuple ``(sequences, targets)`` of NumPy arrays where ``sequences[i]``
        is ``data[i:i + seq_length]`` and ``targets[i]`` is
        ``data[i + seq_length]``. Both arrays are empty when
        ``len(data) <= seq_length``.

    Raises:
        ValueError: If ``seq_length`` is not positive.
    """
    # A zero or negative length would silently produce meaningless windows
    # (negative slice bounds / wrap-around targets), so reject it up front.
    if seq_length <= 0:
        raise ValueError(f"seq_length must be positive, got {seq_length}")

    starts = range(len(data) - seq_length)
    sequences = [data[start:start + seq_length] for start in starts]
    targets = [data[start + seq_length] for start in starts]
    return np.array(sequences), np.array(targets)