diff --git a/pufferlib/ocean/breakout/breakout.h b/pufferlib/ocean/breakout/breakout.h index 8366b5065..0a1785ef3 100644 --- a/pufferlib/ocean/breakout/breakout.h +++ b/pufferlib/ocean/breakout/breakout.h @@ -50,12 +50,13 @@ typedef struct Breakout { float ball_y; float ball_vx; float ball_vy; - float* brick_x; - float* brick_y; - float* brick_states; + float *brick_x; + float *brick_y; + float *brick_state; int balls_fired; float initial_paddle_width; float paddle_width; + float inv_paddle_width; float paddle_height; float paddle_speed; float ball_speed; @@ -64,6 +65,8 @@ typedef struct Breakout { int hits; int width; int height; + float inv_height; + float inv_width; int num_bricks; int brick_rows; int brick_cols; @@ -71,6 +74,8 @@ typedef struct Breakout { int ball_height; int brick_width; int brick_height; + float inv_brick_height; + float inv_brick_width; int num_balls; int max_score; int half_max_score; @@ -108,10 +113,14 @@ void init(Breakout* env) { env->tick = 0; env->num_bricks = env->brick_rows * env->brick_cols; assert(env->num_bricks > 0); + env->inv_height = 1.0f / env->height; + env->inv_width = 1.0f / env->width; + env->inv_brick_height = 1.0f / env->brick_height; + env->inv_brick_width = 1.0f / env->brick_width; env->brick_x = (float*)calloc(env->num_bricks, sizeof(float)); env->brick_y = (float*)calloc(env->num_bricks, sizeof(float)); - env->brick_states = (float*)calloc(env->num_bricks, sizeof(float)); + env->brick_state = (float*)calloc(env->num_bricks, sizeof(float)); env->num_balls = -1; generate_brick_positions(env); } @@ -127,7 +136,7 @@ void allocate(Breakout* env) { void c_close(Breakout* env) { free(env->brick_x); free(env->brick_y); - free(env->brick_states); + free(env->brick_state); } void free_allocated(Breakout* env) { @@ -147,10 +156,10 @@ void add_log(Breakout* env) { } void compute_observations(Breakout* env) { - env->observations[0] = env->paddle_x / env->width; - env->observations[1] = env->paddle_y / env->height; - env->observations[2] = env->ball_x / env->width; - env->observations[3] = env->ball_y / env->height; + env->observations[0] = env->paddle_x * env->inv_width; + env->observations[1] = env->paddle_y * env->inv_height; + env->observations[2] = env->ball_x * env->inv_width; + env->observations[3] = env->ball_y * env->inv_height; env->observations[4] = env->ball_vx / 512.0f; env->observations[5] = env->ball_vy / 512.0f; env->observations[6] = env->balls_fired / 5.0f; @@ -158,7 +167,7 @@ void compute_observations(Breakout* env) { env->observations[8] = env->num_balls / 5.0f; env->observations[9] = env->paddle_width / (2.0f * HALF_PADDLE_WIDTH); for (int i = 0; i < env->num_bricks; i++) { - env->observations[10 + i] = env->brick_states[i]; + env->observations[10 + i] = env->brick_state[i]; } } @@ -166,15 +175,17 @@ void compute_observations(Breakout* env) { // with a moving line segment (x+vx*t,y+vy*t) to (x+vx*t,y+vy*t+h). static inline bool calc_vline_collision(float xw, float yw, float hw, float x, float y, float vx, float vy, float h, CollisionInfo* col) { - float t_new = (xw - x) / vx; - float topmost = fmin(yw + hw, y + h + vy * t_new); - float botmost = fmax(yw, y + vy * t_new); - float overlap_new = topmost - botmost; + const float t_new = (xw - x) / vx; + if (t_new <= 0.0f || t_new > 1.0f) return false; + + const float topmost = fmin(yw + hw, y + h + vy * t_new); + const float botmost = fmax(yw, y + vy * t_new); + const float overlap_new = topmost - botmost; + if (overlap_new <= 0.0f) return false; // Collision finds the smallest time of collision with the greatest overlap // between the ball and the wall. - if (overlap_new > 0.0f && t_new > 0.0f && t_new <= 1.0f && - (t_new < col->t || (t_new == col->t && overlap_new > col->overlap))) { + if ((t_new < col->t || (t_new == col->t && overlap_new > col->overlap))) { col->t = t_new; col->overlap = overlap_new; col->x = xw; @@ -187,10 +198,10 @@ static inline bool calc_vline_collision(float xw, float yw, float hw, float x, } static inline bool calc_hline_collision(float xw, float yw, float ww, float x, float y, float vx, float vy, float w, CollisionInfo* col) { - float t_new = (yw - y) / vy; - float rightmost = fminf(xw + ww, x + w + vx * t_new); - float leftmost = fmaxf(xw, x + vx * t_new); - float overlap_new = rightmost - leftmost; + const float t_new = (yw - y) / vy; + const float rightmost = fminf(xw + ww, x + w + vx * t_new); + const float leftmost = fmaxf(xw, x + vx * t_new); + const float overlap_new = rightmost - leftmost; // Collision finds the smallest time of collision with the greatest overlap between the ball and the wall. if (overlap_new > 0.0f && t_new > 0.0f && t_new <= 1.0f && @@ -205,13 +216,18 @@ static inline bool calc_hline_collision(float xw, float yw, float ww, } return false; } -static inline void calc_brick_collision(Breakout* env, int idx, +static inline void calc_brick_collision(Breakout* env, int idx, + float brick_width, float brick_height, float ball_x, float ball_y, + float ball_vx, float ball_vy, float ball_width, float ball_height, CollisionInfo* collision_info) { bool collision = false; // Brick left wall collides with ball right side - if (env->ball_vx > 0) { - if (calc_vline_collision(env->brick_x[idx], env->brick_y[idx], env->brick_height, - env->ball_x + env->ball_width, env->ball_y, env->ball_vx, env->ball_vy, env->ball_height, collision_info)) { + const float x = env->brick_x[idx]; + const float y = env->brick_y[idx]; + + if (env->ball_vx > 0.0f) { + if (calc_vline_collision(x, y, brick_height, ball_x + ball_width, ball_y, ball_vx, ball_vy, ball_height, + collision_info)) { collision = true; collision_info->x -= env->ball_width; } @@ -219,16 +235,16 @@ static inline void calc_brick_collision(Breakout* env, int idx, // Brick right wall collides with ball left side if (env->ball_vx < 0) { - if (calc_vline_collision(env->brick_x[idx] + env->brick_width, env->brick_y[idx], env->brick_height, - env->ball_x, env->ball_y, env->ball_vx, env->ball_vy, env->ball_height, collision_info)) { + if (calc_vline_collision(x + brick_width, y, brick_height, + ball_x, ball_y, ball_vx, ball_vy, ball_height, collision_info)) { collision = true; } } // Brick top wall collides with ball bottom side - if (env->ball_vy > 0) { - if (calc_hline_collision(env->brick_x[idx], env->brick_y[idx], env->brick_width, - env->ball_x, env->ball_y + env->ball_height, env->ball_vx, env->ball_vy, env->ball_width, collision_info)) { + if (env->ball_vy > 0.0f) { + if (calc_hline_collision(x, y, brick_width, + ball_x, ball_y + ball_height, ball_vx, ball_vy, ball_width, collision_info)) { collision = true; collision_info->y -= env->ball_height; } @@ -236,8 +252,8 @@ static inline void calc_brick_collision(Breakout* env, int idx, // Brick bottom wall collides with ball top side if (env->ball_vy < 0) { - if (calc_hline_collision(env->brick_x[idx], env->brick_y[idx] + env->brick_height, env->brick_width, - env->ball_x, env->ball_y, env->ball_vx, env->ball_vy, env->ball_width, collision_info)) { + if (calc_hline_collision(x, y + brick_height, brick_width, + ball_x, ball_y, ball_vx, ball_vy, ball_width, collision_info)) { collision = true; } } @@ -246,10 +262,10 @@ static inline void calc_brick_collision(Breakout* env, int idx, } } static inline int column_index(Breakout* env, float x) { - return (int)(floorf(x / env->brick_width)); + return (int)(floorf(x * env->inv_brick_width)); } static inline int row_index(Breakout* env, float y) { - return (int)(floorf((y - Y_OFFSET) / env->brick_height)); + return (int)(floorf((y - Y_OFFSET) * env->inv_brick_height)); } void calc_all_brick_collisions(Breakout* env, CollisionInfo* collision_info) { @@ -262,11 +278,21 @@ void calc_all_brick_collisions(Breakout* env, CollisionInfo* collision_info) { int row_to = row_index(env, fmaxf(env->ball_y + env->ball_height + env->ball_vy, env->ball_y + env->ball_height)); row_to = fminf(row_to, env->brick_rows - 1); + const float brick_height = env->brick_height; + const float brick_width = env->brick_width; + const float ball_x = env->ball_x; + const float ball_y = env->ball_y; + const float ball_vx = env->ball_vx; + const float ball_vy = env->ball_vy; + const float ball_width = env->ball_width; + const float ball_height = env->ball_height; + for (int row = row_from; row <= row_to; row++) { for (int column = column_from; column <= column_to; column++) { int brick_index = row * env->brick_cols + column; - if (env->brick_states[brick_index] == 0.0) - calc_brick_collision(env, brick_index, collision_info); + if (env->brick_state[brick_index] == 0.0) + calc_brick_collision(env, brick_index, brick_width, brick_height, ball_x, ball_y, ball_vx, ball_vy, + ball_width, ball_height, collision_info); } } } @@ -303,7 +329,7 @@ bool calc_paddle_ball_collisions(Breakout* env, CollisionInfo* collision_info) { } if (env->score == env->half_max_score) { for (int i = 0; i < env->num_bricks; i++) { - env->brick_states[i] = 0.0; + env->brick_state[i] = 0.0; } } return true; @@ -317,7 +343,7 @@ void calc_all_wall_collisions(Breakout* env, CollisionInfo* collision_info) { collision_info->brick_index = BRICK_INDEX_SIDEWALL_COLLISION; } } - if (env->ball_vx > 0) { + if (env->ball_vx > 0.0f) { if (calc_vline_collision(env->width, 0, env->height, env->ball_x + env->ball_width, env->ball_y, env->ball_vx, env->ball_vy, env->ball_height, collision_info)) { @@ -353,7 +379,7 @@ void destroy_brick(Breakout* env, int brick_idx) { float gained_points = 7 - 3 * ((brick_idx / env->brick_cols) / 2); env->score += gained_points; - env->brick_states[brick_idx] = 1.0; + env->brick_state[brick_idx] = 1.0; env->rewards[0] += gained_points; @@ -415,7 +441,7 @@ void c_reset(Breakout* env) { env->score = 0; env->num_balls = 5; for (int i = 0; i < env->num_bricks; i++) { - env->brick_states[i] = 0.0; + env->brick_state[i] = 0.0; } reset_round(env); env->tick = 0; @@ -443,11 +469,9 @@ void step_frame(Breakout* env, float action) { act = action; } env->paddle_x += act * env->paddle_speed * TICK_RATE; - if (env->paddle_x <= 0){ - env->paddle_x = fmaxf(0, env->paddle_x); - } else { - env->paddle_x = fminf(env->width - env->paddle_width, env->paddle_x); - } + env->paddle_x = fmaxf(0.0f, env->paddle_x); + float max_paddle_x = env->width - env->paddle_width; + env->paddle_x = fminf(max_paddle_x, env->paddle_x); //Handle collisions. //Regular timestepping is done only if there are no collisions. @@ -458,7 +482,9 @@ void step_frame(Breakout* env, float action) { if (env->ball_y >= env->paddle_y + env->paddle_height) { env->num_balls -= 1; - reset_round(env); + if (env->num_balls >= 0 && env->score < env->max_score) { + reset_round(env); + } } if (env->num_balls < 0 || env->score == env->max_score) { env->terminals[0] = 1; @@ -527,7 +553,7 @@ void c_render(Breakout* env) { DrawTexturePro( client->ball, (Rectangle){ - (env->ball_vx > 0) ? 0 : 128, + (env->ball_vx > 0.0f) ? 0 : 128, 0, 128, 128, }, (Rectangle){ @@ -544,7 +570,7 @@ void c_render(Breakout* env) { for (int row = 0; row < env->brick_rows; row++) { for (int col = 0; col < env->brick_cols; col++) { int brick_idx = row * env->brick_cols + col; - if (env->brick_states[brick_idx] == 1) { + if (env->brick_state[brick_idx] == 1) { continue; } int x = env->brick_x[brick_idx];