Skip to content

Commit 1267d52

Browse files
authored
Merge pull request #397 from kywch/tetris
Add levels, garbage lines, 7-bag system to tetris
2 parents a8a5e33 + 4a9843f commit 1267d52

File tree

10 files changed

+425
-170
lines changed

10 files changed

+425
-170
lines changed

pufferlib/config/ocean/tetris.ini

Lines changed: 16 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -8,11 +8,24 @@ rnn_name = Recurrent
88
num_envs = 8
99

1010
[env]
11-
num_envs = 1024
12-
deck_size = 3
11+
num_envs = 2048
12+
n_rows = 20
13+
n_cols = 10
14+
use_deck_obs = True
15+
n_init_garbage = 4
16+
# This is experimental. Sometimes it works.
17+
n_noise_obs = 0
18+
19+
[policy]
20+
hidden_size = 256
21+
22+
[rnn]
23+
input_size = 256
24+
hidden_size = 256
1325

1426
[train]
15-
total_timesteps = 5_000_000_000
27+
# https://wandb.ai/kywch/pufferlib/runs/era6a8p6?nw=nwuserkywch
28+
total_timesteps = 3_000_000_000
1629
batch_size = auto
1730
bptt_horizon = 64
1831
minibatch_size = 65536

pufferlib/ocean/tetris/binding.c

Lines changed: 8 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -6,7 +6,9 @@
66
static int my_init(Env* env, PyObject* args, PyObject* kwargs) {
77
env->n_rows = unpack(kwargs, "n_rows");
88
env->n_cols = unpack(kwargs, "n_cols");
9-
env->deck_size = unpack(kwargs, "deck_size");
9+
env->use_deck_obs = unpack(kwargs, "use_deck_obs");
10+
env->n_noise_obs = unpack(kwargs, "n_noise_obs");
11+
env->n_init_garbage = unpack(kwargs, "n_init_garbage");
1012
init(env);
1113
return 0;
1214
}
@@ -18,9 +20,13 @@ static int my_log(PyObject* dict, Log* log) {
1820
assign_to_dict(dict, "ep_return", log->ep_return);
1921
assign_to_dict(dict, "avg_combo", log->avg_combo);
2022
assign_to_dict(dict, "lines_deleted", log->lines_deleted);
23+
assign_to_dict(dict, "game_level", log->game_level);
24+
assign_to_dict(dict, "ticks_per_line", log->ticks_per_line);
2125

22-
assign_to_dict(dict, "atn_frac_soft_drop", log->atn_frac_soft_drop);
26+
// assign_to_dict(dict, "atn_frac_soft_drop", log->atn_frac_soft_drop);
2327
assign_to_dict(dict, "atn_frac_hard_drop", log->atn_frac_hard_drop);
2428
assign_to_dict(dict, "atn_frac_rotate", log->atn_frac_rotate);
29+
assign_to_dict(dict, "atn_frac_hold", log->atn_frac_hold);
30+
2531
return 0;
2632
}

pufferlib/ocean/tetris/tetris.c

Lines changed: 69 additions & 22 deletions
Original file line number | Diff line number | Diff line change
@@ -4,48 +4,95 @@
44
#define min(a, b) (((a) < (b)) ? (a) : (b))
55
#define max(a, b) (((a) > (b)) ? (a) : (b))
66

