diff --git a/agent/SARL/encoder/util.py b/agent/SARL/encoder/util.py index 8564f2d3..254f6164 100644 --- a/agent/SARL/encoder/util.py +++ b/agent/SARL/encoder/util.py @@ -10,30 +10,20 @@ def prepart_m_lstm_data(df, num_day, technical_indicator): tic_list = df.tic.unique() df_list = [] label_list = [] - for index in df.index.unique()[num_day:]: - dfs = [] + for tic in tqdm(tic_list): + labels = [] - df_date = df[[ - True if i in range(index - num_day, index) else False - for i in df.index - ]] - for tic in tic_list: - df_tic = df_date[df_date.tic == tic] - np_tic = df_tic[technical_indicator].to_numpy() - # print(np_tic.shape) - dfs.append(np_tic) - old_price = float(df_tic[df_tic.index == index - 1].close) - new_price = float(df[(df.index == index) * (df.tic == tic)].close) - if new_price > old_price: - label = 1 - else: - label = 0 - labels.append(label) + df_tic = df[df.tic == tic] + dfs = df_tic[technical_indicator].to_numpy() + old_prices = df_tic[num_day-1:len(df_tic)-1]['close'].astype(float).values + new_prices = df_tic[num_day:len(df_tic)]['close'].astype(float).values + + labels = (new_prices > old_prices).astype(int) + label_list.append(np.expand_dims(labels, axis=1)) df_list.append(dfs) - label_list.append(labels) - label_list = np.array(label_list) - df_list = np.array(df_list) - return label_list, df_list + label_list = np.concatenate(label_list, axis=1) + df_list = np.array(df_list) + return label_list, df_list def prepart_lstm_data(df, num_day, technical_indicator): @@ -119,18 +109,18 @@ def dict_to_args(**kwargs): args = parser.parse_args() return args - class m_lstm_dataset(Dataset): - def __init__(self, df_list, label_list): + def __init__(self, df_list, label_list,num_day): self.df = df_list self.label = label_list + self.num_day = num_day self.X = torch.from_numpy(self.df).float() self.y = torch.from_numpy(self.label).float() def __len__(self): - return self.df.shape[0] + return self.label.shape[0] def __getitem__(self, idx): - X = self.X[idx, :, :, :] + X = self.X[:, idx:idx+self.num_day, :] y = self.y[idx, :] return X, y diff --git a/docs/source/introduction.md b/docs/source/introduction.md index 182b73c0..ff6517a6 100644 --- a/docs/source/introduction.md +++ b/docs/source/introduction.md @@ -16,9 +16,8 @@ Architecture of Trademaster framework could be visualizaed by the figure below. TradeMaster is evaluated in multiple dimenstions. Financial metrics like profit and risk metrics are applied. Additionally, decision tree and shapley value are used to evaluate the explainability of the model. Variability and Alpha decay are used for reliability evaluation. -
- -
+![Architecture.jpg](../../figure/Architecture.jpg) + ## Supported Trading Scenario diff --git a/docs/source/script/yahoo.md b/docs/source/script/yahoo.md index b8616842..e36ea3a0 100644 --- a/docs/source/script/yahoo.md +++ b/docs/source/script/yahoo.md @@ -1,3 +1,14 @@ # Download Data from Yahoo Finance +In order to build up your own dataset, Yahoo Finance is an open-source platform where you can get access to various types of financial market data such as US stock, forex and cryptocurrency via Yahoo Finance python API(yfinance). +Here is an example of script downloading Apple's stock data from yfinance, which contains the open, high, low, close, adjusted close price and volume. + + ``` + import yfinance as yf + start_date='2009-01-02' + end_date='2021-01-01' + df = yf.download('AAPL', start=start_date, end=end_date, interval='1d') + ``` +By modifying the instructions, you can customize your downloaded dataset. + diff --git a/docs/source/tool/csdi.md b/docs/source/tool/csdi.md index 067cb6f7..2592f74d 100644 --- a/docs/source/tool/csdi.md +++ b/docs/source/tool/csdi.md @@ -1,3 +1,9 @@ # Missing Value Imputation with CSDI +Most of the raw data retrieved from different data sources consist of missing values (NaN values), and the most common method of dealing with missing values is directly dropping them. However, we provide an alternative solution by using the imputation model proposed in the following paper. +[CSDI: Conditional Score-based Diffusion Models for Probabilistic Time Series Imputation (Yusuke Tashiro, etc.)](https://arxiv.org/abs/2107.03502) *NeurIPS 2021* + +CSDI is a diffusion model which generates missing values in raw data by diffusion process using observed values as conditional input. The model is trained by optimizing an unsupervised task: recovery of a certain ratio of masked observed data by using the rest observed data as conditional input. When performing real imputation on datasets, all missing values are imputation targets and all observed values serve as conditional input. Please refer to the original paper if you have any enquiries about the methodology. + +We implement the model into a ready-to-use toolbox for missing value imputation of financial data. Please refer to [CSDI for financial data imputation](https://github.com/ZONG0004/TradeMaster/blob/main/data/CSDI/README.md) for detailed guideline of usage and visualization results. diff --git a/docs/source/tool/example_figs/FInal_compass.png b/docs/source/tool/example_figs/FInal_compass.png new file mode 100644 index 00000000..c2684b52 Binary files /dev/null and b/docs/source/tool/example_figs/FInal_compass.png differ diff --git a/docs/source/tool/example_figs/Final-compass.svg b/docs/source/tool/example_figs/Final-compass.svg new file mode 100644 index 00000000..d7b1ba65 --- /dev/null +++ b/docs/source/tool/example_figs/Final-compass.svg @@ -0,0 +1,4133 @@ + +image/svg+xmlRiskControl +Proftability +Explainability +Reliability +Diversity +Universality +c +o +u +n +t +r +y +a +s +s +e +t +t +y +p +e +t +i +m +e +- +s +c +a +l +e +r +i +s +k +r +i +s +k +- +a +d +j +u +s +t +e +d +e +x +t +r +e +m +e +m +a +r +k +e +t +p +r +o +f +i +t +a +l +p +h +a +- +d +e +c +a +y +e +q +u +i +t +y +c +u +r +v +e +p +r +o +f +i +l +e +v +a +r +i +a +b +i +l +i +t +y +r +a +n +k +o +r +d +e +r +t +- +S +N +E +e +n +t +r +o +p +y +c +o +r +r +e +l +a +t +i +o +n +r +o +l +l +i +n +g +w +i +n +d +o +w +A2C +PPO +SAC +SARL +DeepTrader +AlphaMix+ + \ No newline at end of file diff --git a/docs/source/tool/example_figs/Radar_plot.png b/docs/source/tool/example_figs/Radar_plot.png new file mode 100644 index 00000000..6443926a Binary files /dev/null and b/docs/source/tool/example_figs/Radar_plot.png differ diff --git a/docs/source/tool/example_figs/crypto.svg b/docs/source/tool/example_figs/crypto.svg new file mode 100644 index 00000000..0382eee9 --- /dev/null +++ b/docs/source/tool/example_figs/crypto.svg @@ -0,0 +1,420 @@ + +image/svg+xmlA2CPPOSACSARLDTAlphaMix+202468ScoreaverageTRSR diff --git a/docs/source/tool/example_figs/dm_result_1.png b/docs/source/tool/example_figs/dm_result_1.png new file mode 100644 index 00000000..5df9c423 Binary files /dev/null and b/docs/source/tool/example_figs/dm_result_1.png differ diff --git a/docs/source/tool/example_figs/dm_result_2.png b/docs/source/tool/example_figs/dm_result_2.png new file mode 100644 index 00000000..0ec54618 Binary files /dev/null and b/docs/source/tool/example_figs/dm_result_2.png differ diff --git a/docs/source/tool/example_figs/dm_result_3.png b/docs/source/tool/example_figs/dm_result_3.png new file mode 100644 index 00000000..287dd661 Binary files /dev/null and b/docs/source/tool/example_figs/dm_result_3.png differ diff --git a/docs/source/tool/example_figs/octagon.PNG b/docs/source/tool/example_figs/octagon.PNG new file mode 100644 index 00000000..d80de04a Binary files /dev/null and b/docs/source/tool/example_figs/octagon.PNG differ diff --git a/docs/source/tool/example_figs/octagon/A2C.svg b/docs/source/tool/example_figs/octagon/A2C.svg new file mode 100644 index 00000000..509d1eac --- /dev/null +++ b/docs/source/tool/example_figs/octagon/A2C.svg @@ -0,0 +1,1043 @@ + +image/svg+xmlTRSRCRSoRVolMDDENTENB diff --git a/docs/source/tool/example_figs/octagon/AlphaMix.svg b/docs/source/tool/example_figs/octagon/AlphaMix.svg new file mode 100644 index 00000000..cb38b325 --- /dev/null +++ b/docs/source/tool/example_figs/octagon/AlphaMix.svg @@ -0,0 +1,1043 @@ + +image/svg+xmlTRSRCRSoRVolMDDENTENB diff --git a/docs/source/tool/example_figs/octagon/DeepTrader.svg b/docs/source/tool/example_figs/octagon/DeepTrader.svg new file mode 100644 index 00000000..75cb140f --- /dev/null +++ b/docs/source/tool/example_figs/octagon/DeepTrader.svg @@ -0,0 +1,1043 @@ + +image/svg+xmlTRSRCRSoRVolMDDENTENB diff --git a/docs/source/tool/example_figs/octagon/PPO.svg b/docs/source/tool/example_figs/octagon/PPO.svg new file mode 100644 index 00000000..113d9267 --- /dev/null +++ b/docs/source/tool/example_figs/octagon/PPO.svg @@ -0,0 +1,1043 @@ + +image/svg+xmlTRSRCRSoRVolMDDENTENB diff --git a/docs/source/tool/example_figs/octagon/SAC.svg b/docs/source/tool/example_figs/octagon/SAC.svg new file mode 100644 index 00000000..3159d22e --- /dev/null +++ b/docs/source/tool/example_figs/octagon/SAC.svg @@ -0,0 +1,1043 @@ + +image/svg+xmlTRSRCRSoRVolMDDENTENB diff --git a/docs/source/tool/example_figs/octagon/SARL.svg b/docs/source/tool/example_figs/octagon/SARL.svg new file mode 100644 index 00000000..31bb13c9 --- /dev/null +++ b/docs/source/tool/example_figs/octagon/SARL.svg @@ -0,0 +1,1043 @@ + +image/svg+xmlTRSRCRSoRVolMDDENTENB diff --git a/docs/source/tool/example_figs/overall.svg b/docs/source/tool/example_figs/overall.svg new file mode 100644 index 00000000..c77a627c --- /dev/null +++ b/docs/source/tool/example_figs/overall.svg @@ -0,0 +1,480 @@ + +image/svg+xml020406080100total return score ()0.000.250.500.751.00Fraction of runs with score >A2CPPOSACSARLDeepTraderAlphaMix+ diff --git a/docs/source/tool/example_figs/rank.svg b/docs/source/tool/example_figs/rank.svg new file mode 100644 index 00000000..14f21307 --- /dev/null +++ b/docs/source/tool/example_figs/rank.svg @@ -0,0 +1,2043 @@ + +image/svg+xml123456020406080100TR123456020406080100SR123456020406080100VOL123456020406080100EntropyFraction (in %)A2CPPOSACSARLDeepTraderAlphaMix+ diff --git a/docs/source/tool/prudex.md b/docs/source/tool/prudex.md index a65aa362..a82890cb 100644 --- a/docs/source/tool/prudex.md +++ b/docs/source/tool/prudex.md @@ -1,3 +1,181 @@ -# PRUDEX-Compass: Systematic Evaluation Toolkits +# Evaluation Toolkits: PRUDEX-Compass +- **PRUDEX-Compass** is designed for **benchmarking** methods with 6 axes and 16 measures. +- This is an official implementation of [PRUDEX-Compass: Towards Systematic Evaluation of Reinforcement Learning in Financial Markets](https://arxiv.org/abs/2302.00586). +- **PRUDEX-Compass** is an independt tool and is not intergarted in the pipline for now. +- We use **FinRL** methods to demonstrate how **PRUDEX-Compass** works. + + +## Install +To install the dependency of `PRUDEX-Compass`, run the command: +``` +pip install -r requirements.txt +``` +## Usages and Examples + +### PRUDEX-Compass +- The PRUDEX-Compass gives high-level evaluations +- Use the [`create_compass.py`](https://github.com/ai-gamer/PRUDEX-Compass/blob/main/Compass/generate/compass/create_compass.py) python script to fill in the templet and get the compass. + +
+ +
+ +#### Example Usage +The default setting reads the template file from [`blank.tex`](https://github.com/ai-gamer/PRUDEX-Compass/blob/main/Compass/generate/compass/blank.tex) and writes the filled output file into [`filled.tex`](https://github.com/ai-gamer/PRUDEX-Compass/blob/main/Compass/generate/compass/filled.tex) with the data specified via --data : +``` +$ python Compass/generate/compass/create_compass.py--data Compass/generate/compass/data.json +``` +The result is an .tex file [`filled.tex`](https://github.com/ai-gamer/PRUDEX-Compass/blob/main/Compass/generate/compass/filled.tex) +#### JSON Data Format +The JSON file specifies a list of entries, where each element defines a `color`, `label`, `inner_level`, and `outer_level`. The latter two specify the attributes visualized in the compass. + +`color`: Can be one of `["magenta", "green", "blue", "orange", "cyan", "brown"]`. + +`label`: A label describing the compass entry. + +`inner_level`: Specifies the inner compass level attributes. Attribute values must be between 1 and 100. + +`outer_level`: Specifies the outer compass level attributes. Attribute values must boolean `(true/false)`. + +We provide an example data file [`Compass/generate/compass/data.json`](https://github.com/ai-gamer/PRUDEX-Compass/blob/main/Compass/generate/compass/data.json) + + + +### PRIDE-Star +- The PRIDE-Star gives metric-level evaluation of profitability,risk-control and diversity + + + + + + + + + + + + + + + + + +
(a) A2C
(b) PPO
(c) SAC
(d) SARL
(e) DeepTrader
(f) AlphaMix+
+ + +
+ +The file structure of templates for `PRIDE-Star` is as following: +``` +-- PRIDE-Star + |-- A2C.tex + |-- Alphamix+.tex + |-- DeepTrader.tex + |-- PPO.tex + |-- SAC.tex + |-- SARL.tex + |-- blank.tex +``` +Here we provide a [blank tex](https://github.com/ai-gamer/PRUDEX-Compass/blob/main/Compass/generate/PRIDE-Star/blank.tex) that you can play with + +### Performance Profile +- The performance profile reports methods' score distribution of all runs across the different +financial markets that are statistically unbiased and more robust to outliers. + +
+ +
+ + + +#### Example usage +Prepare your result as `overall_dict={method name:result(seed,task)}` and use the following example code +to generate the distribution. Notice that we only use one metrics (`total return`) in the example. + +``` +import numpy as np +import pandas as pd +import matplotlib.pyplot as plt +import scipy.stats +from rliable import library as rly +from rliable import metrics +from rliable import plot_utils +import seaborn as sns +sns.set_style("white") +import matplotlib.patches as mpatches +import collections +import os +from Compass.distribution.distribution import make_distribution_plot +colors = ['moccasin','aquamarine','#dbc2ec','orchid','lightskyblue','pink','orange'] +xlabels = ['A2C','PPO','SAC','SARL','DeepTrader',"AlphaMix+"] +color_idxs = [0, 1,2,3,4,5,6] +ATARI_100K_COLOR_DICT = dict(zip(xlabels, [colors[idx] for idx in color_idxs])) +from scipy.stats.stats import find_repeats +xlabel=r'total return score $(\tau)$', +dict=tt_dict_crypto +algorithms = ['A2C','PPO','SAC','SARL','DeepTrader',"AlphaMix+"] +make_distribution_plot(dict,algorithms,2000,xlabel,"./distribution",ATARI_100K_COLOR_DICT) +``` + +For more precise information, please refer to [`Compass/generate/distribution/distribution.py`](https://github.com/ai-gamer/PRUDEX-Compass/blob/main/Compass/generate/distribution/distribution.py) + +### Rank Distribution +- This visualization shows the distribution of methods' performance ranking. + + +
+ + +
+ +#### Example usage +Prepare result into `dmc_scores={metric name: metric result dictionary}` and use the following example code +to generate the graph. + +``` +from Compass.generate.rank.rank import subsample_scores_mat,get_rank_matrix,make_rank_plot +dmc_scores = {} +dmc_scores["TR"]=tt_dict +dmc_scores["SR"]=sr_dict +dmc_scores["CR"]=cr_dict +dmc_scores["SoR"]=sor_dict +dmc_scores["VOL"]=vol_dict +dmc_scores["Entropy"]=Entropy_dict +indicator_list=['TR','SR','VOL','Entropy'] +algs=['A2C','PPO','SAC','SARL','DeepTrader','AlphaMix+'] +colors=['moccasin','aquamarine','#dbc2ec','salmon','lightskyblue','pink','orange'] +make_rank_plot(algs,indicator_list,".rank.pdf",colors) +``` +For more information, please refer to [`rank.py`](https://github.com/ai-gamer/PRUDEX-Compass/blob/main/Compass/generate/rank/rank.py) +### Performance under Extreme Markets +- The performance under extreme markets reflects methods' ability to deal with black swan events in terms of total return and sharpe ratio compared with uniform policy. +
+ +
+ + +#### Example usage + +Select a volatile period and get the daily return rate and use the following example code +to generate the graph. + +``` +from Compass.generate.exen.exen import evaualte,plot_pictures +A2C=pd.read_csv("./A2C.csv",index_col=0) +SARL=pd.read_csv("./SARL.csv",index_col=0) +DeepTrader=pd.read_csv("./DeepTrader.csv",index_col=0) +PPO=pd.read_csv("./PPO.csv",index_col=0) +SAC=pd.read_csv("./SAC.csv",index_col=0) +AlphaMix=pd.read_csv("AlphaMix+.csv",index_col=0) +path=".exen.pdf" +plot_pictures(new_models,path) +``` +For more information, please refer to [`exen.py`](https://github.com/ai-gamer/PRUDEX-Compass/blob/main/Compass/generate/exen/exen.py) + +## Acknowledgements +This repository is developed based on: +[RLKit](https://github.com/rail-berkeley/rlkit), +[FinRL](https://github.com/AI4Finance-Foundation/FinRL), +[Sunrise](https://github.com/pokaxpoka/sunrise) diff --git a/docs/source/tool/style.md b/docs/source/tool/style.md index 393aa49e..c9550919 100644 --- a/docs/source/tool/style.md +++ b/docs/source/tool/style.md @@ -1,3 +1,100 @@ -# Market Dymamics Modelling +# Evaluation Toolbox: Market Dynamics Modeling + +## Introduction +The evaluation toolbox provides a sandbox for user to evaluate their policy under different scenarios . +The toolbox shows visualizations and reports to assist user compare policies across market dynamic. + +## Market Dynamics Modeling +The Market Dynamics modeling is a module to label raw data with dynamics that is interpretable. +The dynamics are used as meta-information. For example, in the evaluation process, user can run evaluation on specific dynamics. + +## Usage & Example +The Evaluation Toolbox module prepare data for evaluation, to run a full test you should follow this pipeline: +- Run the [`run.py`]() in tools/market_dynamics_labeling or [`linear_model.py`]() to prepare the dataset + 1. Tune the parameters based on the visualization results +
+ +
+ 1. Increase `length_limit` +
+ +
+ 1. Modify `labeling_parameters` +
+ +
+- Update the 'test_style_path' in the config files to the dataset path you get from previous step. + +- Run the trainer with arguments `--task dynamics_test --test_dynamic dynamic_label` to perform evaluation on specific market dynamic. You will get reports and visualization result. +
+ +
+#### Parameters +- `fitting_parameters`: This is a set of parameters for the filter, please refer to the comment in lines for detailed description. +- `labeling_parameters`: This is a set of parameters for regime classification, please refer to the comment in lines for detailed description. +- `regime_number`: This is the number of regimes. +- `length_limit`: This is the minimum length of a consecutive time-series of same regime. + +#### Scoring +The scores of the visualization result are calculated as described: +- Do nothing metrics are used as score 0 +- Blind Buy metrics are used as score 50 (-50 if worse than Do Nothing) +- The score of other agents are given based on the assumption that the scores are following a normal distribution (50,$\sqrt{50}$) +##### Baselines + - Buy and Hold: This is and ideal policy where you spend all your cash on the first tick. + - Blind Buy: Continues buy until the cash runs out. + - Do Nothing: Do nothing + + + + +## Examples +### Use Market Dynamics Model to prepare evaluation datasets +It is recommended to run through the trademaster/evaluation/market_dynamics_labeling/example.ipynb notebook to visualize the labeling process. This will also give hints on +deciding the parameters for your dataset. The example.html contains the visualization results from example.ipynb. + +An example of labeling the data +
+ ``` + $ python tools/MarketRegimeLabeling/Label.py --data_path data/algorithmic_trading/BTC/test.csv --method linear --fitting_parameters 2/7 2/14 4 --labeling_parameters -0.5 0.5 + ``` + +DJ30 + ``` + $ python tools/MarketRegimeLabeling/Label.py --data_path data/portfolio_management/dj30/test.csv --method linear --fitting_parameters 2/7 2/14 4 --labeling_parameters -0.25 0.25 --regime_number 3 --length_limit 24 + ``` +for DJ30 applied in PM use-case, we would like to define the market regime based on DJ30 index. We have provided an example of +test_labeled_3_24.csv which is DJI_labeled_3_24.csv and test.csv merged on 'date' where DJI_labeled_3_24.csv is got from running + +DJI index + ``` + $ python tools/MarketRegimeLabeling/Label.py --data_path data/portfolio_management/dj30/DJI.csv --method linear --fitting_parameters 2/7 2/14 4 --labeling_parameters -0.25 0.25 --regime_number 3 --length_limit 24 --PM data/portfolio_management/dj30/test.csv + ``` + +BTC + + $ python tools/MarketRegimeLabeling/Label.py --data_path data/algorithmic_trading/BTC/test.csv --method linear --fitting_parameters 2/7 2/14 4 --labeling_parameters -0.15 0.15 --regime_number 3 --length_limit 24 +PD_BTC + + $ python tools/MarketRegimeLabeling/Label.py --data_path data/order_execution/PD_BTC/test.csv --method linear --fitting_parameters 2/7 2/14 4 --labeling_parameters -0.15 0.15 --regime_number 3 --length_limit 24 + +OE_BTC + + $ python tools/MarketRegimeLabeling/Label.py --data_path data/order_execution/BTC/test.csv --method linear --fitting_parameters 2/7 2/14 4 --labeling_parameters -0.01 0.01 --regime_number 3 --length_limit 32 --OE_BTC True + +Exchange + + $ python tools/MarketRegimeLabeling/Label.py --data_path data/portfolio_management/exchange/test.csv --method linear --fitting_parameters 2/7 2/14 4 --labeling_parameters -0.05 0.05 --regime_number 3 --length_limit 24 + + +
+ + +The script will take in a data file and output the file with a market regime label column. Besides the market label, we also provide a stock group label column based on DWT clustering. + +### Testing agent under a specific market dynamic +``` + python tools/algorithmic_trading/train.py --task_name dynamics_test --test_dynamic 0 +``` diff --git a/docs/source/tutorial/DeepScalper.jpg b/docs/source/tutorial/DeepScalper.jpg new file mode 100644 index 00000000..a3ad4528 Binary files /dev/null and b/docs/source/tutorial/DeepScalper.jpg differ diff --git a/docs/source/tutorial/EIIE.jpg b/docs/source/tutorial/EIIE.jpg new file mode 100644 index 00000000..614f2553 Binary files /dev/null and b/docs/source/tutorial/EIIE.jpg differ diff --git a/docs/source/tutorial/tutorial1.md b/docs/source/tutorial/tutorial1.md index 03f34dfd..a9f257a0 100644 --- a/docs/source/tutorial/tutorial1.md +++ b/docs/source/tutorial/tutorial1.md @@ -1,20 +1,18 @@ # Tutorial 1: Intraday Crypto Trading with DeepScalper +![DeepScalper.png](DeepScalper.png) -## Task Intraday trading is a fundamental quantitative trading task, where traders actively long/short one pre-selected financial asset within the same trading day to maximize future profit. -## Algorithm -DeepScalper contains 4 technical contributions which all together make it better than direct use of RL algorithms. -- RL optimization with action branching -- reward function with hindsight bonus -- intraday market embedding -- risk-aware auxiliary task +DeepScalper use deep q network to optimize the reward sum got from reinforcement learning where a hindsight reward is used to capture the long-term porfit trends and embedding from both micro-level and macro-level market information. -Here is the construction of the DeepScalper: -
- -
-Here is a [tutorial](https://github.com/DVampire/TradeMasterReBuild/tree/main/tutorial/DeepScalper.ipynb) about how you can build DeepScalper in a few lines of codes using TradeMaster. \ No newline at end of file +## Notebook and Script +In this notebook, we implement the training and testing process of DeepScalper based on the TradeMaster framework. + +[Tutorial1_DeepScalper](https://github.com/TradeMaster-NTU/TradeMaster/blob/main/tutorial/Tutorial1_DeepScalper.ipynb) + +And this is the script for training and testing. + +[train.py](https://github.com/TradeMaster-NTU/TradeMaster/blob/1.0.0/tools/algorithmic_trading/train.py) \ No newline at end of file diff --git a/docs/source/tutorial/tutorial2.md b/docs/source/tutorial/tutorial2.md index 623dfc66..3a6a81f5 100644 --- a/docs/source/tutorial/tutorial2.md +++ b/docs/source/tutorial/tutorial2.md @@ -1,19 +1,20 @@ # Tutorial 2: Portfolio Management with EIIE on US stocks - -## Task +![EIIE.png](EIIE.png) Portfolio management is the action of continuous reallocation of a capital into a number of financial assets periodically. -## Algorithm -EIIE contains 2 technical contributions which all together make it better than direct use of RL algorithms. -- Deterministic Policy Gradient -- Portfolio-Vector Memor +The framework consists of the Ensemble of Identical Independent Evaluators +(EIIE) topology, a Portfolio-Vector Memory (PVM), an Online Stochastic Batch Learning +(OSBL) scheme, and a fully exploiting and explicit reward function. + + + -Here is the construction of the EIIE: -
- -
+## Notebook and Script +In this notebook, we implement the training and testing process of EIIE based on the TradeMaster framework. +[Tutorial2_EIIE](https://github.com/TradeMaster-NTU/TradeMaster/blob/main/tutorial/Tutorial2_EIIE.ipynb) +And this is the script for training and testing. -Here is a [tutorial](https://github.com/DVampire/TradeMasterReBuild/tree/main/tutorial/EIIE.ipynb) about how you can build EIIE in a few lines of codes using TradeMaster. \ No newline at end of file +[train_eiie.py](https://github.com/TradeMaster-NTU/TradeMaster/blob/1.0.0/tools/portfolio_management/train_eiie.py) \ No newline at end of file diff --git a/docs/source/tutorial/tutorial5.md b/docs/source/tutorial/tutorial5.md index 7ee6a7fb..6fa1f6ce 100644 --- a/docs/source/tutorial/tutorial5.md +++ b/docs/source/tutorial/tutorial5.md @@ -1,11 +1,14 @@ # Tutorial 5: High Frequency Trading with Double DQN -## Task High Frequency Trading is a fundamental quantitative trading task, where traders actively buy/sell one pre-selected financial periodically in seconds with the consideration of order execution. -## Algorithm -HFT_DDQN contains 2 technical contributions which all together make it better than direct use of RL algorithms. -- Double deep q network -- Regulator from the true q table +HFT_DDQN use a decayed supervised regulator genereated from the real q table based on the future price information and a double q network to optimizer the portfit margine. -Here is a [tutorial](https://github.com/DVampire/TradeMasterReBuild/tree/main/tutorial/HFT.ipynb) about how you can build DDQN for HFT in a few lines of codes using TradeMaster. \ No newline at end of file +## Notebook and Script +In this notebook, we implement the training and testing process of HFTDDQN based on the TradeMaster framework. + +[Tutorial5_HFT](https://github.com/TradeMaster-NTU/TradeMaster/blob/main/tutorial/Tutorial5_HFT.ipynb) + +And this is the script for training and testing. + +[train.py](https://github.com/TradeMaster-NTU/TradeMaster/blob/1.0.0/tools/high_frequency_trading/train.py) \ No newline at end of file diff --git a/tutorial/Tutorial1_DeepScalper.ipynb b/tutorial/Tutorial1_DeepScalper.ipynb new file mode 100644 index 00000000..816dbb4a --- /dev/null +++ b/tutorial/Tutorial1_DeepScalper.ipynb @@ -0,0 +1,704 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Import Packages\n", + "Modify the system path and load the corresponding packages and functions " + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "ROOT = str(Path(\"__file__\").resolve().parents[1])\n", + "sys.path.append(ROOT)\n", + "import torch\n", + "import argparse\n", + "import os.path as osp\n", + "from mmcv import Config\n", + "from trademaster.utils import replace_cfg_vals\n", + "from trademaster.nets.builder import build_net\n", + "from trademaster.environments.builder import build_environment\n", + "from trademaster.datasets.builder import build_dataset\n", + "from trademaster.agents.builder import build_agent\n", + "from trademaster.optimizers.builder import build_optimizer\n", + "from trademaster.losses.builder import build_loss\n", + "from trademaster.trainers.builder import build_trainer\n", + "from trademaster.transition.builder import build_transition" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Load Configs\n", + "Load default config from the folder `configs/algorithmic_trading/algorithmic_trading_BTC_dqn_dqn_adam_mse.py`" + ] + }, + { + "cell_type": "code", + "execution_count": 18, + "metadata": {}, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser(description='Download Alpaca Datasets')\n", + "parser.add_argument(\"--config\", default=osp.join(ROOT, \"configs\", \"algorithmic_trading\", \"algorithmic_trading_BTC_dqn_dqn_adam_mse.py\"),\n", + " help=\"download datasets config file path\")\n", + "parser.add_argument(\"--task_name\", type=str, default=\"train\")\n", + "parser.add_argument(\"--test_style\", type=str, default='-1')\n", + "args = parser.parse_args([])\n", + "cfg = Config.fromfile(args.config)\n", + "task_name = args.task_name\n", + "\n", + "cfg = replace_cfg_vals(cfg)\n", + "# update test style\n", + "cfg.data.update({'test_style': args.test_style})\n" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Config (path: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/configs/algorithmic_trading/algorithmic_trading_BTC_dqn_dqn_adam_mse.py): {'data': {'type': 'AlgorithmicTradingDataset', 'data_path': 'data/algorithmic_trading/BTC', 'train_path': 'data/algorithmic_trading/BTC/train.csv', 'valid_path': 'data/algorithmic_trading/BTC/valid.csv', 'test_path': 'data/algorithmic_trading/BTC/test.csv', 'test_style_path': 'data/algorithmic_trading/BTC/test_labeled_3_24_-0.15_0.15.csv', 'tech_indicator_list': ['high', 'low', 'open', 'close', 'adjcp', 'zopen', 'zhigh', 'zlow', 'zadjcp', 'zclose', 'zd_5', 'zd_10', 'zd_15', 'zd_20', 'zd_25', 'zd_30'], 'backward_num_day': 5, 'forward_num_day': 5, 'test_style': '-1'}, 'environment': {'type': 'AlgorithmicTradingEnvironment'}, 'agent': {'type': 'AlgorithmicTradingDQN', 'max_step': 12345, 'reward_scale': 1, 'repeat_times': 1, 'gamma': 0.9, 'batch_size': 64, 'clip_grad_norm': 3.0, 'soft_update_tau': 0, 'state_value_tau': 0.005}, 'trainer': {'type': 'AlgorithmicTradingTrainer', 'epochs': 20, 'work_dir': 'work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse', 'seeds_list': (12345,), 'batch_size': 64, 'horizon_len': 1024, 'buffer_size': 1000000.0, 'num_threads': 8, 'if_remove': False, 'if_discrete': True, 'if_off_policy': True, 'if_keep_save': True, 'if_over_write': False, 'if_save_buffer': False}, 'loss': {'type': 'MSELoss'}, 'optimizer': {'type': 'Adam', 'lr': 0.001}, 'act': {'type': 'QNet', 'state_dim': 82, 'action_dim': 3, 'dims': (64, 32), 'explore_rate': 0.25}, 'cri': None, 'transition': {'type': 'Transition'}, 'task_name': 'algorithmic_trading', 'dataset_name': 'BTC', 'optimizer_name': 'adam', 'loss_name': 'mse', 'net_name': 'dqn', 'agent_name': 'dqn', 'work_dir': 'work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse', 'batch_size': 64}" + ] + }, + "execution_count": 19, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cfg" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Build Dataset\n", + "Build datasets from cfg defined above" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = build_dataset(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 21, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Build Reinforcement Learning Environments\n", + "Build environments based on cfg and previously-defined dataset\n", + "\n", + "A style-test is provided as an option to test the algorithm's performance under different market conditions" + ] + }, + { + "cell_type": "code", + "execution_count": 22, + "metadata": {}, + "outputs": [], + "source": [ + "train_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"train\"))\n", + "valid_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"valid\"))\n", + "test_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"test\"))\n", + "if task_name.startswith(\"style_test\"):\n", + " test_style_environments = []\n", + " for i, path in enumerate(dataset.test_style_paths):\n", + " test_style_environments.append(build_environment(cfg, default_args=dict(dataset=dataset, task=\"test_style\",\n", + " style_test_path=path,\n", + " task_index=i)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 23, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 23, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 24, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 24, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "valid_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 25, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_environment" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Build Net \n", + "Update information about the state and action dimension in the config and create nets and optimizer for DQN\n" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "action_dim = train_environment.action_dim\n", + "state_dim = train_environment.state_dim\n", + "\n", + "cfg.act.update(dict(action_dim=action_dim, state_dim=state_dim))\n", + "act = build_net(cfg.act)\n", + "act_optimizer = build_optimizer(cfg, default_args=dict(params=act.parameters()))\n", + "if cfg.cri:\n", + " cfg.cri.update(dict(action_dim=action_dim, state_dim=state_dim))\n", + " cri = build_net(cfg.cri)\n", + " cri_optimizer = build_optimizer(cfg, default_args=dict(params=cri.parameters()))\n", + "else:\n", + " cri = None\n", + " cri_optimizer = None" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Build Loss\n", + "Build loss from config" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = build_loss(cfg)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Build Transition\n", + "Build transition from config" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "transition = build_transition(cfg)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 8: Build Agent\n", + "Build agent from config and detect device" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "agent = build_agent(cfg, default_args=dict(action_dim = action_dim,\n", + " state_dim = state_dim,\n", + " act = act,\n", + " cri = cri,\n", + " act_optimizer = act_optimizer,\n", + " cri_optimizer = cri_optimizer,\n", + " criterion = criterion,\n", + " transition = transition,\n", + " device=device))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 9: Build Trainer\n", + "Build trainer from config and create work directionary to save the result, model and config" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "| Arguments Keep work_dir: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse\n" + ] + } + ], + "source": [ + "if task_name.startswith(\"style_test\"):\n", + " trainers = []\n", + " for env in test_style_environments:\n", + " trainers.append(build_trainer(cfg, default_args=dict(train_environment=train_environment,\n", + " valid_environment=valid_environment,\n", + " test_environment=env,\n", + " agent=agent,\n", + " device=device)))\n", + "else:\n", + " trainer = build_trainer(cfg, default_args=dict(train_environment=train_environment,\n", + " valid_environment=valid_environment,\n", + " test_environment=test_environment,\n", + " agent=agent,\n", + " device=device))\n", + "\n", + "cfg.dump(osp.join(ROOT, cfg.work_dir, osp.basename(args.config)))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 10: Train the Trainer\n", + "Train the trainer based on the config and get results from workdir" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Episode: [1/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -344.513937% | -0.000214 | 6862.304186 | 6.214777 | -27423.674698 | -0.556031 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode: [1/20]\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "| 36.243173% | 0.001800 | 4108.651796 | 0.583186 | 62313.081138 | 0.551188 |\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00001.pth\n", + "Train Episode: [2/20]\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| 2550.222626% | 0.000294 | 74654.187304 | 0.984371 | 2592355.042556 | 0.737448 |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "Valid Episode: [2/20]\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "| 38.089801% | 0.002813 | 2762.654677 | 0.456196 | 83702.883876 | 0.945576 |\n", + "+---------------+-------------+-------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 57332.169254\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00002.pth\n", + "Train Episode: [3/20]\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| 2129.816364% | 0.000301 | 60992.463565 | 0.990535 | 2151453.831529 | 0.761704 |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "Valid Episode: [3/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00003.pth\n", + "Train Episode: [4/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -2334.279839% | -0.000274 | 62958.198404 | 24.260127 | -82729.892836 | -0.704792 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [4/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.064518% | -0.001804 | 4108.595508 | 2.561829 | -14212.530944 | -0.820658 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56365.177630\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00004.pth\n", + "Train Episode: [5/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1641.409323% | -0.000287 | 43526.174039 | 14.819907 | -97857.697010 | -0.740099 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [5/20]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| -0.000000% | 0.000000 | 0.000000 | 0.000000 | 0.000000 | 0nan |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 0.000000\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00005.pth\n", + "Train Episode: [6/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1691.332913% | -0.000290 | 46229.282308 | 36.828678 | -42306.683350 | -0.752619 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [6/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -227.824723% | -0.001990 | 3952.099453 | 2.539431 | -15219.010810 | -0.887165 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -60502.612906\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00006.pth\n", + "Train Episode: [7/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -2174.481852% | -0.000292 | 59661.553550 | 47.345066 | -42758.782985 | -0.759986 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [7/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -176.839873% | -0.001765 | 3528.124043 | 2.195090 | -13939.332801 | -0.796095 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -52680.509863\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00007.pth\n", + "Train Episode: [8/20]\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "| 1885.983659% | 0.000443 | 36656.692362 | 0.972584 | 1941130.966521 | 1.154473 |\n", + "+---------------+-------------+--------------+--------------+----------------+---------------+\n", + "Valid Episode: [8/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00008.pth\n", + "Train Episode: [9/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -2724.499907% | -0.000291 | 74498.020183 | 52.412219 | -48016.107359 | -0.744638 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [9/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00009.pth\n", + "Train Episode: [10/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -2031.685492% | -0.000290 | 54747.564770 | 43.636693 | -42363.764305 | -0.757473 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [10/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00010.pth\n", + "Train Episode: [11/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -2399.121800% | -0.000292 | 64957.047457 | 51.684147 | -42618.599118 | -0.761166 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [11/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.053823% | -0.001886 | 4043.229698 | 2.569442 | -14579.510155 | -0.866852 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -57221.583660\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00011.pth\n", + "Train Episode: [12/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1481.424033% | -0.000284 | 39030.031859 | 31.315723 | -41117.470365 | -0.737378 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [12/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00012.pth\n", + "Train Episode: [13/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -788.802406% | -0.000152 | 19854.508607 | 6.946587 | -50441.493440 | -0.380641 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [13/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -170.396222% | -0.002120 | 3555.765470 | 2.254471 | -16430.348220 | -0.931926 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -60797.589909\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00013.pth\n", + "Train Episode: [14/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1499.160477% | -0.000287 | 39627.327133 | 30.212022 | -43745.116737 | -0.735405 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [14/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -169.507466% | -0.002174 | 3551.437396 | 2.245583 | -16891.155853 | -0.954286 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -61053.498659\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00014.pth\n", + "Train Episode: [15/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -688.419270% | 0.000683 | 12404.350216 | 2.003406 | 491928.343157 | 1.901195 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [15/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -170.679616% | -0.002148 | 3483.353434 | 2.133488 | -17229.228497 | -0.957504 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -60056.333731\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00015.pth\n", + "Train Episode: [16/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -972.226134% | 0.000515 | 15091.307919 | 2.283666 | 395975.151491 | 1.384039 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [16/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00016.pth\n", + "Train Episode: [17/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1400.570121% | -0.000281 | 36908.469387 | 29.389929 | -41092.361952 | -0.717731 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [17/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00017.pth\n", + "Train Episode: [18/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -744.326652% | 0.000579 | 12140.578923 | 2.151154 | 379703.556119 | 1.557721 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [18/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -180.970136% | -0.001587 | 3330.163550 | 2.008192 | -12932.140987 | -0.734850 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -41071.900165\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00018.pth\n", + "Train Episode: [19/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -948.403058% | -0.000274 | 24082.344472 | 19.543640 | -39251.932592 | -0.693784 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [19/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -230.134541% | -0.001800 | 4108.651796 | 2.562529 | -14181.348608 | -0.819081 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -56390.914075\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00019.pth\n", + "Train Episode: [20/20]\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "| -1874.527862% | -0.000289 | 51213.539849 | 40.429782 | -42627.916754 | -0.738897 |\n", + "+---------------+-------------+--------------+--------------+---------------+---------------+\n", + "Valid Episode: [20/20]\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| -133.446994% | -0.002186 | 3066.785470 | 1.822390 | -18069.448586 | -0.946982 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Valid Episode Reward Sum: -54479.743152\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/checkpoint-00020.pth\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/best.pth\n", + "Resume checkpoint /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/algorithmic_trading_BTC_dqn_dqn_adam_mse/checkpoints/best.pth\n", + "Test Best Episode\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "| 49.996525% | -0.005364 | 4336.971275 | 1.165761 | -98033.866998 | -1.900394 |\n", + "+---------------+-------------+-------------+--------------+---------------+---------------+\n", + "Test Best Episode Reward Sum: -231100.467385\n", + "train end\n" + ] + } + ], + "source": [ + "if task_name.startswith(\"train\"):\n", + " trainer.train_and_valid()\n", + " trainer.test()\n", + " print(\"train end\")\n", + "elif task_name.startswith(\"test\"):\n", + " trainer.test()\n", + " print(\"test end\")\n", + "elif task_name.startswith(\"style_test\"):\n", + " daily_return_list = []\n", + " for trainer in trainers:\n", + " daily_return_list.extend(trainer.test())\n", + " print('win rate is: ', sum(r > 0 for r in daily_return_list) / len(daily_return_list))\n", + " print(\"style test end\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "HFT", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.15" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "c33605b009166d65f90ad63d824c8e63d22d0973c031452c4b4158e2872c99ad" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorial/Tutorial2_EIIE.ipynb b/tutorial/Tutorial2_EIIE.ipynb new file mode 100644 index 00000000..d79203f2 --- /dev/null +++ b/tutorial/Tutorial2_EIIE.ipynb @@ -0,0 +1,461 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Import Packages\n", + "Modify the system path and load the corresponding packages and functions " + ] + }, + { + "cell_type": "code", + "execution_count": 25, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "ROOT = str(Path(\"__file__\").resolve().parents[1])\n", + "sys.path.append(ROOT)\n", + "import torch\n", + "import argparse\n", + "import os.path as osp\n", + "from mmcv import Config\n", + "from trademaster.utils import replace_cfg_vals\n", + "from trademaster.nets.builder import build_net\n", + "from trademaster.environments.builder import build_environment\n", + "from trademaster.datasets.builder import build_dataset\n", + "from trademaster.agents.builder import build_agent\n", + "from trademaster.optimizers.builder import build_optimizer\n", + "from trademaster.losses.builder import build_loss\n", + "from trademaster.trainers.builder import build_trainer\n", + "from trademaster.transition.builder import build_transition" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Load Configs\n", + "Load default config from the folder `configs/portfolio_management/portfolio_management_dj30_eiie_eiie_adam_mse.py`" + ] + }, + { + "cell_type": "code", + "execution_count": 26, + "metadata": {}, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser(description='Download Alpaca Datasets')\n", + "parser.add_argument(\"--config\", default=osp.join(ROOT, \"configs\", \"portfolio_management\", \"portfolio_management_dj30_eiie_eiie_adam_mse.py\"),\n", + " help=\"download datasets config file path\")\n", + "parser.add_argument(\"--task_name\", type=str, default=\"train\")\n", + "parser.add_argument(\"--test_style\", type=str, default='-1')\n", + "args = parser.parse_args([])\n", + "cfg = Config.fromfile(args.config)\n", + "task_name = args.task_name\n", + "\n", + "cfg = replace_cfg_vals(cfg)\n", + "# update test style\n", + "cfg.data.update({'test_style': args.test_style})\n" + ] + }, + { + "cell_type": "code", + "execution_count": 27, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Config (path: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/configs/portfolio_management/portfolio_management_dj30_eiie_eiie_adam_mse.py): {'data': {'type': 'PortfolioManagementDataset', 'data_path': 'data/portfolio_management/dj30', 'train_path': 'data/portfolio_management/dj30/train.csv', 'valid_path': 'data/portfolio_management/dj30/valid.csv', 'test_path': 'data/portfolio_management/dj30/test.csv', 'tech_indicator_list': ['zopen', 'zhigh', 'zlow', 'zadjcp', 'zclose', 'zd_5', 'zd_10', 'zd_15', 'zd_20', 'zd_25', 'zd_30'], 'length_day': 10, 'initial_amount': 100000, 'transaction_cost_pct': 0.001, 'test_style_path': 'data/portfolio_management/dj30/DJI_label_by_DJIindex_3_24_-0.25_0.25.csv', 'test_style': '-1'}, 'environment': {'type': 'PortfolioManagementEIIEEnvironment'}, 'agent': {'type': 'PortfolioManagementEIIE', 'memory_capacity': 1000, 'gamma': 0.99, 'policy_update_frequency': 500}, 'trainer': {'type': 'PortfolioManagementEIIETrainer', 'epochs': 10, 'work_dir': 'work_dir/portfolio_management_dj30_eiie_eiie_adam_mse', 'if_remove': True}, 'loss': {'type': 'MSELoss'}, 'optimizer': {'type': 'Adam', 'lr': 0.001}, 'act': {'type': 'EIIEConv', 'input_dim': None, 'output_dim': 1, 'time_steps': 10, 'kernel_size': 3, 'dims': [32]}, 'cri': {'type': 'EIIECritic', 'input_dim': None, 'action_dim': None, 'output_dim': 1, 'time_steps': None, 'num_layers': 1, 'hidden_size': 32}, 'transition': {'type': 'Transition'}, 'task_name': 'portfolio_management', 'dataset_name': 'dj30', 'net_name': 'eiie', 'agent_name': 'eiie', 'optimizer_name': 'adam', 'loss_name': 'mse', 'work_dir': 'work_dir/portfolio_management_dj30_eiie_eiie_adam_mse'}" + ] + }, + "execution_count": 27, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cfg" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Build Dataset\n", + "Build datasets from cfg defined above" + ] + }, + { + "cell_type": "code", + "execution_count": 28, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = build_dataset(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 29, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 29, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Build Reinforcement Learning Environments\n", + "Build environments based on cfg and previously-defined dataset\n", + "\n", + "A style-test is provided as an option to test the algorithm's performance under different market conditions" + ] + }, + { + "cell_type": "code", + "execution_count": 30, + "metadata": {}, + "outputs": [], + "source": [ + "train_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"train\"))\n", + "valid_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"valid\"))\n", + "test_environment = build_environment(cfg, default_args=dict(dataset=dataset, task=\"test\"))\n", + "if task_name.startswith(\"style_test\"):\n", + " test_style_environments = []\n", + " for i, path in enumerate(dataset.test_style_paths):\n", + " test_style_environments.append(build_environment(cfg, default_args=dict(dataset=dataset, task=\"test_style\",\n", + " style_test_path=path,\n", + " task_index=i)))\n" + ] + }, + { + "cell_type": "code", + "execution_count": 31, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 31, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 32, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 32, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "valid_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 33, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 33, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_environment" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Build Net \n", + "Update information about the state and action dimension in the config and create nets and optimizer for EIIE\n" + ] + }, + { + "cell_type": "code", + "execution_count": 34, + "metadata": {}, + "outputs": [], + "source": [ + "action_dim = train_environment.action_dim # 29\n", + "state_dim = train_environment.state_dim # 11\n", + "input_dim = len(train_environment.tech_indicator_list)\n", + "time_steps = train_environment.time_steps\n", + "\n", + "cfg.act.update(dict(input_dim=input_dim, time_steps=time_steps))\n", + "cfg.cri.update(dict(input_dim=input_dim, action_dim= action_dim, time_steps=time_steps))\n", + "\n", + "act = build_net(cfg.act)\n", + "cri = build_net(cfg.cri)\n", + "act_optimizer = build_optimizer(cfg, default_args=dict(params=act.parameters()))\n", + "cri_optimizer = build_optimizer(cfg, default_args=dict(params=cri.parameters()))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Build Loss\n", + "Build loss from config" + ] + }, + { + "cell_type": "code", + "execution_count": 35, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = build_loss(cfg)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Build Transition\n", + "Build transition from config" + ] + }, + { + "cell_type": "code", + "execution_count": 36, + "metadata": {}, + "outputs": [], + "source": [ + "transition = build_transition(cfg)\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 8: Build Agent\n", + "Build agent from config and detect device" + ] + }, + { + "cell_type": "code", + "execution_count": 37, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "agent = build_agent(cfg, default_args=dict(action_dim=action_dim,\n", + " state_dim=state_dim,\n", + " time_steps = time_steps,\n", + " act=act,\n", + " cri=cri,\n", + " act_optimizer=act_optimizer,\n", + " cri_optimizer = cri_optimizer,\n", + " criterion=criterion,\n", + " transition = transition,\n", + " device = device))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 9: Build Trainer\n", + "Build trainer from config and create work directionary to save the result, model and config" + ] + }, + { + "cell_type": "code", + "execution_count": 38, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "| Arguments Remove work_dir: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse\n" + ] + } + ], + "source": [ + "if task_name.startswith(\"style_test\"):\n", + " trainers = []\n", + " for env in test_style_environments:\n", + " trainers.append(build_trainer(cfg, default_args=dict(train_environment=train_environment,\n", + " valid_environment=valid_environment,\n", + " test_environment=env,\n", + " agent=agent,\n", + " device=device)))\n", + "else:\n", + " trainer = build_trainer(cfg, default_args=dict(train_environment=train_environment,\n", + " valid_environment=valid_environment,\n", + " test_environment=test_environment,\n", + " agent=agent,\n", + " device=device,\n", + " ))\n", + "work_dir = os.path.join(ROOT, cfg.trainer.work_dir)\n", + "\n", + "if not os.path.exists(work_dir):\n", + " os.makedirs(work_dir)\n", + "cfg.dump(osp.join(work_dir, osp.basename(args.config)))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 10: Train the Trainer\n", + "Train the trainer based on the config and get results from workdir" + ] + }, + { + "cell_type": "code", + "execution_count": 39, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Train Episode: [1/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 180.543176% | 0.001605 | 0.007575 | 0.645235 | 1.688276 | 4.176283 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode: [1/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 9.519411% | 0.001804 | 0.021395 | 0.361807 | 0.406514 | 0.529803 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 0.090932\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00001.pth\n", + "Train Episode: [2/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 180.534263% | 0.001605 | 0.007575 | 0.645216 | 1.688270 | 4.176311 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode: [2/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 9.520920% | 0.001804 | 0.021394 | 0.361798 | 0.406550 | 0.529857 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 0.090945\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00002.pth\n", + "Train Episode: [3/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 180.519361% | 0.001605 | 0.007574 | 0.645198 | 1.688229 | 4.176275 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode: [3/10]\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| Profit Margin | Sharp Ratio | Volatility | Max Drawdown | Calmar Ratio | Sortino Ratio |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "| 9.521613% | 0.001804 | 0.021394 | 0.361794 | 0.406567 | 0.529882 |\n", + "+---------------+-------------+------------+--------------+--------------+---------------+\n", + "Valid Episode Reward Sum: 0.090952\n", + "save path /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/portfolio_management_dj30_eiie_eiie_adam_mse/checkpoints/checkpoint-00003.pth\n", + "Train Episode: [4/10]\n" + ] + } + ], + "source": [ + "if task_name.startswith(\"train\"):\n", + " trainer.train_and_valid()\n", + " trainer.test()\n", + " print(\"train end\")\n", + "elif task_name.startswith(\"test\"):\n", + " trainer.test()\n", + " print(\"test end\")\n", + "elif task_name.startswith(\"style_test\"):\n", + " daily_return_list = []\n", + " for trainer in trainers:\n", + " daily_return_list.extend(trainer.test())\n", + " print('win rate is: ', sum(r > 0 for r in daily_return_list) / len(daily_return_list))\n", + " print(\"style test end\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "HFT", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.15" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "c33605b009166d65f90ad63d824c8e63d22d0973c031452c4b4158e2872c99ad" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/tutorial/Tutorial5_HFT.ipynb b/tutorial/Tutorial5_HFT.ipynb new file mode 100644 index 00000000..16e82ec7 --- /dev/null +++ b/tutorial/Tutorial5_HFT.ipynb @@ -0,0 +1,451 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 1: Import Packages\n", + "Modify the system path and load the corresponding packages and functions " + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import sys\n", + "from pathlib import Path\n", + "import warnings\n", + "warnings.filterwarnings(\"ignore\")\n", + "ROOT = str(Path(\"__file__\").resolve().parents[1])\n", + "sys.path.append(ROOT)\n", + "import torch\n", + "import argparse\n", + "import os.path as osp\n", + "from mmcv import Config\n", + "from trademaster.utils import replace_cfg_vals\n", + "from trademaster.nets.builder import build_net\n", + "from trademaster.environments.builder import build_environment\n", + "from trademaster.datasets.builder import build_dataset\n", + "from trademaster.agents.builder import build_agent\n", + "from trademaster.optimizers.builder import build_optimizer\n", + "from trademaster.losses.builder import build_loss\n", + "from trademaster.trainers.builder import build_trainer\n", + "from trademaster.transition.builder import build_transition" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 2: Load Configs\n", + "Load default config from the folder `configs/high_frequency_trading/high_frequency_trading_BTC_dqn_dqn_adam_mse.py`" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "metadata": {}, + "outputs": [], + "source": [ + "parser = argparse.ArgumentParser(description='Download Alpaca Datasets')\n", + "parser.add_argument(\"--config\", default=osp.join(ROOT, \"configs\", \"high_frequency_trading\", \"high_frequency_trading_BTC_dqn_dqn_adam_mse.py\"),\n", + " help=\"download datasets config file path\")\n", + "parser.add_argument(\"--task_name\", type=str, default=\"train\")\n", + "parser.add_argument(\"--test_style\", type=str, default='-1')\n", + "args = parser.parse_args([])\n", + "cfg = Config.fromfile(args.config)\n", + "task_name = args.task_name\n", + "\n", + "cfg = replace_cfg_vals(cfg)\n", + "# update test style\n", + "cfg.data.update({'test_style': args.test_style})\n" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "Config (path: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/configs/high_frequency_trading/high_frequency_trading_BTC_dqn_dqn_adam_mse.py): {'data': {'type': 'HighFrequencyTradingDataset', 'data_path': 'data/high_frequency_trading/small_BTC', 'train_path': 'data/high_frequency_trading/small_BTC/train.csv', 'valid_path': 'data/high_frequency_trading/small_BTC/valid.csv', 'test_path': 'data/high_frequency_trading/small_BTC/test.csv', 'test_style_path': 'data/high_frequency_trading/small_BTC/test.csv', 'tech_indicator_list': ['imblance_volume_oe', 'sell_spread_oe', 'buy_spread_oe', 'kmid2', 'bid1_size_n', 'ksft2', 'ma_10', 'ksft', 'kmid', 'ask1_size_n', 'trade_diff', 'qtlu_10', 'qtld_10', 'cntd_10', 'beta_10', 'roc_10', 'bid5_size_n', 'rsv_10', 'imxd_10', 'ask5_size_n', 'ma_30', 'max_10', 'qtlu_30', 'imax_10', 'imin_10', 'min_10', 'qtld_30', 'cntn_10', 'rsv_30', 'cntp_10', 'ma_60', 'max_30', 'qtlu_60', 'qtld_60', 'cntd_30', 'roc_30', 'beta_30', 'bid4_size_n', 'rsv_60', 'ask4_size_n', 'imxd_30', 'min_30', 'max_60', 'imax_30', 'imin_30', 'cntd_60', 'roc_60', 'beta_60', 'cntn_30', 'min_60', 'cntp_30', 'bid3_size_n', 'imxd_60', 'ask3_size_n', 'sell_volume_oe', 'imax_60', 'imin_60', 'cntn_60', 'buy_volume_oe', 'cntp_60', 'bid2_size_n', 'kup', 'bid1_size', 'ask1_size', 'std_30', 'ask2_size_n'], 'transcation_cost': 0, 'backward_num_timestamp': 1, 'max_holding_number': 0.01, 'num_action': 11, 'max_punish': 1000000000000.0, 'episode_length': 14400, 'test_style': '-1'}, 'environment': {'type': 'HighFrequencyTradingEnvironment'}, 'agent': {'type': 'HighFrequencyTradingDDQN', 'auxiliary_coffient': 512, 'reward_scale': 1, 'repeat_times': 1, 'gamma': 0.99, 'batch_size': 64, 'clip_grad_norm': 3.0, 'soft_update_tau': 0, 'state_value_tau': 0.005}, 'trainer': {'type': 'HighFrequencyTradingTrainer', 'epochs': 10, 'work_dir': 'work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse', 'seeds': 12345, 'batch_size': 512, 'horizon_len': 512, 'buffer_size': 100000.0, 'num_threads': 8, 'if_remove': False, 'if_discrete': True, 'if_off_policy': True, 'if_keep_save': True, 'if_over_write': False, 'if_save_buffer': False}, 'loss': {'type': 'HFTLoss', 'ada': 1}, 'optimizer': {'type': 'Adam', 'lr': 0.001}, 'act': {'type': 'HFTQNet', 'state_dim': 66, 'action_dim': 11, 'dims': 16, 'explore_rate': 0.25, 'max_punish': 0}, 'cri': None, 'task_name': 'high_frequency_trading', 'dataset_name': 'BTC', 'optimizer_name': 'adam', 'loss_name': 'mse', 'auxiliry_loss_name': 'KLdiv', 'net_name': 'high_frequency_trading_dqn', 'agent_name': 'ddqn', 'work_dir': 'work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse', 'batch_size': 512, 'train_environment': {'type': 'HighFrequencyTradingTrainingEnvironment'}}" + ] + }, + "execution_count": 9, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "cfg" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 3: Build Dataset\n", + "Build datasets from cfg defined above" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "metadata": {}, + "outputs": [], + "source": [ + "dataset = build_dataset(cfg)" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "dataset" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 4: Build Reinforcement Learning Environments\n", + "Build environments based on cfg and previously-defined dataset\n", + "\n" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "valid\n", + "test\n" + ] + } + ], + "source": [ + "valid_environment = build_environment(\n", + " cfg, default_args=dict(dataset=dataset, task=\"valid\")\n", + " )\n", + "test_environment = build_environment(\n", + " cfg, default_args=dict(dataset=dataset, task=\"test\")\n", + ")\n", + "cfg.environment = cfg.train_environment\n", + "train_environment = build_environment(\n", + " cfg, default_args=dict(dataset=dataset, task=\"train\")\n", + ")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 13, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "train_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 14, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "valid_environment" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "test_environment" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 5: Build Net \n", + "Update information about the state and action dimension in the config and create nets and optimizer for DQN\n" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "metadata": {}, + "outputs": [], + "source": [ + "action_dim = train_environment.action_dim\n", + "state_dim = train_environment.state_dim\n", + "\n", + "cfg.act.update(dict(action_dim=action_dim, state_dim=state_dim))\n", + "act = build_net(cfg.act)\n", + "act_optimizer = build_optimizer(cfg, default_args=dict(params=act.parameters()))\n", + "if cfg.cri:\n", + " cfg.cri.update(dict(action_dim=action_dim, state_dim=state_dim))\n", + " cri = build_net(cfg.cri)\n", + " cri_optimizer = build_optimizer(cfg, default_args=dict(params=cri.parameters()))\n", + "else:\n", + " cri = None\n", + " cri_optimizer = None" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 6: Build Loss\n", + "Build loss from config" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "metadata": {}, + "outputs": [], + "source": [ + "criterion = build_loss(cfg)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 7: Build Agent\n", + "Build agent from config and detect device" + ] + }, + { + "cell_type": "code", + "execution_count": 19, + "metadata": {}, + "outputs": [], + "source": [ + "device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')\n", + "agent = build_agent(\n", + " cfg,\n", + " default_args=dict(\n", + " action_dim=action_dim,\n", + " state_dim=state_dim,\n", + " act=act,\n", + " cri=cri,\n", + " act_optimizer=act_optimizer,\n", + " cri_optimizer=cri_optimizer,\n", + " criterion=criterion,\n", + " device=device,\n", + " ),\n", + " )\n" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 8: Build Trainer\n", + "Build trainer from config and create work directionary to save the result, model and config" + ] + }, + { + "cell_type": "code", + "execution_count": 20, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "| Arguments Keep work_dir: /data1/sunshuo/qml/TradeMaster/TradeMasterReBuild/work_dir/high_frequency_trading_BTC_high_frequency_trading_dqn_ddqn_adam_mse\n" + ] + } + ], + "source": [ + "\n", + "trainer = build_trainer(cfg, default_args=dict(train_environment=train_environment,\n", + " valid_environment=valid_environment,\n", + " test_environment=test_environment,\n", + " agent=agent,\n", + " device=device))\n", + "\n", + "cfg.dump(osp.join(ROOT, cfg.work_dir, osp.basename(args.config)))" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Step 9: Train the Trainer\n", + "Train the trainer based on the config and get results from workdir" + ] + }, + { + "cell_type": "code", + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "Train Episode: [1/10]\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be sell all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n", + "the holding could not be bought all clear\n" + ] + } + ], + "source": [ + "if task_name.startswith(\"train\"):\n", + " trainer.train_and_valid()\n", + " trainer.test()\n", + " print(\"train end\")\n", + "elif task_name.startswith(\"test\"):\n", + " trainer.test()\n", + " print(\"test end\")\n" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "HFT", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.7.15" + }, + "orig_nbformat": 4, + "vscode": { + "interpreter": { + "hash": "c33605b009166d65f90ad63d824c8e63d22d0973c031452c4b4158e2872c99ad" + } + } + }, + "nbformat": 4, + "nbformat_minor": 2 +}