Skip to content

Commit 729bd12

Browse files
committed
add obs_dist to control how much of the state the agent can see on each frame
1 parent 3f81e29 commit 729bd12

File tree

5 files changed

+85
-34
lines changed

5 files changed

+85
-34
lines changed

pufferlib/config/ocean/lock_key.ini

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@ num_envs = 4096
99
num_keys = 1
1010
size = 8
1111
log_interval = 128
12+
obs_dist = 2
1213

1314
[train]
1415
total_timesteps = 20_000_000

pufferlib/ocean/lock_key/binding.c

Lines changed: 14 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,20 @@
66
static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
77
env->size = unpack(kwargs, "size");
88
env->num_keys = unpack(kwargs, "num_keys");
9+
env->obs_dist = unpack(kwargs, "obs_dist");
10+
11+
int tiles = env->size * env->size;
12+
env->state = (unsigned char*)calloc(tiles, sizeof(unsigned char));
13+
if (!env->state) return -1;
14+
15+
return 0;
16+
}
17+
18+
static int my_close(Env* env) {
19+
if (env->state) {
20+
free(env->state);
21+
env->state = NULL;
22+
}
923
return 0;
1024
}
1125

pufferlib/ocean/lock_key/lock_key.c

Lines changed: 22 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,23 @@
1+
#include <time.h>
12
#include "lock_key.h"
23

