Skip to content

Commit 2b12b51

Browse files
committed
update inference example
1 parent 15a2db3 commit 2b12b51

File tree

1 file changed

+14
-15
lines changed

1 file changed

+14
-15
lines changed

stock_prediction_deep_learning_inference.py

Lines changed: 14 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -60,26 +60,25 @@ def main(argv):
6060
model = tf.keras.models.load_model(os.path.join(inference_folder, 'model_weights.h5'))
6161
model.summary()
6262

63-
# display the content of the model
64-
baseline_results = model.evaluate(x_test, y_test, verbose=2)
65-
for name, value in zip(model.metrics_names, baseline_results):
66-
print(name, ': ', value)
67-
print()
68-
6963
# perform a prediction
7064
test_predictions_baseline = model.predict(x_test)
7165
test_predictions_baseline = min_max.inverse_transform(test_predictions_baseline)
72-
test_predictions_baseline = pd.DataFrame(test_predictions_baseline)
73-
74-
test_predictions_baseline.rename(columns={0: STOCK_TICKER + '_predicted'}, inplace=True)
75-
test_predictions_baseline = test_predictions_baseline.round(decimals=0)
76-
test_data.to_csv(os.path.join(inference_folder, 'generated.csv'))
77-
test_predictions_baseline.to_csv(os.path.join(inference_folder, 'inference.csv'))
78-
66+
test_predictions_baseline = pd.DataFrame(test_predictions_baseline, columns=['Predicted_Price'])
67+
68+
# Combine the predicted values with dates from the test data
69+
predicted_dates = pd.date_range(start=test_data.index[0], periods=len(test_predictions_baseline))
70+
test_predictions_baseline['Date'] = predicted_dates
71+
72+
# Reset the index for proper concatenation
73+
test_data.reset_index(inplace=True)
74+
75+
# Concatenate the test_data and predicted data
76+
combined_data = pd.concat([test_data, test_predictions_baseline], ignore_index=True)
77+
7978
# Plotting predictions
8079
plt.figure(figsize=(14, 5))
81-
plt.plot(test_data.Close, color='green', label='Simulated [' + STOCK_TICKER + '] price')
82-
plt.plot(test_predictions_baseline[STOCK_TICKER + '_predicted'], color='red', label='Predicted [' + STOCK_TICKER + '] price')
80+
plt.plot(combined_data['Date'], combined_data.Close, color='green', label='Simulated [' + STOCK_TICKER + '] price')
81+
plt.plot(combined_data['Date'], combined_data['Predicted_Price'], color='red', label='Predicted [' + STOCK_TICKER + '] price')
8382
plt.xlabel('Time')
8483
plt.ylabel('Price [USD]')
8584
plt.legend()

0 commit comments

Comments
 (0)