|
21 | 21 | ParallelEnv,
|
22 | 22 | Resize,
|
23 | 23 | RewardSum,
|
| 24 | + set_gym_backend, |
24 | 25 | SignTransform,
|
25 | 26 | StepCounter,
|
26 | 27 | ToTensorImage,
|
|
45 | 46 |
|
46 | 47 |
|
47 | 48 | def make_base_env(
|
48 |
| - env_name="BreakoutNoFrameskip-v4", frame_skip=4, device="cpu", is_test=False |
| 49 | + env_name="BreakoutNoFrameskip-v4", |
| 50 | + gym_backend="gymnasium", |
| 51 | + frame_skip=4, |
| 52 | + device="cpu", |
| 53 | + is_test=False, |
49 | 54 | ):
|
50 |
| - env = GymEnv( |
51 |
| - env_name, |
52 |
| - frame_skip=frame_skip, |
53 |
| - from_pixels=True, |
54 |
| - pixels_only=False, |
55 |
| - device=device, |
56 |
| - ) |
| 55 | + with set_gym_backend(gym_backend): |
| 56 | + env = GymEnv( |
| 57 | + env_name, |
| 58 | + frame_skip=frame_skip, |
| 59 | + from_pixels=True, |
| 60 | + pixels_only=False, |
| 61 | + device=device, |
| 62 | + ) |
57 | 63 | env = TransformedEnv(env)
|
58 | 64 | env.append_transform(NoopResetEnv(noops=30, random=True))
|
59 | 65 | if not is_test:
|
60 | 66 | env.append_transform(EndOfLifeTransform())
|
61 | 67 | return env
|
62 | 68 |
|
63 | 69 |
|
64 |
| -def make_parallel_env(env_name, num_envs, device, is_test=False): |
| 70 | +def make_parallel_env(env_name, num_envs, device, gym_backend, is_test=False): |
65 | 71 | env = ParallelEnv(
|
66 | 72 | num_envs,
|
67 |
| - EnvCreator(lambda: make_base_env(env_name)), |
| 73 | + EnvCreator( |
| 74 | + lambda: make_base_env(env_name, gym_backend=gym_backend, is_test=is_test), |
| 75 | + ), |
68 | 76 | serial_for_single=True,
|
| 77 | + gym_backend=gym_backend, |
69 | 78 | device=device,
|
70 | 79 | )
|
71 | 80 | env = TransformedEnv(env)
|
@@ -175,9 +184,11 @@ def make_ppo_modules_pixels(proof_environment, device):
|
175 | 184 | return common_module, policy_module, value_module
|
176 | 185 |
|
177 | 186 |
|
178 |
| -def make_ppo_models(env_name, device): |
| 187 | +def make_ppo_models(env_name, device, gym_backend): |
179 | 188 |
|
180 |
| - proof_environment = make_parallel_env(env_name, 1, device="cpu") |
| 189 | + proof_environment = make_parallel_env( |
| 190 | + env_name, num_envs=1, device="cpu", gym_backend=gym_backend |
| 191 | + ) |
181 | 192 | common_module, policy_module, value_module = make_ppo_modules_pixels(
|
182 | 193 | proof_environment, device=device
|
183 | 194 | )
|
|
0 commit comments