Skip to content

Commit 42d7155

Browse files
committed
refactoring
1 parent 1b14a78 commit 42d7155

File tree

3 files changed

+94
-31
lines changed

3 files changed

+94
-31
lines changed

stock_prediction_class.py

Lines changed: 46 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,46 @@
1+
# Copyright 2020 Jordi Corbilla. All Rights Reserved.
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
# ==============================================================================
15+
16+
17+
class StockPrediction:
    """Bundle the settings that describe a single stock-prediction run.

    The object stores the four constructor arguments unchanged and hands
    them back through ``get_*`` / ``set_*`` accessors.  No validation or
    conversion is performed on any value.

    The same values are also reachable as plain properties (``ticker``,
    ``start_date``, ``validation_date``, ``project_folder``).  The accessor
    methods are kept so that existing callers keep working.
    """

    def __init__(self, ticker, start_date, validation_date, project_folder):
        """Store the run settings.

        Args:
            ticker: Value returned later by ``get_ticker()``.
            start_date: Value returned later by ``get_start_date()``.
            validation_date: Value returned later by
                ``get_validation_date()``.
            project_folder: Value returned later by
                ``get_project_folder()``.
        """
        self._ticker = ticker
        self._start_date = start_date
        self._validation_date = validation_date
        self._project_folder = project_folder

    def __repr__(self):
        """Return a debugging representation listing all four settings."""
        return '{}(ticker={!r}, start_date={!r}, validation_date={!r}, project_folder={!r})'.format(
            type(self).__name__,
            self._ticker,
            self._start_date,
            self._validation_date,
            self._project_folder,
        )

    def get_ticker(self):
        """Return the stored ticker."""
        return self._ticker

    def set_ticker(self, value):
        """Replace the stored ticker with ``value``."""
        self._ticker = value

    def get_start_date(self):
        """Return the stored start date."""
        return self._start_date

    def set_start_date(self, value):
        """Replace the stored start date with ``value``."""
        self._start_date = value

    def get_validation_date(self):
        """Return the stored validation date."""
        return self._validation_date

    def set_validation_date(self, value):
        """Replace the stored validation date with ``value``."""
        self._validation_date = value

    def get_project_folder(self):
        """Return the stored project folder."""
        return self._project_folder

    def set_project_folder(self, value):
        """Replace the stored project folder with ``value``."""
        self._project_folder = value

    # Attribute-style aliases for the accessors above.  They are bound to the
    # functions defined in this class body, so they read and write the same
    # private attributes as the get_*/set_* methods.
    ticker = property(get_ticker, set_ticker, doc="The stored ticker.")
    start_date = property(get_start_date, set_start_date, doc="The stored start date.")
    validation_date = property(
        get_validation_date, set_validation_date, doc="The stored validation date."
    )
    project_folder = property(
        get_project_folder, set_project_folder, doc="The stored project folder."
    )

stock_prediction_deep_learning.py

Lines changed: 14 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -16,41 +16,26 @@
1616
import secrets
1717
import pandas as pd
1818
import tensorflow as tf
19-
from sklearn.preprocessing import MinMaxScaler
2019
from datetime import datetime
21-
import yfinance as yf
20+
21+
from stock_prediction_class import StockPrediction
2222
from stock_prediction_lstm import LongShortTermMemory
2323
from stock_prediction_numpy import StockData
2424
from stock_prediction_plotter import Plotter
2525

2626
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
2727

2828

29-
def train_LSTM_network(start_date, ticker, validation_date):
30-
min_max = MinMaxScaler(feature_range=(0, 1))
31-
sec = yf.Ticker(ticker)
32-
end_date = datetime.today()
33-
print('End Date: ' + end_date.strftime("%Y-%m-%d"))
34-
data = yf.download([ticker], start=start_date, end=end_date)[['Close']]
35-
data = data.reset_index()
36-
print(data)
37-
38-
plotter = Plotter(True, project_folder, sec.info['shortName'], sec.info['currency'], STOCK_TICKER)
29+
def train_LSTM_network(stock):
30+
data = StockData(stock)
3931

40-
training_data = data[data['Date'] < validation_date].copy()
41-
test_data = data[data['Date'] >= validation_date].copy()
42-
training_data = training_data.set_index('Date')
43-
# Set the data frame index using column Date
44-
test_data = test_data.set_index('Date')
45-
print(test_data)
46-
plotter.plot_histogram_data_split(training_data, test_data, validation_date)
32+
plotter = Plotter(True, stock.get_project_folder(), data.get_stock_short_name(), data.get_stock_currency(), stock.get_ticker())
4733

48-
data = StockData()
49-
(x_train, y_train), (x_test, y_test) = data.to_numpy(TIME_STEPS, min_max, training_data, test_data)
34+
(x_train, y_train), (x_test, y_test), (min_max, test_data) = data.download_transform_to_numpy(TIME_STEPS)
5035

5136
print(x_test)
5237

53-
lstm = LongShortTermMemory(project_folder)
38+
lstm = LongShortTermMemory(stock.get_project_folder())
5439
model = lstm.create_model(x_train)
5540

