Skip to content

Commit 6c5a8db

Browse files
committed
add max ticks, modify log
1 parent ba34433 commit 6c5a8db

File tree

4 files changed

+61
-35
lines changed

4 files changed

+61
-35
lines changed

pufferlib/config/ocean/g2048.ini

Lines changed: 35 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -5,11 +5,11 @@ policy_name = Policy
55
rnn_name = Recurrent
66

77
[policy]
8-
hidden_size = 1024
8+
hidden_size = 128
99

1010
[rnn]
11-
input_size = 1024
12-
hidden_size = 1024
11+
input_size = 128
12+
hidden_size = 128
1313

1414
[vec]
1515
num_envs = 4
@@ -18,22 +18,36 @@ num_envs = 4
1818
num_envs = 4096
1919

2020
[train]
21-
total_timesteps = 5_000_000_000
22-
adam_beta1 = 0.982603624444803
23-
adam_beta2 = 0.982603624444803
24-
adam_eps = 3.2888696338626164e-11
21+
total_timesteps = 3_000_000_000
22+
anneal_lr = True
23+
batch_size = auto
2524
bptt_horizon = 64
26-
clip_coef = 0.2709219986085283
27-
ent_coef = 0.09221187601118314
28-
gae_lambda = 0.5999999999999999
29-
gamma = 0.9913033082924563
30-
#learning_rate = 0.0032402460796988127
31-
learning_rate = 0.001370087925623787
32-
max_grad_norm = 3.382578348055827
33-
minibatch_size = 32768
34-
prio_alpha = 0.09999999999999998
35-
prio_beta0 = 0.941336023531629
36-
vf_clip_coef = 0.3229933703598912
37-
vf_coef = 3.591594736259073
38-
vtrace_c_clip = 1.405090934486193
39-
vtrace_rho_clip = 0.836535302835556
25+
minibatch_size = 65536
26+
27+
adam_beta1 = 0.99
28+
adam_beta2 = 0.96
29+
adam_eps = 1.0e-10
30+
clip_coef = 0.1
31+
ent_coef = 0.02
32+
gae_lambda = 0.6
33+
gamma = 0.985
34+
learning_rate = 0.003
35+
max_grad_norm = 1.0
36+
prio_alpha = 0.99
37+
prio_beta0 = 0.40
38+
vf_clip_coef = 0.1
39+
vf_coef = 2.0
40+
vtrace_c_clip = 4.3
41+
vtrace_rho_clip = 1.6
42+
43+
44+
[sweep]
45+
metric = score
46+
goal = maximize
47+
48+
[sweep.train.gae_lambda]
49+
distribution = logit_normal
50+
min = 0.01
51+
mean = 0.6
52+
max = 0.995
53+
scale = auto

pufferlib/ocean/g2048/binding.c

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
1-
#include "2048.h"
1+
#include "g2048.h"
22

33
#define Env Game
44
#include "../env_binding.h"
55

