-# TODO
-# - train function to cmd
-# - eval to cmd
-# - deploy model ? (api, scheduler, frontend, redis)
-# - deploy managing multiple type of model
-# - train model_list
-# - eval pth file, model type
+import argparse
+from torch import cuda, device, load, no_grad, tensor, nn
+from torch.utils.data import DataLoader
+from pandas import DataFrame
+
+# Import necessary modules from your package
+from hygdra_forecasting.utils.preprocessing import (
+    ohlv_to_dataframe_inference, dataframe_to_dataset_inference
+)
+from hygdra_forecasting.model.build import ConvCausalLTSM, LtsmAttentionforecastPred
+from hygdra_forecasting.model.build_graph import GraphforecastPred, GraphTransformerforecastPred
+
+from hygdra_forecasting.dataloader.dataloader import StockDataset
+from hygdra_forecasting.dataloader.GraphDataloader import StockGraphDataset
+from hygdra_forecasting.model.train import train_model, setup_seed
+from hygdra_forecasting.model.eval import validate
+from hygdra_forecasting.utils.learning_rate_sheduler import CosineWarmup
+from liquidnet.vision_liquidnet import VisionLiquidNet
+
+
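+# Device helper: prefers CUDA when available. Note that the functions below
+# currently build and run models on CPU; get_device() is provided but not yet wired in.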
+def get_device():
+    if cuda.is_available():
+        print("Running on the GPU")
+        return device("cuda:0")
+    print("Running on the CPU")
+    return device("cpu")
+
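+# Instantiate the requested architecture and, when a checkpoint path is given,
+# restore its weights and switch it to evaluation mode.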
+def load_model(model_name, input_shape, checkpoint_path=None):
+    if model_name == "ConvCausalLTSM":
+        model = ConvCausalLTSM(input_shape=input_shape)
+    elif model_name == "LtsmAttentionforecastPred":
+        model = LtsmAttentionforecastPred(input_shape=input_shape)
+    elif model_name == "VisionLiquidNet":
+        model = VisionLiquidNet(64, 10)
+    elif model_name == "GraphTransformerforecastPred":
+        model = GraphTransformerforecastPred(input_shape=input_shape)
+    elif model_name == "GraphforecastPred":
+        model = GraphforecastPred(input_shape=input_shape)
+    else:
+        raise ValueError(f"Unknown model type: {model_name}")
+
+    if checkpoint_path:
+        model.load_state_dict(load(checkpoint_path, weights_only=True))
+        model.eval()
+    return model
+
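+# Run the pretrained base model on recent data for each ticker and print the
+# de-normalized predictions next to the actual closing prices.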
+def inference(tickers):
+    df, dict_unorm = ohlv_to_dataframe_inference(tickers)
+    sequences_dict = dataframe_to_dataset_inference(df, tickers)
+    input_shape = sequences_dict[tickers[0]][0].shape
+    # Inference always uses the ConvCausalLTSM base checkpoint, independently of --model
+    model = load_model("ConvCausalLTSM", input_shape, "weight/basemodel.pt")
+
+    df_result = DataFrame()
+    df_result.index = df[tickers[0] + "_close"].iloc[-100:].index
+    for ticker in tickers:
+        sequence = tensor(sequences_dict[ticker]).float()
+        with no_grad():
+            predictions = model(sequence)
+        predictions = predictions.squeeze().numpy()
+        # Undo the per-ticker normalization on the last 100 points
+        df_result[ticker + "_pred"] = predictions.reshape(-1)[-100:] * dict_unorm[ticker][1] + dict_unorm[ticker][0]
+        df_result[ticker + "_close"] = df[ticker + "_close"].iloc[-100:] * dict_unorm[ticker][1] + dict_unorm[ticker][0]
+
+    print(df_result)
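+# Evaluate a saved checkpoint on the given tickers and report the L1 validation loss.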
+def evaluate(model_name, tickers):
+    dataset = StockDataset(ticker=tickers)
+    dataloader = DataLoader(dataset, batch_size=256, shuffle=True, num_workers=1)
+    input_sample, _ = dataset[0]
+    model = load_model(model_name, input_sample.shape, "weight/epoch-380_loss-0.2699198153614998.pt")
+    criterion = nn.L1Loss()
+    print(validate(model, dataloader, criterion))
+
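+# Train the selected model from scratch on the training tickers, using the
+# validation tickers for held-out evaluation and a cosine-warmup learning-rate schedule.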
+def train(model_name, tickers, tickers_val, etf_tickers):
+    # Graph models need the graph dataset, which also takes the index/ETF tickers
+    if "graph" in model_name.lower():
+        dataset = StockGraphDataset(ticker=tickers, indics=etf_tickers)
+        dataset_val = StockGraphDataset(ticker=tickers_val, indics=etf_tickers)
+    else:
+        dataset = StockDataset(ticker=tickers)
+        dataset_val = StockDataset(ticker=tickers_val)
+
+    dataloader = DataLoader(dataset, batch_size=32, shuffle=True, num_workers=1)
+    dataloader_val = DataLoader(dataset_val, batch_size=32, shuffle=True, num_workers=1)
+
+    input_sample, _ = dataset[0]
+    setup_seed(20)
+    model = load_model(model_name, input_sample.shape)
+    train_model(model, dataloader, val_dataloader=dataloader_val, epochs=100, learning_rate=0.01,
+                lrfn=CosineWarmup(0.01, 100).lrfn)
+
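+# Command-line entry point: pick a mode (inference / evaluate / train), an
+# architecture, and the ticker lists to work on.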
+def main():
+    parser = argparse.ArgumentParser()
+    parser.add_argument("--mode", choices=["inference", "evaluate", "train"], required=True)
+    parser.add_argument("--model", choices=["ConvCausalLTSM", "LtsmAttentionforecastPred", "GraphforecastPred", "GraphTransformerforecastPred"], required=True)
+    parser.add_argument("--tickers", nargs='+', default=["DEFI", "PANW", "MRVL", "NKLA", "AFRM", "EBIT.TO", "^FCHI", "NKE", "^GSPC", "^IXIC", "BILL", "EXPE", 'LINK-USD', "TTWO", "NET", 'ICP-USD', 'FET-USD', 'FIL-USD', 'THETA-USD', 'AVAX-USD', 'HBAR-USD', 'UNI-USD', 'STX-USD', 'OM-USD', 'FTM-USD', "INJ-USD", "INTC", "SQ", "XOM", "COST", "BP", "BAC", "JPM", "GS", "CVX", "BA", "PFE", "PYPL", "SBUX", "DIS", "NFLX", 'GOOG', "NVDA", "JNJ", "META", "GOOGL", "AAPL", "MSFT", "BTC-EUR", "CRO-EUR", "ETH-USD", "CRO-USD", "BTC-USD", "BNB-USD", "XRP-USD", "ADA-USD", "SOL-USD"])
+    parser.add_argument("--tickers_val", nargs='+', default=["AMZN", "AMD", "ETH-EUR", "ELF", "UBER"])
+    parser.add_argument("--etf_tickers", nargs='+', default=["^GSPC", "^FCHI", "^IXIC", "EBIT.TO", "BTC-USD"])
+    args = parser.parse_args()
+
+    if args.mode == "inference":
+        inference(args.tickers)
+    elif args.mode == "evaluate":
+        evaluate(args.model, args.tickers)
+    elif args.mode == "train":
+        train(args.model, args.tickers, args.tickers_val, args.etf_tickers)
+
+if __name__ == "__main__":
+    main()
+
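+# Example invocations (assuming this file is saved as main.py at the repository root
+# and the checkpoints referenced above exist under weight/):
+#   python main.py --mode train --model GraphforecastPred
+#   python main.py --mode evaluate --model ConvCausalLTSM --tickers AAPL MSFT BTC-USD
+#   python main.py --mode inference --model ConvCausalLTSM --tickers BTC-USD ETH-USD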