Skip to content

Commit 636cbb3

Browse files
committed
Add lock_key environment
1 parent 7a99b3b commit 636cbb3

File tree

8 files changed

+282
-0
lines changed

8 files changed

+282
-0
lines changed
Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,17 @@
1+
[base]
2+
package = ocean
3+
env_name = puffer_lock_key
4+
policy_name = Policy
5+
rnn_name = Recurrent
6+
7+
[env]
8+
num_envs = 4096
9+
num_keys = 1
10+
size = 8
11+
log_interval = 128
12+
13+
[train]
14+
total_timesteps = 20_000_000
15+
gamma = 0.95
16+
learning_rate = 0.05
17+
minibatch_size = 32768

pufferlib/ocean/environment.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -163,6 +163,7 @@ def make_multiagent(buf=None, **kwargs):
163163
'spaces': make_spaces,
164164
'multiagent': make_multiagent,
165165
'slimevolley': 'SlimeVolley',
166+
'lock_key': 'LockKey',
166167
}
167168

168169
def env_creator(name='squared', *args, **kwargs):

pufferlib/ocean/lock_key/binding.c

Lines changed: 18 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,18 @@
1+
#include "lock_key.h"
2+
3+
#define Env LockKey
4+
#include "../env_binding.h"
5+
6+
static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
7+
env->size = unpack(kwargs, "size");
8+
env->num_keys = unpack(kwargs, "num_keys");
9+
return 0;
10+
}
11+
12+
static int my_log(PyObject* dict, Log* log) {
13+
assign_to_dict(dict, "score", log->score);
14+
assign_to_dict(dict, "perf", log->perf);
15+
assign_to_dict(dict, "episode_return", log->episode_return);
16+
assign_to_dict(dict, "episode_length", log->episode_length);
17+
return 0;
18+
}

pufferlib/ocean/lock_key/lock_key

