
Commit aa249c0

better understanding
1 parent 0f70458 commit aa249c0

File tree

- README.md
- main.py

2 files changed: 137 additions & 18 deletions


README.md

Lines changed: 29 additions & 11 deletions

```diff
@@ -10,20 +10,20 @@
 - 🧠 **Deep learning models** for trend prediction
 - 📊 **Automatic extraction of technical features** (Bollinger Bands, RSI, ROC, etc.)
 - 🔥 **Dynamic learning-rate optimization** (Cosine Warmup scheduler)
-- 🏗️ **Modular and extensible architecture** for different time horizons (daily, hourly, minute)
+- 🏠 **Modular and extensible architecture** for different time horizons (daily, hourly, minute)
 - **GPU compatibility** for fast training
 
 ---
 
 ## ⚙️ Installation
 
-### 📋 Prerequisites
+### 👋 Prerequisites
 
 - **Python** `>=3.8`
 - **CUDA-compatible GPU** (optional, but recommended)
 - **Minimum**: 2 CPU cores, 2 GB RAM
 
-### 🏗️ Installation via Docker
+### 🏠 Installation via Docker
 
 Use Docker for a quick, reproducible setup:
 
@@ -33,7 +33,7 @@ docker-compose up -d
 
 > **Note:** Make sure Docker and Docker Compose are installed on your machine.
 
-### 🏗️ Local Installation
+### 🏠 Local Installation
 
 It is recommended to run the project in a virtual environment.
 
@@ -108,25 +108,43 @@ python app/scheduler/scheduler.py
 
 ---
 
-## 📜 License
+## 🐟 Model and Execution Mode Selection
 
-This project is licensed under the **GNU** license.
+The main script lets you choose dynamically:
 
----
+- The model (e.g. `ConvCausalLTSM`, `LtsmAttentionforecastPred`, `VisionLiquidNet`)
+- The data loader type (`StockDataset`, `StockGraphDataset`)
+- The execution mode (`inference`, `evaluation`, `training`)
 
-## 📧 Contact
+Usage:
 
-Bucamp Axle - [[email protected]](mailto:[email protected])
+```bash
+python main.py --model ConvCausalLTSM --dataloader StockDataset --mode inference
+```
 
 ---
 
-## 🚀 Future Projects and Improvements
+## 🌟 Future Improvements
 
+- Integration of new models (Liquid Neural Networks, Transformers, etc.)
 - Unit and integration tests
 - Live mode via the Kraken API
 
 ---
 
+## 📄 License
+
+This project is licensed under the **GNU** license.
+
+---
+
+## 📧 Contact
+
+Bucamp Axle - [[email protected]](mailto:[email protected])
+
+---
+
 Enjoy AI-assisted trading with **Hygdra Forecasting**! 🚀
 
----
+---
+
```
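
The feature list in the README above mentions Bollinger Bands, RSI, and ROC among the automatically extracted technical features. The project's own preprocessing code (`hygdra_forecasting.utils.preprocessing`) is not part of this diff, so the snippet below is only a minimal pandas sketch of the standard textbook formulas for those three indicators, not the project's actual implementation; the `close` column name and the 14-step window are assumptions.

```python
import pandas as pd

def add_basic_indicators(df: pd.DataFrame, window: int = 14) -> pd.DataFrame:
    """Textbook Bollinger Bands, RSI and ROC on a 'close' column (illustrative only)."""
    close = df["close"]

    # Bollinger Bands: rolling mean +/- 2 rolling standard deviations
    sma = close.rolling(window).mean()
    std = close.rolling(window).std()
    df["bb_upper"] = sma + 2 * std
    df["bb_lower"] = sma - 2 * std

    # RSI: 100 - 100 / (1 + average gain / average loss) over the window
    delta = close.diff()
    gain = delta.clip(lower=0).rolling(window).mean()
    loss = (-delta.clip(upper=0)).rolling(window).mean()
    df["rsi"] = 100 - 100 / (1 + gain / loss)

    # ROC: percentage change relative to the close `window` steps earlier
    df["roc"] = close.pct_change(periods=window) * 100

    return df
```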

main.py

Lines changed: 108 additions & 7 deletions

```diff
@@ -1,7 +1,108 @@
-# TODO
-# - train function to cmd
-# - eval to cmd
-# - deploy model ? (api, scheduler, frontend, redis)
-# - deploy managing multiple type of model
-# - train model_list
-# - eval pth file, model type
+import argparse
+
+from torch import cuda, device, load, no_grad, tensor, nn
+from torch.utils.data import DataLoader
+from pandas import DataFrame
+
+# Hygdra Forecasting modules: preprocessing, models, dataloaders, training utilities
+from hygdra_forecasting.utils.preprocessing import (
+    ohlv_to_dataframe_inference, dataframe_to_dataset_inference
+)
+from hygdra_forecasting.model.build import ConvCausalLTSM, LtsmAttentionforecastPred
+from hygdra_forecasting.model.build_graph import GraphforecastPred, GraphTransformerforecastPred
+from hygdra_forecasting.dataloader.dataloader import StockDataset
+from hygdra_forecasting.dataloader.GraphDataloader import StockGraphDataset
+from hygdra_forecasting.model.train import train_model, setup_seed
+from hygdra_forecasting.model.eval import validate
+from hygdra_forecasting.utils.learning_rate_sheduler import CosineWarmup
+from liquidnet.vision_liquidnet import VisionLiquidNet
+
+
+def get_device():
+    """Return the CUDA device if one is available, otherwise the CPU."""
+    if cuda.is_available():
+        print("Running on the GPU")
+        return device("cuda:0")
+    print("Running on the CPU")
+    return device("cpu")
+
+
+def load_model(model_name, input_shape, checkpoint_path=None):
+    """Build the requested architecture, optionally load a checkpoint, and return it in eval mode."""
+    if model_name == "ConvCausalLTSM":
+        model = ConvCausalLTSM(input_shape=input_shape)
+    elif model_name == "LtsmAttentionforecastPred":
+        model = LtsmAttentionforecastPred(input_shape=input_shape)
+    elif model_name == "VisionLiquidNet":
+        model = VisionLiquidNet(64, 10)
+    elif model_name == "GraphTransformerforecastPred":
+        model = GraphTransformerforecastPred(input_shape=input_shape)
+    elif model_name == "GraphforecastPred":
+        model = GraphforecastPred(input_shape=input_shape)
+    else:
+        raise ValueError("Unknown model type")
+
+    if checkpoint_path:
+        model.load_state_dict(load(checkpoint_path, weights_only=True))
+    model.eval()
+    return model
+
+
+def inference(tickers):
+    """Run the pretrained base model and print the last 100 predicted vs. actual closes per ticker."""
+    df, dict_unorm = ohlv_to_dataframe_inference(tickers)
+    sequences_dict = dataframe_to_dataset_inference(df, tickers)
+    input_shape = sequences_dict[tickers[0]][0].shape
+    model = load_model("ConvCausalLTSM", input_shape, "weight/basemodel.pt")
+
+    df_result = DataFrame()
+    df_result.index = df[tickers[0] + "_close"].iloc[-100:].index
+    for ticker in tickers:
+        sequence = tensor(sequences_dict[ticker]).float()
+        with no_grad():
+            predictions = model(sequence)
+            predictions = predictions.squeeze().numpy()
+        # Map predictions and closes back to price scale with the per-ticker normalisation constants
+        df_result[ticker + "_pred"] = predictions.reshape(-1)[-100:] * dict_unorm[ticker][1] + dict_unorm[ticker][0]
+        df_result[ticker + "_close"] = df[ticker + "_close"].iloc[-100:] * dict_unorm[ticker][1] + dict_unorm[ticker][0]
+
+    print(df_result)
+
+
+def evaluate(model_name, tickers):
+    """Report the L1 validation loss of a saved checkpoint on the given tickers."""
+    dataset = StockDataset(ticker=tickers)
+    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=1)
+    input_sample, _ = dataset.__getitem__(0)
+    model = load_model(model_name, input_sample.shape, "weight/epoch-380_loss-0.2699198153614998.pt")
+    criterion = nn.L1Loss()
+    print(validate(model, dataloader, criterion))
+
+
+def train(model_name, tickers, tickers_val, etf_tickers):
+    """Train the requested model for 100 epochs with a cosine-warmup learning-rate schedule."""
+    # Case-insensitive check so the Graph* models get the graph dataset
+    if "graph" in model_name.lower():
+        dataset = StockGraphDataset(ticker=tickers, indics=etf_tickers)
+        dataset_val = StockGraphDataset(ticker=tickers_val, indics=etf_tickers)
+    else:
+        dataset = StockDataset(ticker=tickers)
+        dataset_val = StockDataset(ticker=tickers_val)
+
+    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=1)
+    dataloader_val = DataLoader(dataset_val, batch_size=32, shuffle=True, num_workers=1)
+
+    input_sample, _ = dataset.__getitem__(0)
+    setup_seed(20)
+    model = load_model(model_name, input_sample.shape)
+    train_model(model, dataloader, val_dataloader=dataloader_val, epochs=100, learning_rate=0.01,
+                lrfn=CosineWarmup(0.01, 100).lrfn)
+
+
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--mode", choices=["inference", "evaluate", "train"], required=True)
+    parser.add_argument("--model", choices=["ConvCausalLTSM", "LtsmAttentionforecastPred", "GraphforecastPred", "GraphTransformerforecastPred"], required=True)
+    parser.add_argument("--tickers", nargs='+', default=["DEFI", "PANW", "MRVL", "NKLA", "AFRM", "EBIT.TO", "^FCHI", "NKE", "^GSPC", "^IXIC", "BILL", "EXPE", 'LINK-USD', "TTWO", "NET", 'ICP-USD', 'FET-USD', 'FIL-USD', 'THETA-USD', 'AVAX-USD', 'HBAR-USD', 'UNI-USD', 'STX-USD', 'OM-USD', 'FTM-USD', "INJ-USD", "INTC", "SQ", "XOM", "COST", "BP", "BAC", "JPM", "GS", "CVX", "BA", "PFE", "PYPL", "SBUX", "DIS", "NFLX", 'GOOG', "NVDA", "JNJ", "META", "GOOGL", "AAPL", "MSFT", "BTC-EUR", "CRO-EUR", "ETH-USD", "CRO-USD", "BTC-USD", "BNB-USD", "XRP-USD", "ADA-USD", "SOL-USD"])
+    parser.add_argument("--tickers_val", nargs='+', default=["AMZN", "AMD", "ETH-EUR", "ELF", "UBER"])
+    parser.add_argument("--etf_tickers", nargs='+', default=["^GSPC", "^FCHI", "^IXIC", "EBIT.TO", "BTC-USD"])
+    args = parser.parse_args()
+
+    if args.mode == "inference":
+        inference(args.tickers)
+    elif args.mode == "evaluate":
+        evaluate(args.model, args.tickers)
+    elif args.mode == "train":
+        train(args.model, args.tickers, args.tickers_val, args.etf_tickers)
+
+
+if __name__ == "__main__":
+    main()
```
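
In `train()` above, the learning rate is driven by `CosineWarmup(0.01, 100).lrfn`, the "Cosine Warmup" scheduler highlighted in the README. Its real implementation lives in `hygdra_forecasting.utils.learning_rate_sheduler` and is not shown in this commit, so the sketch below only illustrates what such a schedule typically computes: a short linear warm-up into a cosine decay. Reading the two constructor arguments as peak learning rate and total epochs, and the five-epoch warm-up, are assumptions.

```python
import math

class CosineWarmupSketch:
    """Illustrative cosine-warmup schedule; NOT the package's actual CosineWarmup class."""

    def __init__(self, peak_lr: float, total_epochs: int, warmup_epochs: int = 5):
        self.peak_lr = peak_lr
        self.total_epochs = total_epochs
        self.warmup_epochs = warmup_epochs

    def lrfn(self, epoch: int) -> float:
        # Linear ramp from near zero up to peak_lr during the warm-up phase
        if epoch < self.warmup_epochs:
            return self.peak_lr * (epoch + 1) / self.warmup_epochs
        # Cosine decay from peak_lr down to zero over the remaining epochs
        progress = (epoch - self.warmup_epochs) / max(1, self.total_epochs - self.warmup_epochs)
        return 0.5 * self.peak_lr * (1 + math.cos(math.pi * progress))

# Mirrors the call in train(): lrfn=CosineWarmup(0.01, 100).lrfn
schedule = CosineWarmupSketch(0.01, 100)
learning_rates = [schedule.lrfn(epoch) for epoch in range(100)]
```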

0 commit comments
