Skip to content

Commit 9307ce5

Browse files
committed
minor changes to fix a bug in the env and add sweeped params and a cnn
1 parent f0f952a commit 9307ce5

File tree

3 files changed

+39
-43
lines changed

3 files changed

+39
-43
lines changed

pufferlib/config/ocean/g2048.ini

Lines changed: 18 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -2,11 +2,26 @@
22
package = ocean
33
env_name = puffer_g2048
44
policy_name = G2048
5-
rnn_name = Recurrent
65

76
[env]
8-
num_envs = 4096
7+
num_envs = 4024
98

109
[train]
11-
total_timesteps = 1_000_000_000
10+
total_timesteps = 10_000_000_000 #4.6B
11+
adam_beta1 = 0.8081024539479613
12+
adam_beta2 = 0.9978536811174212
13+
adam_eps = 1.4542006937471102e-9
14+
bptt_horizon = 64
15+
clip_coef = 0.095627870395359
16+
ent_coef = 0.08439222625935798
17+
gae_lambda = 0.9155041484587568
18+
gamma = 0.9661669903070148
19+
learning_rate = 0.0014756768275156805
20+
max_grad_norm = 0.8813109722891985
1221
minibatch_size = 32768
22+
prio_alpha = 0.565686548517019
23+
prio_beta0 = 0.811234742153397
24+
vf_clip_coef = 0.1
25+
vf_coef = 2.2900867366155664
26+
vtrace_c_clip = 0
27+
vtrace_rho_clip = 0.8647738667214979

pufferlib/ocean/g2048/2048.h

Lines changed: 11 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -21,16 +21,15 @@ typedef struct {
2121
float n;
2222
} Log;
2323

24-
// Required struct for env_binding.h compatibility
2524
typedef struct {
2625
Log log; // Required
27-
unsigned char* observations; // Required (flattened 256 floats: 16 tiles * 16 one-hot)
26+
unsigned char* observations; // Cheaper in memory if encoded in uint_8
2827
int* actions; // Required
2928
float* rewards; // Required
3029
unsigned char* terminals; // Required
3130
int score;
3231
int tick;
33-
unsigned char grid[SIZE][SIZE]; // Store tile values directly as floats
32+
unsigned char grid[SIZE][SIZE];
3433
float episode_reward; // Accumulate episode reward
3534
} Game;
3635

@@ -43,27 +42,15 @@ void c_step(Game* env);
4342
void c_render(Game* env);
4443
void c_close(Game* env);
4544

46-
// Update the observation vector to be one-hot encoded for all tiles
4745
static void update_observations(Game* game) {
48-
for (int i = 0; i < SIZE; i++) {
49-
for (int j = 0; j < SIZE; j++) {
50-
game->observations[i * SIZE + j] = game->grid[i][j];
51-
}
52-
}
46+
memcpy(game->observations, game->grid, sizeof(unsigned char) * SIZE * SIZE);
5347
}
5448

5549
// --- Implementation ---
5650

5751
void add_log(Game* game) {
58-
int max_tile = 0;
59-
for (int i = 0; i < SIZE; i++) {
60-
for (int j = 0; j < SIZE; j++) {
61-
int tile = (int)(game->grid[i][j]);
62-
if (tile > max_tile) max_tile = tile;
63-
}
64-
}
65-
game->log.score = (float)pow(2,max_tile);
66-
game->log.perf += (game->rewards[0] > 0) ? 1 : 0;
52+
game->log.score = pow(2,(float)game->score);
53+
game->log.perf += (float)game->score/11.;
6754
game->log.episode_length += game->tick;
6855
game->log.episode_return += game->episode_reward;
6956
game->log.n += 1;
@@ -103,22 +90,12 @@ void add_random_tile(Game* game) {
10390
}
10491
if (count > 0) {
10592
int random_index = rand() % count;
106-
int value = (rand() % 10 == 0) ? 4 : 2;
107-
game->grid[empty_cells[random_index][0]][empty_cells[random_index][1]] = (float)value;
93+
int value = (rand() % 10 == 0) ? 2 : 1;
94+
game->grid[empty_cells[random_index][0]][empty_cells[random_index][1]] = value;
10895
}
10996
update_observations(game);
11097
}
11198

112-
void print_grid(Game* game) {
113-
for (int i = 0; i < SIZE; i++) {
114-
for (int j = 0; j < SIZE; j++) {
115-
printf("%4.0f ", game->grid[i][j]);
116-
}
117-
printf("\n");
118-
}
119-
printf("Score: %d\n", game->score);
120-
}
121-
12299
bool slide_and_merge_row(float* row, float* reward) {
123100
bool moved = false;
124101
// Slide left
@@ -231,11 +208,11 @@ bool is_game_over(Game* game) {
231208
return true;
232209
}
233210

234-
int calc_score(Game* game) {
235-
int max_tile = 0;
211+
unsigned char calc_score(Game* game) {
212+
unsigned char max_tile = 0;
236213
for (int i = 0; i < SIZE; i++) {
237214
for (int j = 0; j < SIZE; j++) {
238-
int tile = (int)(game->grid[i][j]);
215+
int tile = (game->grid[i][j]);
239216
if (tile > max_tile) max_tile = tile;
240217
}
241218
}
@@ -248,7 +225,7 @@ void c_step(Game* game) {
248225
game->tick += 1;
249226
if (did_move) {
250227
add_random_tile(game);
251-
game->score = calc_score(game);
228+
game->score = (float)calc_score(game);
252229
}
253230

254231
game->terminals[0] = is_game_over(game) ? 1 : 0;

pufferlib/ocean/torch.py

Lines changed: 10 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -190,15 +190,20 @@ def decode_actions(self, hidden):
190190
return action, value
191191

192192
class G2048(nn.Module):
193-
def __init__(self, env, hidden_size=128):
193+
def __init__(self, env, cnn_channels=32, hidden_size=128):
194194
super().__init__()
195195
self.hidden_size = hidden_size
196196
self.is_continuous = False
197197

198-
self.encoder= torch.nn.Sequential(
199-
nn.Linear(12*np.prod(env.single_observation_space.shape), hidden_size),
198+
self.cnn = nn.Sequential(
199+
pufferlib.pytorch.layer_init(
200+
nn.Conv2d(16, cnn_channels, 2, stride=1)),
200201
nn.GELU(),
202+
pufferlib.pytorch.layer_init(
203+
nn.Conv2d(cnn_channels, cnn_channels, 2, stride=1)),
204+
nn.Flatten(),
201205
)
206+
202207
self.decoder = pufferlib.pytorch.layer_init(
203208
nn.Linear(hidden_size, env.single_action_space.n), std=0.01)
204209
self.value = pufferlib.pytorch.layer_init(
@@ -213,9 +218,8 @@ def forward(self, x, state=None):
213218
return self.forward_eval(x, state)
214219

215220
def encode_observations(self, observations, state=None):
216-
observations = F.one_hot(observations.long(), 12).view(-1, 12*16).float()
217-
#observations = observations.float().view(-1, 16) / 11.0
218-
return self.encoder(observations)
221+
observations = F.one_hot(observations.long(), 16).view(-1, 16, 4, 4).float()
222+
return self.cnn(observations)
219223

220224
def decode_actions(self, hidden):
221225
action = self.decoder(hidden)

0 commit comments

Comments
 (0)