Commit 3a8914e

initial commit
1 parent 53d9f32 commit 3a8914e

7 files changed: +865 -0 lines changed
Lines changed: 16 additions & 0 deletions
@@ -0,0 +1,16 @@
[base]
package = ocean
env_name = puffer_slimevolley
policy_name = Policy

[env]
; 1 for single-agent (vs bot), 2 for two-agent (self-play)
num_agents = 2

[train]
learning_rate = 0.015
total_timesteps = 10_000_000
num_envs = 128
num_workers = 8
batch_size = 1024
minibatch_size = 128
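
The [env] keys presumably map onto keyword arguments of the SlimeVolley constructor (num_agents matches the argument in slimevolley.py below) and the [train] keys onto the trainer. A minimal sketch of reading the sections with Python's configparser; the file name is an assumption, since the commit does not show this file's path:

# Hypothetical sketch; "slimevolley.ini" is an assumed file name.
import configparser

cfg = configparser.ConfigParser()
cfg.read("slimevolley.ini")

env_kwargs = {k: int(v) for k, v in cfg["env"].items()}   # {'num_agents': 2}
train_cfg = dict(cfg["train"])                             # string values, e.g. '10_000_000'
print(env_kwargs, train_cfg)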

pufferlib/ocean/environment.py

Lines changed: 1 addition & 0 deletions
@@ -158,6 +158,7 @@ def make_multiagent(buf=None, **kwargs):
     'whisker_racer': 'WhiskerRacer',
     'spaces': make_spaces,
     'multiagent': make_multiagent,
+    'slimevolley': 'SlimeVolley',
 }

 def env_creator(name='squared', *args, **kwargs):
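
With this entry registered, the environment should be constructible by name. A hedged sketch, assuming env_creator resolves the string entry to the SlimeVolley class the same way it does for the other Ocean environments:

# Hypothetical usage; assumes env_creator returns the env class for 'slimevolley'.
from pufferlib.ocean.environment import env_creator

SlimeVolley = env_creator('slimevolley')
env = SlimeVolley(num_envs=1, num_agents=2)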
Lines changed: 18 additions & 0 deletions
@@ -0,0 +1,18 @@
#include "slimevolley.h"

#define Env SlimeVolley
#include "../env_binding.h"

// Read environment options from the Python kwargs and initialize the C env
static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
    env->num_agents = unpack(kwargs, "num_agents");
    init(env);
    return 0;
}

// Copy per-episode log fields into the Python info dict
static int my_log(PyObject* dict, Log* log) {
    assign_to_dict(dict, "perf", log->perf);
    assign_to_dict(dict, "score", log->score);
    assign_to_dict(dict, "episode_return", log->episode_return);
    assign_to_dict(dict, "episode_length", log->episode_length);
    return 0;
}
Lines changed: 77 additions & 0 deletions
@@ -0,0 +1,77 @@
import gymnasium
import numpy as np

from pufferlib.ocean.slimevolley import binding
import pufferlib
from pufferlib.ocean.torch import Policy
import torch


class SlimeVolley(pufferlib.PufferEnv):
    def __init__(self, num_envs=1, render_mode=None, log_interval=128, buf=None, seed=0,
            num_agents=1):
        assert num_agents in {1, 2}, "num_agents must be 1 or 2"
        num_obs = 12
        self.single_observation_space = gymnasium.spaces.Box(low=0, high=1,
            shape=(num_obs,), dtype=np.float32)
        self.single_action_space = gymnasium.spaces.MultiDiscrete([2, 2, 2])

        self.render_mode = render_mode
        self.num_agents = num_envs * num_agents
        self.log_interval = log_interval

        super().__init__(buf)
        # One C env per vectorized copy; each gets a num_agents-wide slice
        # of the shared observation/action/reward/terminal/truncation buffers
        c_envs = []
        for i in range(num_envs):
            c_env = binding.env_init(
                self.observations[i*num_agents:(i+1)*num_agents],
                self.actions[i*num_agents:(i+1)*num_agents],
                self.rewards[i*num_agents:(i+1)*num_agents],
                self.terminals[i*num_agents:(i+1)*num_agents],
                self.truncations[i*num_agents:(i+1)*num_agents],
                seed,
                num_agents=num_agents,
            )
            c_envs.append(c_env)

        self.c_envs = binding.vectorize(*c_envs)

    def reset(self, seed=0):
        binding.vec_reset(self.c_envs, seed)
        self.tick = 0
        return self.observations, []

    def step(self, actions):
        self.tick += 1
        self.actions[:] = actions
        binding.vec_step(self.c_envs)

        # Aggregate logs every log_interval steps
        info = []
        if self.tick % self.log_interval == 0:
            log = binding.vec_log(self.c_envs)
            if log:
                info.append(log)

        return (self.observations, self.rewards,
            self.terminals, self.truncations, info)

    def render(self):
        binding.vec_render(self.c_envs, 0)

    def close(self):
        binding.vec_close(self.c_envs)


if __name__ == "__main__":
    # Roll out a trained policy in single-agent mode (vs the built-in bot)
    env = SlimeVolley(num_envs=1, num_agents=1)
    observations, _ = env.reset()
    env.render()
    policy = Policy(env)
    policy.load_state_dict(torch.load("checkpoint.pt", map_location="cpu"))
    with torch.no_grad():
        while True:
            actions = policy(torch.from_numpy(observations))
            actions = [float(torch.argmax(a)) for a in actions[0]]
            o, r, t, _, i = env.step([actions])
            env.render()
            if t[0]:
                break
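
A quick random-action smoke test (not part of this commit) can exercise the vectorized env without a trained checkpoint; the import path below mirrors the other Ocean envs and is an assumption:

# Hypothetical smoke test; the import path is assumed, not shown in this commit.
import numpy as np
from pufferlib.ocean.slimevolley.slimevolley import SlimeVolley

env = SlimeVolley(num_envs=4, num_agents=2)  # 8 agents total
obs, _ = env.reset()
for _ in range(256):
    actions = np.stack([env.single_action_space.sample()
                        for _ in range(env.num_agents)])
    obs, rewards, terminals, truncations, info = env.step(actions)
env.close()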
Lines changed: 66 additions & 0 deletions
@@ -0,0 +1,66 @@
/* Pure C demo file for SlimeVolley. Build it with:
 *   bash scripts/build_ocean.sh target local (debug)
 *   bash scripts/build_ocean.sh target fast
 * We suggest building and debugging your env in pure C first. You
 * get faster builds and better error messages.
 */
#include "slimevolley.h"
#include <stdio.h>
#include <stdlib.h>


// Simple linear controller on agent/ball x-positions; always jumps
void abranti_simple_policy(float* obs, float* action) {
    float x_agent = obs[0];
    float x_ball = obs[4];
    float vx_ball = obs[6];
    float backward = (-23.757145f * x_agent + 23.206863f * x_ball + 0.7943352f * vx_ball) + 1.4617119f;
    float forward = -64.6463748f * backward + 22.4668393f;
    action[0] = forward;
    action[1] = backward;
    action[2] = 1.0f; // always jump
}

// Uniform random actions in [-1, 1]
void random_policy(float* obs, float* action) {
    action[0] = 2.0f * rand() / (float)RAND_MAX - 1.0f;
    action[1] = 2.0f * rand() / (float)RAND_MAX - 1.0f;
    action[2] = 2.0f * rand() / (float)RAND_MAX - 1.0f;
}

int main() {
    int num_obs = 12;
    int num_actions = 3;
    SlimeVolley env = {.num_agents = 1};
    init(&env);
    env.observations = (float*)calloc(env.num_agents*num_obs, sizeof(float));
    env.actions = (float*)calloc(num_actions*env.num_agents, sizeof(float));
    env.rewards = (float*)calloc(env.num_agents, sizeof(float));
    env.terminals = (unsigned char*)calloc(env.num_agents, sizeof(unsigned char));
    // Always call reset and render first
    c_reset(&env);
    c_render(&env);

    fprintf(stderr, "num agents: %d\n", env.num_agents);

    while (!WindowShouldClose()) {
        for (int i=0; i<env.num_agents; i++) {
            if (i == 0) {
                random_policy(&env.observations[12*i], &env.actions[3*i]);
            } else {
                abranti_simple_policy(&env.observations[12*i], &env.actions[3*i]);
            }
        }
        c_step(&env);
        c_render(&env);
        // Only num_agents entries are allocated, so check just those
        int done = 0;
        for (int i=0; i<env.num_agents; i++) {
            done |= env.terminals[i];
        }
        if (done) {
            for (int i=0; i<env.num_agents; i++) {
                fprintf(stderr, "Episode ended. Reward[%d]: %f\n", i, env.rewards[i]);
            }
            break;
        }
    }

    // Try to clean up after yourself
    free(env.observations);
    free(env.actions);
    free(env.rewards);
    free(env.terminals);
    c_close(&env);
}
