forked from notadamking/Stock-Trading-Environment
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathmain.py
More file actions
26 lines (19 loc) · 656 Bytes
/
main.py
File metadata and controls
26 lines (19 loc) · 656 Bytes
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
import gym
import json
import datetime as dt
from stable_baselines.common.policies import MlpPolicy
from stable_baselines.common.vec_env import DummyVecEnv
from stable_baselines import PPO2
from env.StockTradingEnv import StockTradingEnv
import pandas as pd
# Load the price history and order it by the 'Date' column so rows are
# presented to the environment in sorted date order.
df = pd.read_csv('./data/AAPL.csv')
df = df.sort_values('Date')

# The algorithms require a vectorized environment to run
env = DummyVecEnv([lambda: StockTradingEnv(df)])

# Train a PPO2 agent with an MLP policy on the wrapped environment.
model = PPO2(MlpPolicy, env, verbose=1)
model.learn(total_timesteps=20000)

# Roll the trained agent forward for a fixed number of steps, rendering the
# environment after every step. Each iteration feeds the observation returned
# by the previous step back into the model.
obs = env.reset()
for i in range(2000):
    action, _states = model.predict(obs)
    obs, rewards, done, info = env.step(action)
    env.render()