Skip to content

Commit c5c4742

Browse files
author
Joseph Suarez
committed
Asteroids fix
1 parent f609772 commit c5c4742

File tree

2 files changed

+30
-23
lines changed

2 files changed

+30
-23
lines changed

pufferlib/config/ocean/asteroids.ini

Lines changed: 6 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,12 +4,13 @@ env_name = puffer_asteroids
44
policy_name = Policy
55
rnn_name = Recurrent
66

7+
[vec]
8+
num_envs = 8
9+
710
[env]
8-
# num_envs = 4096
9-
num_envs = 4
11+
num_envs = 1024
1012
size = 500
1113

1214
[train]
13-
total_timesteps = 20_000_000
14-
# minibatch_size = 32768
15-
minibatch_size = 128
15+
total_timesteps = 150_000_000
16+
minibatch_size = 32768

pufferlib/ocean/asteroids/asteroids.h

Lines changed: 24 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -76,6 +76,7 @@ typedef struct {
7676
int last_shot;
7777
int tick;
7878
int score;
79+
float episode_return;
7980
int frameskip;
8081
} Asteroids;
8182

@@ -306,6 +307,7 @@ void check_player_asteroid_collision(Asteroids *env) {
306307
dy = env->player_position.y - as->position.y;
307308
if (min_dist * min_dist > dx * dx + dy * dy) {
308309
env->terminals[0] = 1;
310+
env->rewards[0] = -1.0f;
309311
return;
310312
}
311313
}
@@ -372,10 +374,10 @@ void compute_observations(Asteroids *env) {
372374
}
373375

374376
void add_log(Asteroids *env) {
375-
env->log.perf += (env->rewards[0] > 0) ? 1 : 0;
376-
env->log.score += env->rewards[0];
377+
env->log.perf += env->score / 100.0f;
378+
env->log.score += env->score;
377379
env->log.episode_length += env->tick;
378-
env->log.episode_return += env->rewards[0];
380+
env->log.episode_return += env->episode_return;
379381
env->log.n++;
380382
}
381383

@@ -391,6 +393,7 @@ void c_reset(Asteroids *env) {
391393
env->asteroid_index = 0;
392394
env->tick = 0;
393395
env->score = 0;
396+
env->episode_return = 0;
394397
env->last_shot = 0;
395398
}
396399

@@ -456,6 +459,7 @@ void c_step(Asteroids *env) {
456459
step_frame(env, action);
457460
}
458461

462+
env->episode_return += env->rewards[0];
459463
if (env->terminals[0] == 1 || env->tick > MAX_TICK) {
460464
env->terminals[0] = 1;
461465
add_log(env);
@@ -467,6 +471,11 @@ void c_step(Asteroids *env) {
467471
compute_observations(env);
468472
}
469473

474+
const Color PUFF_RED = (Color){187, 0, 0, 255};
475+
const Color PUFF_CYAN = (Color){0, 187, 187, 255};
476+
const Color PUFF_WHITE = (Color){241, 241, 241, 241};
477+
const Color PUFF_BACKGROUND = (Color){6, 24, 24, 255};
478+
470479
void draw_player(Asteroids *env) {
471480
if (global_game_over_timer > 0)
472481
return;
@@ -498,21 +507,20 @@ void draw_player(Asteroids *env) {
498507
for (int i = 0; i < 8; i++)
499508
ps[i] = rotate_vector(ps[i], env->player_position, env->player_angle);
500509

501-
DrawLineV(ps[0], ps[2], RAYWHITE);
502-
DrawLineV(ps[1], ps[2], RAYWHITE);
510+
DrawLineV(ps[0], ps[2], PUFF_RED);
511+
DrawLineV(ps[1], ps[2], PUFF_RED);
503512

504-
DrawLineV(ps[3], ps[4], RAYWHITE);
513+
DrawLineV(ps[3], ps[4], PUFF_RED);
505514

506515
if (env->thruster_on) {
507-
DrawLineV(ps[5], ps[7], RAYWHITE);
508-
DrawLineV(ps[6], ps[7], RAYWHITE);
516+
DrawLineV(ps[5], ps[7], PUFF_RED);
517+
DrawLineV(ps[6], ps[7], PUFF_RED);
509518
}
510519
}
511520

512521
void draw_particles(Asteroids *env) {
513522
for (int i = 0; i < MAX_PARTICLES; i++) {
514-
DrawPixel(env->particles[i].position.x, env->particles[i].position.y,
515-
RAYWHITE);
523+
DrawCircle(env->particles[i].position.x, env->particles[i].position.y, 2, PUFF_RED);
516524
}
517525
}
518526

@@ -532,7 +540,7 @@ void draw_asteroids(Asteroids *env) {
532540
as.position.y + as.shape[v].y};
533541
Vector2 pos2 = {as.position.x + as.shape[next_v].x,
534542
as.position.y + as.shape[next_v].y};
535-
DrawLineV(pos1, pos2, (Color){245, 245, 245, 175});
543+
DrawLineV(pos1, pos2, PUFF_CYAN);
536544
}
537545
}
538546
}
@@ -561,14 +569,13 @@ void c_render(Asteroids *env) {
561569
}
562570

563571
BeginDrawing();
564-
ClearBackground(BLACK);
572+
ClearBackground(PUFF_BACKGROUND);
565573
draw_player(env);
566574
draw_particles(env);
567575
draw_asteroids(env);
568576

569-
DrawText(TextFormat("Score: %d", env->score), 10, 10, 20, RAYWHITE);
570-
DrawText(TextFormat("%d s", (int)(env->tick / 60)), env->size - 40, 10, 20,
571-
RAYWHITE);
577+
DrawText(TextFormat("Score: %d", env->score), 10, 10, 20, PUFF_WHITE);
578+
DrawText(TextFormat("%d s", (int)(env->tick / 60)), env->size - 40, 10, 20, PUFF_WHITE);
572579

573580
if (global_game_over_timer > 0) {
574581
const char *game_over_text = "GAME OVER";
@@ -579,9 +586,8 @@ void c_render(Asteroids *env) {
579586
float alpha = (float)global_game_over_timer / 120.0f;
580587
int alpha_value = (int)(alpha * 255);
581588

582-
Color text_color = ColorAlpha(RED, alpha_value);
583-
DrawTextEx(GetFontDefault(), game_over_text, (Vector2){x, y}, 40, 2,
584-
text_color);
589+
Color text_color = ColorAlpha(PUFF_RED, alpha_value);
590+
DrawTextEx(GetFontDefault(), game_over_text, (Vector2){x, y}, 40, 2, text_color);
585591
}
586592

587593
EndDrawing();

0 commit comments

Comments
 (0)