7+
// Network with hidden size 256. Should go to puffernet
8+
LinearLSTM* make_linearlstm_256(Weights* weights, int num_agents, int input_dim, int logit_sizes[], int num_actions) {
9+
LinearLSTM* net = calloc(1, sizeof(LinearLSTM));
10+
net->num_agents = num_agents;
11+
net->obs = calloc(num_agents*input_dim, sizeof(float));
12+
int hidden_dim = 256;
13+
net->encoder = make_linear(weights, num_agents, input_dim, hidden_dim);
14+
net->gelu1 = make_gelu(num_agents, hidden_dim);
15+
int atn_sum = 0;
16+
for (int i = 0; i < num_actions; i++) {
17+
atn_sum += logit_sizes[i];
18+
}
19+
net->actor = make_linear(weights, num_agents, hidden_dim, atn_sum);
20+
net->value_fn = make_linear(weights, num_agents, hidden_dim, 1);
21+
net->lstm = make_lstm(weights, num_agents, hidden_dim, hidden_dim);
22+
net->multidiscrete = make_multidiscrete(num_agents, logit_sizes, num_actions);
23+
return net;
24+
}
25+
26+
727
void demo() {
828
Tetris env = {
929
.n_rows = 20,
1030
.n_cols = 10,
11-
.deck_size=3,
31+
.use_deck_obs = true,
32+
.n_noise_obs = 0,
33+
.n_init_garbage = 0,
1234
};
1335
allocate(&env);
1436
env.client = make_client(&env);
1537
c_reset(&env);
1638

17-
Weights* weights = load_weights("resources/tetris/tetris_weights.bin", 163208);
39+
Weights* weights = load_weights("resources/tetris/tetris_weights.bin", 588552);
1840
int logit_sizes[1] = {7};
19-
LinearLSTM* net = make_linearlstm(weights, 1, 234, logit_sizes, 1);
41+
LinearLSTM* net = make_linearlstm_256(weights, 1, 234, logit_sizes, 1);
2042

43+
// State tracking for single-press actions to avoid using IsKeyPressed
44+
// because IsKeyPressed doesn't work well in web browsers
45+
static bool rotate_key_was_down = false;
46+
static bool hard_drop_key_was_down = false;
47+
static bool swap_key_was_down = false;
48+
49+
int frame = 0;
50+
env.actions[0] = 0;
2151
while (!WindowShouldClose()) {
52+
bool process_logic = true;
53+
frame++;
54+
2255
if (IsKeyDown(KEY_LEFT_SHIFT)) {
23-
if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)){
24-
env.actions[0] = 1;
25-
}
26-
if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)){
27-
env.actions[0] = 2;
28-
}
29-
if (IsKeyPressed(KEY_UP) || IsKeyDown(KEY_W)) {
30-
env.actions[0] = 3;
31-
}
32-
if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) {
33-
env.actions[0] = 4;
34-
}
35-
if (IsKeyPressed(KEY_SPACE)) {
36-
env.actions[0] = 5;
37-
}
38-
if (IsKeyPressed(KEY_C)) {
39-
env.actions[0] = 6;
56+
if (frame % 3 != 0) {
57+
// This effectively slows down the client by 3x
58+
process_logic = false;
59+
} else {
60+
// Use KeyDown for left, right, down to allow continuous input
61+
// Note: IsKeyDown is re-read on every processed frame, so a held key can overshoot the intended position.
62+
if (IsKeyDown(KEY_LEFT) || IsKeyDown(KEY_A)) {
63+
env.actions[0] = 1;
64+
} else if (IsKeyDown(KEY_RIGHT) || IsKeyDown(KEY_D)) {
65+
env.actions[0] = 2;
66+
} else if (IsKeyDown(KEY_DOWN) || IsKeyDown(KEY_S)) {
67+
env.actions[0] = 4; // Soft drop
68+
}
69+
// Manual state tracking for single-press actions, mutually exclusive
70+
else if ((IsKeyDown(KEY_UP) || IsKeyDown(KEY_W)) && !rotate_key_was_down) {
71+
env.actions[0] = 3; // Rotate
72+
} else if (IsKeyDown(KEY_SPACE) && !hard_drop_key_was_down) {
73+
env.actions[0] = 5; // Hard drop
74+
} else if (IsKeyDown(KEY_C) && !swap_key_was_down) {
75+
env.actions[0] = 6; // Swap
76+
}
4077
}
4178
} else {
4279
forward_linearlstm(net, env.observations, env.actions);
4380
}
4481

45-
c_step(&env);
46-
env.actions[0] = 0;
82+
if (process_logic) {
83+
// Update key state flags after processing actions for the frame
84+
rotate_key_was_down = IsKeyDown(KEY_UP) || IsKeyDown(KEY_W);
85+
hard_drop_key_was_down = IsKeyDown(KEY_SPACE);
86+
swap_key_was_down = IsKeyDown(KEY_C);
87+
88+
c_step(&env);
89+
90+
env.actions[0] = 0;
91+
}
92+
4793
c_render(&env);
4894
}
95+
4996
free_linearlstm(net);
5097
free_allocated(&env);
5198
close_client(env.client);

0 commit comments

Comments
 (0)