Skip to content

Commit a8a5e33

Browse files
authored
Merge pull request #404 from kywch/h2048
Modified G2048 for better sweeps
2 parents b02b11c + 026119d commit a8a5e33

File tree

6 files changed

+139
-104
lines changed

6 files changed

+139
-104
lines changed

pufferlib/config/ocean/g2048.ini

Lines changed: 48 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ policy_name = Policy
55
rnn_name = Recurrent
66

77
[policy]
8-
hidden_size = 128
8+
hidden_size = 256
99

1010
[rnn]
11-
input_size = 128
12-
hidden_size = 128
11+
input_size = 256
12+
hidden_size = 256
1313

1414
[vec]
1515
num_envs = 4
@@ -18,23 +18,51 @@ num_envs = 4
1818
num_envs = 4096
1919

2020
[train]
21-
total_timesteps = 5_000_000_000
22-
adam_beta1 = 0.9529488439604378
23-
adam_beta2 = 0.9993901829477296
24-
adam_eps = 2.745365927413118e-7
21+
# https://wandb.ai/kywch/pufferlib/runs/n8xml0u9?nw=nwuserkywch
22+
total_timesteps = 3_000_000_000
23+
anneal_lr = True
24+
batch_size = auto
2525
bptt_horizon = 64
26-
clip_coef = 0.596573170393339
27-
ent_coef = 0.02107417730003862
28-
gae_lambda = 0.9940613415815854
29-
gamma = 0.9889857974154952
30-
#learning_rate = 0.0032402460796988127
26+
minibatch_size = 65536
27+
28+
adam_beta1 = 0.99
29+
adam_beta2 = 0.96
30+
adam_eps = 1.0e-10
31+
clip_coef = 0.1
32+
ent_coef = 0.02
33+
gae_lambda = 0.6
34+
gamma = 0.985
3135
learning_rate = 0.001
32-
max_grad_norm = 1.0752406726589745
33-
minibatch_size = 16384
34-
prio_alpha = 0.25297099593586336
35-
prio_beta0 = 0.940606268942572
36+
max_grad_norm = 1.0
37+
prio_alpha = 0.99
38+
prio_beta0 = 0.40
3639
vf_clip_coef = 0.1
37-
vf_coef = 1.6362878279900643
38-
vtrace_c_clip = 0
39-
vtrace_rho_clip = 1.2917509971869054
40-
anneal_lr = False
40+
vf_coef = 2.0
41+
vtrace_c_clip = 4.3
42+
vtrace_rho_clip = 1.6
43+
44+
45+
[sweep]
46+
metric = score
47+
goal = maximize
48+
49+
[sweep.train.total_timesteps]
50+
distribution = log_normal
51+
min = 3e8
52+
max = 1e10
53+
mean = 1e9
54+
scale = time
55+
56+
[sweep.train.learning_rate]
57+
distribution = log_normal
58+
min = 0.00001
59+
mean = 0.001
60+
max = 0.1
61+
scale = 0.5
62+
63+
[sweep.train.gae_lambda]
64+
distribution = logit_normal
65+
min = 0.01
66+
mean = 0.6
67+
max = 0.995
68+
scale = auto

pufferlib/ocean/g2048/binding.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
#include "2048.h"
1+
#include "g2048.h"
22

33
#define Env Game
44
#include "../env_binding.h"
55