5641
defined_metrics = [
@@ -63,7 +48,7 @@ def train_LSTM_network(start_date, ticker, validation_date):
6348
history = model.fit(x_train, y_train, epochs=EPOCHS, batch_size=BATCH_SIZE, validation_data=(x_test, y_test),
6449
callbacks=[callback])
6550
print("saving weights")
66-
model.save(os.path.join(project_folder, 'model_weights.h5'))
51+
model.save(os.path.join(stock.get_project_folder(), 'model_weights.h5'))
6752
plotter.plot_loss(history)
6853
plotter.plot_mse(history)
6954

@@ -77,7 +62,7 @@ def train_LSTM_network(start_date, ticker, validation_date):
7762
test_predictions_baseline = model.predict(x_test)
7863
test_predictions_baseline = min_max.inverse_transform(test_predictions_baseline)
7964
test_predictions_baseline = pd.DataFrame(test_predictions_baseline)
80-
test_predictions_baseline.to_csv(os.path.join(project_folder, 'predictions.csv'))
65+
test_predictions_baseline.to_csv(os.path.join(stock.get_project_folder(), 'predictions.csv'))
8166

8267
test_predictions_baseline.rename(columns={0: STOCK_TICKER + '_predicted'}, inplace=True)
8368
test_predictions_baseline = test_predictions_baseline.round(decimals=0)
@@ -105,9 +90,10 @@ def train_LSTM_network(start_date, ticker, validation_date):
10590
print('Validation Date: ' + STOCK_START_DATE.strftime("%Y-%m-%d"))
10691
print('Generating folder: ' + TOKEN)
10792
# create project run folder
108-
project_folder = os.path.join(os.getcwd(), TOKEN)
109-
if not os.path.exists(project_folder):
110-
os.makedirs(project_folder)
93+
PROJECT_FOLDER = os.path.join(os.getcwd(), TOKEN)
94+
if not os.path.exists(PROJECT_FOLDER):
95+
os.makedirs(PROJECT_FOLDER)
11196

97+
stock_prediction = StockPrediction(STOCK_TICKER, STOCK_START_DATE, STOCK_VALIDATION_DATE, PROJECT_FOLDER)
11298
# Execute Deep Learning model
113-
train_LSTM_network(STOCK_START_DATE, STOCK_TICKER, STOCK_VALIDATION_DATE)
99+
train_LSTM_network(stock_prediction)

stock_prediction_numpy.py

Lines changed: 34 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -13,19 +13,50 @@
1313
# limitations under the License.
1414
# ==============================================================================
1515
import numpy as np
16-
import pandas as pd
1716
from datetime import timedelta
1817
import random
18+
import pandas as pd
19+
from sklearn.preprocessing import MinMaxScaler
20+
from datetime import datetime
21+
import yfinance as yf
22+
from stock_prediction_plotter import Plotter
1923

2024

2125
class StockData:
26+
def __init__(self, stock):
    """Keep the run settings and create the yfinance handle for the ticker.

    Args:
        stock: Settings object for the run.  Only ``get_ticker()`` is
            called here; its result is passed straight to ``yf.Ticker``.
    """
    self._stock = stock
    # One Ticker handle is created up front and kept on the instance.
    self._sec = yf.Ticker(self._stock.get_ticker())
29+
2230
def __data_verification(self, train):
    """Print summary statistics of ``train`` to standard output.

    Args:
        train: Object providing ``mean(axis=...)``, ``max()``, ``min()``
            and ``std(axis=...)``.  Presumably a NumPy array of scaled
            training values; confirm against the caller.

    Returns:
        None.  The method only prints.
    """
    # Mean and standard deviation are taken along axis 0; max and min are
    # taken over the whole object.
    print('mean:', train.mean(axis=0))
    print('max', train.max())
    print('min', train.min())
    print('Std dev:', train.std(axis=0))
2735

28-
def to_numpy(self, time_steps, min_max, training_data, test_data):
36+
def get_stock_short_name(self):
    """Return the ``'shortName'`` entry of the ticker handle's ``info``.

    The lookup is unguarded, so any error raised for a missing entry
    propagates to the caller.
    """
    return self._sec.info['shortName']
38+
39+
def get_stock_currency(self):
    """Return the ``'currency'`` entry of the ticker handle's ``info``.

    The lookup is unguarded, so any error raised for a missing entry
    propagates to the caller.
    """
    return self._sec.info['currency']
41+
42+
def download_transform_to_numpy(self, time_steps):
43+
min_max = MinMaxScaler(feature_range=(0, 1))
44+
end_date = datetime.today()
45+
print('End Date: ' + end_date.strftime("%Y-%m-%d"))
46+
data = yf.download([self._stock.get_ticker()], start=self._stock.get_start_date(), end=end_date)[['Close']]
47+
data = data.reset_index()
48+
print(data)
49+
50+
plotter = Plotter(True, self._stock.get_project_folder(), self._sec.info['shortName'], self._sec.info['currency'], self._stock.get_ticker())
51+
52+
training_data = data[data['Date'] < self._stock.get_validation_date()].copy()
53+
test_data = data[data['Date'] >= self._stock.get_validation_date()].copy()
54+
training_data = training_data.set_index('Date')
55+
# Set the data frame index using column Date
56+
test_data = test_data.set_index('Date')
57+
print(test_data)
58+
plotter.plot_histogram_data_split(training_data, test_data, self._stock.get_validation_date())
59+
2960
train_scaled = min_max.fit_transform(training_data)
3061
self.__data_verification(train_scaled)
3162

@@ -52,7 +83,7 @@ def to_numpy(self, time_steps, min_max, training_data, test_data):
5283

5384
x_test, y_test = np.array(x_test), np.array(y_test)
5485
x_test = np.reshape(x_test, (x_test.shape[0], x_test.shape[1], 1))
55-
return (x_train, y_train), (x_test, y_test)
86+
return (x_train, y_train), (x_test, y_test), (min_max, test_data)
5687

5788
def __daterange(self, start_date, end_date):
5889
for n in range(int((end_date - start_date).days)):

0 commit comments

Comments
 (0)