diff --git a/pufferlib/config/ocean/tcg.ini b/pufferlib/config/ocean/tcg.ini new file mode 100644 index 000000000..202d3697c --- /dev/null +++ b/pufferlib/config/ocean/tcg.ini @@ -0,0 +1,10 @@ +[base] +package = ocean +env_name = puffer_tcg +policy_name = Policy + +[env] +num_envs = 4096 + +[train] +total_timesteps = 1_000_000 diff --git a/pufferlib/ocean/environment.py b/pufferlib/ocean/environment.py index 6c56a4ea2..aa8975d62 100644 --- a/pufferlib/ocean/environment.py +++ b/pufferlib/ocean/environment.py @@ -162,6 +162,7 @@ def make_multiagent(buf=None, **kwargs): 'spaces': make_spaces, 'multiagent': make_multiagent, 'slimevolley': 'SlimeVolley', + 'tcg': 'TCG', } def env_creator(name='squared', *args, **kwargs): diff --git a/pufferlib/ocean/tcg/binding.c b/pufferlib/ocean/tcg/binding.c new file mode 100644 index 000000000..d5b8f0bd4 --- /dev/null +++ b/pufferlib/ocean/tcg/binding.c @@ -0,0 +1,17 @@ +#include "tcg.h" + +#define Env TCG +#include "../env_binding.h" + +static int my_init(Env* env, PyObject* args, PyObject* kwargs) { + init_tcg(env); + return 0; +} + +static int my_log(PyObject* dict, Log* log) { + assign_to_dict(dict, "perf", log->perf); + assign_to_dict(dict, "score", log->score); + assign_to_dict(dict, "episode_return", log->episode_return); + assign_to_dict(dict, "episode_length", log->episode_length); + return 0; +} diff --git a/pufferlib/ocean/tcg/tcg.h b/pufferlib/ocean/tcg/tcg.h index b82d50b9a..2b7a53f3d 100644 --- a/pufferlib/ocean/tcg/tcg.h +++ b/pufferlib/ocean/tcg/tcg.h @@ -1,17 +1,26 @@ #include #include #include +#include #include +#include #include "raylib.h" #define HAND_SIZE 10 #define BOARD_SIZE 10 #define DECK_SIZE 60 #define STACK_SIZE 100 +#define MAX_TURNS 32 +#define DRAW_PENALTY -0.5f +#define WIN_REWARD 1.0f +#define LOSE_REWARD -1.0f #define ACTION_ENTER 10 #define ACTION_NOOP 11 +#define NUM_PLAYERS 2 +#define OBS_SIZE (sizeof(Obs)) + #define TO_USER true; #define TO_STACK false; @@ -23,6 +32,7 @@ bool phase_play(TCG* env, unsigned char atn); bool phase_attack(TCG* env, unsigned char atn); bool phase_block(TCG* env, unsigned char atn); void reset(TCG* env); +void draw_episode(TCG* env); typedef struct Stack Stack; struct Stack { @@ -89,6 +99,14 @@ void condense_card_array(CardArray* ary) { ary->length = idx; } +typedef struct { + float perf; + float score; + float episode_return; + float episode_length; + float n; +} Log; + struct TCG { CardArray* my_hand; CardArray* my_board; @@ -103,15 +121,101 @@ struct TCG { int op_health; int op_mana; bool op_land_played; + bool debug_logs; Stack* stack; //bool attackers[BOARD_SIZE]; //bool defenders[BOARD_SIZE][BOARD_SIZE]; int block_idx; int turn; + + int tick; + unsigned char* observations; + int* actions; + float* rewards; + unsigned char* terminals; + Log log; }; -void allocate_tcg(TCG* env) { +static inline void dbg(TCG* env, const char* fmt, ...) { + if (!env || !env->debug_logs) { + return; + } + va_list args; + va_start(args, fmt); + vprintf(fmt, args); + va_end(args); +} + +static inline CardArray* player_hand(TCG* env, int idx) { + return idx == 0 ? env->my_hand : env->op_hand; +} + +static inline CardArray* player_board(TCG* env, int idx) { + return idx == 0 ? env->my_board : env->op_board; +} + +static inline CardArray* player_deck(TCG* env, int idx) { + return idx == 0 ? env->my_deck : env->op_deck; +} + +static inline int* player_health(TCG* env, int idx) { + return idx == 0 ? &env->my_health : &env->op_health; +} + +static inline int* player_mana(TCG* env, int idx) { + return idx == 0 ? &env->my_mana : &env->op_mana; +} + +static inline bool* player_land_played(TCG* env, int idx) { + return idx == 0 ? &env->my_land_played : &env->op_land_played; +} + +static inline int current_player(TCG* env) { + return env->turn % NUM_PLAYERS; +} + +void add_log(TCG* env) { + env->log.perf += (env->rewards[0] > 0) ? 1 : 0; + env->log.score += env->rewards[0]; + env->log.episode_length += env->tick; + env->log.episode_return += env->rewards[0]; + env->log.n++; +} + +void end_episode(TCG* env, int winner) { + int loser = 1 - winner; + env->rewards[winner] = WIN_REWARD; + env->rewards[loser] = LOSE_REWARD; + env->terminals[winner] = 1; + env->terminals[loser] = 1; + add_log(env); + reset(env); +} + +void draw_episode(TCG* env) { + for (int player = 0; player < NUM_PLAYERS; player++) { + env->rewards[player] = DRAW_PENALTY; + env->terminals[player] = 1; + } + add_log(env); + reset(env); +} + +typedef struct { + int turn; + int my_health; + int my_mana; + int op_health; + int op_mana; + int op_hand_length; + + Card my_hand[HAND_SIZE]; + Card my_board[BOARD_SIZE]; + Card op_board[BOARD_SIZE]; +} Obs; + +void init_tcg(TCG* env) { env->stack = calloc(1, sizeof(Stack)); env->my_hand = allocate_card_array(HAND_SIZE); env->op_hand = allocate_card_array(HAND_SIZE); @@ -119,6 +223,15 @@ void allocate_tcg(TCG* env) { env->op_board = allocate_card_array(BOARD_SIZE); env->my_deck = allocate_card_array(DECK_SIZE); env->op_deck = allocate_card_array(DECK_SIZE); + env->debug_logs = getenv("TCG_DEBUG") != NULL; +} + +void allocate_tcg(TCG* env) { + init_tcg(env); + env->observations = (unsigned char*)calloc(NUM_PLAYERS * OBS_SIZE, sizeof(unsigned char)); + env->actions = (int*)calloc(NUM_PLAYERS, sizeof(int)); + env->rewards = (float*)calloc(NUM_PLAYERS, sizeof(float)); + env->terminals = (unsigned char*)calloc(NUM_PLAYERS, sizeof(unsigned char)); } void free_tcg(TCG* env) { @@ -128,6 +241,11 @@ void free_tcg(TCG* env) { free_card_array(env->op_board); free_card_array(env->my_deck); free_card_array(env->op_deck); + free(env->stack); + free(env->observations); + free(env->actions); + free(env->rewards); + free(env->terminals); } void randomize_deck(CardArray* deck) { @@ -144,18 +262,19 @@ void randomize_deck(CardArray* deck) { } } -void draw_card(TCG* env, CardArray* deck, CardArray* hand) { +bool draw_card(TCG* env, int player_idx, CardArray* deck, CardArray* hand) { if (deck->length == 0) { - reset(env); - return; + end_episode(env, 1 - player_idx); + return false; } if (hand->length == hand->max) { - return; + return true; } Card card = deck->cards[deck->length - 1]; hand->cards[hand->length] = card; deck->length -= 1; hand->length += 1; + return true; } bool can_attack(CardArray* board) { @@ -168,7 +287,7 @@ bool can_attack(CardArray* board) { } int tappable_mana(TCG* env) { - CardArray* board = (env->turn == 0) ? env->my_board : env->op_board; + CardArray* board = player_board(env, current_player(env)); int tappable = 0; for (int i = 0; i < board->length; i++) { Card card = board->cards[i]; @@ -180,9 +299,10 @@ int tappable_mana(TCG* env) { } bool can_play(TCG* env) { - CardArray* hand = (env->turn == 0) ? env->my_hand : env->op_hand; - int* mana = (env->turn == 0) ? &env->my_mana : &env->op_mana; - bool* land_played = (env->turn == 0) ? &env->my_land_played : &env->op_land_played; + int player = current_player(env); + CardArray* hand = player_hand(env, player); + int* mana = player_mana(env, player); + bool* land_played = player_land_played(env, player); int min_cost = 99; for (int i = 0; i < hand->length; i++) { @@ -198,14 +318,20 @@ bool can_play(TCG* env) { } bool phase_untap(TCG* env, unsigned char atn) { - printf("PHASE_UNTAP\n"); - bool* land_played = (env->turn == 0) ? &env->my_land_played : &env->op_land_played; + dbg(env, "PHASE_UNTAP\n"); + env->turn += 1; + if (env->turn >= MAX_TURNS) { + draw_episode(env); + return TO_STACK; + } + + int player = current_player(env); + bool* land_played = player_land_played(env, player); *land_played = false; - env->turn = 1 - env->turn; - CardArray* board = (env->turn == 0) ? env->my_board : env->op_board; + CardArray* board = player_board(env, player); - int* mana = (env->turn == 0) ? &env->my_mana : &env->op_mana; + int* mana = player_mana(env, player); *mana = 0; for (int i = 0; i < board->length; i++) { @@ -220,29 +346,33 @@ bool phase_untap(TCG* env, unsigned char atn) { } bool phase_draw(TCG* env, unsigned char atn) { - printf("PHASE_DRAW\n"); - CardArray* deck = (env->turn == 0) ? env->my_deck : env->op_deck; - CardArray* hand = (env->turn == 0) ? env->my_hand : env->op_hand; - draw_card(env, deck, hand); + dbg(env, "PHASE_DRAW\n"); + int player = current_player(env); + CardArray* deck = player_deck(env, player); + CardArray* hand = player_hand(env, player); + if (!draw_card(env, player, deck, hand)) { + return TO_STACK; + } push(env->stack, phase_play); return TO_STACK; } bool phase_play(TCG* env, unsigned char atn) { - printf("PHASE_PLAY\n"); - CardArray* hand = (env->turn == 0) ? env->my_hand : env->op_hand; - CardArray* board = (env->turn == 0) ? env->my_board : env->op_board; - int* mana = (env->turn == 0) ? &env->my_mana : &env->op_mana; - bool* land_played = (env->turn == 0) ? &env->my_land_played : &env->op_land_played; + dbg(env, "PHASE_PLAY\n"); + int player = current_player(env); + CardArray* hand = player_hand(env, player); + CardArray* board = player_board(env, player); + int* mana = player_mana(env, player); + bool* land_played = player_land_played(env, player); if (board->length == BOARD_SIZE) { - printf("\t Board full\n"); + dbg(env, "\t Board full\n"); push(env->stack, phase_attack); return TO_STACK; } if (!can_play(env)) { - printf("\t No valid moves\n"); + dbg(env, "\t No valid moves\n"); push(env->stack, phase_attack); return TO_STACK; } @@ -254,7 +384,7 @@ bool phase_play(TCG* env, unsigned char atn) { push(env->stack, phase_attack); return TO_STACK; } else if (atn >= hand->length) { - printf("\t Invalid action: %i\n. Hand length: %i\n", atn, hand->length); + dbg(env, "\t Invalid action: %i\n. Hand length: %i\n", atn, hand->length); push(env->stack, phase_play); return TO_USER; } @@ -262,7 +392,7 @@ bool phase_play(TCG* env, unsigned char atn) { Card card = hand->cards[atn]; if (card.is_land) { if (*land_played) { - printf("\t Already played land this turn\n"); + dbg(env, "\t Already played land this turn\n"); push(env->stack, phase_play); return TO_USER; } @@ -271,13 +401,13 @@ bool phase_play(TCG* env, unsigned char atn) { *land_played = true; hand->cards[atn].remove = true; condense_card_array(hand); - printf("\t Land played\n"); + dbg(env, "\t Land played\n"); push(env->stack, phase_play); return TO_USER; } if (card.cost > *mana + tappable_mana(env)) { - printf("\t Not enough mana\n"); + dbg(env, "\t Not enough mana\n"); push(env->stack, phase_play); return TO_USER; } @@ -300,17 +430,18 @@ bool phase_play(TCG* env, unsigned char atn) { board->length += 1; hand->cards[atn].remove = true; condense_card_array(hand); - printf("\t Card played\n"); + dbg(env, "\t Card played\n"); push(env->stack, phase_play); return TO_USER; } bool phase_attack(TCG* env, unsigned char atn) { - printf("PHASE_ATTACK\n"); - CardArray* board = (env->turn == 0) ? env->my_board : env->op_board; + dbg(env, "PHASE_ATTACK\n"); + int attacker = current_player(env); + CardArray* board = player_board(env, attacker); if (!can_attack(board)) { - printf("\t No valid attacks. Phase end\n"); + dbg(env, "\t No valid attacks. Phase end\n"); push(env->stack, phase_untap); return TO_STACK; } @@ -319,20 +450,19 @@ bool phase_attack(TCG* env, unsigned char atn) { push(env->stack, phase_attack); return TO_USER; } else if (atn == ACTION_ENTER) { - printf("\t Attacks confirmed. Phase end\n"); - env->turn = 1 - env->turn; + dbg(env, "\t Attacks confirmed. Phase end\n"); push(env->stack, phase_block); return TO_STACK; } else if (atn >= board->length) { - printf("\t Invalid action %i\n", atn); + dbg(env, "\t Invalid action %i\n", atn); push(env->stack, phase_attack); return TO_USER; } else if (board->cards[atn].is_land) { - printf("\t Cannot attack with land\n"); + dbg(env, "\t Cannot attack with land\n"); push(env->stack, phase_attack); return TO_USER; } else { - printf("\t Setting attacker %i\n", atn); + dbg(env, "\t Setting attacker %i\n", atn); board->cards[atn].attacking = !board->cards[atn].attacking; push(env->stack, phase_attack); return TO_USER; @@ -340,13 +470,15 @@ bool phase_attack(TCG* env, unsigned char atn) { } bool phase_block(TCG* env, unsigned char atn) { - printf("PHASE_BLOCK\n"); - CardArray* defender_board = (env->turn == 0) ? env->my_board : env->op_board; - CardArray* board = (env->turn == 0) ? env->op_board : env->my_board; - int* health = (env->turn == 0) ? &env->op_health : &env->my_health; + dbg(env, "PHASE_BLOCK\n"); + int attacker_player = current_player(env); + int defender_player = 1 - attacker_player; + CardArray* defender_board = player_board(env, defender_player); + CardArray* board = player_board(env, attacker_player); + int* health = player_health(env, defender_player); while (env->block_idx < board->length && !board->cards[env->block_idx].attacking) { - printf("\t Skipping block for %i (not attacking)\n", env->block_idx); + dbg(env, "\t Skipping block for %i (not attacking)\n", env->block_idx); env->block_idx++; } @@ -358,7 +490,7 @@ bool phase_block(TCG* env, unsigned char atn) { } if (card->defending == -1 || card->defending == env->block_idx) { can_block = true; - printf("\t Can block with %i\n", i); + dbg(env, "\t Can block with %i\n", i); break; } } @@ -367,12 +499,12 @@ bool phase_block(TCG* env, unsigned char atn) { } if (env->block_idx == board->length) { - printf("\t Attacker board length: %i\n", board->length); + dbg(env, "\t Attacker board length: %i\n", board->length); for (int atk = 0; atk < board->length; atk++) { - printf("\t Resolving %i\n", atk); + dbg(env, "\t Resolving %i\n", atk); Card* attacker = &board->cards[atk]; if (!attacker->attacking) { - printf("\t Not attacking\n"); + dbg(env, "\t Not attacking\n"); continue; } int attacker_attack = attacker->attack; @@ -396,21 +528,25 @@ bool phase_block(TCG* env, unsigned char atn) { break; } } - printf("\t Reducing health by %i\n", attacker_attack); - *health -= attacker_attack; + int damage_to_player = attacker_attack; + dbg(env, "\t Reducing health by %i\n", damage_to_player); + *health -= damage_to_player; } if (*health <= 0) { - printf("\t Game over\n"); - reset(env); + int winner = attacker_player; + end_episode(env, winner); + return TO_STACK; } condense_card_array(env->my_board); condense_card_array(env->op_board); - CardArray* defender_deck = (env->turn == 0) ? env->my_deck : env->op_deck; - CardArray* defender_hand = (env->turn == 0) ? env->my_hand : env->op_hand; - draw_card(env, defender_deck, defender_hand); + CardArray* defender_deck = player_deck(env, defender_player); + CardArray* defender_hand = player_hand(env, defender_player); + if (!draw_card(env, defender_player, defender_deck, defender_hand)) { + return TO_STACK; + } for (int i = 0; i < board->length; i++) { board->cards[i].attacking = false; @@ -418,9 +554,8 @@ bool phase_block(TCG* env, unsigned char atn) { for (int i = 0; i < defender_board->length; i++) { defender_board->cards[i].defending = -1; } - printf("\t Set block idx to 0\n"); + dbg(env, "\t Set block idx to 0\n"); env->block_idx = 0; - env->turn = 1 - env->turn; push(env->stack, phase_untap); return TO_STACK; } @@ -429,28 +564,28 @@ bool phase_block(TCG* env, unsigned char atn) { push(env->stack, phase_block); return TO_USER; } else if (atn == ACTION_ENTER) { - printf("\t Manual block confirm %i\n", env->block_idx); + dbg(env, "\t Manual block confirm %i\n", env->block_idx); env->block_idx++; push(env->stack, phase_block); return TO_STACK; } else if (atn >= defender_board->length) { - printf("\t Invalid block action %i\n", atn); + dbg(env, "\t Invalid block action %i\n", atn); push(env->stack, phase_block); return TO_USER; } else if (defender_board->cards[atn].is_land) { - printf("\t Cannot block with land\n"); + dbg(env, "\t Cannot block with land\n"); push(env->stack, phase_block); return TO_USER; } for (int i = 0; i < env->block_idx; i++) { if (defender_board->cards[atn].defending == i) { - printf("\t Already blocked\n"); + dbg(env, "\t Already blocked\n"); push(env->stack, phase_block); return TO_USER; } } - printf("\t Blocking index %i with %i\n", env->block_idx, atn); + dbg(env, "\t Blocking index %i with %i\n", env->block_idx, atn); Card* card = &defender_board->cards[atn]; if (card->defending == env->block_idx) { card->defending = -1; @@ -461,12 +596,79 @@ bool phase_block(TCG* env, unsigned char atn) { return TO_USER; } +void write_card(unsigned char *obs, Card *c) { + obs[0] = c->is_land; + obs[1] = c->cost; + obs[2] = c->attack; + obs[3] = c->health; + obs[4] = c->tapped; + obs[5] = c->attacking; + obs[6] = c->defending; +} + +void update_observations(TCG* env) { + for (int player = 0; player < NUM_PLAYERS; player++) { + unsigned char* obs = env->observations + player * OBS_SIZE; + int idx = 0; + int opponent = 1 - player; + int turn_player = current_player(env); + + CardArray* self_hand = player_hand(env, player); + CardArray* self_board = player_board(env, player); + CardArray* opp_board = player_board(env, opponent); + CardArray* opp_hand = player_hand(env, opponent); + + obs[idx++] = *player_health(env, player); + obs[idx++] = *player_health(env, opponent); + obs[idx++] = (turn_player == player); + obs[idx++] = *player_mana(env, player); + obs[idx++] = *player_mana(env, opponent); + obs[idx++] = self_hand->length; + obs[idx++] = self_board->length; + obs[idx++] = opp_board->length; + + for (int i = 0; i < HAND_SIZE; i++) { + if (i < self_hand->length) { + write_card(&obs[idx], &self_hand->cards[i]); + } else { + memset(&obs[idx], 0, 7); + } + idx += 7; + } + + for (int i = 0; i < BOARD_SIZE; i++) { + if (i < self_board->length) { + write_card(&obs[idx], &self_board->cards[i]); + } else { + memset(&obs[idx], 0, 7); + } + idx += 7; + } + + for (int i = 0; i < BOARD_SIZE; i++) { + if (i < opp_board->length) { + write_card(&obs[idx], &opp_board->cards[i]); + } else { + memset(&obs[idx], 0, 7); + } + idx += 7; + } + + obs[idx++] = opp_hand->length; + + } +} + void step(TCG* env, unsigned char atn) { - printf("Turn: %i, Action: %i\n", env->turn, atn); + dbg(env, "Turn: %i (player %i), Action: %i\n", env->turn, current_player(env), atn); + env->rewards[0] = 0; + env->rewards[1] = 0; + env->tick += 1; while (true) { call fn = pop(env->stack); bool return_to_user = fn(env, atn); if (return_to_user) { + update_observations(env); return; } atn = ACTION_NOOP; @@ -482,12 +684,18 @@ void reset(TCG* env) { env->op_board->length = 0; env->my_health = 20; env->op_health = 20; + env->my_mana = 0; + env->op_mana = 0; + env->tick = 0; + env->op_mana = 0; + env->tick = 0; + memset(env->observations, 0, NUM_PLAYERS * OBS_SIZE * sizeof(unsigned char)); randomize_deck(env->my_deck); randomize_deck(env->op_deck); env->turn = rand() % 2; for (int i = 0; i < 5; i++) { - draw_card(env, env->my_deck, env->my_hand); - draw_card(env, env->op_deck, env->op_hand); + draw_card(env, 0, env->my_deck, env->my_hand); + draw_card(env, 1, env->op_deck, env->op_hand); } push(env->stack, phase_draw); step(env, ACTION_NOOP); @@ -536,7 +744,7 @@ void render(TCG* env) { int x = card_x(i, env->my_hand->length); int y = card_y(3); render_card(&card, x, y, RED); - if (env->turn == 0) { + if (current_player(env) == 0) { render_label(x, y, i); } } @@ -550,7 +758,7 @@ void render(TCG* env) { } Color color = (card.tapped) ? (Color){128, 0, 0, 255}: RED; render_card(&card, x, y, color); - if (env->turn == 0) { + if (current_player(env) == 0) { render_label(x, y, i); } } @@ -589,18 +797,50 @@ void render(TCG* env) { int y = 32; call fn = peek(env->stack); - if (fn == phase_draw) { - DrawText("Draw", x, y, 20, WHITE); - } else if (fn == phase_play) { - DrawText("Play", x, y, 20, WHITE); - } else if (fn == phase_attack) { - DrawText("Attack", x, y, 20, WHITE); - } else if (fn == phase_block) { - DrawText("Block", x, y, 20, WHITE); - } + if (fn == phase_draw) { + DrawText("Draw", x, y, 20, WHITE); + } else if (fn == phase_play) { + DrawText("Play", x, y, 20, WHITE); + } else if (fn == phase_attack) { + DrawText("Attack", x, y, 20, WHITE); + } else if (fn == phase_block) { + DrawText("Block", x, y, 20, WHITE); + } DrawText(TextFormat("Health: %i", env->my_health), 32, 32, 20, WHITE); DrawText(TextFormat("Health: %i", env->op_health), 32, GetScreenHeight() - 64, 20, WHITE); EndDrawing(); } + +void c_render(TCG* env) { + if (!IsWindowReady()) { + init_client(env); + } + render(env); +} + +void c_step(TCG* env) { + call phase = peek(env->stack); + int actor = current_player(env); + if (phase == phase_block) { + actor = 1 - actor; + } + int action = env->actions[actor]; + if (action < 0 || action >= 12) { + action = ACTION_NOOP; + } + step(env, (unsigned char)action); +} + +void c_reset(TCG* env) { + reset(env); +} + +// Required function. Should clean up anything you allocated +// Do not free env->observations, actions, rewards, terminals +void c_close(TCG* env) { + if (IsWindowReady()) { + CloseWindow(); + } +} diff --git a/pufferlib/ocean/tcg/tcg.py b/pufferlib/ocean/tcg/tcg.py new file mode 100644 index 000000000..5c3b90676 --- /dev/null +++ b/pufferlib/ocean/tcg/tcg.py @@ -0,0 +1,70 @@ +import gymnasium +import numpy as np + +import pufferlib +from pufferlib.ocean.tcg import binding + +class TCG(pufferlib.PufferEnv): + def __init__(self, num_envs=1, render_mode=None, log_interval=128, buf=None, seed=0): + self.single_observation_space = gymnasium.spaces.Box( + low=0, high=1, + shape=(624,), dtype=np.uint8 + ) + self.single_action_space = gymnasium.spaces.Discrete(12) + + self.render_mode = render_mode + self.players_per_env = 2 + self.num_envs = num_envs + self.num_agents = num_envs * self.players_per_env + self.log_interval = log_interval + + super().__init__(buf) + obs_dim = int(np.prod(self.single_observation_space.shape)) + self._obs_view = self.observations.reshape(num_envs, self.players_per_env, obs_dim) + self._act_view = self.actions.reshape(num_envs, self.players_per_env) + self._rew_view = self.rewards.reshape(num_envs, self.players_per_env) + self._term_view = self.terminals.reshape(num_envs, self.players_per_env) + self._trunc_view = self.truncations.reshape(num_envs, self.players_per_env) + + c_envs = [] + for i in range(num_envs): + env_id = binding.env_init( + self._obs_view[i], + self._act_view[i], + self._rew_view[i], + self._term_view[i], + self._trunc_view[i], + seed + i, + ) + c_envs.append(env_id) + + self.c_envs = binding.vectorize(*c_envs) + + def _update_masks(self): + self.masks[:] = self.observations[:, 2].astype(bool) + + def reset(self, seed=0): + binding.vec_reset(self.c_envs, seed) + self.tick = 0 + self._update_masks() + return self.observations, [] + + def step(self, actions): + self.tick += 1 + + self.actions[:] = actions + binding.vec_step(self.c_envs) + self._update_masks() + + info = [] + if self.tick % self.log_interval == 0: + info.append(binding.vec_log(self.c_envs)) + + return (self.observations, self.rewards, + self.terminals, self.truncations, info) + + def render(self): + binding.vec_render(self.c_envs, 0) + + def close(self): + binding.vec_close(self.c_envs)