import gymnasium
import numpy as np
import torch

import pufferlib
from pufferlib.ocean.slimevolley import binding
from pufferlib.ocean.torch import Policy
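
# SlimeVolley is a PufferEnv wrapper around the C slimevolley binding.
# Each of the num_envs matches exposes num_agents controllable slimes
# (presumably two for self-play), all batched into flat shared buffers.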
class SlimeVolley(pufferlib.PufferEnv):
    def __init__(self, num_envs=1, render_mode=None, log_interval=128,
            buf=None, seed=0, num_agents=1):
        assert num_agents in {1, 2}, "num_agents must be 1 or 2"

        # Observations are 12 floats normalized to [0, 1]; actions are
        # three binary buttons (presumably move left, move right, and jump)
        num_obs = 12
        self.single_observation_space = gymnasium.spaces.Box(
            low=0, high=1, shape=(num_obs,), dtype=np.float32)
        self.single_action_space = gymnasium.spaces.MultiDiscrete([2, 2, 2])

        self.render_mode = render_mode
        self.num_agents = num_envs * num_agents  # total agents across all envs
        self.log_interval = log_interval

        super().__init__(buf)
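
        # The per-env slices passed to env_init below are views into the
        # flat buffers PufferEnv just allocated, so the C side reads actions
        # and writes observations, rewards, and terminals in place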
        c_envs = []
        for i in range(num_envs):
            c_env = binding.env_init(
                self.observations[i*num_agents:(i + 1)*num_agents],
                self.actions[i*num_agents:(i + 1)*num_agents],
                self.rewards[i*num_agents:(i + 1)*num_agents],
                self.terminals[i*num_agents:(i + 1)*num_agents],
                self.truncations[i*num_agents:(i + 1)*num_agents],
                seed,
                num_agents=num_agents,
            )
            c_envs.append(c_env)

        self.c_envs = binding.vectorize(*c_envs)

    def reset(self, seed=0):
        binding.vec_reset(self.c_envs, seed)
        self.tick = 0
        return self.observations, []

    def step(self, actions):
        self.tick += 1
        self.actions[:] = actions
        binding.vec_step(self.c_envs)

        # Only collect aggregate logs from the C side every log_interval
        # steps to keep per-step overhead low
        info = []
        if self.tick % self.log_interval == 0:
            log = binding.vec_log(self.c_envs)
            if log:
                info.append(log)

        return (self.observations, self.rewards,
            self.terminals, self.truncations, info)

    def render(self):
        binding.vec_render(self.c_envs, 0)

    def close(self):
        binding.vec_close(self.c_envs)
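
# Illustrative sketch, not part of the original file: a smoke test that
# drives the env with random actions when no trained checkpoint is
# available. The helper name random_rollout is hypothetical.
def random_rollout(steps=100):
    env = SlimeVolley(num_envs=1, num_agents=1)
    env.reset()
    for _ in range(steps):
        # One MultiDiscrete sample per agent, stacked to (num_agents, 3)
        actions = np.stack([env.single_action_space.sample()
            for _ in range(env.num_agents)])
        env.step(actions)
    env.close()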


if __name__ == "__main__":
    env = SlimeVolley(num_envs=1, num_agents=1)
    observations, _ = env.reset()
    env.render()

    policy = Policy(env)
    policy.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))

    with torch.no_grad():
        while True:
            # The first element of the policy output holds the per-button
            # logits; act greedily by taking the argmax of each
            outputs = policy(torch.from_numpy(observations))
            actions = [float(torch.argmax(a)) for a in outputs[0]]
            observations, rewards, terminals, truncations, infos = env.step([actions])
            env.render()
            if terminals[0]:
                break

    env.close()