Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
120 changes: 73 additions & 47 deletions pufferlib/ocean/breakout/breakout.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,12 +50,13 @@ typedef struct Breakout {
float ball_y;
float ball_vx;
float ball_vy;
float* brick_x;
float* brick_y;
float* brick_states;
float *brick_x;
float *brick_y;
float *brick_state;
int balls_fired;
float initial_paddle_width;
float paddle_width;
float inv_paddle_width;
float paddle_height;
float paddle_speed;
float ball_speed;
Expand All @@ -64,13 +65,17 @@ typedef struct Breakout {
int hits;
int width;
int height;
float inv_height;
float inv_width;
int num_bricks;
int brick_rows;
int brick_cols;
int ball_width;
int ball_height;
int brick_width;
int brick_height;
float inv_brick_height;
float inv_brick_width;
int num_balls;
int max_score;
int half_max_score;
Expand Down Expand Up @@ -108,10 +113,14 @@ void init(Breakout* env) {
env->tick = 0;
env->num_bricks = env->brick_rows * env->brick_cols;
assert(env->num_bricks > 0);
env->inv_height = 1.0f / env->height;
env->inv_width = 1.0f / env->width;
env->inv_brick_height = 1.0f / env->brick_height;
env->inv_brick_width = 1.0f / env->brick_width;

env->brick_x = (float*)calloc(env->num_bricks, sizeof(float));
env->brick_y = (float*)calloc(env->num_bricks, sizeof(float));
env->brick_states = (float*)calloc(env->num_bricks, sizeof(float));
env->brick_state = (float*)calloc(env->num_bricks, sizeof(float));
env->num_balls = -1;
generate_brick_positions(env);
}
Expand All @@ -127,7 +136,7 @@ void allocate(Breakout* env) {
// Free the per-brick arrays referenced by env (brick positions and brick
// state). The env struct itself is not freed here.
// NOTE(review): this text appears to come from a rendered diff in which the
// removed and added lines are both present. `brick_states` looks like the
// pre-rename spelling of `brick_state` -- confirm against the real header
// before treating both free() calls as intended.
void c_close(Breakout* env) {
free(env->brick_x);
free(env->brick_y);
free(env->brick_states);
free(env->brick_state);
}

void free_allocated(Breakout* env) {
Expand All @@ -147,34 +156,36 @@ void add_log(Breakout* env) {
}

// Fill env->observations from the current game state.
//   [0..3]  paddle x/y and ball x/y, scaled by the field width/height
//   [4..5]  ball velocity components divided by 512
//   [6]     balls_fired / 5
//   [7]     score / 864
//   [8]     num_balls / 5
//   [9]     paddle_width relative to 2 * HALF_PADDLE_WIDTH
//   [10..]  one entry per brick, copied from the brick state array
//           (elsewhere in this file the state is set to 0.0 on reset and
//           1.0 when a brick is destroyed)
// The caller must provide room for 10 + num_bricks floats in observations.
// NOTE(review): this text appears to come from a rendered diff. Slots 0-3 are
// written twice (first with `/ env->width`, then with `* env->inv_width`) and
// slot 10+i is written from both `brick_states` and `brick_state`; the first
// of each pair looks like the removed side of the diff -- confirm against the
// real header.
void compute_observations(Breakout* env) {
env->observations[0] = env->paddle_x / env->width;
env->observations[1] = env->paddle_y / env->height;
env->observations[2] = env->ball_x / env->width;
env->observations[3] = env->ball_y / env->height;
// Same four values, using the reciprocals (1.0f / width, 1.0f / height)
// that init() stores in inv_width / inv_height.
env->observations[0] = env->paddle_x * env->inv_width;
env->observations[1] = env->paddle_y * env->inv_height;
env->observations[2] = env->ball_x * env->inv_width;
env->observations[3] = env->ball_y * env->inv_height;
env->observations[4] = env->ball_vx / 512.0f;
env->observations[5] = env->ball_vy / 512.0f;
env->observations[6] = env->balls_fired / 5.0f;
env->observations[7] = env->score / 864.0f;
env->observations[8] = env->num_balls / 5.0f;
env->observations[9] = env->paddle_width / (2.0f * HALF_PADDLE_WIDTH);
// Per-brick state follows the ten scalar features.
for (int i = 0; i < env->num_bricks; i++) {
env->observations[10 + i] = env->brick_states[i];
env->observations[10 + i] = env->brick_state[i];
}
}

// Collision of a stationary vertical line segment (xw,yw) to (xw,yw+hw)
// with a moving line segment (x+vx*t,y+vy*t) to (x+vx*t,y+vy*t+h).
static inline bool calc_vline_collision(float xw, float yw, float hw, float x,
float y, float vx, float vy, float h, CollisionInfo* col) {
float t_new = (xw - x) / vx;
float topmost = fmin(yw + hw, y + h + vy * t_new);
float botmost = fmax(yw, y + vy * t_new);
float overlap_new = topmost - botmost;
const float t_new = (xw - x) / vx;
if (t_new <= 0.0f || t_new > 1.0f) return false;

const float topmost = fmin(yw + hw, y + h + vy * t_new);
const float botmost = fmax(yw, y + vy * t_new);
const float overlap_new = topmost - botmost;
if (overlap_new <= 0.0f) return false;

// Collision finds the smallest time of collision with the greatest overlap
// between the ball and the wall.
if (overlap_new > 0.0f && t_new > 0.0f && t_new <= 1.0f &&
(t_new < col->t || (t_new == col->t && overlap_new > col->overlap))) {
if ((t_new < col->t || (t_new == col->t && overlap_new > col->overlap))) {
col->t = t_new;
col->overlap = overlap_new;
col->x = xw;
Expand All @@ -187,10 +198,10 @@ static inline bool calc_vline_collision(float xw, float yw, float hw, float x,
}
static inline bool calc_hline_collision(float xw, float yw, float ww,
float x, float y, float vx, float vy, float w, CollisionInfo* col) {
float t_new = (yw - y) / vy;
float rightmost = fminf(xw + ww, x + w + vx * t_new);
float leftmost = fmaxf(xw, x + vx * t_new);
float overlap_new = rightmost - leftmost;
const float t_new = (yw - y) / vy;
const float rightmost = fminf(xw + ww, x + w + vx * t_new);
const float leftmost = fmaxf(xw, x + vx * t_new);
const float overlap_new = rightmost - leftmost;

// Collision finds the smallest time of collision with the greatest overlap between the ball and the wall.
if (overlap_new > 0.0f && t_new > 0.0f && t_new <= 1.0f &&
Expand All @@ -205,39 +216,44 @@ static inline bool calc_hline_collision(float xw, float yw, float ww,
}
return false;
}
static inline void calc_brick_collision(Breakout* env, int idx,
static inline void calc_brick_collision(Breakout* env, int idx,
float brick_width, float brick_height, float ball_x, float ball_y,
float ball_vx, float ball_vy, float ball_width, float ball_height,
CollisionInfo* collision_info) {
bool collision = false;
// Brick left wall collides with ball right side
if (env->ball_vx > 0) {
if (calc_vline_collision(env->brick_x[idx], env->brick_y[idx], env->brick_height,
env->ball_x + env->ball_width, env->ball_y, env->ball_vx, env->ball_vy, env->ball_height, collision_info)) {
const float x = env->brick_x[idx];
const float y = env->brick_y[idx];

if (env->ball_vx > 0.0f) {
if (calc_vline_collision(x, y, brick_height, ball_x + ball_width, ball_y, ball_vx, ball_vy, ball_height,
collision_info)) {
collision = true;
collision_info->x -= env->ball_width;
}
}

// Brick right wall collides with ball left side
if (env->ball_vx < 0) {
if (calc_vline_collision(env->brick_x[idx] + env->brick_width, env->brick_y[idx], env->brick_height,
env->ball_x, env->ball_y, env->ball_vx, env->ball_vy, env->ball_height, collision_info)) {
if (calc_vline_collision(x + brick_width, y, brick_height,
ball_x, ball_y, ball_vx, ball_vy, ball_height, collision_info)) {
collision = true;
}
}

// Brick top wall collides with ball bottom side
if (env->ball_vy > 0) {
if (calc_hline_collision(env->brick_x[idx], env->brick_y[idx], env->brick_width,
env->ball_x, env->ball_y + env->ball_height, env->ball_vx, env->ball_vy, env->ball_width, collision_info)) {
if (env->ball_vy > 0.0f) {
if (calc_hline_collision(x, y, brick_width,
ball_x, ball_y + ball_height, ball_vx, ball_vy, ball_width, collision_info)) {
collision = true;
collision_info->y -= env->ball_height;
}
}

// Brick bottom wall collides with ball top side
if (env->ball_vy < 0) {
if (calc_hline_collision(env->brick_x[idx], env->brick_y[idx] + env->brick_height, env->brick_width,
env->ball_x, env->ball_y, env->ball_vx, env->ball_vy, env->ball_width, collision_info)) {
if (calc_hline_collision(x, y + brick_height, brick_width,
ball_x, ball_y, ball_vx, ball_vy, ball_width, collision_info)) {
collision = true;
}
}
Expand All @@ -246,10 +262,10 @@ static inline void calc_brick_collision(Breakout* env, int idx,
}
}
// Map an x coordinate to a brick column index by flooring x measured in
// brick widths. No clamping to [0, brick_cols) is done in this function.
// NOTE(review): this text appears to come from a rendered diff, so two return
// statements are present. Read as literal C only the first (division) one
// executes; the second (multiply by inv_brick_width, which init() sets to
// 1.0f / brick_width) looks like the added side of the diff -- confirm
// against the real header.
static inline int column_index(Breakout* env, float x) {
return (int)(floorf(x / env->brick_width));
return (int)(floorf(x * env->inv_brick_width));
}
// Map a y coordinate to a brick row index: subtract Y_OFFSET, then floor the
// result measured in brick heights. No clamping to [0, brick_rows) is done in
// this function.
// NOTE(review): this text appears to come from a rendered diff, so two return
// statements are present. Read as literal C only the first (division) one
// executes; the second (multiply by inv_brick_height, which init() sets to
// 1.0f / brick_height) looks like the added side of the diff -- confirm
// against the real header.
static inline int row_index(Breakout* env, float y) {
return (int)(floorf((y - Y_OFFSET) / env->brick_height));
return (int)(floorf((y - Y_OFFSET) * env->inv_brick_height));
}

void calc_all_brick_collisions(Breakout* env, CollisionInfo* collision_info) {
Expand All @@ -262,11 +278,21 @@ void calc_all_brick_collisions(Breakout* env, CollisionInfo* collision_info) {
int row_to = row_index(env, fmaxf(env->ball_y + env->ball_height + env->ball_vy, env->ball_y + env->ball_height));
row_to = fminf(row_to, env->brick_rows - 1);

const float brick_height = env->brick_height;
const float brick_width = env->brick_width;
const float ball_x = env->ball_x;
const float ball_y = env->ball_y;
const float ball_vx = env->ball_vx;
const float ball_vy = env->ball_vy;
const float ball_width = env->ball_width;
const float ball_height = env->ball_height;

for (int row = row_from; row <= row_to; row++) {
for (int column = column_from; column <= column_to; column++) {
int brick_index = row * env->brick_cols + column;
if (env->brick_states[brick_index] == 0.0)
calc_brick_collision(env, brick_index, collision_info);
if (env->brick_state[brick_index] == 0.0)
calc_brick_collision(env, brick_index, brick_width, brick_height, ball_x, ball_y, ball_vx, ball_vy,
ball_width, ball_height, collision_info);
}
}
}
Expand Down Expand Up @@ -303,7 +329,7 @@ bool calc_paddle_ball_collisions(Breakout* env, CollisionInfo* collision_info) {
}
if (env->score == env->half_max_score) {
for (int i = 0; i < env->num_bricks; i++) {
env->brick_states[i] = 0.0;
env->brick_state[i] = 0.0;
}
}
return true;
Expand All @@ -317,7 +343,7 @@ void calc_all_wall_collisions(Breakout* env, CollisionInfo* collision_info) {
collision_info->brick_index = BRICK_INDEX_SIDEWALL_COLLISION;
}
}
if (env->ball_vx > 0) {
if (env->ball_vx > 0.0f) {
if (calc_vline_collision(env->width, 0, env->height,
env->ball_x + env->ball_width, env->ball_y, env->ball_vx, env->ball_vy, env->ball_height,
collision_info)) {
Expand Down Expand Up @@ -353,7 +379,7 @@ void destroy_brick(Breakout* env, int brick_idx) {
float gained_points = 7 - 3 * ((brick_idx / env->brick_cols) / 2);

env->score += gained_points;
env->brick_states[brick_idx] = 1.0;
env->brick_state[brick_idx] = 1.0;

env->rewards[0] += gained_points;

Expand Down Expand Up @@ -415,7 +441,7 @@ void c_reset(Breakout* env) {
env->score = 0;
env->num_balls = 5;
for (int i = 0; i < env->num_bricks; i++) {
env->brick_states[i] = 0.0;
env->brick_state[i] = 0.0;
}
reset_round(env);
env->tick = 0;
Expand Down Expand Up @@ -443,11 +469,9 @@ void step_frame(Breakout* env, float action) {
act = action;
}
env->paddle_x += act * env->paddle_speed * TICK_RATE;
if (env->paddle_x <= 0){
env->paddle_x = fmaxf(0, env->paddle_x);
} else {
env->paddle_x = fminf(env->width - env->paddle_width, env->paddle_x);
}
env->paddle_x = fmaxf(0.0f, env->paddle_x);
float max_paddle_x = env->width - env->paddle_width;
env->paddle_x = fminf(max_paddle_x, env->paddle_x);

//Handle collisions.
//Regular timestepping is done only if there are no collisions.
Expand All @@ -458,7 +482,9 @@ void step_frame(Breakout* env, float action) {

if (env->ball_y >= env->paddle_y + env->paddle_height) {
env->num_balls -= 1;
reset_round(env);
if (env->num_balls >= 0 && env->score < env->max_score) {
reset_round(env);
}
}
if (env->num_balls < 0 || env->score == env->max_score) {
env->terminals[0] = 1;
Expand Down Expand Up @@ -527,7 +553,7 @@ void c_render(Breakout* env) {
DrawTexturePro(
client->ball,
(Rectangle){
(env->ball_vx > 0) ? 0 : 128,
(env->ball_vx > 0.0f) ? 0 : 128,
0, 128, 128,
},
(Rectangle){
Expand All @@ -544,7 +570,7 @@ void c_render(Breakout* env) {
for (int row = 0; row < env->brick_rows; row++) {
for (int col = 0; col < env->brick_cols; col++) {
int brick_idx = row * env->brick_cols + col;
if (env->brick_states[brick_idx] == 1) {
if (env->brick_state[brick_idx] == 1) {
continue;
}
int x = env->brick_x[brick_idx];
Expand Down