# Post-diff import block, de-garbled from the diff view and grouped
# stdlib / third-party / local per PEP 8.

# Standard library
import json
from os import getenv

# Third-party
import numpy as np
import redis
from torch import cuda, device, load, no_grad, tensor

# Local / project
from datamodel.ticker_cluster import TKGroup
from hygdra_forecasting.model.build import ConvCausalLTSM
from hygdra_forecasting.utils.dataset import dict_to_dataset_inference
from hygdra_forecasting.utils.ohlv import get_kraken_data_to_json
911class StockPredictor :
1012 def __init__ (self , interval : str = 'days' ):
@@ -27,34 +29,58 @@ def predict(self):
# NOTE(review): this span is the interior of predict(); `groups_name` and
# `tickers` are bound earlier in the method (outside this view) — presumably
# by a loop over TKGroup entries; confirm against the full file.
self.model.load_state_dict(load(f'weight/{self.interval}/best_model.pth')["model_state_dict"])  # TODO: per-group weights ({groups_name}.pth)
self.model.eval()

# Fetch OHLCV data. From usage below: `df` maps ticker -> dict of feature
# arrays, `dict_unorm` maps ticker -> {"close": {"mean": ..., "std": ...}}
# normalization stats, `index_timestamp` maps ticker -> bar timestamps in
# epoch seconds.
df, dict_unorm, index_timestamp = get_kraken_data_to_json(tickers, interval=self.interval_transform[self.interval])
sequences_dict = dict_to_dataset_inference(df)

# Forecast horizon: 14 bars of the current interval. Loop-invariant, so it
# is computed once instead of once per ticker.  # (double check)
if self.interval == "minutes":
    offset = np.timedelta64(14, "m")
elif self.interval == "thirty":
    offset = np.timedelta64(14 * 30, "m")
elif self.interval == "hours":
    offset = np.timedelta64(14, "h")
else:
    offset = np.timedelta64(14, "D")

for ticker in tickers:
    try:
        # Kraken keys pairs as "<BASE>USD": strip any "-QUOTE" suffix.
        ticker = ticker.split("-")[0] + "USD"
        sequence = tensor(sequences_dict[ticker]).float()

        with no_grad():
            predictions = self.model(sequence)

        date_array = np.array(index_timestamp[ticker], dtype='datetime64[s]')
        df[ticker]["Date"] = date_array
        predictions = predictions.squeeze().numpy().reshape(-1)

        # De-normalize with x * std + mean. BUG FIX: the close column
        # previously added "std" where "mean" was intended (compare the
        # forecasting line).
        df[ticker]["forecasting"] = predictions * dict_unorm[ticker]["close"]["std"] + dict_unorm[ticker]["close"]["mean"]
        df[ticker]["close"] = df[ticker]["close"] * dict_unorm[ticker]["close"]["std"] + dict_unorm[ticker]["close"]["mean"]

        df[ticker]["pred_date"] = df[ticker]["Date"] + offset

        # json.dumps cannot serialize numpy arrays: cast to plain lists.
        float_keys = ['close', 'forecasting', 'low', 'high', 'open', 'volume', 'upper', 'lower', 'width', 'rsi', 'roc', 'diff', 'percent_change_close']
        for key in float_keys:
            df[ticker][key] = df[ticker][key].astype(float).tolist()

        date_keys = ['Date', 'pred_date']
        for key in date_keys:
            df[ticker][key] = df[ticker][key].astype(str).tolist()
    except Exception as e:
        # Broad per-ticker catch kept deliberately: one bad ticker must not
        # abort the whole batch.
        print(f"error predicting stock : {e} on stock {ticker}")
        # BUG FIX: skip registration for a failed ticker. Previously the
        # save block still ran, hitting an undefined (first iteration) or
        # stale (later iterations) redis_key and an incomplete df[ticker].
        continue

    # register data
    redis_key = f"{ticker}_{self.interval}"
    try:
        df_json = json.dumps(df[ticker])
        self.redis_client.set(redis_key, df_json)
        print(f"Saved predictions for group '{groups_name}' to Redis with key '{redis_key}'")
    except Exception as e:
        print(f"Error saving predictions for group '{groups_name}' to Redis: {e}")
5884
# Script entry point: build a predictor on the daily interval.
# NOTE(review): the capture appears truncated here (a page footer follows) —
# presumably predictor.predict() is called next in the full file; confirm
# before relying on this snippet.
5985if __name__ == "__main__" :
6086 predictor = StockPredictor (interval = 'days' )
0 commit comments