diff --git a/README.md b/README.md index 07ee14403..7d1a98cd8 100644 --- a/README.md +++ b/README.md @@ -1,93 +1,11 @@ -# The pokemon showdown Python environment +>[!IMPORTANT] +> This is a fork of [poke-env](https://github.com/hsahovic/poke-env) that installs with [`metamon`](https://github.com/UT-Austin-RPL/metamon). It attempts to extend the lifespan of poke-env as it was during Metamon's development: +> 1. Maintains the original gymnasium interface that existed until v0.8.3. `OpenAIGymEnv` (+ ability to swap in custom Players). Rewards functions that take `last_battle` and `current_battle` as input (+ a speed boost). Removes "observation" system that slows fps and is already handled by Metamon. +> 2. Preserves minor early-generation battle details as they were when Metamon's original models were trained. +> 3. Tries to bring key fixes/improvements since v0.8.3 that are unrelated to gymnasium. +> +> Please see the main repo [here](https://github.com/hsahovic/poke-env) for any other use case. I only plan to update this to fix breaking changes to the Showdown sim/request message API. Any improvements to early-generation state tracking/sim protocol are now done in metamon. -[![PyPI version fury.io](https://badge.fury.io/py/poke-env.svg)](https://pypi.python.org/pypi/poke-env/) -[![PyPI pyversions](https://img.shields.io/pypi/pyversions/poke-env.svg)](https://pypi.python.org/pypi/poke-env/) -[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) -[![Documentation Status](https://readthedocs.org/projects/poke-env/badge/?version=stable)](https://poke-env.readthedocs.io/en/stable/?badge=stable) -[![codecov](https://codecov.io/gh/hsahovic/poke-env/branch/master/graph/badge.svg)](https://codecov.io/gh/hsahovic/poke-env) - -A Python interface to create battling pokemon agents. `poke-env` offers an easy-to-use interface for creating rule-based or training Reinforcement Learning bots to battle on [pokemon showdown](https://pokemonshowdown.com/). - -![A simple agent in action](rl-gif.gif) - -## Getting started - -Agents are instance of python classes inheriting from `Player`. Here is what your first agent could look like: - -```python -class YourFirstAgent(Player): - def choose_move(self, battle): - for move in battle.available_moves: - if move.base_power > 90: - # A powerful move! Let's use it - return self.create_order(move) - - # No available move? Let's switch then! - for switch in battle.available_switches: - if switch.current_hp_fraction > battle.active_pokemon.current_hp_fraction: - # This other pokemon has more HP left... Let's switch it in? - return self.create_order(switch) - - # Not sure what to do? - return self.choose_random_move(battle) -``` - -To get started, take a look at [our documentation](https://poke-env.readthedocs.io/en/stable/)! - - -## Documentation and examples - -Documentation, detailed examples and starting code can be found [on readthedocs](https://poke-env.readthedocs.io/en/stable/). - - -## Installation - -This project requires python >= 3.9 and a [Pokemon Showdown](https://github.com/Zarel/Pokemon-Showdown) server. - -``` -pip install poke-env -``` - -You can use [smogon's server](https://play.pokemonshowdown.com/) to try out your agents against humans, but having a development server is strongly recommended. In particular, it is recommended to use the `--no-security` flag to run a local server with most rate limiting and throttling turned off. Please refer to [the docs](https://poke-env.readthedocs.io/en/stable/getting_started.html#configuring-a-showdown-server) for detailed setup instructions. - - -``` -git clone https://github.com/smogon/pokemon-showdown.git -cd pokemon-showdown -npm install -cp config/config-example.js config/config.js -node pokemon-showdown start --no-security -``` - -## Development version - -You can also clone the latest master version with: - -``` -git clone https://github.com/hsahovic/poke-env.git -``` - -Dependencies and development dependencies can then be installed with: - -``` -pip install -r requirements.txt -pip install -r requirements-dev.txt -``` - -## Acknowledgements - -This project is a follow-up of a group project from an artifical intelligence class at [Ecole Polytechnique](https://www.polytechnique.edu/). - -You can find the original repository [here](https://github.com/hsahovic/inf581-project). It is partially inspired by the [showdown-battle-bot project](https://github.com/Synedh/showdown-battle-bot). Of course, none of these would have been possible without [Pokemon Showdown](https://github.com/Zarel/Pokemon-Showdown). - -Team data comes from [Smogon forums' RMT section](https://www.smogon.com/). - -## Data - -Data files are adapted version of the `js` data files of [Pokemon Showdown](https://github.com/Zarel/Pokemon-Showdown). - -## License -[![License: MIT](https://img.shields.io/badge/License-MIT-yellow.svg)](https://opensource.org/licenses/MIT) ## Citing `poke-env` diff --git a/pyproject.toml b/pyproject.toml index 3c4e6d563..0d614e20c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta" [project] name = "poke_env" -version = "0.8.3" +version = "0.8.3.2" description = "A python interface for training Reinforcement Learning bots to battle on pokemon showdown." readme = "README.md" requires-python = ">=3.9.0" diff --git a/src/poke_env/data/gen_data.py b/src/poke_env/data/gen_data.py index 9f8c3cc10..047996805 100644 --- a/src/poke_env/data/gen_data.py +++ b/src/poke_env/data/gen_data.py @@ -34,7 +34,15 @@ def load_moves(self, gen: int) -> Dict[str, Any]: with open( os.path.join(self._static_files_root, "moves", f"gen{gen}moves.json") ) as f: - return orjson.loads(f.read()) + data = orjson.loads(f.read()) + + # manually fix data entries gathered from Showdown data files + if gen == 1: + data["recover"]["heal"] = [1, 2] + data["softboiled"]["heal"] = [1, 2] + # TODO: check vicegrip / visegrip + + return data def load_natures(self) -> Dict[str, Dict[str, Union[int, float]]]: with open(os.path.join(self._static_files_root, "natures.json")) as f: @@ -64,6 +72,12 @@ def load_pokedex(self, gen: int) -> Dict[str, Any]: dex.update(other_forms_dex) for name, value in dex.items(): + if gen <= 2 and "abilities" in value: + # remove abilities from gen 1-2. Gens before abilities + # existed will often list an "ability" called "No Ability". + # Because it is the only option, `Pokemon` will assume it + # is active at the start of the battle. + value["abilities"] = {"0": "No Ability"} if "baseSpecies" in value: value["species"] = value["baseSpecies"] else: diff --git a/src/poke_env/environment/__init__.py b/src/poke_env/environment/__init__.py index 4a5b71411..fc0bbf7f3 100644 --- a/src/poke_env/environment/__init__.py +++ b/src/poke_env/environment/__init__.py @@ -6,8 +6,6 @@ field, move, move_category, - observation, - observed_pokemon, pokemon, pokemon_gender, pokemon_type, @@ -24,8 +22,6 @@ from poke_env.environment.field import Field from poke_env.environment.move import SPECIAL_MOVES, EmptyMove, Move from poke_env.environment.move_category import MoveCategory -from poke_env.environment.observation import Observation -from poke_env.environment.observed_pokemon import ObservedPokemon from poke_env.environment.pokemon import Pokemon from poke_env.environment.pokemon_gender import PokemonGender from poke_env.environment.pokemon_type import PokemonType @@ -44,8 +40,6 @@ "Field", "Move", "MoveCategory", - "Observation", - "ObservedPokemon", "Pokemon", "PokemonGender", "PokemonType", diff --git a/src/poke_env/environment/abstract_battle.py b/src/poke_env/environment/abstract_battle.py index 9b7f2c259..bd2261f9b 100644 --- a/src/poke_env/environment/abstract_battle.py +++ b/src/poke_env/environment/abstract_battle.py @@ -7,8 +7,6 @@ from poke_env.data import GenData, to_id_str from poke_env.data.replay_template import REPLAY_TEMPLATE from poke_env.environment.field import Field -from poke_env.environment.observation import Observation -from poke_env.environment.observed_pokemon import ObservedPokemon from poke_env.environment.pokemon import Pokemon from poke_env.environment.side_condition import STACKABLE_CONDITIONS, SideCondition from poke_env.environment.weather import Weather @@ -37,6 +35,7 @@ class AbstractBattle(ABC): "J", "L", "askreg", + "badge", "c", "chat", "crit", @@ -53,7 +52,9 @@ class AbstractBattle(ABC): "leave", "n", "name", + "noinit", "rated", + "rename", "resisted", "sentchoice", "split", @@ -74,7 +75,6 @@ class AbstractBattle(ABC): "_can_mega_evolve", "_can_tera", "_can_z_move", - "_current_observation", "_data", "_dynamax_turn", "_fields", @@ -86,8 +86,6 @@ class AbstractBattle(ABC): "_last_request", "_max_team_size", "_maybe_trapped", - "_move_on_next_request", - "_observations", "_opponent_can_dynamax", "_opponent_can_mega_evolve", "_opponent_can_terrastallize", @@ -152,7 +150,6 @@ def __init__( # Turn choice attributes self.in_team_preview: bool = False - self._move_on_next_request: bool = False self._wait: Optional[bool] = None # Battle state attributes @@ -181,10 +178,6 @@ def __init__( self._team: Dict[str, Pokemon] = {} self._opponent_team: Dict[str, Pokemon] = {} - # Initialize Observations - self._observations: Dict[int, Observation] = {} - self._current_observation: Observation = Observation() - def get_pokemon( self, identifier: str, @@ -367,9 +360,6 @@ def field_start(self, field_str: str): self._fields[field] = self.turn def _finish_battle(self): - # Recording the battle state and save events as we finish up - self.observations[self.turn] = self._current_observation - if self._save_replays: if self._save_replays is True: folder = "replays" @@ -387,7 +377,6 @@ def _finish_battle(self): encoding="utf-8", ) as f: formatted_replay = REPLAY_TEMPLATE - formatted_replay = formatted_replay.replace( "{BATTLE_TAG}", f"{self.battle_tag}" ) @@ -398,23 +387,17 @@ def _finish_battle(self): "{OPPONENT_USERNAME}", f"{self._opponent_username}" ) replay_log = f">{self.battle_tag}" + "\n".join( - [ - "|".join(split_message) - for turn in sorted(self._observations.keys()) - for split_message in self._observations[turn].events - ] + ["|".join(split_message) for split_message in self._replay_data] ) formatted_replay = formatted_replay.replace("{REPLAY_LOG}", replay_log) - f.write(formatted_replay) self._finished = True def parse_message(self, split_message: List[str]): - self._current_observation.events.append(split_message) + if self._save_replays: + self._replay_data.append(split_message) - # We copy because we directly modify split_message in poke-env; this is to - # preserve further usage of this event upstream event = split_message[:] if event[1] in self.MESSAGES_TO_IGNORE: @@ -428,6 +411,9 @@ def parse_message(self, split_message: List[str]): self._check_damage_message_for_item(event) self._check_damage_message_for_ability(event) elif event[1] == "move": + # JAKE: this is inaccurate for early gens, but for some reason I left it like this + # during all the metamon evals. Now that we'll be doing a new `Battle`, leave as it + # was for backwards compatibility. failed = False override_move = None reveal_other_move = False @@ -457,14 +443,12 @@ def parse_message(self, split_message: List[str]): reveal_other_move = True elif override_move in {"Copycat", "Metronome", "Nature Power"}: pass - elif override_move in {"Grass Pledge", "Water Pledge", "Fire Pledge"}: - override_move = None elif self.logger is not None: self.logger.warning( "Unmanaged [from]move message received - move %s in cleaned up " "message %s in battle %s turn %d", override_move, - event, + split_message, self.battle_tag, self.turn, ) @@ -474,7 +458,7 @@ def parse_message(self, split_message: List[str]): if event[-1].startswith("[from]ability: "): revealed_ability = event.pop()[15:] - pokemon = event[2] + pokemon = split_message[2] self.get_pokemon(pokemon).ability = revealed_ability if revealed_ability == "Magic Bounce": @@ -486,7 +470,7 @@ def parse_message(self, split_message: List[str]): "Unmanaged [from]ability: message received - ability %s in " "cleaned up message %s in battle %s turn %d", revealed_ability, - event, + split_message, self.battle_tag, self.turn, ) @@ -523,14 +507,17 @@ def parse_message(self, split_message: List[str]): ) else: pokemon, move, presumed_target = event[2:5] - if self.logger is not None: - self.logger.warning( - "Unmanaged move message format received - cleaned up message %s in " - "battle %s turn %d", - event, - self.battle_tag, - self.turn, - ) + if len(event) == 6 and "[from]" in event[-1]: + pass + else: + if self.logger is not None: + self.logger.warning( + "Unmanaged move message format received - cleaned up message %s in " + "battle %s turn %d", + event, + self.battle_tag, + self.turn, + ) # Check if a silent-effect move has occurred (Minimize) and add the effect if move.upper().strip() == "MINIMIZE": @@ -544,49 +531,12 @@ def parse_message(self, split_message: List[str]): self.get_pokemon(pokemon).moved(override_move, failed=failed, use=False) if override_move is None or reveal_other_move: self.get_pokemon(pokemon).moved(move, failed=failed, use=False) + elif event[1] == "cant": pokemon, _ = event[2:4] self.get_pokemon(pokemon).cant_move() elif event[1] == "turn": - # Saving the beginning-of-turn battle state and events as we go into the turn - self.observations[self.turn] = self._current_observation - self.end_turn(int(event[2])) - - opp_active_mon, active_mon = None, None - if isinstance(self.opponent_active_pokemon, Pokemon): - opp_active_mon = ObservedPokemon.from_pokemon( - self.opponent_active_pokemon - ) - active_mon = ObservedPokemon.from_pokemon(self.active_pokemon) - else: - opp_active_mon = [ - ObservedPokemon.from_pokemon(mon) - for mon in self.opponent_active_pokemon - ] - active_mon = [ - ObservedPokemon.from_pokemon(mon) for mon in self.active_pokemon - ] - - # Create new Observation and record battle state going into the next turn - self._current_observation = Observation( - side_conditions={k: v for (k, v) in self.side_conditions.items()}, - opponent_side_conditions={ - k: v for (k, v) in self.opponent_side_conditions.items() - }, - weather={k: v for (k, v) in self.weather.items()}, - fields={k: v for (k, v) in self.fields.items()}, - active_pokemon=active_mon, - team={ - ident: ObservedPokemon.from_pokemon(mon) - for (ident, mon) in self.team.items() - }, - opponent_active_pokemon=opp_active_mon, - opponent_team={ - ident: ObservedPokemon.from_pokemon(mon) - for (ident, mon) in self.opponent_team.items() - }, - ) elif event[1] == "-heal": pokemon, hp_status = event[2:4] self.get_pokemon(pokemon).heal(hp_status) @@ -635,7 +585,10 @@ def parse_message(self, split_message: List[str]): self.opponent_can_dynamax = False elif event[1] == "-activate": target, effect = event[2:4] - if target and effect == "move: Skill Swap": + if effect.startswith("ability: "): + ability = effect[9:] + self.get_pokemon(target).ability = ability + elif target and effect == "move: Skill Swap": self.get_pokemon(target).start_effect(effect, event[4:6]) actor = event[6].replace("[of] ", "") self.get_pokemon(actor).set_temporary_ability(event[5]) @@ -752,8 +705,8 @@ def parse_message(self, split_message: List[str]): elif event[1] == "-prepare": try: attacker, move, defender = event[2:5] - defender_mon = self.get_pokemon(defender) - if to_id_str(move) == "skydrop": + defender_mon = self.get_pokemon(defender) if defender != "[premajor]" else None + if defender_mon is not None and to_id_str(move) == "skydrop": defender_mon.start_effect("Sky Drop") except ValueError: attacker, move = event[2:4] @@ -781,11 +734,31 @@ def parse_message(self, split_message: List[str]): source, target, stats = event[2:5] source_mon = self.get_pokemon(source) target_mon = self.get_pokemon(target) - for stat in stats.split(", "): - source_mon.boosts[stat], target_mon.boosts[stat] = ( - target_mon.boosts[stat], - source_mon.boosts[stat], - ) + if "[from]" in stats: + if "guardswap" in stats: + # JAKE: need to use metamon to check if this still ever triggers + all_stats = ["def", "spd"] + else: + all_stats = [ + "accuracy", + "atk", + "def", + "evasion", + "spa", + "spd", + "spe", + ] + for stat in all_stats: + source_mon.boosts[stat], target_mon.boosts[stat] = ( + target_mon.boosts[stat], + source_mon.boosts[stat], + ) + else: + for stat in stats.split(", "): + source_mon.boosts[stat], target_mon.boosts[stat] = ( + target_mon.boosts[stat], + source_mon.boosts[stat], + ) elif event[1] == "-transform": pokemon, into = event[2:4] self.get_pokemon(pokemon).transform(self.get_pokemon(into)) @@ -924,7 +897,7 @@ def side_end(self, side: str, condition_str: str): else: conditions = self.opponent_side_conditions condition = SideCondition.from_showdown_message(condition_str) - if condition is not SideCondition.UNKNOWN: + if condition is not SideCondition.UNKNOWN and condition in conditions: conditions.pop(condition) def _side_start(self, side: str, condition_str: str): @@ -1020,16 +993,6 @@ def can_z_move(self) -> Any: def can_tera(self) -> Any: pass - @property - def current_observation(self) -> Observation: - """ - :return: The current observation of the current turn in the Battle. - Most useful for when a force_switch triggers in the middle of a - turn, and our player has to return an action. - :rtype: Observation - """ - return self._current_observation - @property def dynamax_turns_left(self) -> Optional[int]: """ @@ -1114,16 +1077,6 @@ def max_team_size(self) -> Optional[int]: def maybe_trapped(self) -> Any: pass - @property - def observations(self) -> Dict[int, Observation]: - """ - :return: Observations of the battle on a turn, where the key is the turn number. - The Observation stores the battle state at the beginning of the turn, - and all the events that transpired on that turn. - :rtype: Dict[int, Observation] - """ - return self._observations - @property @abstractmethod def opponent_active_pokemon(self) -> Any: @@ -1378,19 +1331,6 @@ def won(self) -> Optional[bool]: """ return self._won - @property - def move_on_next_request(self) -> bool: - """ - :return: Wheter the next received request should yield a move order directly. - This can happen when a switch is forced, or an error is encountered. - :rtype: bool - """ - return self._move_on_next_request - - @move_on_next_request.setter - def move_on_next_request(self, value: bool): - self._move_on_next_request = value - @property def reviving(self) -> bool: return self._reviving diff --git a/src/poke_env/environment/battle.py b/src/poke_env/environment/battle.py index 655baf728..558df9846 100644 --- a/src/poke_env/environment/battle.py +++ b/src/poke_env/environment/battle.py @@ -81,9 +81,6 @@ def parse_request(self, request: Dict[str, Any]) -> None: self._trapped = False self._force_switch = request.get("forceSwitch", [False])[0] - if self._force_switch: - self._move_on_next_request = True - self._last_request = request if request.get("teamPreview", False): @@ -129,7 +126,7 @@ def parse_request(self, request: Dict[str, Any]) -> None: if not self.trapped and self.reviving: for pokemon in side["pokemon"]: - if pokemon and pokemon.get("reviving", False): + if pokemon and not pokemon.get("reviving", False): pokemon = self._team[pokemon["ident"]] if not pokemon.active: self._available_switches.append(pokemon) diff --git a/src/poke_env/environment/effect.py b/src/poke_env/environment/effect.py index 711b3327b..d9809f251 100644 --- a/src/poke_env/environment/effect.py +++ b/src/poke_env/environment/effect.py @@ -1,5 +1,4 @@ -"""This module defines the Effect class, which represents in-game effects. -""" +"""This module defines the Effect class, which represents in-game effects.""" from __future__ import annotations @@ -24,6 +23,7 @@ class Effect(Enum): BANEFUL_BUNKER = auto() BATTLE_BOND = auto() BEAK_BLAST = auto() + BEAT_UP = auto() BIDE = auto() BIND = auto() BURNING_BULWARK = auto() @@ -75,6 +75,7 @@ class Effect(Enum): FUTURE_SIGHT = auto() GASTRO_ACID = auto() GLAIVE_RUSH = auto() + GOOEY = auto() GRAVITY = auto() GRUDGE = auto() GUARD_SPLIT = auto() @@ -85,6 +86,7 @@ class Effect(Enum): G_MAX_RAPID_FLOW = auto() G_MAX_SANDBLAST = auto() HADRON_ENGINE = auto() + HAZE = auto() HEAL_BELL = auto() HEAL_BLOCK = auto() HEALER = auto() @@ -107,6 +109,7 @@ class Effect(Enum): LEECH_SEED = auto() LEPPA_BERRY = auto() LIGHTNING_ROD = auto() + LIGHT_SCREEN = auto() LIMBER = auto() LIQUID_OOZE = auto() LOCKED_MOVE = auto() @@ -124,8 +127,10 @@ class Effect(Enum): MIRACLE_EYE = auto() MIST = auto() MISTY_TERRAIN = auto() + MUD_SPORT = auto() MUMMY = auto() MUST_RECHARGE = auto() + MYSTERY_BERRY = auto() NEUTRALIZING_GAS = auto() NIGHTMARE = auto() NO_RETREAT = auto() @@ -169,6 +174,7 @@ class Effect(Enum): QUICK_GUARD = auto() RAGE = auto() RAGE_POWDER = auto() + RAMPAGE = auto() REFLECT = auto() RIPEN = auto() ROOST = auto() @@ -239,6 +245,14 @@ class Effect(Enum): def __str__(self) -> str: return f"{self.name} (effect) object" + @staticmethod + def _manual_message_corrections(message: str) -> str: + if message == "FALLENUNDEFINED": + return "FALLEN" + elif message == "LIGHTSCREEN": + return "LIGHT_SCREEN" + return message + @staticmethod def from_showdown_message(message: str) -> Effect: """Returns the Effect object corresponding to the message. @@ -254,21 +268,25 @@ def from_showdown_message(message: str) -> Effect: message = message.replace(" ", "_") message = message.replace("-", "_") message = message.upper() - - if message == "FALLENUNDEFINED": - message = "FALLEN" + message = Effect._manual_message_corrections(message) try: return Effect[message] except KeyError: - logging.getLogger("poke-env").warning( - "Unexpected effect '%s' received. Effect.UNKNOWN will be used instead. " - "If this is unexpected, please open an issue at " - "https://github.com/hsahovic/poke-env/issues/ along with this error " - "message and a description of your program.", - message, - ) - return Effect.UNKNOWN + # catch inconsistent use of whitespace both within Showdown protocol + # and between sim protocol and static data + for effect in Effect: + if effect.name.replace("_", "") == message: + return effect + # if we get here, we didn't find a match + logging.getLogger("poke-env").warning( + "Unexpected effect '%s' received. Effect.UNKNOWN will be used instead. " + "If this is unexpected, please open an issue at " + "https://github.com/hsahovic/poke-env/issues/ along with this error " + "message and a description of your program.", + message, + ) + return Effect.UNKNOWN @staticmethod def from_data(message: str) -> Effect: @@ -279,21 +297,9 @@ def from_data(message: str) -> Effect: :return: The corresponding Effect object. :rtype: Effect """ - message = message.replace("_", "") - message = message.replace(" ", "") - message = message.replace("-", "") - message = message.upper() - try: - return _FROM_DATA[message] - except KeyError: - logging.getLogger("poke-env").warning( - "Unexpected effect '%s' received. Effect.UNKNOWN will be used instead. " - "If this is unexpected, please open an issue at " - "https://github.com/hsahovic/poke-env/issues/ along with this error " - "message and a description of your program.", - message, - ) - return Effect.UNKNOWN + # JAKE: this should no longer be necessary, but leave the door open for + # one-off changes specific to how the effects are stored in static data + return Effect.from_showdown_message(message) @property def breaks_protect(self): @@ -771,227 +777,3 @@ def is_from_move(self) -> bool: } _ACTION_COUNTER_EFFECTS: Set[Effect] = {Effect.RAGE, Effect.STOCKPILE} - -_FROM_DATA: Dict[str, Effect] = { - "UNKNOWN": Effect.UNKNOWN, - "AFTERYOU": Effect.AFTER_YOU, - "AFTERMATH": Effect.AFTERMATH, - "AQUARING": Effect.AQUA_RING, - "AROMATHERAPY": Effect.AROMATHERAPY, - "AROMAVEIL": Effect.AROMA_VEIL, - "ATTRACT": Effect.ATTRACT, - "AUTOTOMIZE": Effect.AUTOTOMIZE, - "BADDREAMS": Effect.BAD_DREAMS, - "BANEFULBUNKER": Effect.BANEFUL_BUNKER, - "BATTLEBOND": Effect.BATTLE_BOND, - "BIDE": Effect.BIDE, - "BIND": Effect.BIND, - "BURNINGBULWARK": Effect.BURNING_BULWARK, - "BURNUP": Effect.BURN_UP, - "CELEBRATE": Effect.CELEBRATE, - "CHARGE": Effect.CHARGE, - "CLAMP": Effect.CLAMP, - "COMMANDER": Effect.COMMANDER, - "CONFUSION": Effect.CONFUSION, - "COURTCHANGE": Effect.COURT_CHANGE, - "CRAFTYSHIELD": Effect.CRAFTY_SHIELD, - "CUDCHEW": Effect.CUD_CHEW, - "CURSE": Effect.CURSE, - "CUSTAPBERRY": Effect.CUSTAP_BERRY, - "DANCER": Effect.DANCER, - "DEFENSECURL": Effect.DEFENSE_CURL, - "DESTINYBOND": Effect.DESTINY_BOND, - "DISABLE": Effect.DISABLE, - "DISGUISE": Effect.DISGUISE, - "DOOMDESIRE": Effect.DOOM_DESIRE, - "DRAGONCHEER": Effect.DRAGON_CHEER, - "DYNAMAX": Effect.DYNAMAX, - "EERIESPELL": Effect.EERIE_SPELL, - "ELECTRICTERRAIN": Effect.ELECTRIC_TERRAIN, - "ELECTRIFY": Effect.ELECTRIFY, - "EMBARGO": Effect.EMBARGO, - "EMERGENCYEXIT": Effect.EMERGENCY_EXIT, - "ENCORE": Effect.ENCORE, - "ENDURE": Effect.ENDURE, - "FALLEN": Effect.FALLEN, - "FALLEN1": Effect.FALLEN1, - "FALLEN2": Effect.FALLEN2, - "FALLEN3": Effect.FALLEN3, - "FALLEN4": Effect.FALLEN4, - "FALLEN5": Effect.FALLEN5, - "FAIRYLOCK": Effect.FAIRY_LOCK, - "FEINT": Effect.FEINT, - "FICKLEBEAM": Effect.FICKLE_BEAM, - "FIRESPIN": Effect.FIRE_SPIN, - "FLASHFIRE": Effect.FLASH_FIRE, - "FLINCH": Effect.FLINCH, - "FLOWERVEIL": Effect.FLOWER_VEIL, - "FOCUSBAND": Effect.FOCUS_BAND, - "FOCUSENERGY": Effect.FOCUS_ENERGY, - "FOLLOWME": Effect.FOLLOW_ME, - "FORESIGHT": Effect.FORESIGHT, - "FOREWARN": Effect.FOREWARN, - "FUTURESIGHT": Effect.FUTURE_SIGHT, - "GASTROACID": Effect.GASTRO_ACID, - "GLAIVERUSH": Effect.GLAIVE_RUSH, - "GRAVITY": Effect.GRAVITY, - "GRUDGE": Effect.GRUDGE, - "GUARDSPLIT": Effect.GUARD_SPLIT, - "GULPMISSILE": Effect.GULP_MISSILE, - "GMAXCENTIFERNO": Effect.G_MAX_CENTIFERNO, - "GMAXCHISTRIKE": Effect.G_MAX_CHI_STRIKE, - "GMAXONEBLOW": Effect.G_MAX_ONE_BLOW, - "GMAXRAPIDFLOW": Effect.G_MAX_RAPID_FLOW, - "GMAXSANDBLAST": Effect.G_MAX_SANDBLAST, - "HADRONENGINE": Effect.HADRON_ENGINE, - "HEALBELL": Effect.HEAL_BELL, - "HEALBLOCK": Effect.HEAL_BLOCK, - "HEALER": Effect.HEALER, - "HELPINGHAND": Effect.HELPING_HAND, - "HYDRATION": Effect.HYDRATION, - "HYPERSPACEFURY": Effect.HYPERSPACE_FURY, - "HYPERSPACEHOLE": Effect.HYPERSPACE_HOLE, - "ICEFACE": Effect.ICE_FACE, - "ILLUSION": Effect.ILLUSION, - "IMMUNITY": Effect.IMMUNITY, - "IMPRISON": Effect.IMPRISON, - "INFESTATION": Effect.INFESTATION, - "INGRAIN": Effect.INGRAIN, - "INNARDSOUT": Effect.INNARDS_OUT, - "INSTRUCT": Effect.INSTRUCT, - "INSOMNIA": Effect.INSOMNIA, - "IRONBARBS": Effect.IRON_BARBS, - "KINGSSHIELD": Effect.KINGS_SHIELD, - "LASERFOCUS": Effect.LASER_FOCUS, - "LEECHSEED": Effect.LEECH_SEED, - "LEPPABERRY": Effect.LEPPA_BERRY, - "LIGHTNINGROD": Effect.LIGHTNING_ROD, - "LIMBER": Effect.LIMBER, - "LIQUIDOOZE": Effect.LIQUID_OOZE, - "LOCKEDMOVE": Effect.LOCKED_MOVE, - "LOCKON": Effect.LOCK_ON, - "MAGICCOAT": Effect.MAGIC_COAT, - "MAGMASTORM": Effect.MAGMA_STORM, - "MAGNETRISE": Effect.MAGNET_RISE, - "MAGNITUDE": Effect.MAGNITUDE, - "MATBLOCK": Effect.MAT_BLOCK, - "MAXGUARD": Effect.MAX_GUARD, - "MIMIC": Effect.MIMIC, - "MIMICRY": Effect.MIMICRY, - "MINDREADER": Effect.MIND_READER, - "MINIMIZE": Effect.MINIMIZE, - "MIRACLEEYE": Effect.MIRACLE_EYE, - "MIST": Effect.MIST, - "MISTYTERRAIN": Effect.MISTY_TERRAIN, - "MUMMY": Effect.MUMMY, - "MUSTRECHARGE": Effect.MUST_RECHARGE, - "NEUTRALIZINGGAS": Effect.NEUTRALIZING_GAS, - "NIGHTMARE": Effect.NIGHTMARE, - "NORETREAT": Effect.NO_RETREAT, - "OBLIVIOUS": Effect.OBLIVIOUS, - "OBSTRUCT": Effect.OBSTRUCT, - "OCTOLOCK": Effect.OCTOLOCK, - "ORICHALCUMPULSE": Effect.ORICHALCUM_PULSE, - "OWNTEMPO": Effect.OWN_TEMPO, - "PARTIALLYTRAPPED": Effect.PARTIALLY_TRAPPED, - "PASTELVEIL": Effect.PASTEL_VEIL, - "PERISH0": Effect.PERISH0, - "PERISH1": Effect.PERISH1, - "PERISH2": Effect.PERISH2, - "PERISH3": Effect.PERISH3, - "PHANTOMFORCE": Effect.PHANTOM_FORCE, - "POLTERGEIST": Effect.POLTERGEIST, - "POWDER": Effect.POWDER, - "POWERCONSTRUCT": Effect.POWER_CONSTRUCT, - "POWERSHIFT": Effect.POWER_SHIFT, - "POWERSPLIT": Effect.POWER_SPLIT, - "POWERTRICK": Effect.POWER_TRICK, - "PROTECT": Effect.PROTECT, - "PROTECTIVEPADS": Effect.PROTECTIVE_PADS, - "PROTOSYNTHESIS": Effect.PROTOSYNTHESIS, - "PROTOSYNTHESISATK": Effect.PROTOSYNTHESISATK, - "PROTOSYNTHESISDEF": Effect.PROTOSYNTHESISDEF, - "PROTOSYNTHESISSPA": Effect.PROTOSYNTHESISSPA, - "PROTOSYNTHESISSPD": Effect.PROTOSYNTHESISSPD, - "PROTOSYNTHESISSPE": Effect.PROTOSYNTHESISSPE, - "PSYCHICTERRAIN": Effect.PSYCHIC_TERRAIN, - "PURSUIT": Effect.PURSUIT, - "QUARKDRIVE": Effect.QUARK_DRIVE, - "QUARKDRIVEATK": Effect.QUARKDRIVEATK, - "QUARKDRIVEDEF": Effect.QUARKDRIVEDEF, - "QUARKDRIVESPA": Effect.QUARKDRIVESPA, - "QUARKDRIVESPD": Effect.QUARKDRIVESPD, - "QUARKDRIVESPE": Effect.QUARKDRIVESPE, - "QUASH": Effect.QUASH, - "QUICKCLAW": Effect.QUICK_CLAW, - "QUICKDRAW": Effect.QUICK_DRAW, - "QUICKGUARD": Effect.QUICK_GUARD, - "RAGE": Effect.RAGE, - "RAGEPOWDER": Effect.RAGE_POWDER, - "REFLECT": Effect.REFLECT, - "RIPEN": Effect.RIPEN, - "ROOST": Effect.ROOST, - "ROUGHSKIN": Effect.ROUGH_SKIN, - "SAFEGUARD": Effect.SAFEGUARD, - "SAFETYGOGGLES": Effect.SAFETY_GOGGLES, - "SALTCURE": Effect.SALT_CURE, - "SANDTOMB": Effect.SAND_TOMB, - "SCREENCLEANER": Effect.SCREEN_CLEANER, - "SHADOWFORCE": Effect.SHADOW_FORCE, - "SHEDSKIN": Effect.SHED_SKIN, - "SILKTRAP": Effect.SILK_TRAP, - "SKETCH": Effect.SKETCH, - "SKILLSWAP": Effect.SKILL_SWAP, - "SKYDROP": Effect.SKY_DROP, - "SLOWSTART": Effect.SLOW_START, - "SMACKDOWN": Effect.SMACK_DOWN, - "SNAPTRAP": Effect.SNAP_TRAP, - "SNATCH": Effect.SNATCH, - "SPARKLINGARIA": Effect.SPARKLING_ARIA, - "SPEEDSWAP": Effect.SPEED_SWAP, - "SPIKYSHIELD": Effect.SPIKY_SHIELD, - "SPITE": Effect.SPITE, - "SPOTLIGHT": Effect.SPOTLIGHT, - "STICKY_HOLD": Effect.STICKY_HOLD, - "STICKY_WEB": Effect.STICKY_WEB, - "STOCKPILE": Effect.STOCKPILE, - "STOCKPILE1": Effect.STOCKPILE1, - "STOCKPILE2": Effect.STOCKPILE2, - "STOCKPILE3": Effect.STOCKPILE3, - "STORMDRAIN": Effect.STORM_DRAIN, - "STRUGGLE": Effect.STRUGGLE, - "SUBSTITUTE": Effect.SUBSTITUTE, - "SUCTIONCUPS": Effect.SUCTION_CUPS, - "SUPREMEOVERLORD": Effect.SUPREME_OVERLORD, - "SYRUPBOMB": Effect.SYRUP_BOMB, - "SWEETVEIL": Effect.SWEET_VEIL, - "SYMBIOSIS": Effect.SYMBIOSIS, - "SYNCHRONIZE": Effect.SYNCHRONIZE, - "TARSHOT": Effect.TAR_SHOT, - "TAUNT": Effect.TAUNT, - "TELEKINESIS": Effect.TELEKINESIS, - "TELEPATHY": Effect.TELEPATHY, - "TERASHELL": Effect.TERA_SHELL, - "TERASHIFT": Effect.TERA_SHIFT, - "TIDYUP": Effect.TIDY_UP, - "TOXICDEBRIS": Effect.TOXIC_DEBRIS, - "THERMALEXCHANGE": Effect.THERMAL_EXCHANGE, - "THROATCHOP": Effect.THROAT_CHOP, - "THUNDERCAGE": Effect.THUNDER_CAGE, - "TORMENT": Effect.TORMENT, - "TRAPPED": Effect.TRAPPED, - "TRICK": Effect.TRICK, - "TYPEADD": Effect.TYPEADD, - "TYPECHANGE": Effect.TYPECHANGE, - "UPROAR": Effect.UPROAR, - "VITALSPIRIT": Effect.VITAL_SPIRIT, - "WANDERINGSPIRIT": Effect.WANDERING_SPIRIT, - "WATERBUBBLE": Effect.WATER_BUBBLE, - "WATERVEIL": Effect.WATER_VEIL, - "WHIRLPOOL": Effect.WHIRLPOOL, - "WIDEGUARD": Effect.WIDE_GUARD, - "WIMPOUT": Effect.WIMP_OUT, - "WRAP": Effect.WRAP, - "YAWN": Effect.YAWN, - "ZERO_TO_HERO": Effect.ZERO_TO_HERO, -} diff --git a/src/poke_env/environment/field.py b/src/poke_env/environment/field.py index 83b5af6b7..430b020ef 100644 --- a/src/poke_env/environment/field.py +++ b/src/poke_env/environment/field.py @@ -1,5 +1,4 @@ -"""This module defines the Field class, which represents a battle field. -""" +"""This module defines the Field class, which represents a battle field.""" from __future__ import annotations diff --git a/src/poke_env/environment/move.py b/src/poke_env/environment/move.py index c1352beb8..13d5a3439 100644 --- a/src/poke_env/environment/move.py +++ b/src/poke_env/environment/move.py @@ -71,6 +71,7 @@ class Move: PokemonType.ROCK: MoveCategory.PHYSICAL, PokemonType.STEEL: MoveCategory.PHYSICAL, PokemonType.WATER: MoveCategory.SPECIAL, + PokemonType.THREE_QUESTION_MARKS: MoveCategory.PHYSICAL, } __slots__ = ( @@ -80,7 +81,7 @@ class Move: "_dynamaxed_move", "_gen", "_is_empty", - "_moves_dict", + "_gen_data", "_request_target", ) @@ -88,7 +89,7 @@ def __init__(self, move_id: str, gen: int, raw_id: Optional[str] = None): self._id = move_id self._base_power_override = None self._gen = gen - self._moves_dict = GenData.from_gen(gen).moves + self._gen_data = GenData.from_gen(gen) if move_id.startswith("hiddenpower") and raw_id is not None: base_power = "".join([c for c in raw_id if c.isdigit()]) @@ -106,6 +107,11 @@ def __init__(self, move_id: str, gen: int, raw_id: Optional[str] = None): self._dynamaxed_move = None self._request_target = None + @property + def _moves_dict(self) -> Dict[str, Any]: + # we'll only need the moves dict, but let GenData shallow copy + return self._gen_data.moves + def __repr__(self) -> str: return f"{self._id} (Move object)" @@ -297,7 +303,15 @@ def entry(self) -> Dict[str, Any]: elif self._id.startswith("z") and self._id[1:] in self._moves_dict: return self._moves_dict[self._id[1:]] elif self._id == "recharge": - return {"pp": 1, "type": "normal", "category": "Special", "accuracy": 1} + return { + "pp": 1, + "type": "normal", + "category": "Special", + "accuracy": 1, + "priority": 0, + "flags": {}, + "target": "self", + } else: raise ValueError("Unknown move: %s" % self._id) diff --git a/src/poke_env/environment/move_category.py b/src/poke_env/environment/move_category.py index fa5b86c9a..e71137755 100644 --- a/src/poke_env/environment/move_category.py +++ b/src/poke_env/environment/move_category.py @@ -1,5 +1,4 @@ -"""This module defines the MoveCategory class, which represents a move category. -""" +"""This module defines the MoveCategory class, which represents a move category.""" from enum import Enum, auto, unique diff --git a/src/poke_env/environment/observation.py b/src/poke_env/environment/observation.py deleted file mode 100644 index 1fcf8eca4..000000000 --- a/src/poke_env/environment/observation.py +++ /dev/null @@ -1,34 +0,0 @@ -"""This module defines the Observation class, which stores the state of the battle. -It is updated whenever a new event is received and processed from showdown. Each observation -records a turn. Each property is the state of the battle at the beginning of the turn, and -the events are ones that occurred that turn. In this way, you can instanciate a new battle -with the Observations' properties, and then recreate that turn with the events property. -""" - -from dataclasses import dataclass, field -from typing import Dict, List, Optional, Union - -from poke_env.environment.field import Field -from poke_env.environment.observed_pokemon import ObservedPokemon -from poke_env.environment.side_condition import SideCondition -from poke_env.environment.weather import Weather - - -@dataclass -class Observation: - side_conditions: Dict[SideCondition, int] = field(default_factory=dict) - opponent_side_conditions: Dict[SideCondition, int] = field(default_factory=dict) - - weather: Dict[Weather, int] = field(default_factory=dict) - fields: Dict[Field, int] = field(default_factory=dict) - - active_pokemon: Union[ObservedPokemon, None, List[ObservedPokemon]] = None - opponent_active_pokemon: Union[ObservedPokemon, List[ObservedPokemon], None] = None - - # The player's team, so we can track states of mons throughout the battle - team: Dict[str, Optional[ObservedPokemon]] = field(default_factory=dict) - - # The opponent's team that has been exposed to the player, for VGC - opponent_team: Dict[str, Optional[ObservedPokemon]] = field(default_factory=dict) - - events: List[List[str]] = field(default_factory=list) diff --git a/src/poke_env/environment/observed_pokemon.py b/src/poke_env/environment/observed_pokemon.py deleted file mode 100644 index b60398e48..000000000 --- a/src/poke_env/environment/observed_pokemon.py +++ /dev/null @@ -1,84 +0,0 @@ -"""This module defines the ObservedPokmon class, which stores -what we have observed about a pokemon throughout a battle -""" - -import sys -from copy import copy -from dataclasses import dataclass, field -from typing import Dict, List, Mapping, Optional, Union - -from poke_env.environment.effect import Effect -from poke_env.environment.move import Move -from poke_env.environment.pokemon import Pokemon -from poke_env.environment.pokemon_gender import PokemonGender -from poke_env.environment.pokemon_type import PokemonType -from poke_env.environment.status import Status - - -@dataclass -class ObservedPokemon: - species: str - level: int - - ability: Optional[str] = None - boosts: Dict[str, int] = field( - default_factory=lambda: { - "accuracy": 0, - "atk": 0, - "def": 0, - "evasion": 0, - "spa": 0, - "spd": 0, - "spe": 0, - } - ) - current_hp_fraction: float = 1.0 - effects: Dict[Effect, int] = field(default_factory=dict) - is_dynamaxed: bool = False - is_terastallized: bool = False - item: Optional[str] = None - gender: Optional[PokemonGender] = None - moves: Dict[str, Move] = field(default_factory=dict) - tera_type: Optional[PokemonType] = None - shiny: Optional[bool] = None - stats: Optional[Mapping[str, Union[List[int], int, None]]] = None - status: Optional[Status] = None - - @staticmethod - def initial_stats() -> Dict[str, List[int]]: - return { - "atk": [0, sys.maxsize], - "def": [0, sys.maxsize], - "spa": [0, sys.maxsize], - "spd": [0, sys.maxsize], - "spe": [0, sys.maxsize], - } - - @staticmethod - def from_pokemon(mon: Pokemon): - if mon is None: - return None - - stats: Optional[Mapping[str, Union[List[int], int, None]]] = ( - ObservedPokemon.initial_stats() - ) - if mon.stats is not None: - stats = {k: v for (k, v) in mon.stats.items()} - - return ObservedPokemon( - species=mon.species, - level=mon.level, - ability=mon.ability, - boosts={k: v for (k, v) in mon.boosts.items()}, - current_hp_fraction=mon.current_hp_fraction, - effects={k: v for (k, v) in mon.effects.items()}, - is_dynamaxed=mon.is_dynamaxed, - is_terastallized=mon.is_terastallized, - item=mon.item, - gender=mon.gender, - moves={k: copy(v) for (k, v) in mon.moves.items()}, - tera_type=mon.tera_type, - shiny=mon.shiny, - stats=stats, - status=mon.status, - ) diff --git a/src/poke_env/environment/pokemon.py b/src/poke_env/environment/pokemon.py index 35387fa36..b740c07a8 100644 --- a/src/poke_env/environment/pokemon.py +++ b/src/poke_env/environment/pokemon.py @@ -39,6 +39,7 @@ class Pokemon: "_possible_abilities", "_preparing_move", "_preparing_target", + "_previous_move", "_protect_counter", "_shiny", "_stats", @@ -109,6 +110,7 @@ def __init__( self._must_recharge: bool = False self._preparing_move: Optional[Move] = None self._preparing_target: Optional[bool | Pokemon] = None + self._previous_move: Optional[Move] = None self._protect_counter: int = 0 self._revealed: bool = False self._stats: Dict[str, Optional[int]] = { @@ -298,6 +300,8 @@ def moved(self, move_id: str, failed: bool = False, use: bool = True): self._preparing_move = None self._preparing_target = None move = self._add_move(move_id, use=use) + self._previous_move = move + self._first_turn = False if move and move.is_protect_counter and not failed: self._protect_counter += 1 @@ -343,6 +347,10 @@ def prepare(self, move_id: str, target: Optional[Pokemon]): self._preparing_move = move self._preparing_target = target + @property + def previous_move(self) -> Optional[Move]: + return self._previous_move + def primal(self): species_id_str = to_id_str(self._species) primal_species = ( @@ -638,19 +646,11 @@ def available_moves_from_request(self, request: Dict[str, Any]) -> List[Move]: [v for m, v in self.moves.items() if m.startswith("hiddenpower")][0] ) else: - assert { - "copycat", - "metronome", - "mefirst", - "mirrormove", - "assist", - "transform", - "mimic", - }.intersection(self.moves), ( - f"Error with move {move}. Expected self.moves to contain copycat, " - "metronome, mefirst, mirrormove, assist, transform or mimic. Got" - f" {self.moves}" - ) + # JAKE: almost always means stolen/imitated/dynamic movesets + # (transform, mimic, etc.). There used to be a sanity-check + # for those moves here, but it fails for known reasons, as + # the move discovery system doesn't have enough info to handle + # the edge cases. moves.append(Move(move, gen=self._data.gen)) return moves @@ -770,8 +770,8 @@ def current_hp_fraction(self) -> float: :rtype: float """ if self.current_hp: - return self.current_hp / self.max_hp - return 0 + return self.current_hp / float(self.max_hp) + return 0.0 @property def effects(self) -> Dict[Effect, int]: diff --git a/src/poke_env/environment/side_condition.py b/src/poke_env/environment/side_condition.py index dec8c9944..51ac4075f 100644 --- a/src/poke_env/environment/side_condition.py +++ b/src/poke_env/environment/side_condition.py @@ -51,18 +51,23 @@ def from_showdown_message(message: str): message = message.replace("move: ", "") message = message.replace(" ", "_") message = message.replace("-", "_") - + message = message.upper() try: - return SideCondition[message.upper()] + return SideCondition[message] except KeyError: - logging.getLogger("poke-env").warning( - "Unexpected side condition '%s' received. SideCondition.UNKNOWN will be" - " used instead. If this is unexpected, please open an issue at " - "https://github.com/hsahovic/poke-env/issues/ along with this error " - "message and a description of your program.", - message, - ) - return SideCondition.UNKNOWN + # catch inconsistent use of whitespace both within Showdown protocol + # and between sim protocol and static data + for known_condition in SideCondition: + if known_condition.name.replace("_", "") == message: + return known_condition + logging.getLogger("poke-env").warning( + "Unexpected side condition '%s' received. SideCondition.UNKNOWN will be" + " used instead. If this is unexpected, please open an issue at " + "https://github.com/hsahovic/poke-env/issues/ along with this error " + "message and a description of your program.", + message, + ) + return SideCondition.UNKNOWN @staticmethod def from_data(message: str): @@ -73,50 +78,10 @@ def from_data(message: str): :return: The corresponding SideCondition object. :rtype: SideCondition """ - message = message.replace("_", "") - message = message.replace(" ", "") - message = message.replace("-", "") - message = message.upper() - - try: - return _FROM_DATA[message] - except KeyError: - logging.getLogger("poke-env").warning( - "Unexpected SideCondition '%s' received. SideCondition.UNKNOWN will be used " - "instead. If this is unexpected, please open an issue at " - "https://github.com/hsahovic/poke-env/issues/ along with this error " - "message and a description of your program.", - message, - ) - return SideCondition.UNKNOWN + # JAKE: this should no longer be necessary, but leave the door open for + # one-off changes specific to how the side conditions are stored in static data + return SideCondition.from_showdown_message(message) # SideCondition -> Max useful stack level STACKABLE_CONDITIONS = {SideCondition.SPIKES: 3, SideCondition.TOXIC_SPIKES: 2} - -_FROM_DATA: Dict[str, SideCondition] = { - "UNKNOWN": SideCondition.UNKNOWN, - "AURORAVEIL": SideCondition.AURORA_VEIL, - "CRAFTYSHIELD": SideCondition.CRAFTY_SHIELD, - "FIREPLEDGE": SideCondition.FIRE_PLEDGE, - "GMAXCANNONADE": SideCondition.G_MAX_CANNONADE, - "GMAXSTEELSURGE": SideCondition.G_MAX_STEELSURGE, - "GMAXVINELASH": SideCondition.G_MAX_VINE_LASH, - "GMAXVOLCALITH": SideCondition.G_MAX_VOLCALITH, - "GMAXWILDFIRE": SideCondition.G_MAX_WILDFIRE, - "GRASSPLEDGE": SideCondition.GRASS_PLEDGE, - "LIGHTSCREEN": SideCondition.LIGHT_SCREEN, - "LUCKYCHANT": SideCondition.LUCKY_CHANT, - "MATBLOCK": SideCondition.MATBLOCK, - "MIST": SideCondition.MIST, - "QUICKGUARD": SideCondition.QUICK_GUARD, - "REFLECT": SideCondition.REFLECT, - "SAFEGUARD": SideCondition.SAFEGUARD, - "SPIKES": SideCondition.SPIKES, - "STEALTHROCK": SideCondition.STEALTH_ROCK, - "STICKYWEB": SideCondition.STICKY_WEB, - "TAILWIND": SideCondition.TAILWIND, - "TOXICSPIKES": SideCondition.TOXIC_SPIKES, - "WATERPLEDGE": SideCondition.WATER_PLEDGE, - "WIDEGUARD": SideCondition.WIDE_GUARD, -} diff --git a/src/poke_env/environment/weather.py b/src/poke_env/environment/weather.py index 28fcc2dd6..b245612ec 100644 --- a/src/poke_env/environment/weather.py +++ b/src/poke_env/environment/weather.py @@ -1,5 +1,4 @@ -"""This module defines the Weather class, which represents a in-battle weather. -""" +"""This module defines the Weather class, which represents a in-battle weather.""" import logging from enum import Enum, auto, unique @@ -17,6 +16,7 @@ class Weather(Enum): RAINDANCE = auto() SANDSTORM = auto() SNOW = auto() + SNOWSCAPE = auto() SUNNYDAY = auto() def __str__(self) -> str: diff --git a/src/poke_env/player/__init__.py b/src/poke_env/player/__init__.py index 0f88467cc..8349adf63 100644 --- a/src/poke_env/player/__init__.py +++ b/src/poke_env/player/__init__.py @@ -1,5 +1,4 @@ -"""poke_env.player module init. -""" +"""poke_env.player module init.""" from poke_env.concurrency import POKE_LOOP from poke_env.player import env_player, openai_api, player, random_player, utils diff --git a/src/poke_env/player/env_player.py b/src/poke_env/player/env_player.py index 61c0f810f..babb01b39 100644 --- a/src/poke_env/player/env_player.py +++ b/src/poke_env/player/env_player.py @@ -1,5 +1,4 @@ -"""This module defines a player class exposing the Open AI Gym API with utility functions. -""" +"""This module defines a player class exposing the Open AI Gym API with utility functions.""" from abc import ABC from threading import Lock diff --git a/src/poke_env/player/openai_api.py b/src/poke_env/player/openai_api.py index 19270c3d5..4c25f198b 100644 --- a/src/poke_env/player/openai_api.py +++ b/src/poke_env/player/openai_api.py @@ -10,7 +10,7 @@ import time from abc import ABC, abstractmethod from logging import Logger -from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Tuple, Union +from typing import Any, Awaitable, Callable, Dict, Generic, List, Optional, Tuple, Union, Type from gymnasium.core import ActType, Env, ObsType from gymnasium.spaces import Discrete, Space @@ -56,42 +56,50 @@ async def async_join(self): await self.queue.join() -class _AsyncPlayer(Generic[ObsType, ActType], Player): - actions: _AsyncQueue - observations: _AsyncQueue +def _create_async_player_class(parent_class: Type[Player]) -> Type[Player]: - def __init__( - self, - user_funcs: OpenAIGymEnv[ObsType, ActType], - username: str, - **kwargs: Any, - ): - self.__class__.__name__ = username - super().__init__(**kwargs) - self.__class__.__name__ = "_AsyncPlayer" - self.observations = _AsyncQueue(create_in_poke_loop(asyncio.Queue, 1)) - self.actions = _AsyncQueue(create_in_poke_loop(asyncio.Queue, 1)) - self.current_battle: Optional[AbstractBattle] = None - self._user_funcs = user_funcs + class _AsyncPlayer(Generic[ObsType, ActType], parent_class): + actions: _AsyncQueue + observations: _AsyncQueue - def choose_move(self, battle: AbstractBattle) -> Awaitable[BattleOrder]: - return self._env_move(battle) - - async def _env_move(self, battle: AbstractBattle) -> BattleOrder: - if not self.current_battle or self.current_battle.finished: - self.current_battle = battle - if not self.current_battle == battle: - raise RuntimeError("Using different battles for queues") - battle_to_send = self._user_funcs.embed_battle(battle) - await self.observations.async_put(battle_to_send) - action = await self.actions.async_get() - if action == -1: - return ForfeitBattleOrder() - return self._user_funcs.action_to_move(action, battle) - - def _battle_finished_callback(self, battle: AbstractBattle): - to_put = self._user_funcs.embed_battle(battle) - asyncio.run_coroutine_threadsafe(self.observations.async_put(to_put), POKE_LOOP) + def __init__( + self, + user_funcs: OpenAIGymEnv[ObsType, ActType], + username: str, + **kwargs: Any, + ): + self.__class__.__name__ = username + super().__init__(**kwargs) + self.__class__.__name__ = "_AsyncPlayer" + self.observations = _AsyncQueue(create_in_poke_loop(asyncio.Queue, 1)) + self.actions = _AsyncQueue(create_in_poke_loop(asyncio.Queue, 1)) + self.current_battle: Optional[AbstractBattle] = None + self._user_funcs = user_funcs + + def choose_move(self, battle: AbstractBattle) -> Awaitable[BattleOrder]: + return self._env_move(battle) + + async def _env_move(self, battle: AbstractBattle) -> BattleOrder: + if not self.current_battle or self.current_battle.finished: + self.current_battle = battle + if not self.current_battle == battle: + raise RuntimeError("Using different battles for queues") + battle_to_send = self._user_funcs.embed_battle(battle) + await self.observations.async_put(battle_to_send) + action = await self.actions.async_get() + if action == -1: + return ForfeitBattleOrder() + return self._user_funcs.action_to_move(action, battle) + + def _battle_finished_callback(self, battle: AbstractBattle): + to_put = self._user_funcs.embed_battle(battle) + asyncio.run_coroutine_threadsafe(self.observations.async_put(to_put), POKE_LOOP) + + return _AsyncPlayer + + +# Create the default _AsyncPlayer class with Player as parent +_AsyncPlayer = _create_async_player_class(Player) class OpenAIGymEnv( @@ -109,6 +117,7 @@ class OpenAIGymEnv( def __init__( self, + player_class: Type[Player] = Player, account_configuration: Optional[AccountConfiguration] = None, *, avatar: Optional[int] = None, @@ -170,7 +179,8 @@ def __init__( leave it inactive. :type start_challenging: bool """ - self.agent = _AsyncPlayer( + player_class = _create_async_player_class(player_class) + self.agent = player_class( self, username=self.__class__.__name__, # type: ignore account_configuration=account_configuration, @@ -350,7 +360,19 @@ def step( return obs, 0.0, False, False, info if self.current_battle.finished: raise RuntimeError("Battle is already finished, call reset") - battle = copy.copy(self.current_battle) + # This deepcopy was removed (412254) because it was a major slowdown (#451), + # and this led to a series of changes that made reward functions much harder to write. + # With no deepcopy, the last_battle was the same as the current_battle, so the + # r(last_battle, current_battle) broke (#662). ---> the reward function was + # changed to r(current_battle)(#671). ---> there was no way to write rews as + # the diff between two turns (e.g., net change in health). ---> a rew buffer + # was added to track the scalar reward from previous turns. ---> Now the rew + # funcs have hidden state, which is bad, and it's much easier to write rews with + # some terms that are net diffs than to write ones where taking the diff of the + # output value does what you want. + # + # --> roll back all of that, and make the deepcopy fast. + battle = copy.deepcopy(self.current_battle) battle.logger = None self.last_battle = battle self._actions.put(action) @@ -507,26 +529,32 @@ async def _ladder_loop( self, n_challenges: Optional[int] = None, callback: Optional[Callable[[AbstractBattle], None]] = None, + sleep_between: Optional[int] = None, ): if n_challenges: if n_challenges <= 0: raise ValueError( f"Number of challenges must be > 0. Got {n_challenges}" ) - for _ in range(n_challenges): + for game_num in range(n_challenges): await self.agent.ladder(1) if callback and self.current_battle is not None: callback(self.current_battle) + if game_num < n_challenges - 1 and sleep_between is not None: + await asyncio.sleep(random.randint(0, sleep_between)) else: while self._keep_challenging: await self.agent.ladder(1) if callback and self.current_battle is not None: callback(self.current_battle) + if sleep_between is not None: + await asyncio.sleep(random.randint(0, sleep_between)) def start_laddering( self, n_challenges: Optional[int] = None, callback: Optional[Callable[[AbstractBattle], None]] = None, + sleep_between: Optional[int] = None, ): """ Starts the laddering loop. @@ -548,7 +576,7 @@ def start_laddering( if not n_challenges: self._keep_challenging = True self._challenge_task = asyncio.run_coroutine_threadsafe( - self._ladder_loop(n_challenges, callback), POKE_LOOP + self._ladder_loop(n_challenges, callback, sleep_between=sleep_between), POKE_LOOP ) async def _stop_challenge_loop( diff --git a/src/poke_env/player/player.py b/src/poke_env/player/player.py index 302b00d81..8546c8842 100644 --- a/src/poke_env/player/player.py +++ b/src/poke_env/player/player.py @@ -1,7 +1,7 @@ -"""This module defines a base class for players. -""" +"""This module defines a base class for players.""" import asyncio +import time import random from abc import ABC, abstractmethod from asyncio import Condition, Event, Queue, Semaphore @@ -144,6 +144,8 @@ def __init__( ) self._battle_end_condition: Condition = create_in_poke_loop(Condition) self._challenge_queue: Queue[Any] = create_in_poke_loop(Queue) + self._waiting: Event = create_in_poke_loop(Event) + self._trying_again: Event = create_in_poke_loop(Event) self._team: Optional[Teambuilder] = None if isinstance(team, Teambuilder): @@ -276,9 +278,10 @@ async def _handle_battle_message(self, split_messages: List[List[str]]): if split_message[2]: request = orjson.loads(split_message[2]) battle.parse_request(request) - if battle.move_on_next_request: + if battle._wait: + self._waiting.set() + else: await self._handle_battle_request(battle) - battle.move_on_next_request = False elif split_message[1] == "win" or split_message[1] == "tie": if split_message[1] == "win": battle.won_by(split_message[2]) @@ -289,6 +292,8 @@ async def _handle_battle_message(self, split_messages: List[List[str]]): self._battle_finished_callback(battle) async with self._battle_end_condition: self._battle_end_condition.notify_all() + if hasattr(self.ps_client, "websocket"): + await self.ps_client.send_message(f"/leave {battle.battle_tag}") elif split_message[1] == "error": self.logger.log( 25, "Error message received: %s", "|".join(split_message) @@ -297,15 +302,14 @@ async def _handle_battle_message(self, split_messages: List[List[str]]): "[Invalid choice] Sorry, too late to make a different move" ): if battle.trapped: - await self._handle_battle_request(battle) + self._trying_again.set() elif split_message[2].startswith( "[Unavailable choice] Can't switch: The active Pokémon is " "trapped" ) or split_message[2].startswith( "[Invalid choice] Can't switch: The active Pokémon is trapped" ): - battle.trapped = True - await self._handle_battle_request(battle) + self._trying_again.set() elif split_message[2].startswith( "[Invalid choice] Can't switch: You can't switch to an active " "Pokémon" @@ -356,12 +360,6 @@ async def _handle_battle_message(self, split_messages: List[List[str]]): await self._handle_battle_request(battle, maybe_default_order=True) else: self.logger.critical("Unexpected error message: %s", split_message) - elif split_message[1] == "turn": - battle.parse_message(split_message) - await self._handle_battle_request(battle) - elif split_message[1] == "teampreview": - battle.parse_message(split_message) - await self._handle_battle_request(battle, from_teampreview_request=True) elif split_message[1] == "bigerror": self.logger.warning("Received 'bigerror' message: %s", split_message) elif split_message[1] == "uhtml" and split_message[2] == "otsrequest": @@ -372,16 +370,15 @@ async def _handle_battle_message(self, split_messages: List[List[str]]): async def _handle_battle_request( self, battle: AbstractBattle, - from_teampreview_request: bool = False, maybe_default_order: bool = False, ): if maybe_default_order and random.random() < self.DEFAULT_CHOICE_CHANCE: message = self.choose_default_move().message elif battle.teampreview: - if not from_teampreview_request: - return message = self.teampreview(battle) else: + if maybe_default_order: + self._trying_again.set() choice = self.choose_move(battle) if isinstance(choice, Awaitable): choice = await choice @@ -663,21 +660,23 @@ def choose_random_move(battle: AbstractBattle) -> BattleOrder: f"battle should be Battle or DoubleBattle. Received {type(battle)}" ) - async def ladder(self, n_games: int): + async def ladder(self, n_games: int, sleep_between: Optional[int] = None): """Make the player play games on the ladder. n_games defines how many battles will be played. :param n_games: Number of battles that will be played :type n_games: int + :param sleep_between: Seconds to wait before challenging again + :type sleep_between: int """ - await handle_threaded_coroutines(self._ladder(n_games)) + await handle_threaded_coroutines(self._ladder(n_games, sleep_between=sleep_between)) - async def _ladder(self, n_games: int): + async def _ladder(self, n_games: int, sleep_between: Optional[int] = None): await self.ps_client.logged_in.wait() start_time = perf_counter() - for _ in range(n_games): + for game_num in range(n_games): async with self._battle_start_condition: await self.ps_client.search_ladder_game(self._format, self.next_team) await self._battle_start_condition.wait() @@ -685,6 +684,8 @@ async def _ladder(self, n_games: int): async with self._battle_end_condition: await self._battle_end_condition.wait() await self._battle_semaphore.acquire() + if game_num < n_games - 1 and sleep_between is not None: + await asyncio.sleep(random.randint(0, sleep_between)) await self._battle_count_queue.join() self.logger.info( "Laddering (%d battles) finished in %fs", @@ -851,11 +852,13 @@ def format_is_doubles(self) -> bool: @property def n_finished_battles(self) -> int: - return len([None for b in self._battles.values() if b.finished]) + battles = list(self._battles.values()) + return len([None for b in battles if b.finished]) @property def n_lost_battles(self) -> int: - return len([None for b in self._battles.values() if b.lost]) + battles = list(self._battles.values()) + return len([None for b in battles if b.lost]) @property def n_tied_battles(self) -> int: @@ -863,7 +866,8 @@ def n_tied_battles(self) -> int: @property def n_won_battles(self) -> int: - return len([None for b in self._battles.values() if b.won]) + battles = list(self._battles.values()) + return len([None for b in battles if b.won]) @property def accept_open_team_sheet(self) -> bool: diff --git a/src/poke_env/player/random_player.py b/src/poke_env/player/random_player.py index f7cd35c25..3055ccb9d 100644 --- a/src/poke_env/player/random_player.py +++ b/src/poke_env/player/random_player.py @@ -1,5 +1,4 @@ -"""This module defines a random players baseline -""" +"""This module defines a random players baseline""" from poke_env.environment import AbstractBattle from poke_env.player.battle_order import BattleOrder diff --git a/src/poke_env/player/utils.py b/src/poke_env/player/utils.py index e2d9dc24d..ff7c4cb36 100644 --- a/src/poke_env/player/utils.py +++ b/src/poke_env/player/utils.py @@ -1,5 +1,4 @@ -"""This module contains utility functions and objects related to Player classes. -""" +"""This module contains utility functions and objects related to Player classes.""" import asyncio import math