1616import secrets
1717import pandas as pd
1818import tensorflow as tf
19- from sklearn .preprocessing import MinMaxScaler
2019from datetime import datetime
21- import yfinance as yf
20+
21+ from stock_prediction_class import StockPrediction
2222from stock_prediction_lstm import LongShortTermMemory
2323from stock_prediction_numpy import StockData
2424from stock_prediction_plotter import Plotter
2525
2626os .environ ["PATH" ] += os .pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
2727
2828
29- def train_LSTM_network (start_date , ticker , validation_date ):
30- min_max = MinMaxScaler (feature_range = (0 , 1 ))
31- sec = yf .Ticker (ticker )
32- end_date = datetime .today ()
33- print ('End Date: ' + end_date .strftime ("%Y-%m-%d" ))
34- data = yf .download ([ticker ], start = start_date , end = end_date )[['Close' ]]
35- data = data .reset_index ()
36- print (data )
37-
38- plotter = Plotter (True , project_folder , sec .info ['shortName' ], sec .info ['currency' ], STOCK_TICKER )
29+ def train_LSTM_network (stock ):
30+ data = StockData (stock )
3931
40- training_data = data [data ['Date' ] < validation_date ].copy ()
41- test_data = data [data ['Date' ] >= validation_date ].copy ()
42- training_data = training_data .set_index ('Date' )
43- # Set the data frame index using column Date
44- test_data = test_data .set_index ('Date' )
45- print (test_data )
46- plotter .plot_histogram_data_split (training_data , test_data , validation_date )
32+ plotter = Plotter (True , stock .get_project_folder (), data .get_stock_short_name (), data .get_stock_currency (), stock .get_ticker ())
4733
48- data = StockData ()
49- (x_train , y_train ), (x_test , y_test ) = data .to_numpy (TIME_STEPS , min_max , training_data , test_data )
34+ (x_train , y_train ), (x_test , y_test ), (min_max , test_data ) = data .download_transform_to_numpy (TIME_STEPS )
5035
5136 print (x_test )
5237
53- lstm = LongShortTermMemory (project_folder )
38+ lstm = LongShortTermMemory (stock . get_project_folder () )
5439 model = lstm .create_model (x_train )
5540
5641 defined_metrics = [
@@ -63,7 +48,7 @@ def train_LSTM_network(start_date, ticker, validation_date):
6348 history = model .fit (x_train , y_train , epochs = EPOCHS , batch_size = BATCH_SIZE , validation_data = (x_test , y_test ),
6449 callbacks = [callback ])
6550 print ("saving weights" )
66- model .save (os .path .join (project_folder , 'model_weights.h5' ))
51+ model .save (os .path .join (stock . get_project_folder () , 'model_weights.h5' ))
6752 plotter .plot_loss (history )
6853 plotter .plot_mse (history )
6954
@@ -77,7 +62,7 @@ def train_LSTM_network(start_date, ticker, validation_date):
7762 test_predictions_baseline = model .predict (x_test )
7863 test_predictions_baseline = min_max .inverse_transform (test_predictions_baseline )
7964 test_predictions_baseline = pd .DataFrame (test_predictions_baseline )
80- test_predictions_baseline .to_csv (os .path .join (project_folder , 'predictions.csv' ))
65+ test_predictions_baseline .to_csv (os .path .join (stock . get_project_folder () , 'predictions.csv' ))
8166
8267 test_predictions_baseline .rename (columns = {0 : STOCK_TICKER + '_predicted' }, inplace = True )
8368 test_predictions_baseline = test_predictions_baseline .round (decimals = 0 )
@@ -105,9 +90,10 @@ def train_LSTM_network(start_date, ticker, validation_date):
10590 print ('Validation Date: ' + STOCK_START_DATE .strftime ("%Y-%m-%d" ))
10691 print ('Generating folder: ' + TOKEN )
10792 # create project run folder
108- project_folder = os .path .join (os .getcwd (), TOKEN )
109- if not os .path .exists (project_folder ):
110- os .makedirs (project_folder )
93+ PROJECT_FOLDER = os .path .join (os .getcwd (), TOKEN )
94+ if not os .path .exists (PROJECT_FOLDER ):
95+ os .makedirs (PROJECT_FOLDER )
11196
97+ stock_prediction = StockPrediction (STOCK_TICKER , STOCK_START_DATE , STOCK_VALIDATION_DATE , PROJECT_FOLDER )
11298 # Execute Deep Learning model
113- train_LSTM_network (STOCK_START_DATE , STOCK_TICKER , STOCK_VALIDATION_DATE )
99+ train_LSTM_network (stock_prediction )
0 commit comments