Skip to content

Commit 3feaed0

Browse files
committed
framework update for inference on kraken data
1 parent 0d98b9f commit 3feaed0

File tree

24 files changed

+91
-233
lines changed

24 files changed

+91
-233
lines changed

app/api/main.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -156,12 +156,8 @@ def predict_stock(ticker: str, seq: SequenceRequest):
156156
Raises:
157157
HTTPException: If the ticker is not found in any group or if the CSV file cannot be read.
158158
"""
159-
group_name = get_group_for_ticker(ticker)
160-
if not group_name:
161-
raise HTTPException(status_code=400, detail="Ticker not found in groups")
162-
163159
try:
164-
df = pd.read_csv(f'data/{group_name}_{seq}.csv', parse_dates=['Date'])
160+
df = pd.read_csv(f'data/{ticker}_{seq}.csv', parse_dates=['Date'])
165161
except Exception as e:
166162
raise HTTPException(status_code=500, detail=f"Error reading CSV file: {str(e)}")
167163

app/sheduler/datamodel/ticker_cluster.py

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,17 @@
1515
MEMECOIN = ("memecoin", ['BONK-USD', 'SHIB-USD', 'PEPECOIN-USD', 'DOGE-USD'])
1616
"""
1717

18+
class TKGroup(Enum):
19+
CRYPTO_STANDARD = ("crypto-standard", ['SAND-USD', 'IMX-USD', "GALA-USD", "AXS-USD", "MANA-USD", "AAVE-USD", "ETH-USD", "CRO-USD", "BTC-USD", "XRP-USD", "ADA-USD", "SOL-USD", "PEPE-USD", "POPCAT-USD", "DOGE-USD", "TRUMP-USD", "SUI-USD"])
20+
21+
class TKGroupName(str, Enum):
22+
cryptoStandard = "crypto-standard"
23+
24+
"""
1825
class TKGroup(Enum):
1926
CRYPTO_STANDARD = ("crypto-standard", ["ADA-USD", "SOL-USD", "XRP-USD", "ETH-USD", "BTC-USD"])
2027
FREEDOM = ("thisiselonmusk", ["XRP-USD", "SHIB-USD", "SOL-USD", "BTC-USD"])
21-
WEB3 = ("web3", ["LINK-USD", "SOL"])
28+
WEB3 = ("web3", ["LINK-USD", "SOL"]) # "LINK-USD", "SOL"
2229
GAME = ('gaming', ['SAND', 'IMX', "GALA", "AXS", "MANA"])
2330
DEFI = ("defi", ['AVAX-USD', 'LINK-USD', 'UNI-USD', 'STX-USD', 'FTM-USD', "INJ-USD"])
2431
MEMECOIN = ("memecoin", [ "SHIB-USD"])
@@ -37,3 +44,4 @@ class TKGroupName(str, Enum):
3744
#finance = "finance"
3845
#energie = "energie"
3946
47+
"""

app/sheduler/finetune.py

Lines changed: 7 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,13 +4,14 @@
44
from torch import save, load
55
from torch.utils.data import DataLoader
66
from torch import cuda, device
7-
from hygdra_forecasting.dataloader.dataloader import StockDataset
7+
from hygdra_forecasting.dataloader.dataloader import StockDataset as standard
8+
from hygdra_forecasting.dataloader.dataloader_kraken import StockDataset
89
from hygdra_forecasting.utils.learning_rate_sheduler import CosineWarmup
910
from hygdra_forecasting.model.train import train_model, setup_seed
1011
import torch.nn as nn
1112

1213
class StockFineTuner:
13-
def __init__(self, interval: str = 'days', base_weight: str = 'weight/days/best_model.pth', epoch=5, learnig_rate=0.01):
14+
def __init__(self, interval: str = 'days', base_weight: str = 'weight/days/best_model.pth', epoch=100, learnig_rate=0.01):
1415
self.interval = interval
1516
self.base_weight = base_weight
1617
self.device = device('cuda:0') if cuda.is_available() else device('cpu')
@@ -36,8 +37,9 @@ def finetune_one(self, tickers: List[str], path: str):
3637
epochs=self.epoch,
3738
learning_rate=self.learning_rate,
3839
save_epoch=False,
39-
lrfn=self.tuning_scheduler,
40-
criterion=nn.L1Loss(),
40+
# lrfn=self.tuning_scheduler,
41+
lrfn=CosineWarmup(self.learning_rate, self.epoch).lrfn,
42+
# criterion=nn.L1Loss(), # l1 seems to make it harder
4143
checkpoint_file=load(self.base_weight)
4244
)
4345

@@ -52,7 +54,7 @@ def finetune_many(self):
5254
if __name__ == "__main__":
5355
# {"days" : '1440', "minutes" : '1', "hours" : '60', "thrity" : "30"}
5456
interval = "minutes"
55-
tuner = StockFineTuner(interval=interval, base_weight=f'../../weight/best_model.pth')
57+
tuner = StockFineTuner(interval=interval, base_weight=f'weight/minutes/best_model.pth')
5658
tuner.finetune_many()
5759
# training is still really weirdly reset while not in training phase
5860
# manage non crypto course via kraken

app/sheduler/finetune_temp.py

Lines changed: 0 additions & 48 deletions
This file was deleted.