33.4 KB
Binary file not shown.
Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
#include "lock_key.h"
2+
3+
int main() {
4+
LockKey env = {.size = 8, .num_keys = 3};
5+
env.observations = (unsigned char*)calloc(env.size*env.size, sizeof(unsigned char));
6+
env.actions = (int*)calloc(1, sizeof(int));
7+
env.rewards = (float*)calloc(1, sizeof(float));
8+
env.terminals = (unsigned char*)calloc(1, sizeof(unsigned char));
9+
10+
c_reset(&env);
11+
c_render(&env);
12+
while (!WindowShouldClose()) {
13+
if (IsKeyDown(KEY_LEFT_SHIFT)) {
14+
if (IsKeyDown(KEY_A) || IsKeyDown(KEY_LEFT)) {
15+
env.actions[0] = 0;
16+
} else if (IsKeyDown(KEY_D) || IsKeyDown(KEY_RIGHT)) {
17+
env.actions[0] = 1;
18+
} else if (IsKeyDown(KEY_W) || IsKeyDown(KEY_UP)) {
19+
env.actions[0] = 2;
20+
} else if (IsKeyDown(KEY_S) || IsKeyDown(KEY_DOWN)) {
21+
env.actions[0] = 3;
22+
} else {
23+
env.actions[0] = -1;
24+
}
25+
} else {
26+
env.actions[0] = rand() % 5;
27+
}
28+
c_step(&env);
29+
c_render(&env);
30+
}
31+
free(env.observations);
32+
free(env.actions);
33+
free(env.rewards);
34+
free(env.terminals);
35+
c_close(&env);
36+
}
37+
Lines changed: 159 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,159 @@
1+
#ifndef LOCK_KEY_H
2+
#define LOCK_KEY_H
3+
4+
#include <stdio.h>
5+
#include <stdlib.h>
6+
#include <string.h>
7+
#include "raylib.h"
8+
9+
static const Color PUFF_RED = (Color){187, 0, 0, 255};
10+
static const Color PUFF_CYAN = (Color){0, 187, 187, 255};
11+
static const Color PUFF_GREEN = (Color){0, 187, 0, 255};
12+
static const Color PUFF_BACKGROUND = (Color){45, 30, 20, 255};
13+
14+
typedef struct {
15+
float perf;
16+
float score;
17+
float episode_return;
18+
float episode_length;
19+
float n;
20+
} Log;
21+
22+
typedef struct {
23+
Log log;
24+
25+
unsigned char* observations;
26+
int* actions;
27+
float* rewards;
28+
unsigned char* terminals;
29+
unsigned char* truncations;
30+
31+
int size;
32+
int num_keys;
33+
int tick;
34+
int x;
35+
int y;
36+
int num_keys_collected;
37+
} LockKey;
38+
39+
static inline int lk_pos(LockKey* env, int x, int y) {
40+
return y * env->size + x;
41+
}
42+
43+
void add_log(LockKey* env) {
44+
env->log.perf += (env->rewards[0] > 0) ? 1 : 0;
45+
env->log.score += env->rewards[0];
46+
env->log.episode_return += env->rewards[0];
47+
env->log.episode_length += env->tick;
48+
env->log.n++;
49+
}
50+
51+
static inline void c_reset(LockKey* env) {
52+
int tiles = env->size * env->size;
53+
memset(env->observations, 0, tiles * sizeof(unsigned char));
54+
55+
env->x = env->size / 2;
56+
env->y = env->size / 2;
57+
int player_pos = lk_pos(env, env->x, env->y);
58+
env->observations[player_pos] = 1;
59+
env->tick = 0;
60+
61+
int lock_idx;
62+
do lock_idx = rand() % tiles;
63+
while (lock_idx == player_pos);
64+
env->observations[lock_idx] = 2;
65+
66+
for (int i = 0; i < env->num_keys; i++) {
67+
int key_idx;
68+
do key_idx = rand() % tiles;
69+
while (env->observations[key_idx] != 0);
70+
env->observations[key_idx] = 3;
71+
}
72+
73+
env->num_keys_collected = 0;
74+
}
75+
76+
static inline void c_step(LockKey* env) {
77+
env->tick++;
78+
env->rewards[0] = -0.1f;
79+
env->terminals[0] = 0;
80+
if (env->truncations) env->truncations[0] = 0;
81+
82+
// clear agent from previous position if not on lock
83+
if (env->observations[lk_pos(env, env->x, env->y)] != 2)
84+
env->observations[lk_pos(env, env->x, env->y)] = 0;
85+
86+
int a = env->actions[0];
87+
if (a == 0) env->x--;
88+
else if (a == 1) env->x++;
89+
else if (a == 2) env->y--;
90+
else if (a == 3) env->y++;
91+
92+
// terminal if out of bounds or max steps reached
93+
int max_steps = 3*env->size + env->num_keys*env->num_keys;
94+
if (env->tick > max_steps || env->x < 0 || env->x >= env->size || env->y < 0 || env->y >= env->size) {
95+
env->rewards[0] = -3.0f;
96+
env->terminals[0] = 1;
97+
add_log(env);
98+
c_reset(env);
99+
return;
100+
}
101+
102+
int pos = lk_pos(env, env->x, env->y);
103+
104+
// collect key
105+
if (env->observations[pos] == 3) {
106+
env->rewards[0] += 1.0f;
107+
env->num_keys_collected++;
108+
}
109+
110+
// open lock if all keys collected
111+
if (env->observations[pos] == 2 && env->num_keys_collected == env->num_keys) {
112+
env->rewards[0] = 3.0f;
113+
env->terminals[0] = 1;
114+
add_log(env);
115+
c_reset(env);
116+
return;
117+
}
118+
119+
// move agent, but don't override observations if on lock
120+
if (env->observations[pos] != 2)
121+
env->observations[pos] = 1;
122+
}
123+
124+
static inline void c_render(LockKey* env) {
125+
if (!IsWindowReady()) {
126+
InitWindow(64*env->size, 64*env->size, "LockKey");
127+
SetTargetFPS(5);
128+
}
129+
130+
if (IsKeyDown(KEY_ESCAPE)) exit(0);
131+
132+
BeginDrawing();
133+
ClearBackground(PUFF_BACKGROUND);
134+
135+
for (int y = 0; y < env->size; y++) {
136+
for (int x = 0; x < env->size; x++) {
137+
int pos = y * env->size + x;
138+
unsigned char v = env->observations[pos];
139+
if (!v) continue;
140+
141+
Color color =
142+
(v == 1) ? PUFF_CYAN :
143+
(v == 2) ? PUFF_RED :
144+
(v == 3) ? PUFF_GREEN :
145+
PUFF_BACKGROUND;
146+
147+
DrawRectangle(x * 64, y * 64, 64, 64, color);
148+
}
149+
}
150+
151+
EndDrawing();
152+
}
153+
154+
static inline void c_close(LockKey* env) {
155+
(void)env;
156+
if (IsWindowReady()) CloseWindow();
157+
}
158+
159+
#endif
Lines changed: 50 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,50 @@
1+
import gymnasium
2+
import numpy as np
3+
import pufferlib
4+
from pufferlib.ocean.lock_key import binding
5+
6+
class LockKey(pufferlib.PufferEnv):
7+
def __init__(self, num_envs=1, render_mode=None, log_interval=128, size=8, num_keys=3, buf=None, seed=0):
8+
self.size = size
9+
self.num_keys = num_keys
10+
11+
# C writes a flattened size*size uint8 grid with values {0,1,2}
12+
self.single_observation_space = gymnasium.spaces.Box(
13+
low=0, high=2, shape=(size * size,), dtype=np.uint8
14+
)
15+
16+
# 0=L, 1=R, 2=U, 3=D, 4=NOOP
17+
self.single_action_space = gymnasium.spaces.Discrete(5)
18+
19+
self.render_mode = render_mode
20+
self.num_agents = num_envs
21+
self.log_interval = log_interval
22+
super().__init__(buf)
23+
24+
self.c_envs = binding.vec_init(
25+
self.observations, self.actions, self.rewards,
26+
self.terminals, self.truncations, num_envs, seed,
27+
size=size,
28+
num_keys=num_keys,
29+
)
30+
31+
def reset(self, seed=0):
32+
binding.vec_reset(self.c_envs, seed)
33+
self.tick = 0
34+
return self.observations, []
35+
36+
def step(self, actions):
37+
self.tick += 1
38+
self.actions[:] = actions
39+
binding.vec_step(self.c_envs)
40+
info = []
41+
if self.tick % self.log_interval == 0:
42+
info.append(binding.vec_log(self.c_envs))
43+
return (self.observations, self.rewards,
44+
self.terminals, self.truncations, info)
45+
46+
def render(self):
47+
binding.vec_render(self.c_envs, 0)
48+
49+
def close(self):
50+
binding.vec_close(self.c_envs)
552 KB
Binary file not shown.

0 commit comments

Comments
 (0)