34
int main() {
4-
LockKey env = {.size = 8, .num_keys = 3};
5-
env.observations = (unsigned char*)calloc(env.size*env.size, sizeof(unsigned char));
6-
env.actions = (int*)calloc(1, sizeof(int));
7-
env.rewards = (float*)calloc(1, sizeof(float));
8-
env.terminals = (unsigned char*)calloc(1, sizeof(unsigned char));
5+
srand((unsigned int)time(NULL));
6+
7+
LockKey env = {.size = 8, .num_keys = 3, .obs_dist = 2};
8+
9+
int tiles = env.size * env.size;
10+
11+
env.state = (unsigned char*)calloc(tiles, sizeof(unsigned char));
12+
env.observations = (unsigned char*)calloc(tiles, sizeof(unsigned char));
13+
env.actions = (int*)calloc(1, sizeof(int));
14+
env.rewards = (float*)calloc(1, sizeof(float));
15+
env.terminals = (unsigned char*)calloc(1, sizeof(unsigned char));
16+
env.truncations = (unsigned char*)calloc(1, sizeof(unsigned char)); // optional
917

1018
c_reset(&env);
1119
c_render(&env);
20+
1221
while (!WindowShouldClose()) {
1322
if (IsKeyDown(KEY_LEFT_SHIFT)) {
1423
if (IsKeyDown(KEY_A) || IsKeyDown(KEY_LEFT)) {
@@ -20,18 +29,23 @@ int main() {
2029
} else if (IsKeyDown(KEY_S) || IsKeyDown(KEY_DOWN)) {
2130
env.actions[0] = 3;
2231
} else {
23-
env.actions[0] = -1;
32+
env.actions[0] = -1; // no-op
2433
}
2534
} else {
26-
env.actions[0] = rand() % 5;
35+
env.actions[0] = rand() % 5; // 4 == no-op, still fine
2736
}
37+
2838
c_step(&env);
2939
c_render(&env);
3040
}
41+
42+
free(env.state);
3143
free(env.observations);
3244
free(env.actions);
3345
free(env.rewards);
3446
free(env.terminals);
47+
if (env.truncations) free(env.truncations);
48+
3549
c_close(&env);
50+
return 0;
3651
}
37-

pufferlib/ocean/lock_key/lock_key.h

Lines changed: 45 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@
99
static const Color PUFF_RED = (Color){187, 0, 0, 255};
1010
static const Color PUFF_CYAN = (Color){0, 187, 187, 255};
1111
static const Color PUFF_GREEN = (Color){0, 187, 0, 255};
12-
static const Color PUFF_BACKGROUND = (Color){45, 30, 20, 255};
12+
static const Color PUFF_BACKGROUND = (Color){65, 30, 40, 255};
13+
static const Color PUFF_BLACK = (Color){0, 0, 0, 255};
1314

1415
typedef struct {
1516
float perf;
@@ -22,7 +23,11 @@ typedef struct {
2223
typedef struct {
2324
Log log;
2425

26+
// observations: partial view for agent
2527
unsigned char* observations;
28+
// state: full system state (same tile encoding as before)
29+
unsigned char* state;
30+
2631
int* actions;
2732
float* rewards;
2833
unsigned char* terminals;
@@ -34,12 +39,32 @@ typedef struct {
3439
int x;
3540
int y;
3641
int num_keys_collected;
42+
int obs_dist;
3743
} LockKey;
3844

3945
static inline int lk_pos(LockKey* env, int x, int y) {
4046
return y * env->size + x;
4147
}
4248

49+
static inline int lk_visible(LockKey* env, int x, int y) {
50+
int dx = x - env->x; if (dx < 0) dx = -dx;
51+
int dy = y - env->y; if (dy < 0) dy = -dy;
52+
return (dx > dy ? dx : dy) <= env->obs_dist;
53+
}
54+
55+
static inline void lk_update_observations(LockKey* env) {
56+
int tiles = env->size * env->size;
57+
memset(env->observations, 0, tiles * sizeof(unsigned char));
58+
59+
for (int y = 0; y < env->size; y++) {
60+
for (int x = 0; x < env->size; x++) {
61+
if (!lk_visible(env, x, y)) continue;
62+
int pos = lk_pos(env, x, y);
63+
env->observations[pos] = env->state[pos];
64+
}
65+
}
66+
}
67+
4368
void add_log(LockKey* env) {
4469
env->log.perf += (env->rewards[0] > 0) ? 1 : 0;
4570
env->log.score += env->rewards[0];
@@ -50,27 +75,28 @@ void add_log(LockKey* env) {
5075

5176
static inline void c_reset(LockKey* env) {
5277
int tiles = env->size * env->size;
53-
memset(env->observations, 0, tiles * sizeof(unsigned char));
78+
memset(env->state, 0, tiles * sizeof(unsigned char));
5479

5580
env->x = env->size / 2;
5681
env->y = env->size / 2;
5782
int player_pos = lk_pos(env, env->x, env->y);
58-
env->observations[player_pos] = 1;
83+
env->state[player_pos] = 1;
5984
env->tick = 0;
6085

6186
int lock_idx;
6287
do lock_idx = rand() % tiles;
6388
while (lock_idx == player_pos);
64-
env->observations[lock_idx] = 2;
89+
env->state[lock_idx] = 2;
6590

6691
for (int i = 0; i < env->num_keys; i++) {
6792
int key_idx;
6893
do key_idx = rand() % tiles;
69-
while (env->observations[key_idx] != 0);
70-
env->observations[key_idx] = 3;
94+
while (env->state[key_idx] != 0);
95+
env->state[key_idx] = 3;
7196
}
7297

7398
env->num_keys_collected = 0;
99+
lk_update_observations(env);
74100
}
75101

76102
static inline void c_step(LockKey* env) {
@@ -79,17 +105,16 @@ static inline void c_step(LockKey* env) {
79105
env->terminals[0] = 0;
80106
if (env->truncations) env->truncations[0] = 0;
81107

82-
// clear agent from previous position if not on lock
83-
if (env->observations[lk_pos(env, env->x, env->y)] != 2)
84-
env->observations[lk_pos(env, env->x, env->y)] = 0;
108+
int prev_pos = lk_pos(env, env->x, env->y);
109+
if (env->state[prev_pos] != 2)
110+
env->state[prev_pos] = 0;
85111

86112
int a = env->actions[0];
87113
if (a == 0) env->x--;
88114
else if (a == 1) env->x++;
89115
else if (a == 2) env->y--;
90116
else if (a == 3) env->y++;
91117

92-
// terminal if out of bounds or max steps reached
93118
int max_steps = 3*env->size + env->num_keys*env->num_keys;
94119
if (env->tick > max_steps || env->x < 0 || env->x >= env->size || env->y < 0 || env->y >= env->size) {
95120
env->rewards[0] = -3.0f;
@@ -101,24 +126,23 @@ static inline void c_step(LockKey* env) {
101126

102127
int pos = lk_pos(env, env->x, env->y);
103128

104-
// collect key
105-
if (env->observations[pos] == 3) {
129+
if (env->state[pos] == 3) {
106130
env->rewards[0] += 1.0f;
107131
env->num_keys_collected++;
108132
}
109133

110-
// open lock if all keys collected
111-
if (env->observations[pos] == 2 && env->num_keys_collected == env->num_keys) {
134+
if (env->state[pos] == 2 && env->num_keys_collected == env->num_keys) {
112135
env->rewards[0] = 3.0f;
113136
env->terminals[0] = 1;
114137
add_log(env);
115138
c_reset(env);
116139
return;
117140
}
118141

119-
// move agent, but don't override observations if on lock
120-
if (env->observations[pos] != 2)
121-
env->observations[pos] = 1;
142+
if (env->state[pos] != 2)
143+
env->state[pos] = 1;
144+
145+
lk_update_observations(env);
122146
}
123147

124148
static inline void c_render(LockKey* env) {
@@ -130,11 +154,13 @@ static inline void c_render(LockKey* env) {
130154
if (IsKeyDown(KEY_ESCAPE)) exit(0);
131155

132156
BeginDrawing();
133-
ClearBackground(PUFF_BACKGROUND);
134157

135158
for (int y = 0; y < env->size; y++) {
136159
for (int x = 0; x < env->size; x++) {
137-
int pos = y * env->size + x;
160+
Color bg = lk_visible(env, x, y) ? PUFF_BACKGROUND : PUFF_BLACK;
161+
DrawRectangle(x * 64, y * 64, 64, 64, bg);
162+
163+
int pos = lk_pos(env, x, y);
138164
unsigned char v = env->observations[pos];
139165
if (!v) continue;
140166

pufferlib/ocean/lock_key/lock_key.py

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -4,16 +4,11 @@
44
from pufferlib.ocean.lock_key import binding
55

66
class LockKey(pufferlib.PufferEnv):
7-
def __init__(self, num_envs=1, render_mode=None, log_interval=128, size=8, num_keys=3, buf=None, seed=0):
8-
self.size = size
9-
self.num_keys = num_keys
10-
11-
# C writes a flattened size*size uint8 grid with values {0,1,2}
7+
def __init__(self, num_envs=1, render_mode=None, log_interval=128, size=8, num_keys=3, buf=None, seed=0, obs_dist=2):
128
self.single_observation_space = gymnasium.spaces.Box(
13-
low=0, high=2, shape=(size * size,), dtype=np.uint8
9+
low=0, high=3, shape=(size * size,), dtype=np.uint8
1410
)
1511

16-
# 0=L, 1=R, 2=U, 3=D, 4=NOOP
1712
self.single_action_space = gymnasium.spaces.Discrete(5)
1813

1914
self.render_mode = render_mode
@@ -26,6 +21,7 @@ def __init__(self, num_envs=1, render_mode=None, log_interval=128, size=8, num_k
2621
self.terminals, self.truncations, num_envs, seed,
2722
size=size,
2823
num_keys=num_keys,
24+
obs_dist=obs_dist,
2925
)
3026

3127
def reset(self, seed=0):

0 commit comments

Comments
 (0)