app/sheduler/inference.py

Lines changed: 55 additions & 29 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,12 @@
11
from hygdra_forecasting.model.build import ConvCausalLTSM
22
from datamodel.ticker_cluster import TKGroup
33
from torch import load, tensor, no_grad, cuda, device
4-
from hygdra_forecasting.utils.preprocessing import ohlv_to_dataframe_inference, dataframe_to_dataset_inference
5-
from pandas import DataFrame, DateOffset
4+
from hygdra_forecasting.utils.ohlv import get_kraken_data_to_json
5+
from hygdra_forecasting.utils.dataset import dict_to_dataset_inference
6+
import numpy as np
67
import redis
78
from os import getenv
9+
import json
810

911
class StockPredictor:
1012
def __init__(self, interval: str = 'days'):
@@ -27,34 +29,58 @@ def predict(self):
2729
self.model.load_state_dict(load(f'weight/{self.interval}/best_model.pth')["model_state_dict"]) # {groups_name}.pth
2830
self.model.eval()
2931

30-
df, dict_unorm = ohlv_to_dataframe_inference(tickers, interval=self.interval_transform[self.interval])
31-
sequences_dict = dataframe_to_dataset_inference(df, tickers)
32-
33-
df_result = DataFrame()
34-
df_result.index = df[tickers[0] + "_close"].iloc[-100:].index
35-
df_result["Date"] = df_result.index
36-
32+
df, dict_unorm, index_timestamp = get_kraken_data_to_json(tickers, interval=self.interval_transform[self.interval])
33+
sequences_dict = dict_to_dataset_inference(df)
34+
35+
# 1) join all as before
36+
# 2) optimize json dict of dataframe
3737
for ticker in tickers:
38-
sequence = tensor(sequences_dict[ticker]).float()
39-
40-
with no_grad():
41-
predictions = self.model(sequence)
42-
43-
predictions = predictions.squeeze().numpy()
44-
df_result[ticker + "_pred"] = predictions.reshape(-1)[-100:] * dict_unorm[ticker][1] + dict_unorm[ticker][0]
45-
df_result[ticker + "_close"] = df[ticker + '_close'].iloc[-100:] * dict_unorm[ticker][1] + dict_unorm[ticker][0]
46-
47-
df_result["pred_date"] = df_result.index + DateOffset(days=14)
48-
df_result.to_csv(f'data/{groups_name}_{self.interval}.csv')
49-
50-
json_data = df_result.to_json(orient="records", date_format="iso")
51-
redis_key = f"{groups_name}_{self.interval}"
52-
53-
try:
54-
self.redis_client.set(redis_key, json_data)
55-
print(f"Saved predictions for group '{groups_name}' to Redis with key '{redis_key}'")
56-
except Exception as e:
57-
print(f"Error saving predictions for group '{groups_name}' to Redis: {e}")
38+
try :
39+
ticker = ticker.split("-")[0] + "USD"
40+
sequence = tensor(sequences_dict[ticker]).float()
41+
42+
with no_grad():
43+
predictions = self.model(sequence)
44+
45+
date_array = np.array(index_timestamp[ticker], dtype='datetime64[s]')
46+
df[ticker]["Date"] = date_array
47+
predictions = predictions.squeeze().numpy().reshape(-1)
48+
49+
df[ticker]["forecasting"] = predictions * dict_unorm[ticker]["close"]["std"] + dict_unorm[ticker]["close"]["mean"]
50+
df[ticker]["close"] = df[ticker]["close"] * dict_unorm[ticker]["close"]["std"] + dict_unorm[ticker]["close"]["std"]
51+
52+
# prediction interval # (double check)
53+
offset = None
54+
if self.interval == "minutes":
55+
offset = np.timedelta64(14, "m")
56+
elif self.interval == "thirty":
57+
offset= np.timedelta64(14 * 30, "m")
58+
elif self.interval == "hours":
59+
offset = np.timedelta64(14, "h")
60+
else :
61+
offset = np.timedelta64(14, "D")
62+
63+
df[ticker]["pred_date"] = df[ticker]["Date"] + offset
64+
float_keys = ['close', 'forecasting', 'low', 'high', 'open', 'volume', 'upper', 'lower', 'width', 'rsi', 'roc', 'diff', 'percent_change_close']
65+
for key in float_keys:
66+
df[ticker][key] = df[ticker][key].astype(float).tolist()
67+
68+
date_keys = ['Date', 'pred_date']
69+
for key in date_keys:
70+
df[ticker][key] = df[ticker][key].astype(str).tolist()
71+
72+
redis_key = f"{ticker}_{self.interval}"
73+
74+
except Exception as e:
75+
print(f"error predicting stock : {e} on stock {ticker}")
76+
77+
# register data
78+
try:
79+
df_json = json.dumps(df[ticker])
80+
self.redis_client.set(redis_key, df_json)
81+
print(f"Saved predictions for group '{groups_name}' to Redis with key '{redis_key}'")
82+
except Exception as e:
83+
print(f"Error saving predictions for group '{groups_name}' to Redis: {e}")
5884

5985
if __name__ == "__main__":
6086
predictor = StockPredictor(interval='days')
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.
0 Bytes
Binary file not shown.

0 commit comments

Comments
 (0)