6-
// 2048.h does not have a 'size' field, so my_init can just return 0
6+
// g2048.h does not have a 'size' field, so my_init can just return 0
77
static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
88
// No custom initialization needed for 2048
99
return 0;
@@ -12,6 +12,7 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
1212
static int my_log(PyObject* dict, Log* log) {
1313
assign_to_dict(dict, "perf", log->perf);
1414
assign_to_dict(dict, "score", log->score);
15+
assign_to_dict(dict, "merge_score", log->merge_score);
1516
assign_to_dict(dict, "episode_return", log->episode_return);
1617
assign_to_dict(dict, "episode_length", log->episode_length);
1718
return 0;

pufferlib/ocean/g2048/g2048.c

Lines changed: 36 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,25 @@
1-
#include "2048.h"
1+
#include "g2048.h"
22
#include "puffernet.h"
33

4+
// Network with hidden size 256. Should go to puffernet
5+
/*
 * Build a LinearLSTM network whose hidden width is fixed at 256.
 *
 * weights      - weight source handed to every make_linear/make_lstm call.
 * num_agents   - number of agents; sizes the observation buffer and layers.
 * input_dim    - number of observation floats per agent.
 * logit_sizes  - per-action-head logit counts (num_actions entries).
 * num_actions  - number of entries in logit_sizes.
 *
 * Returns the new network, or NULL if the network struct or its observation
 * buffer cannot be allocated.
 *
 * The layers are created in the order encoder, actor, value head, LSTM.
 * NOTE(review): every layer constructor receives the same `weights` handle,
 * so this order presumably has to match the weight file layout -- do not
 * reorder without confirming against puffernet.
 */
LinearLSTM* make_linearlstm_256(Weights* weights, int num_agents, int input_dim, int logit_sizes[], int num_actions) {
    int hidden_dim = 256;

    LinearLSTM* net = calloc(1, sizeof *net);
    if (net == NULL) {
        return NULL;
    }
    net->num_agents = num_agents;

    // Compute the element count in size_t so the product is not done in int.
    size_t obs_count = (size_t)num_agents * (size_t)input_dim;
    net->obs = calloc(obs_count, sizeof(float));
    // calloc may legitimately return NULL for a zero-sized request.
    if (net->obs == NULL && obs_count != 0) {
        free(net);
        return NULL;
    }

    net->encoder = make_linear(weights, num_agents, input_dim, hidden_dim);
    net->gelu1 = make_gelu(num_agents, hidden_dim);

    // The actor head emits the logits of all action heads side by side.
    int atn_sum = 0;
    for (int i = 0; i < num_actions; i++) {
        atn_sum += logit_sizes[i];
    }

    net->actor = make_linear(weights, num_agents, hidden_dim, atn_sum);
    net->value_fn = make_linear(weights, num_agents, hidden_dim, 1);
    net->lstm = make_lstm(weights, num_agents, hidden_dim, hidden_dim);
    net->multidiscrete = make_multidiscrete(num_agents, logit_sizes, num_actions);
    return net;
}
22+
423
int main() {
524
srand(time(NULL));
625
Game env;
@@ -14,26 +33,27 @@ int main() {
1433
env.actions = actions;
1534
env.rewards = rewards;
1635

17-
Weights* weights = load_weights("resources/g2048/g2048_weights.bin", 134917);
36+
Weights* weights = load_weights("resources/g2048/g2048_weights.bin", 531973);
1837
int logit_sizes[1] = {4};
19-
LinearLSTM* net = make_linearlstm(weights, 1, 16, logit_sizes, 1);
38+
LinearLSTM* net = make_linearlstm_256(weights, 1, 16, logit_sizes, 1);
2039
c_reset(&env);
2140
c_render(&env);
2241

2342
// Main game loop
2443
int frame = 0;
44+
int action = -1;
2545
while (!WindowShouldClose()) {
2646
c_render(&env);
2747
frame++;
28-
29-
int action = 0;
48+
3049
if (IsKeyDown(KEY_LEFT_SHIFT)) {
31-
if (IsKeyPressed(KEY_W) || IsKeyPressed(KEY_UP)) action = UP;
32-
else if (IsKeyPressed(KEY_S) || IsKeyPressed(KEY_DOWN)) action = DOWN;
33-
else if (IsKeyPressed(KEY_A) || IsKeyPressed(KEY_LEFT)) action = LEFT;
34-
else if (IsKeyPressed(KEY_D) || IsKeyPressed(KEY_RIGHT)) action = RIGHT;
50+
action = -1;
51+
if (IsKeyDown(KEY_W) || IsKeyDown(KEY_UP)) action = UP;
52+
else if (IsKeyDown(KEY_S) || IsKeyDown(KEY_DOWN)) action = DOWN;
53+
else if (IsKeyDown(KEY_A) || IsKeyDown(KEY_LEFT)) action = LEFT;
54+
else if (IsKeyDown(KEY_D) || IsKeyDown(KEY_RIGHT)) action = RIGHT;
3555
env.actions[0] = action - 1;
36-
} else if (frame % 10 != 0) {
56+
} else if (frame % 1 != 0) {
3757
continue;
3858
} else {
3959
action = 1;
@@ -43,9 +63,14 @@ int main() {
4363
forward_linearlstm(net, net->obs, env.actions);
4464
}
4565

46-
if (action != 0) {
66+
if (action > 0) {
4767
c_step(&env);
4868
}
69+
70+
if (IsKeyDown(KEY_LEFT_SHIFT) && action > 0) {
71+
// Don't need to be super reactive
72+
WaitTime(0.1);
73+
}
4974
}
5075

5176
free_linearlstm(net);
Lines changed: 52 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,28 @@
55
#include <math.h>
66
#include <string.h>
77
#include "raylib.h"
8+
#define max(a, b) (((a) > (b)) ? (a) : (b))
89

910
#define SIZE 4
1011
#define EMPTY 0
1112
#define UP 1
1213
#define DOWN 2
1314
#define LEFT 3
1415
#define RIGHT 4
16+
#define BASE_MAX_TICKS 2000
1517

1618
// Precomputed constants
17-
#define REWARD_MULTIPLIER 0.09090909f
19+
#define REWARD_MULTIPLIER 0.0625f
1820
#define INVALID_MOVE_PENALTY -0.05f
1921
#define GAME_OVER_PENALTY -1.0f
2022

23+
// To normalize perf from 0 to 1. Reachable with hidden size 256.
24+
#define OBSERVED_MAX_TILE 4096.0f
25+
2126
typedef struct {
2227
float perf;
2328
float score;
29+
float merge_score;
2430
float episode_return;
2531
float episode_length;
2632
float n;
@@ -36,6 +42,8 @@ typedef struct {
3642
int tick;
3743
unsigned char grid[SIZE][SIZE];
3844
float episode_reward; // Accumulate episode reward
45+
int moves_made;
46+
int max_episode_ticks; // Dynamic max_ticks based on score
3947

4048
// Cached values to avoid recomputation
4149
int empty_count;
@@ -93,9 +101,24 @@ static inline void update_empty_count(Game* game) {
93101
game->empty_count = count;
94102
}
95103

104+
/*
 * Return the largest value stored in any cell of the board.
 *
 * Every one of the SIZE x SIZE cells is visited once; a board whose cells
 * are all EMPTY (0) yields 0.
 */
static inline unsigned char get_max_tile(Game* game) {
    unsigned char best = 0;
    for (int r = 0; r < SIZE; r++) {
        const unsigned char* cells = game->grid[r];
        for (int c = 0; c < SIZE; c++) {
            best = (cells[c] > best) ? cells[c] : best;
        }
    }
    return best;
}
116+
96117
void add_log(Game* game) {
97-
game->log.score = (float)(1 << game->score);
98-
game->log.perf += ((float)game->score) * REWARD_MULTIPLIER;
118+
unsigned char s = get_max_tile(game);
119+
game->log.score += (float)(1 << s);
120+
game->log.perf += (float)(1 << s) / OBSERVED_MAX_TILE;
121+
game->log.merge_score += (float)game->score;
99122
game->log.episode_length += game->tick;
100123
game->log.episode_return += game->episode_reward;
101124
game->log.n += 1;
@@ -114,6 +137,8 @@ void c_reset(Game* game) {
114137
game->empty_count = SIZE * SIZE;
115138
game->game_over_cached = false;
116139
game->grid_changed = true;
140+
game->moves_made = 0;
141+
game->max_episode_ticks = BASE_MAX_TICKS;
117142

118143
if (game->terminals) game->terminals[0] = 0;
119144

@@ -153,6 +178,7 @@ void add_random_tile(Game* game) {
153178
if (chosen_pos >= 0) {
154179
int i = chosen_pos / SIZE;
155180
int j = chosen_pos % SIZE;
181+
// Spawn a 2 tile (stored as exponent 1) 90% of the time and a 4 tile (stored as exponent 2) 10% of the time
156182
game->grid[i][j] = (rand() % 10 == 0) ? 2 : 1;
157183
game->empty_count--;
158184
game->grid_changed = true;
@@ -162,7 +188,7 @@ void add_random_tile(Game* game) {
162188
}
163189

164190
// Optimized slide and merge with fewer memory operations
165-
static inline bool slide_and_merge(unsigned char* row, float* reward) {
191+
static inline bool slide_and_merge(unsigned char* row, float* reward, float* score_increase) {
166192
bool moved = false;
167193
int write_pos = 0;
168194

@@ -183,6 +209,7 @@ static inline bool slide_and_merge(unsigned char* row, float* reward) {
183209
if (row[i] != EMPTY && row[i] == row[i + 1]) {
184210
row[i]++;
185211
*reward += ((float)row[i]) * REWARD_MULTIPLIER;
212+
*score_increase += (float)(1 << (int)row[i]);
186213
// Shift remaining elements left
187214
for (int j = i + 1; j < SIZE - 1; j++) {
188215
row[j] = row[j + 1];
@@ -195,7 +222,7 @@ static inline bool slide_and_merge(unsigned char* row, float* reward) {
195222
return moved;
196223
}
197224

198-
bool move(Game* game, int direction, float* reward) {
225+
bool move(Game* game, int direction, float* reward, float* score_increase) {
199226
bool moved = false;
200227
unsigned char temp[SIZE];
201228

@@ -207,7 +234,7 @@ bool move(Game* game, int direction, float* reward) {
207234
temp[i] = game->grid[idx][col];
208235
}
209236

210-
if (slide_and_merge(temp, reward)) {
237+
if (slide_and_merge(temp, reward, score_increase)) {
211238
moved = true;
212239
// Write back column
213240
for (int i = 0; i < SIZE; i++) {
@@ -224,7 +251,7 @@ bool move(Game* game, int direction, float* reward) {
224251
temp[i] = game->grid[row][idx];
225252
}
226253

227-
if (slide_and_merge(temp, reward)) {
254+
if (slide_and_merge(temp, reward, score_increase)) {
228255
moved = true;
229256
// Write back row
230257
for (int i = 0; i < SIZE; i++) {
@@ -235,9 +262,7 @@ bool move(Game* game, int direction, float* reward) {
235262
}
236263
}
237264

238-
if (!moved) {
239-
*reward = INVALID_MOVE_PENALTY;
240-
} else {
265+
if (moved) {
241266
game->grid_changed = true;
242267
game->game_over_cached = false; // Invalidate cache
243268
}
@@ -280,34 +305,28 @@ bool is_game_over(Game* game) {
280305
return true;
281306
}
282307

283-
// Optimized score calculation
284-
static inline unsigned char calc_score(Game* game) {
285-
unsigned char max_tile = 0;
286-
// Unroll loop for better performance
287-
for (int i = 0; i < SIZE; i++) {
288-
for (int j = 0; j < SIZE; j++) {
289-
if (game->grid[i][j] > max_tile) {
290-
max_tile = game->grid[i][j];
291-
}
292-
}
293-
}
294-
return max_tile;
295-
}
296-
297308
void c_step(Game* game) {
298309
float reward = 0.0f;
299-
bool did_move = move(game, game->actions[0] + 1, &reward);
310+
float score_add = 0.0f;
311+
bool did_move = move(game, game->actions[0] + 1, &reward, &score_add);
300312
game->tick++;
301313

302314
if (did_move) {
315+
game->moves_made++;
303316
add_random_tile(game);
304-
game->score = calc_score(game);
317+
game->score += score_add;
305318
update_empty_count(game); // Update after adding tile
319+
// Cap the episode length so that endless invalid moves cannot stall evaluation.
320+
// The cap does not need to be tight, and probably does not need to be shown to the user.
321+
game->max_episode_ticks = max(BASE_MAX_TICKS, game->score / 10);
322+
} else {
323+
reward = INVALID_MOVE_PENALTY;
306324
}
307-
325+
308326
bool game_over = is_game_over(game);
309-
game->terminals[0] = game_over ? 1 : 0;
310-
327+
bool max_ticks_reached = game->tick >= game->max_episode_ticks;
328+
game->terminals[0] = (game_over || max_ticks_reached) ? 1 : 0;
329+
311330
if (game_over) {
312331
reward = GAME_OVER_PENALTY;
313332
}
@@ -369,8 +388,11 @@ void c_render(Game* game) {
369388
}
370389

371390
// Draw score (format once per frame)
372-
snprintf(score_text, sizeof(score_text), "Score: %d", 1 << game->score);
391+
snprintf(score_text, sizeof(score_text), "Score: %d", game->score);
373392
DrawText(score_text, 10, px * SIZE + 10, 24, PUFF_WHITE);
393+
394+
snprintf(score_text, sizeof(score_text), "Moves: %d", game->moves_made);
395+
DrawText(score_text, 210, px * SIZE + 10, 24, PUFF_WHITE);
374396

375397
EndDrawing();
376398
}

0 commit comments

Comments
 (0)