@@ -60,26 +60,25 @@ def main(argv):
6060 model = tf .keras .models .load_model (os .path .join (inference_folder , 'model_weights.h5' ))
6161 model .summary ()
6262
63- # display the content of the model
64- baseline_results = model .evaluate (x_test , y_test , verbose = 2 )
65- for name , value in zip (model .metrics_names , baseline_results ):
66- print (name , ': ' , value )
67- print ()
68-
6963 # perform a prediction
7064 test_predictions_baseline = model .predict (x_test )
7165 test_predictions_baseline = min_max .inverse_transform (test_predictions_baseline )
72- test_predictions_baseline = pd .DataFrame (test_predictions_baseline )
73-
74- test_predictions_baseline .rename (columns = {0 : STOCK_TICKER + '_predicted' }, inplace = True )
75- test_predictions_baseline = test_predictions_baseline .round (decimals = 0 )
76- test_data .to_csv (os .path .join (inference_folder , 'generated.csv' ))
77- test_predictions_baseline .to_csv (os .path .join (inference_folder , 'inference.csv' ))
78-
66+ test_predictions_baseline = pd .DataFrame (test_predictions_baseline , columns = ['Predicted_Price' ])
67+
68+ # Combine the predicted values with dates from the test data
69+ predicted_dates = pd .date_range (start = test_data .index [0 ], periods = len (test_predictions_baseline ))
70+ test_predictions_baseline ['Date' ] = predicted_dates
71+
72+ # Reset the index for proper concatenation
73+ test_data .reset_index (inplace = True )
74+
75+ # Concatenate the test_data and predicted data
76+ combined_data = pd .concat ([test_data , test_predictions_baseline ], ignore_index = True )
77+
7978 # Plotting predictions
8079 plt .figure (figsize = (14 , 5 ))
81- plt .plot (test_data .Close , color = 'green' , label = 'Simulated [' + STOCK_TICKER + '] price' )
82- plt .plot (test_predictions_baseline [ STOCK_TICKER + '_predicted ' ], color = 'red' , label = 'Predicted [' + STOCK_TICKER + '] price' )
80+ plt .plot (combined_data [ 'Date' ], combined_data .Close , color = 'green' , label = 'Simulated [' + STOCK_TICKER + '] price' )
81+ plt .plot (combined_data [ 'Date' ], combined_data [ 'Predicted_Price ' ], color = 'red' , label = 'Predicted [' + STOCK_TICKER + '] price' )
8382 plt .xlabel ('Time' )
8483 plt .ylabel ('Price [USD]' )
8584 plt .legend ()
0 commit comments