11import numpy as np
2- from gymnasium .wrappers import FlattenObservation
32import torch
3+ from gymnasium .wrappers import FlattenObservation
44
55from openrl .configs .config import create_config_parser
66from openrl .envs .common import make
77from openrl .envs .wrappers .base_wrapper import BaseWrapper
8- from openrl .envs .wrappers .extra_wrappers import FrameSkip , GIFWrapper ,ConvertEmptyBoxWrapper
8+ from openrl .envs .wrappers .extra_wrappers import (
9+ ConvertEmptyBoxWrapper ,
10+ FrameSkip ,
11+ GIFWrapper ,
12+ )
913from openrl .modules .common import PPONet as Net
1014from openrl .runners .common import PPOAgent as Agent
1115
12-
1316env_name = "dm_control/cartpole-balance-v0"
1417# env_name = "dm_control/walker-walk-v0"
1518
@@ -25,7 +28,7 @@ def train():
2528 env_name ,
2629 env_num = env_num ,
2730 asynchronous = True ,
28- env_wrappers = [FrameSkip , FlattenObservation ,ConvertEmptyBoxWrapper ],
31+ env_wrappers = [FrameSkip , FlattenObservation , ConvertEmptyBoxWrapper ],
2932 )
3033
3134 net = Net (env , cfg = cfg , device = "cuda" if torch .cuda .is_available () else "cpu" )
@@ -52,10 +55,10 @@ def evaluation():
5255 render_mode = render_mode ,
5356 env_num = 4 ,
5457 asynchronous = True ,
55- env_wrappers = [FrameSkip , FlattenObservation ,ConvertEmptyBoxWrapper ],
58+ env_wrappers = [FrameSkip , FlattenObservation , ConvertEmptyBoxWrapper ],
5659 )
5760 # Wrap the environment with GIFWrapper to record the GIF, and set the frame rate to 5.
58- # env = GIFWrapper(env, gif_path="./new.gif", fps=5)
61+ env = GIFWrapper (env , gif_path = "./new.gif" , fps = 5 )
5962
6063 net = Net (env , cfg = cfg , device = "cuda" if torch .cuda .is_available () else "cpu" )
6164 # initialize the trainer
0 commit comments