Skip to content

Commit 15a2db3

Browse files
committed
update inference
1 parent cd55b1e commit 15a2db3

File tree

1 file changed

+49
-51
lines changed

1 file changed

+49
-51
lines changed

stock_prediction_deep_learning_inference.py

Lines changed: 49 additions & 51 deletions
Original file line numberDiff line numberDiff line change
@@ -20,22 +20,21 @@
2020

2121
from stock_prediction_class import StockPrediction
2222
from stock_prediction_numpy import StockData
23-
from datetime import timedelta
24-
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
23+
from datetime import timedelta, datetime
2524

25+
os.environ["PATH"] += os.pathsep + 'C:/Program Files (x86)/Graphviz2.38/bin/'
2626

2727
def main(argv):
2828
print(tf.version.VERSION)
2929
inference_folder = os.path.join(os.getcwd(), RUN_FOLDER)
30-
stock = StockPrediction(STOCK_TICKER, STOCK_START_DATE, STOCK_VALIDATION_DATE, inference_folder)
30+
stock = StockPrediction(STOCK_TICKER, STOCK_START_DATE, STOCK_VALIDATION_DATE, inference_folder, GITHUB_URL, EPOCHS, TIME_STEPS, TOKEN, BATCH_SIZE)
3131

3232
data = StockData(stock)
3333

3434
(x_train, y_train), (x_test, y_test), (training_data, test_data) = data.download_transform_to_numpy(TIME_STEPS, inference_folder)
3535
min_max = data.get_min_max()
3636

3737
# load future data
38-
3938
print('Latest Stock Price')
4039
latest_close_price = test_data.Close.iloc[-1]
4140
latest_date = test_data[-1:]['Close'].idxmin()
@@ -45,7 +44,7 @@ def main(argv):
4544

4645
tomorrow_date = latest_date + timedelta(1)
4746
# Specify the next 300 days
48-
next_year = latest_date + timedelta(TIME_STEPS*100)
47+
next_year = latest_date + timedelta(TIME_STEPS * 100)
4948

5049
print('Future Date')
5150
print(tomorrow_date)
@@ -55,53 +54,52 @@ def main(argv):
5554

5655
x_test, y_test, test_data = data.generate_future_data(TIME_STEPS, min_max, tomorrow_date, next_year, latest_close_price)
5756

58-
# load the weights from our best model
59-
model = tf.keras.models.load_model(os.path.join(inference_folder, 'model_weights.h5'))
60-
model.summary()
61-
62-
#print(x_test)
63-
#print(test_data)
64-
# display the content of the model
65-
baseline_results = model.evaluate(x_test, y_test, verbose=2)
66-
for name, value in zip(model.metrics_names, baseline_results):
67-
print(name, ': ', value)
68-
print()
69-
70-
# perform a prediction
71-
test_predictions_baseline = model.predict(x_test)
72-
test_predictions_baseline = min_max.inverse_transform(test_predictions_baseline)
73-
test_predictions_baseline = pd.DataFrame(test_predictions_baseline)
74-
75-
test_predictions_baseline.rename(columns={0: STOCK_TICKER + '_predicted'}, inplace=True)
76-
test_predictions_baseline = test_predictions_baseline.round(decimals=0)
77-
test_data.to_csv(os.path.join(inference_folder, 'generated.csv'))
78-
test_predictions_baseline.to_csv(os.path.join(inference_folder, 'inference.csv'))
79-
80-
print("plotting predictions")
81-
plt.figure(figsize=(14, 5))
82-
plt.plot(test_predictions_baseline[STOCK_TICKER + '_predicted'], color='red', label='Predicted [' + 'GOOG' + '] price')
83-
plt.xlabel('Time')
84-
plt.ylabel('Price [' + 'USD' + ']')
85-
plt.legend()
86-
plt.title('Prediction')
87-
plt.savefig(os.path.join(inference_folder, STOCK_TICKER + '_future_prediction.png'))
88-
plt.pause(0.001)
89-
90-
plt.figure(figsize=(14, 5))
91-
plt.plot(test_data.Close, color='green', label='Simulated [' + 'GOOG' + '] price')
92-
plt.xlabel('Time')
93-
plt.ylabel('Price [' + 'USD' + ']')
94-
plt.legend()
95-
plt.title('Random')
96-
plt.savefig(os.path.join(inference_folder, STOCK_TICKER + '_future_random.png'))
97-
plt.pause(0.001)
98-
plt.show()
99-
57+
# Check if the future data is not empty
58+
if x_test.shape[0] > 0:
59+
# load the weights from our best model
60+
model = tf.keras.models.load_model(os.path.join(inference_folder, 'model_weights.h5'))
61+
model.summary()
62+
63+
# display the content of the model
64+
baseline_results = model.evaluate(x_test, y_test, verbose=2)
65+
for name, value in zip(model.metrics_names, baseline_results):
66+
print(name, ': ', value)
67+
print()
68+
69+
# perform a prediction
70+
test_predictions_baseline = model.predict(x_test)
71+
test_predictions_baseline = min_max.inverse_transform(test_predictions_baseline)
72+
test_predictions_baseline = pd.DataFrame(test_predictions_baseline)
73+
74+
test_predictions_baseline.rename(columns={0: STOCK_TICKER + '_predicted'}, inplace=True)
75+
test_predictions_baseline = test_predictions_baseline.round(decimals=0)
76+
test_data.to_csv(os.path.join(inference_folder, 'generated.csv'))
77+
test_predictions_baseline.to_csv(os.path.join(inference_folder, 'inference.csv'))
78+
79+
# Plotting predictions
80+
plt.figure(figsize=(14, 5))
81+
plt.plot(test_data.Close, color='green', label='Simulated [' + STOCK_TICKER + '] price')
82+
plt.plot(test_predictions_baseline[STOCK_TICKER + '_predicted'], color='red', label='Predicted [' + STOCK_TICKER + '] price')
83+
plt.xlabel('Time')
84+
plt.ylabel('Price [USD]')
85+
plt.legend()
86+
plt.title('Simulated vs Predicted Prices')
87+
plt.savefig(os.path.join(inference_folder, STOCK_TICKER + '_future_comparison.png'))
88+
plt.show()
89+
else:
90+
print("Error: Future data is empty.")
10091

10192
if __name__ == '__main__':
10293
TIME_STEPS = 3
103-
RUN_FOLDER = 'GOOG_20200711_76c9683d2457682b0e2e918d8ef6670e'
104-
STOCK_TICKER = 'GOOG'
105-
STOCK_START_DATE = pd.to_datetime('2004-08-01')
106-
STOCK_VALIDATION_DATE = pd.to_datetime('2017-01-01')
94+
RUN_FOLDER = '^FTSE_20240103_edae6b8f5fc742031805151aeba98571'
95+
TOKEN = 'edae6b8f5fc742031805151aeba98571'
96+
STOCK_TICKER = '^FTSE'
97+
BATCH_SIZE = 10
98+
STOCK_START_DATE = pd.to_datetime('2017-11-01')
99+
start_date = pd.to_datetime('2017-01-01')
100+
end_date = datetime.today()
101+
duration = end_date - start_date
102+
STOCK_VALIDATION_DATE = start_date + 0.8 * duration
103+
GITHUB_URL = "https://github.com/JordiCorbilla/stock-prediction-deep-neural-learning/raw/master/"
104+
EPOCHS = 100
107105
app.run(main)

0 commit comments

Comments
 (0)