6-
// 2048.h does not have a 'size' field, so my_init can just return 0
6+
// g2048.h does not have a 'size' field, so my_init can just return 0
77
static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
88
// No custom initialization needed for 2048
99
return 0;
@@ -12,6 +12,7 @@ static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
1212
static int my_log(PyObject* dict, Log* log) {
1313
assign_to_dict(dict, "perf", log->perf);
1414
assign_to_dict(dict, "score", log->score);
15+
assign_to_dict(dict, "max_tile", log->max_tile);
1516
assign_to_dict(dict, "episode_return", log->episode_return);
1617
assign_to_dict(dict, "episode_length", log->episode_length);
1718
return 0;

pufferlib/ocean/g2048/g2048.c

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
#include "2048.h"
1+
#include "g2048.h"
22
#include "puffernet.h"
33

44
int main() {
Lines changed: 22 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -5,22 +5,28 @@
55
#include <math.h>
66
#include <string.h>
77
#include "raylib.h"
8+
#define max(a, b) (((a) > (b)) ? (a) : (b))
89

910
#define SIZE 4
1011
#define EMPTY 0
1112
#define UP 1
1213
#define DOWN 2
1314
#define LEFT 3
1415
#define RIGHT 4
16+
#define BASE_MAX_TICKS 2000
1517

1618
// Precomputed constants
1719
#define REWARD_MULTIPLIER 0.0625f
1820
#define INVALID_MOVE_PENALTY -0.05f
1921
#define GAME_OVER_PENALTY -1.0f
2022

23+
// To normalize perf from 0 to 1. Update when beaten.
24+
#define OBSERVED_MAX_SCORE 100000.0f
25+
2126
typedef struct {
2227
float perf;
2328
float score;
29+
float max_tile;
2430
float episode_return;
2531
float episode_length;
2632
float n;
@@ -36,6 +42,7 @@ typedef struct {
3642
int tick;
3743
unsigned char grid[SIZE][SIZE];
3844
float episode_reward; // Accumulate episode reward
45+
int max_episode_ticks; // Dynamic max_ticks based on score
3946

4047
// Cached values to avoid recomputation
4148
int empty_count;
@@ -93,8 +100,7 @@ static inline void update_empty_count(Game* game) {
93100
game->empty_count = count;
94101
}
95102

96-
// Optimized score calculation
97-
static inline unsigned char calc_score(Game* game) {
103+
static inline unsigned char get_max_tile(Game* game) {
98104
unsigned char max_tile = 0;
99105
// Unroll loop for better performance
100106
for (int i = 0; i < SIZE; i++) {
@@ -108,9 +114,10 @@ static inline unsigned char calc_score(Game* game) {
108114
}
109115

110116
void add_log(Game* game) {
111-
unsigned char s = calc_score(game);
112-
game->log.score = (float)(1 << s);
113-
game->log.perf += ((float)s) * 0.0909f;
117+
unsigned char s = get_max_tile(game);
118+
game->log.max_tile += (float)(1 << s);
119+
game->log.score += (float)game->score;
120+
game->log.perf += (float)game->score / OBSERVED_MAX_SCORE;
114121
game->log.episode_length += game->tick;
115122
game->log.episode_return += game->episode_reward;
116123
game->log.n += 1;
@@ -129,6 +136,7 @@ void c_reset(Game* game) {
129136
game->empty_count = SIZE * SIZE;
130137
game->game_over_cached = false;
131138
game->grid_changed = true;
139+
game->max_episode_ticks = BASE_MAX_TICKS;
132140

133141
if (game->terminals) game->terminals[0] = 0;
134142

@@ -251,9 +259,7 @@ bool move(Game* game, int direction, float* reward, float* score_increase) {
251259
}
252260
}
253261

254-
if (!moved) {
255-
*reward = INVALID_MOVE_PENALTY;
256-
} else {
262+
if (moved) {
257263
game->grid_changed = true;
258264
game->game_over_cached = false; // Invalidate cache
259265
}
@@ -306,11 +312,16 @@ void c_step(Game* game) {
306312
add_random_tile(game);
307313
game->score += score_add;
308314
update_empty_count(game); // Update after adding tile
315+
// This is to limit infinite invalid moves during eval
316+
game->max_episode_ticks = max(BASE_MAX_TICKS, game->score / 20);
317+
} else {
318+
reward = INVALID_MOVE_PENALTY;
309319
}
310-
320+
311321
bool game_over = is_game_over(game);
312-
game->terminals[0] = game_over ? 1 : 0;
313-
322+
bool max_ticks_reached = game->tick >= game->max_episode_ticks;
323+
game->terminals[0] = (game_over || max_ticks_reached) ? 1 : 0;
324+
314325
if (game_over) {
315326
reward = GAME_OVER_PENALTY;
316327
}

0 commit comments

Comments
 (0)