diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index bff2c4656b..1a549ae3ec 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -64,7 +64,7 @@ jobs: include-hidden-files: true upload-coverage: - needs: [test] + needs: [ test ] runs-on: ubuntu-latest steps: @@ -151,6 +151,11 @@ jobs: run: | nox -s verify-types + - name: Pyright (tests) + if: always() + run: | + nox -s pyright-tests + - name: Ruff if: always() && !cancelled() run: | @@ -220,7 +225,7 @@ jobs: # Allows us to add this as a required check in Github branch rules, as all the other jobs are subject to change ci-done: - needs: [upload-coverage, linting, twemoji, docs] + needs: [ upload-coverage, linting, twemoji, docs ] if: always() && !cancelled() runs-on: ubuntu-latest diff --git a/hikari/users.py b/hikari/users.py index 65c178ebf2..da45257064 100644 --- a/hikari/users.py +++ b/hikari/users.py @@ -622,12 +622,6 @@ class User(PartialUser, abc.ABC): __slots__: typing.Sequence[str] = () - @property - @abc.abstractmethod - @typing_extensions.override - def app(self) -> traits.RESTAware: - """Client application that models may use for procedures.""" - @property @abc.abstractmethod @typing_extensions.override diff --git a/pipelines/nox.py b/pipelines/nox.py index 27b5445609..b1fadfee88 100644 --- a/pipelines/nox.py +++ b/pipelines/nox.py @@ -28,7 +28,16 @@ NoxCallbackSigT = typing.Callable[[nox.Session], None] # Default sessions should be defined here -nox.options.sessions = ["reformat-code", "codespell", "pytest", "ruff", "slotscheck", "mypy", "verify-types"] +nox.options.sessions = [ + "reformat-code", + "codespell", + "pytest", + "pyright-tests", + "ruff", + "slotscheck", + "mypy", + "verify-types", +] nox.options.default_venv_backend = "uv" diff --git a/pipelines/pyright.nox.py b/pipelines/pyright.nox.py index 2260808510..47791aed81 100644 --- a/pipelines/pyright.nox.py +++ b/pipelines/pyright.nox.py @@ -35,7 +35,14 @@ def pyright(session: nox.Session) -> None: exists to make it easier to test and eventually reach that 100% compatibility. """ nox.sync(session, self=True, extras=["speedups", "server"], groups=["pyright"]) - session.run("pyright") + session.run("pyright", config.MAIN_PACKAGE) + + +@nox.session() +def pyright_tests(session: nox.Session) -> None: + """Perform type analysis on the tests using Pyright.""" + nox.sync(session, self=True, extras=["speedups", "server"], groups=["pyright", "pytest"]) + session.run("pyright", config.TEST_PACKAGE) @nox.session() diff --git a/pyproject.toml b/pyproject.toml index 2c044b75f8..790b26e6a6 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -157,9 +157,7 @@ exclude_lines = [ ] [tool.pyright] -include = ["hikari", "examples"] -exclude = [ - "examples/simple_dashboard.py", +ignore = [ "**/__init__.py", "hikari/internal/enums.py", "hikari/internal/fast_protocol.py", @@ -180,6 +178,11 @@ reportUnknownMemberType = "warning" reportUntypedFunctionDecorator = "warning" reportOptionalMemberAccess = "warning" +executionEnvironments = [ + { root = "hikari" }, + { root = "tests/hikari", reportPrivateUsage = "none" }, +] + [tool.pytest.ini_options] xfail_strict = true norecursedirs = [ @@ -192,7 +195,7 @@ norecursedirs = [ ".venv", "venv", "public", - "ci", + "pipelines", ] # Treat warnings as errors filterwarnings = [ diff --git a/tests/hikari/conftest.py b/tests/hikari/conftest.py new file mode 100644 index 0000000000..5ed562b101 --- /dev/null +++ b/tests/hikari/conftest.py @@ -0,0 +1,159 @@ +# Copyright (c) 2020 Nekokatt +# Copyright (c) 2021-present davfsa +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +import datetime + +import mock +import pytest + +from hikari import channels +from hikari import emojis +from hikari import guilds +from hikari import messages +from hikari import snowflakes +from hikari import stickers +from hikari import traits +from hikari import users + + +@pytest.fixture +def hikari_app() -> traits.RESTAware: + return mock.Mock(spec=traits.RESTAware, rest=mock.AsyncMock()) + + +@pytest.fixture +def hikari_partial_guild() -> guilds.PartialGuild: + return guilds.PartialGuild( + app=mock.Mock(), id=snowflakes.Snowflake(123), icon_hash="partial_guild_icon_hash", name="partial_guild" + ) + + +@pytest.fixture +def hikari_guild_text_channel() -> channels.GuildTextChannel: + return channels.GuildTextChannel( + app=mock.Mock(), + id=snowflakes.Snowflake(4560), + name="guild_text_channel_name", + type=channels.ChannelType.GUILD_TEXT, + guild_id=mock.Mock(), # FIXME: Can this be pulled from the actual fixture? + parent_id=mock.Mock(), # FIXME: Can this be pulled from the actual fixture? + position=0, + is_nsfw=False, + permission_overwrites={}, + topic=None, + last_message_id=None, + rate_limit_per_user=datetime.timedelta(seconds=10), + last_pin_timestamp=None, + default_auto_archive_duration=datetime.timedelta(seconds=10), + ) + + +@pytest.fixture +def hikari_user() -> users.User: + return users.UserImpl( + id=snowflakes.Snowflake(789), + app=mock.Mock(), + discriminator="0", + username="user_username", + global_name="user_global_name", + avatar_decoration=users.AvatarDecoration( + asset_hash="avatar_decoration_asset_hash", sku_id=snowflakes.Snowflake(999), expires_at=None + ), + avatar_hash="user_avatar_hash", + banner_hash="user_banner_hash", + accent_color=None, + is_bot=False, + is_system=False, + flags=users.UserFlag.NONE, + primary_guild=None, + ) + + +@pytest.fixture +def hikari_message() -> messages.Message: + return messages.Message( + id=snowflakes.Snowflake(101), + app=mock.Mock(), + channel_id=snowflakes.Snowflake(456), + guild_id=None, + author=mock.Mock(), + member=mock.Mock(), + content=None, + timestamp=datetime.datetime.fromtimestamp(6000), + edited_timestamp=None, + is_tts=False, + user_mentions={}, + role_mention_ids=[], + channel_mentions={}, + mentions_everyone=False, + attachments=[], + embeds=[], + poll=None, + reactions=[], + is_pinned=False, + webhook_id=snowflakes.Snowflake(432), + type=messages.MessageType.DEFAULT, + activity=None, + application=None, + message_reference=None, + flags=messages.MessageFlag.NONE, + stickers=[], + nonce=None, + referenced_message=None, + message_snapshots=[], + application_id=None, + components=[], + thread=None, + interaction_metadata=None, + ) + + +@pytest.fixture +def hikari_partial_sticker() -> stickers.PartialSticker: + return stickers.PartialSticker( + id=snowflakes.Snowflake(222), name="sticker_name", format_type=stickers.StickerFormatType.PNG + ) + + +@pytest.fixture +def hikari_guild_sticker() -> stickers.GuildSticker: + return stickers.GuildSticker( + id=snowflakes.Snowflake(2220), + name="guild_sticker_name", + format_type=stickers.StickerFormatType.PNG, + description="guild_sticker_description", + guild_id=snowflakes.Snowflake(123), + is_available=True, + tag="guild_sticker_tag", + user=None, + ) + + +@pytest.fixture +def hikari_custom_emoji() -> emojis.CustomEmoji: + return emojis.CustomEmoji(id=snowflakes.Snowflake(444), name="custom_emoji_name", is_animated=False) + + +@pytest.fixture +def hikari_unicode_emoji() -> emojis.UnicodeEmoji: + return emojis.UnicodeEmoji("🙂") diff --git a/tests/hikari/events/test_base_events.py b/tests/hikari/events/test_base_events.py index a372ebc0c7..cd48523d5f 100644 --- a/tests/hikari/events/test_base_events.py +++ b/tests/hikari/events/test_base_events.py @@ -91,18 +91,18 @@ def error(self): return ex @pytest.fixture - def event(self, error): + def event(self, error: RuntimeError) -> base_events.ExceptionEvent[mock.Mock]: return base_events.ExceptionEvent( exception=error, failed_event=mock.Mock(base_events.Event), failed_callback=mock.AsyncMock() ) - def test_app_property(self, event): + def test_app_property(self, event: base_events.ExceptionEvent[mock.Mock]): app = mock.Mock() event.failed_event.app = app assert event.app is app @pytest.mark.parametrize("has_shard", [True, False]) - def test_shard_property(self, has_shard, event): + def test_shard_property(self, has_shard: bool, event: base_events.ExceptionEvent[mock.Mock]): shard = mock.Mock(spec_set=gateway_shard.GatewayShard) if has_shard: event.failed_event.shard = shard @@ -110,10 +110,11 @@ def test_shard_property(self, has_shard, event): else: assert event.shard is None - def test_exc_info_property(self, event, error): + def test_exc_info_property(self, event: base_events.ExceptionEvent[mock.Mock], error: RuntimeError): assert event.exc_info == (type(error), error, error.__traceback__) @pytest.mark.asyncio - async def test_retry(self, event): - await event.retry() - event.failed_callback.assert_awaited_once_with(event.failed_event) + async def test_retry(self, event: base_events.ExceptionEvent[mock.Mock]): + with mock.patch.object(event, "failed_callback") as patched_failed_callback: + await event.retry() + patched_failed_callback.assert_awaited_once_with(event.failed_event) diff --git a/tests/hikari/events/test_channel_events.py b/tests/hikari/events/test_channel_events.py index 1a490e1450..9dc932d5e6 100644 --- a/tests/hikari/events/test_channel_events.py +++ b/tests/hikari/events/test_channel_events.py @@ -20,181 +20,259 @@ # SOFTWARE. from __future__ import annotations +import typing + import mock import pytest from hikari import channels from hikari import snowflakes +from hikari import traits +from hikari.api import shard as shard_api from hikari.events import channel_events -from tests.hikari import hikari_test_helpers class TestGuildChannelEvent: - @pytest.fixture - def event(self): - cls = hikari_test_helpers.mock_class_namespace( - channel_events.GuildChannelEvent, - guild_id=mock.PropertyMock(return_value=snowflakes.Snowflake(929292929)), - channel_id=mock.PropertyMock(return_value=snowflakes.Snowflake(432432432)), - ) - return cls() + class MockGuildChannelEvent(channel_events.GuildChannelEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._channel_id = snowflakes.Snowflake(123) + self._guild_id = snowflakes.Snowflake(456) - def test_get_guild_when_available(self, event): - result = event.get_guild() + @property + def app(self) -> traits.RESTAware: + return self._app - assert result is event.app.cache.get_available_guild.return_value - event.app.cache.get_available_guild.assert_called_once_with(929292929) - event.app.cache.get_unavailable_guild.assert_not_called() + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard - def test_get_guild_when_unavailable(self, event): - event.app.cache.get_available_guild.return_value = None - result = event.get_guild() + @property + def channel_id(self) -> snowflakes.Snowflake: + return self._channel_id - assert result is event.app.cache.get_unavailable_guild.return_value - event.app.cache.get_available_guild.assert_called_once_with(929292929) - event.app.cache.get_unavailable_guild.assert_called_once_with(929292929) + @property + def guild_id(self) -> snowflakes.Snowflake: + return self._guild_id - def test_get_guild_without_cache(self): - event = hikari_test_helpers.mock_class_namespace(channel_events.GuildChannelEvent, app=None)() - - assert event.get_guild() is None + @pytest.fixture + def guild_channel_event(self, hikari_app: traits.RESTAware) -> channel_events.GuildChannelEvent: + return TestGuildChannelEvent.MockGuildChannelEvent(hikari_app) + + def test_get_guild_when_available(self, guild_channel_event: channel_events.GuildChannelEvent): + with ( + mock.patch.object(guild_channel_event, "_app") as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_available_guild") as patched_get_available_guild, + mock.patch.object(patched_cache, "get_unavailable_guild") as patched_get_unavailable_guild, + ): + result = guild_channel_event.get_guild() + + assert result is patched_get_available_guild.return_value + patched_get_available_guild.assert_called_once_with(456) + patched_get_unavailable_guild.assert_not_called() + + def test_get_guild_when_unavailable(self, guild_channel_event: channel_events.GuildChannelEvent): + with ( + mock.patch.object(guild_channel_event, "_app") as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_available_guild", return_value=None) as patched_get_available_guild, + mock.patch.object(patched_cache, "get_unavailable_guild") as patched_get_unavailable_guild, + ): + result = guild_channel_event.get_guild() + + assert result is patched_get_unavailable_guild.return_value + patched_get_available_guild.assert_called_once_with(456) + patched_get_unavailable_guild.assert_called_once_with(456) + + def test_get_guild_without_cache(self, guild_channel_event: channel_events.GuildChannelEvent): + with mock.patch.object(guild_channel_event, "_app", None): + assert guild_channel_event.get_guild() is None @pytest.mark.asyncio - async def test_fetch_guild(self, event): - event.app.rest.fetch_guild = mock.AsyncMock() - result = await event.fetch_guild() - - assert result is event.app.rest.fetch_guild.return_value - event.app.rest.fetch_guild.assert_awaited_once_with(929292929) - - def test_get_channel(self, event): - result = event.get_channel() - - assert result is event.app.cache.get_guild_channel.return_value - event.app.cache.get_guild_channel.assert_called_once_with(432432432) - - def test_get_channel_without_cache(self): - event = hikari_test_helpers.mock_class_namespace(channel_events.GuildChannelEvent, app=None)() - - assert event.get_channel() is None + async def test_fetch_guild(self, guild_channel_event: channel_events.GuildChannelEvent): + with mock.patch.object( + guild_channel_event.app.rest, "fetch_guild", new_callable=mock.AsyncMock + ) as patched_fetch_guild: + result = await guild_channel_event.fetch_guild() + + assert result is patched_fetch_guild.return_value + patched_fetch_guild.assert_awaited_once_with(456) + + def test_get_channel(self, guild_channel_event: channel_events.GuildChannelEvent): + with ( + mock.patch.object(guild_channel_event, "_app") as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_guild_channel") as patched_get_guild_channel, + ): + result = guild_channel_event.get_channel() + + assert result is patched_get_guild_channel.return_value + patched_get_guild_channel.assert_called_once_with(123) + + def test_get_channel_without_cache(self, guild_channel_event: channel_events.GuildChannelEvent): + with mock.patch.object(guild_channel_event, "_app", None): + assert guild_channel_event.get_channel() is None @pytest.mark.asyncio - async def test_fetch_channel(self, event): - event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock.MagicMock(spec=channels.GuildChannel)) - result = await event.fetch_channel() + async def test_fetch_channel(self, guild_channel_event: channel_events.GuildChannelEvent): + with mock.patch.object( + guild_channel_event.app.rest, + "fetch_channel", + mock.AsyncMock(return_value=mock.MagicMock(spec=channels.GuildChannel)), + ) as patched_fetch_channel: + result = await guild_channel_event.fetch_channel() - assert result is event.app.rest.fetch_channel.return_value - event.app.rest.fetch_channel.assert_awaited_once_with(432432432) + assert result is patched_fetch_channel.return_value + patched_fetch_channel.assert_awaited_once_with(123) class TestGuildChannelCreateEvent: @pytest.fixture - def event(self): - return channel_events.GuildChannelCreateEvent(channel=mock.Mock(), shard=None) + def event(self) -> channel_events.GuildChannelCreateEvent: + return channel_events.GuildChannelCreateEvent(channel=mock.Mock(), shard=mock.Mock()) - def test_app_property(self, event): + def test_app_property(self, event: channel_events.GuildChannelCreateEvent): assert event.app is event.channel.app - def test_channel_id_property(self, event): - event.channel.id = 123 + def test_channel_id_property(self, event: channel_events.GuildChannelCreateEvent): + event.channel.id = snowflakes.Snowflake(123) assert event.channel_id == 123 - def test_guild_id_property(self, event): - event.channel.guild_id = 123 + def test_guild_id_property(self, event: channel_events.GuildChannelCreateEvent): + event.channel.guild_id = snowflakes.Snowflake(123) assert event.guild_id == 123 class TestGuildChannelUpdateEvent: @pytest.fixture - def event(self): - return channel_events.GuildChannelUpdateEvent(channel=mock.Mock(), old_channel=mock.Mock(), shard=None) + def event(self) -> channel_events.GuildChannelUpdateEvent: + return channel_events.GuildChannelUpdateEvent(channel=mock.Mock(), old_channel=mock.Mock(), shard=mock.Mock()) - def test_app_property(self, event): + def test_app_property(self, event: channel_events.GuildChannelUpdateEvent): assert event.app is event.channel.app - def test_channel_id_property(self, event): - event.channel.id = 123 + def test_channel_id_property(self, event: channel_events.GuildChannelUpdateEvent): + event.channel.id = snowflakes.Snowflake(123) assert event.channel_id == 123 - def test_guild_id_property(self, event): - event.channel.guild_id = 123 + def test_guild_id_property(self, event: channel_events.GuildChannelUpdateEvent): + event.channel.guild_id = snowflakes.Snowflake(123) assert event.guild_id == 123 - def test_old_channel_id_property(self, event): - event.old_channel.id = 123 + def test_old_channel_id_property(self, event: channel_events.GuildChannelUpdateEvent): + assert event.old_channel + event.old_channel.id = snowflakes.Snowflake(123) assert event.old_channel.id == 123 class TestGuildChannelDeleteEvent: @pytest.fixture - def event(self): - return channel_events.GuildChannelDeleteEvent(channel=mock.Mock(), shard=None) + def event(self) -> channel_events.GuildChannelDeleteEvent: + return channel_events.GuildChannelDeleteEvent(channel=mock.Mock(), shard=mock.Mock()) - def test_app_property(self, event): + def test_app_property(self, event: channel_events.GuildChannelDeleteEvent): assert event.app is event.channel.app - def test_channel_id_property(self, event): - event.channel.id = 123 + def test_channel_id_property(self, event: channel_events.GuildChannelDeleteEvent): + event.channel.id = snowflakes.Snowflake(123) assert event.channel_id == 123 - def test_guild_id_property(self, event): - event.channel.guild_id = 123 + def test_guild_id_property(self, event: channel_events.GuildChannelDeleteEvent): + event.channel.guild_id = snowflakes.Snowflake(123) assert event.guild_id == 123 class TestGuildPinsUpdateEvent: @pytest.fixture - def event(self): + def event(self) -> channel_events.GuildPinsUpdateEvent: return channel_events.GuildPinsUpdateEvent( - app=mock.Mock(), shard=None, channel_id=12343, guild_id=None, last_pin_timestamp=None + app=mock.Mock(), + shard=mock.Mock(), + channel_id=snowflakes.Snowflake(12343), + guild_id=snowflakes.Snowflake(45676), + last_pin_timestamp=None, ) @pytest.mark.parametrize("result", [mock.Mock(spec=channels.GuildTextChannel), None]) - def test_get_channel(self, event, result): - event.app.cache.get_guild_channel.return_value = result - - result = event.get_channel() + def test_get_channel( + self, event: channel_events.GuildPinsUpdateEvent, result: typing.Optional[channels.GuildTextChannel] + ): + with ( + mock.patch.object(event, "app") as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_guild_channel", return_value=result) as patched_get_guild_channel, + ): + channel = event.get_channel() - assert result is event.app.cache.get_guild_channel.return_value - event.app.cache.get_guild_channel.assert_called_once_with(event.channel_id) + assert channel is patched_get_guild_channel.return_value + patched_get_guild_channel.assert_called_once_with(event.channel_id) @pytest.mark.asyncio class TestInviteEvent: + class MockInviteEvent(channel_events.InviteEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._channel_id = snowflakes.Snowflake(123) + self._guild_id = snowflakes.Snowflake(456) + self._code = "code" + + @property + def app(self) -> traits.RESTAware: + return self._app + + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard + + @property + def channel_id(self) -> snowflakes.Snowflake: + return self._channel_id + + @property + def guild_id(self) -> snowflakes.Snowflake: + return self._guild_id + + @property + def code(self) -> str: + return self._code + @pytest.fixture - def event(self): - return hikari_test_helpers.mock_class_namespace( - channel_events.InviteEvent, slots_=False, code=mock.PropertyMock(return_value="Jx4cNGG") - )() + def invite_event(self, hikari_app: traits.RESTAware) -> channel_events.InviteEvent: + return TestInviteEvent.MockInviteEvent(hikari_app) - async def test_fetch_invite(self, event): - event.app.rest.fetch_invite = mock.AsyncMock() + async def test_fetch_invite(self, invite_event: channel_events.InviteEvent): + invite_event.app.rest.fetch_invite = mock.AsyncMock() - await event.fetch_invite() + with mock.patch.object(invite_event.app.rest, "fetch_invite", mock.AsyncMock()) as patched_fetch_invite: + await invite_event.fetch_invite() - event.app.rest.fetch_invite.assert_awaited_once_with("Jx4cNGG") + patched_fetch_invite.assert_awaited_once_with("code") class TestInviteCreateEvent: @pytest.fixture - def event(self): - return channel_events.InviteCreateEvent(shard=None, invite=mock.Mock()) + def event(self) -> channel_events.InviteCreateEvent: + return channel_events.InviteCreateEvent(shard=mock.Mock(), invite=mock.Mock()) - def test_app_property(self, event): + def test_app_property(self, event: channel_events.InviteCreateEvent): assert event.app is event.invite.app @pytest.mark.asyncio - async def test_channel_id_property(self, event): - event.invite.channel_id = 123 + async def test_channel_id_property(self, event: channel_events.InviteCreateEvent): + event.invite.channel_id = snowflakes.Snowflake(123) assert event.channel_id == 123 @pytest.mark.asyncio - async def test_guild_id_property(self, event): - event.invite.guild_id = 123 + async def test_guild_id_property(self, event: channel_events.InviteCreateEvent): + event.invite.guild_id = snowflakes.Snowflake(123) assert event.guild_id == 123 @pytest.mark.asyncio - async def test_code_property(self, event): + async def test_code_property(self, event: channel_events.InviteCreateEvent): event.invite.code = "Jx4cNGG" assert event.code == "Jx4cNGG" @@ -202,33 +280,64 @@ async def test_code_property(self, event): @pytest.mark.asyncio class TestWebhookUpdateEvent: @pytest.fixture - def event(self): - return channel_events.WebhookUpdateEvent(app=mock.AsyncMock(), shard=mock.Mock(), channel_id=123, guild_id=456) - - async def test_fetch_channel_webhooks(self, event): - await event.fetch_channel_webhooks() - - event.app.rest.fetch_channel_webhooks.assert_awaited_once_with(123) + def event(self) -> channel_events.WebhookUpdateEvent: + return channel_events.WebhookUpdateEvent( + app=mock.AsyncMock(), + shard=mock.Mock(), + channel_id=snowflakes.Snowflake(123), + guild_id=snowflakes.Snowflake(456), + ) - async def test_fetch_guild_webhooks(self, event): - await event.fetch_guild_webhooks() + async def test_fetch_channel_webhooks(self, event: channel_events.WebhookUpdateEvent): + with mock.patch.object(event.app.rest, "fetch_channel_webhooks") as patched_fetch_channel_webhooks: + await event.fetch_channel_webhooks() + patched_fetch_channel_webhooks.assert_awaited_once_with(123) - event.app.rest.fetch_guild_webhooks.assert_awaited_once_with(456) + async def test_fetch_guild_webhooks(self, event: channel_events.WebhookUpdateEvent): + with mock.patch.object(event.app.rest, "fetch_guild_webhooks") as patched_fetch_guild_webhooks: + await event.fetch_guild_webhooks() + patched_fetch_guild_webhooks.assert_awaited_once_with(456) class TestGuildThreadEvent: - @pytest.mark.asyncio - async def test_fetch_channel(self): - mock_app = mock.AsyncMock() - mock_app.rest.fetch_channel.return_value = mock.Mock(channels.GuildThreadChannel) - event = hikari_test_helpers.mock_class_namespace( - channel_events.GuildThreadEvent, app=mock_app, thread_id=123321 - )() + class MockGuildThreadEvent(channel_events.GuildThreadEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._thread_id = snowflakes.Snowflake(123) + self._guild_id = snowflakes.Snowflake(456) + self._code = "code" + + @property + def app(self) -> traits.RESTAware: + return self._app + + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard + + @property + def guild_id(self) -> snowflakes.Snowflake: + return self._guild_id + + @property + def thread_id(self) -> snowflakes.Snowflake: + return self._thread_id - result = await event.fetch_channel() - - assert result is mock_app.rest.fetch_channel.return_value - mock_app.rest.fetch_channel.assert_awaited_once_with(123321) + @pytest.mark.asyncio + async def test_fetch_channel(self, hikari_app: traits.RESTAware): + with mock.patch.object( + hikari_app.rest, + "fetch_channel", + new_callable=mock.AsyncMock, + return_value=mock.Mock(channels.GuildThreadChannel), + ) as patched_fetch_channel: + event = TestGuildThreadEvent.MockGuildThreadEvent(hikari_app) + + result = await event.fetch_channel() + + assert result is patched_fetch_channel.return_value + patched_fetch_channel.assert_awaited_once_with(123) class TestGuildThreadAccessEvent: diff --git a/tests/hikari/events/test_guild_events.py b/tests/hikari/events/test_guild_events.py index a98a1ce2c4..6659e84885 100644 --- a/tests/hikari/events/test_guild_events.py +++ b/tests/hikari/events/test_guild_events.py @@ -26,60 +26,89 @@ from hikari import guilds from hikari import presences from hikari import snowflakes +from hikari import traits +from hikari import users +from hikari.api import shard as shard_api from hikari.events import guild_events -from tests.hikari import hikari_test_helpers class TestGuildEvent: - @pytest.fixture - def event(self): - cls = hikari_test_helpers.mock_class_namespace( - guild_events.GuildEvent, guild_id=mock.PropertyMock(return_value=snowflakes.Snowflake(534123123)) - ) - return cls() - - def test_get_guild_when_available(self, event): - result = event.get_guild() - - assert result is event.app.cache.get_available_guild.return_value - event.app.cache.get_available_guild.assert_called_once_with(534123123) - event.app.cache.get_unavailable_guild.assert_not_called() + class MockGuildEvent(guild_events.GuildEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._guild_id = snowflakes.Snowflake(123) - def test_get_guild_when_unavailable(self, event): - event.app.cache.get_available_guild.return_value = None - result = event.get_guild() + @property + def app(self) -> traits.RESTAware: + return self._app - assert result is event.app.cache.get_unavailable_guild.return_value - event.app.cache.get_unavailable_guild.assert_called_once_with(534123123) - event.app.cache.get_available_guild.assert_called_once_with(534123123) + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard - def test_get_guild_cacheless(self, event): - event = hikari_test_helpers.mock_class_namespace(guild_events.GuildEvent, app=object())() + @property + def guild_id(self) -> snowflakes.Snowflake: + return self._guild_id - assert event.get_guild() is None + @pytest.fixture + def guild_event(self, hikari_app: traits.RESTAware) -> guild_events.GuildEvent: + return TestGuildEvent.MockGuildEvent(hikari_app) + + def test_get_guild_when_available(self, guild_event: guild_events.GuildEvent): + with ( + mock.patch.object(guild_event, "_app", mock.Mock(traits.CacheAware)) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_available_guild") as patched_get_available_guild, + mock.patch.object(patched_cache, "get_unavailable_guild") as patched_get_unavailable_guild, + ): + result = guild_event.get_guild() + + assert result is patched_get_available_guild.return_value + patched_get_available_guild.assert_called_once_with(123) + patched_get_unavailable_guild.assert_not_called() + + def test_get_guild_when_unavailable(self, guild_event: guild_events.GuildEvent): + with ( + mock.patch.object(guild_event, "_app", mock.Mock(traits.CacheAware)) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_available_guild", return_value=None) as patched_get_available_guild, + mock.patch.object(patched_cache, "get_unavailable_guild") as patched_get_unavailable_guild, + ): + result = guild_event.get_guild() + + assert result is patched_get_unavailable_guild.return_value + patched_get_unavailable_guild.assert_called_once_with(123) + patched_get_available_guild.assert_called_once_with(123) + + def test_get_guild_cacheless(self, guild_event: guild_events.GuildEvent): + with mock.patch.object(guild_event, "_app", None): + assert guild_event.get_guild() is None @pytest.mark.asyncio - async def test_fetch_guild(self, event): - event.app.rest.fetch_guild = mock.AsyncMock() - result = await event.fetch_guild() + async def test_fetch_guild(self, guild_event: guild_events.GuildEvent): + with mock.patch.object(guild_event.app.rest, "fetch_guild", mock.AsyncMock()) as patched_fetch_guild: + result = await guild_event.fetch_guild() - assert result is event.app.rest.fetch_guild.return_value - event.app.rest.fetch_guild.assert_called_once_with(534123123) + assert result is patched_fetch_guild.return_value + patched_fetch_guild.assert_called_once_with(123) @pytest.mark.asyncio - async def test_fetch_guild_preview(self, event): - event.app.rest.fetch_guild_preview = mock.AsyncMock() - result = await event.fetch_guild_preview() + async def test_fetch_guild_preview(self, guild_event: guild_events.GuildEvent): + with mock.patch.object( + guild_event.app.rest, "fetch_guild_preview", mock.AsyncMock() + ) as patched_fetch_guild_preview: + result = await guild_event.fetch_guild_preview() - assert result is event.app.rest.fetch_guild_preview.return_value - event.app.rest.fetch_guild_preview.assert_called_once_with(534123123) + assert result is patched_fetch_guild_preview.return_value + patched_fetch_guild_preview.assert_called_once_with(123) class TestGuildAvailableEvent: @pytest.fixture - def event(self): + def event(self) -> guild_events.GuildAvailableEvent: return guild_events.GuildAvailableEvent( - shard=object(), + shard=mock.Mock(), guild=mock.Mock(guilds.Guild), emojis={}, stickers={}, @@ -91,19 +120,19 @@ def event(self): threads={}, ) - def test_app_property(self, event): + def test_app_property(self, event: guild_events.GuildAvailableEvent): assert event.app is event.guild.app - def test_guild_id_property(self, event): - event.guild.id = 123 + def test_guild_id_property(self, event: guild_events.GuildAvailableEvent): + event.guild.id = snowflakes.Snowflake(123) assert event.guild_id == 123 class TestGuildUpdateEvent: @pytest.fixture - def event(self): + def event(self) -> guild_events.GuildUpdateEvent: return guild_events.GuildUpdateEvent( - shard=object(), + shard=mock.Mock(), guild=mock.Mock(guilds.Guild), old_guild=mock.Mock(guilds.Guild), emojis={}, @@ -111,59 +140,81 @@ def event(self): roles={}, ) - def test_app_property(self, event): + def test_app_property(self, event: guild_events.GuildUpdateEvent): assert event.app is event.guild.app - def test_guild_id_property(self, event): - event.guild.id = 123 + def test_guild_id_property(self, event: guild_events.GuildUpdateEvent): + event.guild.id = snowflakes.Snowflake(123) assert event.guild_id == 123 - def test_old_guild_id_property(self, event): - event.old_guild.id = 123 - assert event.old_guild.id == 123 + def test_old_guild_id_property(self, event: guild_events.GuildUpdateEvent): + with mock.patch.object(event.old_guild, "id", snowflakes.Snowflake(123)): + assert event.old_guild is not None + assert event.old_guild.id == 123 class TestBanEvent: + class MockBanEvent(guild_events.BanEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._guild_id = snowflakes.Snowflake(123) + self._user = mock.Mock(app=app, id=snowflakes.Snowflake(456)) + + @property + def app(self) -> traits.RESTAware: + return self._app + + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard + + @property + def guild_id(self) -> snowflakes.Snowflake: + return self._guild_id + + @property + def user(self) -> users.User: + return self._user + @pytest.fixture - def event(self): - return hikari_test_helpers.mock_class_namespace(guild_events.BanEvent)() + def ban_event(self, hikari_app: traits.RESTAware) -> guild_events.BanEvent: + return TestBanEvent.MockBanEvent(hikari_app) - def test_app_property(self, event): - assert event.app is event.user.app + def test_app_property(self, ban_event: guild_events.BanEvent): + assert ban_event.app is ban_event.user.app class TestPresenceUpdateEvent: @pytest.fixture - def event(self): + def event(self) -> guild_events.PresenceUpdateEvent: return guild_events.PresenceUpdateEvent( - shard=object(), + shard=mock.Mock(), presence=mock.Mock(presences.MemberPresence), old_presence=mock.Mock(presences.MemberPresence), user=mock.Mock(), ) - def test_app_property(self, event): + def test_app_property(self, event: guild_events.PresenceUpdateEvent): assert event.app is event.presence.app - def test_user_id_property(self, event): - event.presence.user_id = 123 + def test_user_id_property(self, event: guild_events.PresenceUpdateEvent): + event.presence.user_id = snowflakes.Snowflake(123) assert event.user_id == 123 - def test_guild_id_property(self, event): - event.presence.guild_id = 123 + def test_guild_id_property(self, event: guild_events.PresenceUpdateEvent): + event.presence.guild_id = snowflakes.Snowflake(123) assert event.guild_id == 123 - def test_old_presence(self, event): - event.old_presence.id = 123 - event.old_presence.guild_id = 456 - - assert event.old_presence.id == 123 - assert event.old_presence.guild_id == 456 + def test_old_presence(self, event: guild_events.PresenceUpdateEvent): + with mock.patch.object(event.old_presence, "guild_id", 456): + assert event.old_presence is not None + assert event.old_presence.guild_id == 456 class TestGuildStickersUpdateEvent: @pytest.fixture - def event(self): + def event(self) -> guild_events.StickersUpdateEvent: return guild_events.StickersUpdateEvent( app=mock.Mock(), shard=mock.Mock(), @@ -173,7 +224,7 @@ def event(self): ) @pytest.mark.asyncio - async def test_fetch_stickers(self, event): + async def test_fetch_stickers(self, event: guild_events.StickersUpdateEvent): event.app.rest.fetch_guild_stickers = mock.AsyncMock() assert await event.fetch_stickers() is event.app.rest.fetch_guild_stickers.return_value @@ -183,11 +234,11 @@ async def test_fetch_stickers(self, event): class TestAuditLogEntryCreateEvent: @pytest.fixture - def event(self): + def event(self) -> guild_events.AuditLogEntryCreateEvent: return guild_events.AuditLogEntryCreateEvent(shard=mock.Mock(), entry=mock.Mock()) - def test_app_property(self, event): + def test_app_property(self, event: guild_events.AuditLogEntryCreateEvent): assert event.app is event.entry.app - def test_guild_id_property(self, event): + def test_guild_id_property(self, event: guild_events.AuditLogEntryCreateEvent): assert event.guild_id is event.entry.guild_id diff --git a/tests/hikari/events/test_interaction_events.py b/tests/hikari/events/test_interaction_events.py index 141c08ff9b..1ff58683b2 100644 --- a/tests/hikari/events/test_interaction_events.py +++ b/tests/hikari/events/test_interaction_events.py @@ -27,6 +27,6 @@ class TestInteractionCreateEvent: def test_app_property(self): - mock_event = interaction_events.InteractionCreateEvent(shard=object(), interaction=mock.Mock()) + mock_event = interaction_events.InteractionCreateEvent(shard=mock.Mock(), interaction=mock.Mock()) assert mock_event.app is mock_event.interaction.app diff --git a/tests/hikari/events/test_member_events.py b/tests/hikari/events/test_member_events.py index f1d53aa346..bd736c2a49 100644 --- a/tests/hikari/events/test_member_events.py +++ b/tests/hikari/events/test_member_events.py @@ -25,80 +25,112 @@ from hikari import guilds from hikari import snowflakes +from hikari import traits +from hikari import users +from hikari.api import shard as shard_api from hikari.events import member_events -from tests.hikari import hikari_test_helpers class TestMemberEvent: - @pytest.fixture - def event(self): - cls = hikari_test_helpers.mock_class_namespace( - member_events.MemberEvent, - slots_=False, - guild_id=mock.PropertyMock(return_value=snowflakes.Snowflake(123)), - user=mock.Mock(id=456), - ) - return cls() - - def test_app_property(self, event): - assert event.app is event.user.app - - def test_user_id_property(self, event): - event.user_id == 456 - - def test_guild_when_no_cache_trait(self): - event = hikari_test_helpers.mock_class_namespace(member_events.MemberEvent, app=None)() + class MockMemberEvent(member_events.MemberEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._guild_id = snowflakes.Snowflake(123) + self._user = mock.Mock(app=app, id=snowflakes.Snowflake(456)) - assert event.get_guild() is None + @property + def app(self) -> traits.RESTAware: + return self._app - def test_get_guild_when_available(self, event): - result = event.get_guild() + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard - assert result is event.app.cache.get_available_guild.return_value - event.app.cache.get_available_guild.assert_called_once_with(123) - event.app.cache.get_unavailable_guild.assert_not_called() + @property + def guild_id(self) -> snowflakes.Snowflake: + return self._guild_id - def test_guild_when_unavailable(self, event): - event.app.cache.get_available_guild.return_value = None - result = event.get_guild() + @property + def user(self) -> users.User: + return self._user - assert result is event.app.cache.get_unavailable_guild.return_value - event.app.cache.get_unavailable_guild.assert_called_once_with(123) - event.app.cache.get_available_guild.assert_called_once_with(123) + @pytest.fixture + def member_event(self, hikari_app: traits.RESTAware) -> member_events.MemberEvent: + return TestMemberEvent.MockMemberEvent(hikari_app) + + def test_app_property(self, member_event: member_events.MemberEvent): + assert member_event.app is member_event.user.app + + def test_user_id_property(self, member_event: member_events.MemberEvent): + assert member_event.user_id == 456 + + def test_guild_when_no_cache_trait(self, member_event: member_events.MemberEvent): + with mock.patch.object(member_event, "_app", None): + assert member_event.get_guild() is None + + def test_get_guild_when_available(self, member_event: member_events.MemberEvent): + with ( + mock.patch.object(member_event, "_app", mock.Mock(traits.CacheAware)) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_available_guild") as patched_get_available_guild, + mock.patch.object(patched_cache, "get_unavailable_guild") as patched_get_unavailable_guild, + ): + result = member_event.get_guild() + + assert result is patched_get_available_guild.return_value + patched_get_available_guild.assert_called_once_with(123) + patched_get_unavailable_guild.assert_not_called() + + def test_guild_when_unavailable(self, member_event: member_events.MemberEvent): + with ( + mock.patch.object(member_event, "_app", mock.Mock(traits.CacheAware)) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_available_guild", return_value=None) as patched_get_available_guild, + mock.patch.object(patched_cache, "get_unavailable_guild") as patched_get_unavailable_guild, + ): + result = member_event.get_guild() + + assert result is patched_get_unavailable_guild.return_value + patched_get_unavailable_guild.assert_called_once_with(123) + patched_get_available_guild.assert_called_once_with(123) class TestMemberCreateEvent: @pytest.fixture - def event(self): - return member_events.MemberCreateEvent(shard=None, member=mock.Mock()) + def event(self) -> member_events.MemberCreateEvent: + return member_events.MemberCreateEvent(shard=mock.Mock(), member=mock.Mock()) - def test_guild_property(self, event): - event.member.guild_id = 123 - event.guild_id == 123 + def test_guild_property(self, event: member_events.MemberCreateEvent): + with mock.patch.object(event.member, "guild_id", snowflakes.Snowflake(123)): + assert event.guild_id == 123 - def test_user_property(self, event): - user = object() - event.member.user = user - event.user == user + def test_user_property(self, event: member_events.MemberCreateEvent): + user = mock.Mock() + with mock.patch.object(event.member, "user", user): + assert event.user == user class TestMemberUpdateEvent: @pytest.fixture - def event(self): - return member_events.MemberUpdateEvent(shard=None, member=mock.Mock(), old_member=mock.Mock(guilds.Member)) - - def test_guild_property(self, event): - event.member.guild_id = 123 - event.guild_id == 123 - - def test_user_property(self, event): - user = object() - event.member.user = user - event.user == user - - def test_old_user_property(self, event): - event.member.guild_id = 123 - event.member.id = 456 + def event(self) -> member_events.MemberUpdateEvent: + return member_events.MemberUpdateEvent( + shard=mock.Mock(), member=mock.Mock(), old_member=mock.Mock(guilds.Member) + ) - assert event.member.guild_id == 123 - assert event.member.id == 456 + def test_guild_property(self, event: member_events.MemberUpdateEvent): + with mock.patch.object(event.member, "guild_id", snowflakes.Snowflake(123)): + assert event.guild_id == 123 + + def test_user_property(self, event: member_events.MemberUpdateEvent): + user = mock.Mock() + with mock.patch.object(event.member, "user", user): + assert event.user == user + + def test_old_user_property(self, event: member_events.MemberUpdateEvent): + with ( + mock.patch.object(event.member, "guild_id", snowflakes.Snowflake(123)), + mock.patch.object(event.member, "id", 456), + ): + assert event.member.guild_id == 123 + assert event.member.id == 456 diff --git a/tests/hikari/events/test_message_events.py b/tests/hikari/events/test_message_events.py index e33424c946..fbc68b78ca 100644 --- a/tests/hikari/events/test_message_events.py +++ b/tests/hikari/events/test_message_events.py @@ -20,114 +20,151 @@ # SOFTWARE. from __future__ import annotations +import typing + import mock import pytest from hikari import channels from hikari import messages from hikari import snowflakes +from hikari import traits from hikari import undefined from hikari import users +from hikari.api import shard as shard_api from hikari.events import message_events -from tests.hikari import hikari_test_helpers class TestMessageCreateEvent: - @pytest.fixture - def event(self): - cls = hikari_test_helpers.mock_class_namespace( - message_events.MessageCreateEvent, - message=mock.Mock(spec_set=messages.Message, author=mock.Mock(spec_set=users.User)), - shard=mock.Mock(), - ) + class MockMessageCreateEvent(message_events.MessageCreateEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._message = mock.Mock(channel_id=snowflakes.Snowflake(123), id=snowflakes.Snowflake(456)) + + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard + + @property + def message(self) -> messages.Message: + return self._message - return cls() + @pytest.fixture + def message_create_event(self, hikari_app: traits.RESTAware) -> message_events.MessageCreateEvent: + return TestMessageCreateEvent.MockMessageCreateEvent(hikari_app) - def test_app_property(self, event): - assert event.app is event.message.app + def test_app_property(self, message_create_event: message_events.MessageCreateEvent): + assert message_create_event.app is message_create_event.message.app - def test_author_property(self, event): - assert event.author is event.message.author + def test_author_property(self, message_create_event: message_events.MessageCreateEvent): + assert message_create_event.author is message_create_event.message.author - def test_author_id_property(self, event): - assert event.author_id is event.author.id + def test_author_id_property(self, message_create_event: message_events.MessageCreateEvent): + assert message_create_event.author_id is message_create_event.author.id - def test_channel_id_property(self, event): - assert event.channel_id is event.message.channel_id + def test_channel_id_property(self, message_create_event: message_events.MessageCreateEvent): + assert message_create_event.channel_id is message_create_event.message.channel_id - def test_content_property(self, event): - assert event.content is event.message.content + def test_content_property(self, message_create_event: message_events.MessageCreateEvent): + assert message_create_event.content is message_create_event.message.content - def test_embeds_property(self, event): - assert event.embeds is event.message.embeds + def test_embeds_property(self, message_create_event: message_events.MessageCreateEvent): + assert message_create_event.embeds is message_create_event.message.embeds @pytest.mark.parametrize("is_bot", [True, False]) - def test_is_bot_property(self, event, is_bot): - event.message.author.is_bot = is_bot - assert event.is_bot is is_bot + def test_is_bot_property(self, message_create_event: message_events.MessageCreateEvent, is_bot: bool): + with mock.patch.object(message_create_event.message.author, "is_bot", is_bot): + assert message_create_event.is_bot is is_bot @pytest.mark.parametrize( ("author_is_bot", "webhook_id", "expected_is_human"), [(True, 123, False), (True, None, False), (False, 123, False), (False, None, True)], ) - def test_is_human_property(self, event, author_is_bot, webhook_id, expected_is_human): - event.message.author.is_bot = author_is_bot - event.message.webhook_id = webhook_id - assert event.is_human is expected_is_human + def test_is_human_property( + self, + message_create_event: message_events.MessageCreateEvent, + author_is_bot: bool, + webhook_id: snowflakes.Snowflake, + expected_is_human: bool, + ): + with ( + mock.patch.object(message_create_event.message.author, "is_bot", author_is_bot), + mock.patch.object(message_create_event.message, "webhook_id", webhook_id), + ): + assert message_create_event.is_human is expected_is_human @pytest.mark.parametrize(("webhook_id", "is_webhook"), [(123, True), (None, False)]) - def test_is_webhook_property(self, event, webhook_id, is_webhook): - event.message.webhook_id = webhook_id - assert event.is_webhook is is_webhook + def test_is_webhook_property( + self, + message_create_event: message_events.MessageCreateEvent, + webhook_id: typing.Optional[int], + is_webhook: bool, + ): + with mock.patch.object(message_create_event.message, "webhook_id", webhook_id): + assert message_create_event.is_webhook is is_webhook - def test_message_id_property(self, event): - assert event.message_id is event.message.id + def test_message_id_property(self, message_create_event: message_events.MessageCreateEvent): + assert message_create_event.message_id is message_create_event.message.id class TestMessageUpdateEvent: - @pytest.fixture - def event(self): - cls = hikari_test_helpers.mock_class_namespace( - message_events.MessageUpdateEvent, - message=mock.Mock(spec_set=messages.Message, author=mock.Mock(spec_set=users.User)), - shard=mock.Mock(), - ) + class MockMessageUpdateEvent(message_events.MessageUpdateEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._message = mock.Mock(channel_id=snowflakes.Snowflake(123), id=snowflakes.Snowflake(456)) + + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard - return cls() + @property + def message(self) -> messages.Message: + return self._message - def test_app_property(self, event): - assert event.app is event.message.app + @pytest.fixture + def message_update_event(self, hikari_app: traits.RESTAware) -> message_events.MessageUpdateEvent: + return TestMessageUpdateEvent.MockMessageUpdateEvent(hikari_app) + + def test_app_property(self, message_update_event: message_events.MessageUpdateEvent): + assert message_update_event.app is message_update_event.message.app @pytest.mark.parametrize("author", [mock.Mock(spec_set=users.User), undefined.UNDEFINED]) - def test_author_property(self, event, author): - event.message.author = author - assert event.author is author + def test_author_property(self, message_update_event: message_events.MessageUpdateEvent, author: users.User): + message_update_event.message.author = author + assert message_update_event.author is author @pytest.mark.parametrize( ("author", "expected_id"), [(mock.Mock(spec_set=users.User, id=91827), 91827), (undefined.UNDEFINED, undefined.UNDEFINED)], ) - def test_author_id_property(self, event, author, expected_id): - event.message.author = author - assert event.author_id == expected_id + def test_author_id_property( + self, + message_update_event: message_events.MessageUpdateEvent, + author: undefined.UndefinedOr[users.User], + expected_id: undefined.UndefinedOr[int], + ): + message_update_event.message.author = author + assert message_update_event.author_id == expected_id - def test_channel_id_property(self, event): - assert event.channel_id is event.message.channel_id + def test_channel_id_property(self, message_update_event: message_events.MessageUpdateEvent): + assert message_update_event.channel_id is message_update_event.message.channel_id - def test_content_property(self, event): - assert event.content is event.message.content + def test_content_property(self, message_update_event: message_events.MessageUpdateEvent): + assert message_update_event.content is message_update_event.message.content - def test_embeds_property(self, event): - assert event.embeds is event.message.embeds + def test_embeds_property(self, message_update_event: message_events.MessageUpdateEvent): + assert message_update_event.embeds is message_update_event.message.embeds @pytest.mark.parametrize("is_bot", [True, False]) - def test_is_bot_property(self, event, is_bot): - event.message.author.is_bot = is_bot - assert event.is_bot is is_bot + def test_is_bot_property(self, message_update_event: message_events.MessageUpdateEvent, is_bot: bool): + with mock.patch.object(message_update_event.message.author, "is_bot", is_bot): + assert message_update_event.is_bot is is_bot - def test_is_bot_property_if_no_author(self, event): - event.message.author = undefined.UNDEFINED - assert event.is_bot is undefined.UNDEFINED + def test_is_bot_property_if_no_author(self, message_update_event: message_events.MessageUpdateEvent): + message_update_event.message.author = undefined.UNDEFINED + assert message_update_event.is_bot is undefined.UNDEFINED @pytest.mark.parametrize( ("author", "webhook_id", "expected_is_human"), @@ -140,88 +177,117 @@ def test_is_bot_property_if_no_author(self, event): (undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED), ], ) - def test_is_human_property(self, event, author, webhook_id, expected_is_human): - event.message.author = author - event.message.webhook_id = webhook_id - assert event.is_human is expected_is_human + def test_is_human_property( + self, + message_update_event: message_events.MessageUpdateEvent, + author: undefined.UndefinedOr[users.User], + webhook_id: undefined.UndefinedOr[snowflakes.Snowflake], + expected_is_human: undefined.UndefinedOr[bool], + ): + message_update_event.message.author = author + message_update_event.message.webhook_id = webhook_id + assert message_update_event.is_human is expected_is_human @pytest.mark.parametrize( ("webhook_id", "is_webhook"), [(123, True), (None, False), (undefined.UNDEFINED, undefined.UNDEFINED)] ) - def test_is_webhook_property(self, event, webhook_id, is_webhook): - event.message.webhook_id = webhook_id - assert event.is_webhook is is_webhook + def test_is_webhook_property( + self, + message_update_event: message_events.MessageUpdateEvent, + webhook_id: undefined.UndefinedOr[snowflakes.Snowflake], + is_webhook: undefined.UndefinedOr[bool], + ): + message_update_event.message.webhook_id = webhook_id + assert message_update_event.is_webhook is is_webhook - def test_message_id_property(self, event): - assert event.message_id is event.message.id + def test_message_id_property(self, message_update_event: message_events.MessageUpdateEvent): + assert message_update_event.message_id is message_update_event.message.id class TestGuildMessageCreateEvent: @pytest.fixture - def event(self): + def guild_message_create_event(self) -> message_events.GuildMessageCreateEvent: return message_events.GuildMessageCreateEvent( message=mock.Mock( - spec_set=messages.Message, + spec=messages.Message, guild_id=snowflakes.Snowflake(342123123), channel_id=snowflakes.Snowflake(9121234), ), shard=mock.Mock(), ) - def test_guild_id_property(self, event): - assert event.guild_id == snowflakes.Snowflake(342123123) - - def test_get_channel_when_no_cache_trait(self): - event = hikari_test_helpers.mock_class_namespace( - message_events.GuildMessageCreateEvent, app=None, init_=False - )() + def test_guild_id_property(self, guild_message_create_event: message_events.GuildMessageCreateEvent): + assert guild_message_create_event.guild_id == snowflakes.Snowflake(342123123) - assert event.get_channel() is None + def test_get_channel_when_no_cache_trait(self, guild_message_create_event: message_events.GuildMessageCreateEvent): + with mock.patch.object(message_events.GuildMessageCreateEvent, "app", None): + assert guild_message_create_event.get_channel() is None @pytest.mark.parametrize("guild_channel_impl", [channels.GuildTextChannel, channels.GuildNewsChannel]) - def test_get_channel(self, event, guild_channel_impl): - event.app.cache.get_guild_channel = mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl)) - - result = event.get_channel() - assert result is event.app.cache.get_guild_channel.return_value - event.app.cache.get_guild_channel.assert_called_once_with(9121234) - - def test_get_guild_when_no_cache_trait(self): - event = hikari_test_helpers.mock_class_namespace( - message_events.GuildMessageCreateEvent, app=None, init_=False - )() - - assert event.get_guild() is None - - def test_get_guild(self, event): - result = event.get_guild() - - assert result is event.app.cache.get_guild.return_value - event.app.cache.get_guild.assert_called_once_with(342123123) - - def test_author_property(self, event): - assert event.author is event.message.author - - def test_member_property(self, event): - assert event.member is event.message.member - - def test_get_member_when_cacheless(self, event): - event.message.app = None - - result = event.get_member() - - assert result is None - - def test_get_member(self, event): - result = event.get_member() - - assert result is event.app.cache.get_member.return_value - event.app.cache.get_member.assert_called_once_with(event.guild_id, event.author_id) + def test_get_channel( + self, + guild_message_create_event: message_events.GuildMessageCreateEvent, + guild_channel_impl: typing.Union[channels.GuildTextChannel, channels.GuildNewsChannel], + ): + with ( + mock.patch.object( + message_events.GuildMessageCreateEvent, "app", mock.Mock(traits.CacheAware) + ) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_guild_channel", mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl)) + ) as patched_get_guild_channel, + ): + result = guild_message_create_event.get_channel() + assert result is patched_get_guild_channel.return_value + patched_get_guild_channel.assert_called_once_with(9121234) + + def test_get_guild_when_no_cache_trait(self, guild_message_create_event: message_events.GuildMessageCreateEvent): + with mock.patch.object(message_events.GuildMessageCreateEvent, "app", None): + assert guild_message_create_event.get_guild() is None + + def test_get_guild(self, guild_message_create_event: message_events.GuildMessageCreateEvent): + with ( + mock.patch.object( + message_events.GuildMessageCreateEvent, "app", mock.Mock(traits.CacheAware) + ) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_guild") as patched_get_guild, + ): + result = guild_message_create_event.get_guild() + + assert result is patched_get_guild.return_value + patched_get_guild.assert_called_once_with(342123123) + + def test_author_property(self, guild_message_create_event: message_events.GuildMessageCreateEvent): + assert guild_message_create_event.author is guild_message_create_event.message.author + + def test_member_property(self, guild_message_create_event: message_events.GuildMessageCreateEvent): + assert guild_message_create_event.member is guild_message_create_event.message.member + + def test_get_member_when_cacheless(self, guild_message_create_event: message_events.GuildMessageCreateEvent): + with mock.patch.object(guild_message_create_event.message, "app", None): + assert guild_message_create_event.get_member() is None + + def test_get_member(self, guild_message_create_event: message_events.GuildMessageCreateEvent): + with ( + mock.patch.object( + message_events.GuildMessageCreateEvent, "app", mock.Mock(traits.CacheAware) + ) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_member") as patched_get_member, + ): + result = guild_message_create_event.get_member() + + assert result is patched_get_member.return_value + patched_get_member.assert_called_once_with( + guild_message_create_event.guild_id, guild_message_create_event.author_id + ) class TestGuildMessageUpdateEvent: @pytest.fixture - def event(self): + def guild_message_update_event(self) -> message_events.GuildMessageUpdateEvent: return message_events.GuildMessageUpdateEvent( message=mock.Mock( spec_set=messages.Message, @@ -232,63 +298,82 @@ def event(self): shard=mock.Mock(), ) - def test_author_property(self, event): - assert event.author is event.message.author - - def test_member_property(self, event): - assert event.member is event.message.member + def test_author_property(self, guild_message_update_event: message_events.GuildMessageUpdateEvent): + assert guild_message_update_event.author is guild_message_update_event.message.author - def test_guild_id_property(self, event): - assert event.guild_id == snowflakes.Snowflake(54123123123) + def test_member_property(self, guild_message_update_event: message_events.GuildMessageUpdateEvent): + assert guild_message_update_event.member is guild_message_update_event.message.member - def test_get_channel_when_no_cache_trait(self): - event = hikari_test_helpers.mock_class_namespace( - message_events.GuildMessageUpdateEvent, app=None, init_=False - )() + def test_guild_id_property(self, guild_message_update_event: message_events.GuildMessageUpdateEvent): + assert guild_message_update_event.guild_id == snowflakes.Snowflake(54123123123) - assert event.get_channel() is None + def test_get_channel_when_no_cache_trait(self, guild_message_update_event: message_events.GuildMessageUpdateEvent): + with mock.patch.object(message_events.GuildMessageUpdateEvent, "app", None): + assert guild_message_update_event.get_channel() is None @pytest.mark.parametrize("guild_channel_impl", [channels.GuildTextChannel, channels.GuildNewsChannel]) - def test_get_channel(self, event, guild_channel_impl): - event.app.cache.get_guild_channel = mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl)) - - result = event.get_channel() - assert result is event.app.cache.get_guild_channel.return_value - event.app.cache.get_guild_channel.assert_called_once_with(800001066) - - def test_get_member_when_cacheless(self, event): - event.message.app = None - - result = event.get_member() - - assert result is None - - def test_get_member(self, event): - result = event.get_member() - - assert result is event.app.cache.get_member.return_value - event.app.cache.get_member.assert_called_once_with(event.guild_id, event.author_id) - - def test_get_guild_when_no_cache_trait(self): - event = hikari_test_helpers.mock_class_namespace( - message_events.GuildMessageUpdateEvent, app=None, init_=False - )() - - assert event.get_guild() is None - - def test_get_guild(self, event): - result = event.get_guild() - - assert result is event.app.cache.get_guild.return_value - event.app.cache.get_guild.assert_called_once_with(54123123123) - - def test_old_message(self, event): - assert event.old_message.id == 123 + def test_get_channel( + self, + guild_message_update_event: message_events.GuildMessageUpdateEvent, + guild_channel_impl: typing.Union[channels.GuildTextChannel, channels.GuildNewsChannel], + ): + with ( + mock.patch.object( + message_events.GuildMessageUpdateEvent, "app", mock.Mock(traits.CacheAware) + ) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_guild_channel", mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl)) + ) as patched_get_guild_channel, + ): + result = guild_message_update_event.get_channel() + assert result is patched_get_guild_channel.return_value + patched_get_guild_channel.assert_called_once_with(800001066) + + def test_get_member_when_cacheless(self, guild_message_update_event: message_events.GuildMessageUpdateEvent): + with mock.patch.object(guild_message_update_event.message, "app", None): + assert guild_message_update_event.get_member() is None + + def test_get_member(self, guild_message_update_event: message_events.GuildMessageUpdateEvent): + with ( + mock.patch.object( + message_events.GuildMessageUpdateEvent, "app", mock.Mock(traits.CacheAware) + ) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_member") as patched_get_member, + ): + result = guild_message_update_event.get_member() + + assert result is patched_get_member.return_value + patched_get_member.assert_called_once_with( + guild_message_update_event.guild_id, guild_message_update_event.author_id + ) + + def test_get_guild_when_no_cache_trait(self, guild_message_update_event: message_events.GuildMessageUpdateEvent): + with mock.patch.object(message_events.GuildMessageUpdateEvent, "app", None): + assert guild_message_update_event.get_guild() is None + + def test_get_guild(self, guild_message_update_event: message_events.GuildMessageUpdateEvent): + with ( + mock.patch.object( + message_events.GuildMessageUpdateEvent, "app", mock.Mock(traits.CacheAware) + ) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_guild") as patched_get_guild, + ): + result = guild_message_update_event.get_guild() + + assert result is patched_get_guild.return_value + patched_get_guild.assert_called_once_with(54123123123) + + def test_old_message(self, guild_message_update_event: message_events.GuildMessageUpdateEvent): + assert guild_message_update_event.old_message is not None + assert guild_message_update_event.old_message.id == 123 class TestDMMessageUpdateEvent: @pytest.fixture - def event(self): + def dm_message_update_event(self) -> message_events.DMMessageUpdateEvent: return message_events.DMMessageUpdateEvent( message=mock.Mock( spec_set=messages.Message, author=mock.Mock(spec_set=users.User, id=snowflakes.Snowflake(8000010662)) @@ -297,42 +382,62 @@ def event(self): shard=mock.Mock(), ) - def test_old_message(self, event): - assert event.old_message.id == 123 + def test_old_message(self, dm_message_update_event: message_events.DMMessageUpdateEvent): + assert dm_message_update_event.old_message is not None + assert dm_message_update_event.old_message.id == 123 class TestGuildMessageDeleteEvent: @pytest.fixture - def event(self): + def guild_message_delete_event(self) -> message_events.GuildMessageDeleteEvent: return message_events.GuildMessageDeleteEvent( guild_id=snowflakes.Snowflake(542342354564), channel_id=snowflakes.Snowflake(54213123123), app=mock.Mock(), - shard=object(), - message_id=9, - old_message=object(), + shard=mock.Mock(), + message_id=snowflakes.Snowflake(9), + old_message=mock.Mock(), ) - def test_get_channel_when_no_cache_trait(self, event): - event.app = object() + def test_get_channel_when_no_cache_trait(self, guild_message_delete_event: message_events.GuildMessageDeleteEvent): + guild_message_delete_event.app = mock.Mock(traits.RESTAware) - assert event.get_channel() is None + assert guild_message_delete_event.get_channel() is None @pytest.mark.parametrize("guild_channel_impl", [channels.GuildTextChannel, channels.GuildNewsChannel]) - def test_get_channel(self, event, guild_channel_impl): - event.app.cache.get_guild_channel = mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl)) - result = event.get_channel() - - assert result is event.app.cache.get_guild_channel.return_value - event.app.cache.get_guild_channel.assert_called_once_with(54213123123) - - def test_get_guild_when_no_cache_trait(self, event): - event.app = object() - - assert event.get_guild() is None - - def test_get_guild_property(self, event): - result = event.get_guild() - - assert result is event.app.cache.get_guild.return_value - event.app.cache.get_guild.assert_called_once_with(542342354564) + def test_get_channel( + self, + guild_message_delete_event: message_events.GuildMessageDeleteEvent, + guild_channel_impl: typing.Union[channels.GuildTextChannel, channels.GuildNewsChannel], + ): + with ( + mock.patch.object( + message_events.GuildMessageDeleteEvent, "app", mock.Mock(traits.CacheAware) + ) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_guild_channel", mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl)) + ) as patched_get_guild_channel, + ): + result = guild_message_delete_event.get_channel() + + assert result is patched_get_guild_channel.return_value + patched_get_guild_channel.assert_called_once_with(54213123123) + + def test_get_guild_when_no_cache_trait(self, guild_message_delete_event: message_events.GuildMessageDeleteEvent): + guild_message_delete_event.app = mock.Mock(traits.RESTAware) + + assert guild_message_delete_event.get_guild() is None + + def test_get_guild_property(self, guild_message_delete_event: message_events.GuildMessageDeleteEvent): + with ( + mock.patch.object( + message_events.GuildMessageDeleteEvent, "app", mock.Mock(traits.CacheAware) + ) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_guild") as patched_get_guild, + ): + result = guild_message_delete_event.get_guild() + + assert result is patched_get_guild.return_value + patched_get_guild.assert_called_once_with(542342354564) diff --git a/tests/hikari/events/test_reaction_events.py b/tests/hikari/events/test_reaction_events.py index b4c6dd231e..93bfa3f9ea 100644 --- a/tests/hikari/events/test_reaction_events.py +++ b/tests/hikari/events/test_reaction_events.py @@ -20,155 +20,326 @@ # SOFTWARE. from __future__ import annotations +import typing + import mock import pytest from hikari import emojis from hikari import guilds +from hikari import snowflakes +from hikari import traits +from hikari.api import shard as shard_api from hikari.events import reaction_events -from tests.hikari import hikari_test_helpers class TestReactionAddEvent: - def test_is_for_emoji_when_custom_emoji_matches(self): - event = hikari_test_helpers.mock_class_namespace(reaction_events.ReactionAddEvent, emoji_id=333333)() + class MockReactionAddEvent(reaction_events.ReactionAddEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._channel_id = snowflakes.Snowflake(123) + self._message_id = snowflakes.Snowflake(456) + self._user_id = snowflakes.Snowflake(789) + self._emoji_name = "reaction_add_emoji" + self._emoji_id = snowflakes.Snowflake(112) + self._is_animated = False + + @property + def app(self) -> traits.RESTAware: + return self._shard + + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard + + @property + def channel_id(self) -> snowflakes.Snowflake: + return self._channel_id + + @property + def message_id(self) -> snowflakes.Snowflake: + return self._message_id + + @property + def user_id(self) -> snowflakes.Snowflake: + return self._user_id + + @property + def emoji_name(self) -> typing.Union[emojis.UnicodeEmoji, str, None]: + return self._emoji_name + + @property + def emoji_id(self) -> typing.Optional[snowflakes.Snowflake]: + return self._emoji_id + + @property + def is_animated(self) -> bool: + return self._is_animated - assert event.is_for_emoji(emojis.CustomEmoji(id=333333, name=None, is_animated=True)) + @pytest.fixture + def reaction_add_event(self, hikari_app: traits.RESTAware) -> reaction_events.ReactionAddEvent: + return TestReactionAddEvent.MockReactionAddEvent(hikari_app) - def test_is_for_emoji_when_unicode_emoji_matches(self): - event = hikari_test_helpers.mock_class_namespace(reaction_events.ReactionAddEvent, emoji_name="🌲")() + def test_is_for_emoji_when_custom_emoji_matches(self, reaction_add_event: reaction_events.ReactionAddEvent): + assert reaction_add_event.is_for_emoji( + emojis.CustomEmoji(id=snowflakes.Snowflake(112), name="reaction_add_emoji", is_animated=False) + ) - assert event.is_for_emoji(emojis.UnicodeEmoji("🌲")) + def test_is_for_emoji_when_unicode_emoji_matches(self, reaction_add_event: reaction_events.ReactionAddEvent): + with mock.patch.object(reaction_add_event, "_emoji_name", "🌲"): + assert reaction_add_event.is_for_emoji(emojis.UnicodeEmoji("🌲")) @pytest.mark.parametrize( ("emoji_id", "emoji_name", "emoji"), [ - (None, "hi", emojis.CustomEmoji(name=None, id=54123, is_animated=False)), + (None, "hi", emojis.CustomEmoji(name="hi", id=snowflakes.Snowflake(54123), is_animated=False)), (123321, None, emojis.UnicodeEmoji("no u")), ], ) - def test_is_for_emoji_when_wrong_emoji_type(self, emoji_id, emoji_name, emoji): - event = hikari_test_helpers.mock_class_namespace( - reaction_events.ReactionAddEvent, emoji_id=emoji_id, emoji_name=emoji_name - )() - - assert event.is_for_emoji(emoji) is False + def test_is_for_emoji_when_wrong_emoji_type( + self, + reaction_add_event: reaction_events.ReactionAddEvent, + emoji_id: typing.Optional[int], + emoji_name: typing.Optional[str], + emoji: emojis.Emoji, + ): + with ( + mock.patch.object(reaction_add_event, "_emoji_id", emoji_id), + mock.patch.object(reaction_add_event, "_emoji_name", emoji_name), + ): + assert reaction_add_event.is_for_emoji(emoji) is False @pytest.mark.parametrize( ("emoji_id", "emoji_name", "emoji"), [ (None, "hi", emojis.UnicodeEmoji("bye")), - (123321, None, emojis.CustomEmoji(id=123312123, name=None, is_animated=False)), + (123321, "test", emojis.CustomEmoji(id=snowflakes.Snowflake(123312123), name="test", is_animated=False)), ], ) - def test_is_for_emoji_when_emoji_miss_match(self, emoji_id, emoji_name, emoji): - event = hikari_test_helpers.mock_class_namespace( - reaction_events.ReactionAddEvent, emoji_id=emoji_id, emoji_name=emoji_name - )() - - assert event.is_for_emoji(emoji) is False + def test_is_for_emoji_when_emoji_miss_match( + self, + reaction_add_event: reaction_events.ReactionAddEvent, + emoji_id: typing.Optional[int], + emoji_name: typing.Optional[str], + emoji: emojis.Emoji, + ): + with ( + mock.patch.object(reaction_add_event, "_emoji_id", emoji_id), + mock.patch.object(reaction_add_event, "_emoji_name", emoji_name), + ): + assert reaction_add_event.is_for_emoji(emoji) is False class TestReactionDeleteEvent: - def test_is_for_emoji_when_custom_emoji_matches(self): - event = hikari_test_helpers.mock_class_namespace(reaction_events.ReactionDeleteEvent, emoji_id=333)() + class MockReactionDeleteEvent(reaction_events.ReactionDeleteEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._channel_id = snowflakes.Snowflake(123) + self._message_id = snowflakes.Snowflake(456) + self._user_id = snowflakes.Snowflake(789) + self._emoji_name = "reaction_delete_emoji" + self._emoji_id = snowflakes.Snowflake(112) + + @property + def app(self) -> traits.RESTAware: + return self._shard + + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard + + @property + def channel_id(self) -> snowflakes.Snowflake: + return self._channel_id + + @property + def message_id(self) -> snowflakes.Snowflake: + return self._message_id + + @property + def user_id(self) -> snowflakes.Snowflake: + return self._user_id + + @property + def emoji_name(self) -> typing.Union[emojis.UnicodeEmoji, str, None]: + return self._emoji_name + + @property + def emoji_id(self) -> typing.Optional[snowflakes.Snowflake]: + return self._emoji_id - assert event.is_for_emoji(emojis.CustomEmoji(id=333, name=None, is_animated=True)) + @pytest.fixture + def reaction_delete_event(self, hikari_app: traits.RESTAware) -> reaction_events.ReactionDeleteEvent: + return TestReactionDeleteEvent.MockReactionDeleteEvent(hikari_app) - def test_is_for_emoji_when_unicode_emoji_matches(self): - event = hikari_test_helpers.mock_class_namespace(reaction_events.ReactionDeleteEvent, emoji_name="e")() + def test_is_for_emoji_when_custom_emoji_matches(self, reaction_delete_event: reaction_events.ReactionDeleteEvent): + assert reaction_delete_event.is_for_emoji( + emojis.CustomEmoji(id=snowflakes.Snowflake(112), name="reaction_delete_emoji", is_animated=True) + ) - assert event.is_for_emoji(emojis.UnicodeEmoji("e")) + def test_is_for_emoji_when_unicode_emoji_matches(self, reaction_delete_event: reaction_events.ReactionDeleteEvent): + with mock.patch.object(reaction_delete_event, "_emoji_name", "e"): + assert reaction_delete_event.is_for_emoji(emojis.UnicodeEmoji("e")) @pytest.mark.parametrize( ("emoji_id", "emoji_name", "emoji"), [ - (None, "hasdi", emojis.CustomEmoji(name=None, id=3123, is_animated=False)), + (None, "hasdi", emojis.CustomEmoji(name="hasdi", id=snowflakes.Snowflake(3123), is_animated=False)), (534123, None, emojis.UnicodeEmoji("nodfgdu")), ], ) - def test_is_for_emoji_when_wrong_emoji_type(self, emoji_id, emoji_name, emoji): - event = hikari_test_helpers.mock_class_namespace( - reaction_events.ReactionDeleteEvent, emoji_id=emoji_id, emoji_name=emoji_name - )() - - assert event.is_for_emoji(emoji) is False + def test_is_for_emoji_when_wrong_emoji_type( + self, + reaction_delete_event: reaction_events.ReactionDeleteEvent, + emoji_id: typing.Optional[int], + emoji_name: typing.Optional[str], + emoji: emojis.Emoji, + ): + with ( + mock.patch.object(reaction_delete_event, "_emoji_id", emoji_id), + mock.patch.object(reaction_delete_event, "_emoji_name", emoji_name), + ): + assert reaction_delete_event.is_for_emoji(emoji) is False @pytest.mark.parametrize( ("emoji_id", "emoji_name", "emoji"), [ (None, "hfdasi", emojis.UnicodeEmoji("bgye")), - (54123, None, emojis.CustomEmoji(id=34123, name=None, is_animated=False)), + (54123, "test", emojis.CustomEmoji(id=snowflakes.Snowflake(34123), name="test", is_animated=False)), ], ) - def test_is_for_emoji_when_emoji_miss_match(self, emoji_id, emoji_name, emoji): - event = hikari_test_helpers.mock_class_namespace( - reaction_events.ReactionDeleteEvent, emoji_id=emoji_id, emoji_name=emoji_name - )() - - assert event.is_for_emoji(emoji) is False + def test_is_for_emoji_when_emoji_miss_match( + self, + reaction_delete_event: reaction_events.ReactionDeleteEvent, + emoji_id: typing.Optional[int], + emoji_name: typing.Optional[str], + emoji: emojis.Emoji, + ): + with ( + mock.patch.object(reaction_delete_event, "_emoji_id", emoji_id), + mock.patch.object(reaction_delete_event, "_emoji_name", emoji_name), + ): + assert reaction_delete_event.is_for_emoji(emoji) is False class TestReactionDeleteEmojiEvent: - def test_is_for_emoji_when_custom_emoji_matches(self): - event = hikari_test_helpers.mock_class_namespace(reaction_events.ReactionDeleteEmojiEvent, emoji_id=332223333)() + class MockReactionDeleteEmojiEvent(reaction_events.ReactionDeleteEmojiEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._channel_id = snowflakes.Snowflake(123) + self._message_id = snowflakes.Snowflake(456) + self._emoji_name = "reaction_delete_emoji_emoji" + self._emoji_id = snowflakes.Snowflake(112) + + @property + def app(self) -> traits.RESTAware: + return self._shard + + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard + + @property + def channel_id(self) -> snowflakes.Snowflake: + return self._channel_id + + @property + def message_id(self) -> snowflakes.Snowflake: + return self._message_id + + @property + def emoji_name(self) -> typing.Union[emojis.UnicodeEmoji, str, None]: + return self._emoji_name + + @property + def emoji_id(self) -> typing.Optional[snowflakes.Snowflake]: + return self._emoji_id - assert event.is_for_emoji(emojis.CustomEmoji(id=332223333, name=None, is_animated=True)) - - def test_is_for_emoji_when_unicode_emoji_matches(self): - event = hikari_test_helpers.mock_class_namespace(reaction_events.ReactionDeleteEmojiEvent, emoji_name="🌲e")() + @pytest.fixture + def reaction_delete_emoji_event(self, hikari_app: traits.RESTAware) -> reaction_events.ReactionDeleteEmojiEvent: + return TestReactionDeleteEmojiEvent.MockReactionDeleteEmojiEvent(hikari_app) + + def test_is_for_emoji_when_custom_emoji_matches( + self, reaction_delete_emoji_event: reaction_events.ReactionDeleteEmojiEvent + ): + assert reaction_delete_emoji_event.is_for_emoji( + emojis.CustomEmoji(id=snowflakes.Snowflake(112), name="reaction_delete_emoji_emoji", is_animated=True) + ) - assert event.is_for_emoji(emojis.UnicodeEmoji("🌲e")) + def test_is_for_emoji_when_unicode_emoji_matches( + self, reaction_delete_emoji_event: reaction_events.ReactionDeleteEmojiEvent + ): + with mock.patch.object(reaction_delete_emoji_event, "_emoji_name", "🌲e"): + assert reaction_delete_emoji_event.is_for_emoji(emojis.UnicodeEmoji("🌲e")) @pytest.mark.parametrize( ("emoji_id", "emoji_name", "emoji"), [ - (None, "heeei", emojis.CustomEmoji(name=None, id=541123, is_animated=False)), + (None, "heeei", emojis.CustomEmoji(name="heeei", id=snowflakes.Snowflake(541123), is_animated=False)), (1233211, None, emojis.UnicodeEmoji("no eeeu")), ], ) - def test_is_for_emoji_when_wrong_emoji_type(self, emoji_id, emoji_name, emoji): - event = hikari_test_helpers.mock_class_namespace( - reaction_events.ReactionDeleteEmojiEvent, emoji_id=emoji_id, emoji_name=emoji_name - )() - - assert event.is_for_emoji(emoji) is False + def test_is_for_emoji_when_wrong_emoji_type( + self, + reaction_delete_emoji_event: reaction_events.ReactionDeleteEmojiEvent, + emoji_id: typing.Optional[int], + emoji_name: typing.Optional[str], + emoji: emojis.Emoji, + ): + with ( + mock.patch.object(reaction_delete_emoji_event, "_emoji_id", emoji_id), + mock.patch.object(reaction_delete_emoji_event, "_emoji_name", emoji_name), + ): + assert reaction_delete_emoji_event.is_for_emoji(emoji) is False @pytest.mark.parametrize( ("emoji_id", "emoji_name", "emoji"), [ (None, "dsahi", emojis.UnicodeEmoji("bye321")), - (12331231, None, emojis.CustomEmoji(id=121233312123, name=None, is_animated=False)), + ( + 12331231, + "sadfasd", + emojis.CustomEmoji(id=snowflakes.Snowflake(121233312123), name="sdf", is_animated=False), + ), ], ) - def test_is_for_emoji_when_emoji_miss_match(self, emoji_id, emoji_name, emoji): - event = hikari_test_helpers.mock_class_namespace( - reaction_events.ReactionDeleteEmojiEvent, emoji_id=emoji_id, emoji_name=emoji_name - )() - - assert event.is_for_emoji(emoji) is False + def test_is_for_emoji_when_emoji_miss_match( + self, + reaction_delete_emoji_event: reaction_events.ReactionDeleteEmojiEvent, + emoji_id: typing.Optional[int], + emoji_name: typing.Optional[str], + emoji: emojis.Emoji, + ): + with ( + mock.patch.object(reaction_delete_emoji_event, "_emoji_id", emoji_id), + mock.patch.object(reaction_delete_emoji_event, "_emoji_name", emoji_name), + ): + assert reaction_delete_emoji_event.is_for_emoji(emoji) is False class TestGuildReactionAddEvent: @pytest.fixture - def event(self): + def guild_reaction_add_event(self) -> reaction_events.GuildReactionAddEvent: return reaction_events.GuildReactionAddEvent( - shard=object(), + shard=mock.Mock(), member=mock.MagicMock(guilds.Member), - channel_id=123, - message_id=456, + channel_id=snowflakes.Snowflake(123), + message_id=snowflakes.Snowflake(456), emoji_name="👌", emoji_id=None, is_animated=False, ) - def test_app_property(self, event): - assert event.app is event.member.app + def test_app_property(self, guild_reaction_add_event: reaction_events.GuildReactionAddEvent): + assert guild_reaction_add_event.app is guild_reaction_add_event.member.app - def test_guild_id_property(self, event): - event.member.guild_id = 123 - assert event.guild_id == 123 + def test_guild_id_property(self, guild_reaction_add_event: reaction_events.GuildReactionAddEvent): + guild_reaction_add_event.member.guild_id = snowflakes.Snowflake(123) + assert guild_reaction_add_event.guild_id == 123 - def test_user_id_property(self, event): - event.member.user.id = 123 - assert event.user_id == 123 + def test_user_id_property(self, guild_reaction_add_event: reaction_events.GuildReactionAddEvent): + with mock.patch.object(guild_reaction_add_event.member.user, "id", 123): + assert guild_reaction_add_event.user_id == 123 diff --git a/tests/hikari/events/test_role_events.py b/tests/hikari/events/test_role_events.py index e1eb824848..8b77496e00 100644 --- a/tests/hikari/events/test_role_events.py +++ b/tests/hikari/events/test_role_events.py @@ -24,45 +24,50 @@ import pytest from hikari import guilds +from hikari import snowflakes from hikari.events import role_events class TestRoleCreateEvent: @pytest.fixture - def event(self): - return role_events.RoleCreateEvent(shard=object(), role=mock.Mock(guilds.Role)) + def event(self) -> role_events.RoleCreateEvent: + return role_events.RoleCreateEvent(shard=mock.Mock(), role=mock.Mock(guilds.Role)) - def test_app_property(self, event): + def test_app_property(self, event: role_events.RoleCreateEvent): assert event.app is event.role.app - def test_guild_id_property(self, event): - event.role.guild_id = 123 + def test_guild_id_property(self, event: role_events.RoleCreateEvent): + event.role.guild_id = snowflakes.Snowflake(123) assert event.guild_id == 123 - def test_role_id_property(self, event): - event.role.id = 123 + def test_role_id_property(self, event: role_events.RoleCreateEvent): + event.role.id = snowflakes.Snowflake(123) assert event.role_id == 123 class TestRoleUpdateEvent: @pytest.fixture - def event(self): - return role_events.RoleUpdateEvent(shard=object(), role=mock.Mock(guilds.Role), old_role=mock.Mock(guilds.Role)) + def event(self) -> role_events.RoleUpdateEvent: + return role_events.RoleUpdateEvent( + shard=mock.Mock(), role=mock.Mock(guilds.Role), old_role=mock.Mock(guilds.Role) + ) - def test_app_property(self, event): + def test_app_property(self, event: role_events.RoleUpdateEvent): assert event.app is event.role.app - def test_guild_id_property(self, event): - event.role.guild_id = 123 + def test_guild_id_property(self, event: role_events.RoleUpdateEvent): + event.role.guild_id = snowflakes.Snowflake(123) assert event.guild_id == 123 - def test_role_id_property(self, event): - event.role.id = 123 + def test_role_id_property(self, event: role_events.RoleUpdateEvent): + event.role.id = snowflakes.Snowflake(123) assert event.role_id == 123 - def test_old_role(self, event): - event.old_role.guild_id = 123 - event.old_role.id = 456 - - assert event.old_role.guild_id == 123 - assert event.old_role.id == 456 + def test_old_role(self, event: role_events.RoleUpdateEvent): + with ( + mock.patch.object(event.old_role, "guild_id", snowflakes.Snowflake(123)), + mock.patch.object(event.old_role, "id", snowflakes.Snowflake(456)), + ): + assert event.old_role is not None + assert event.old_role.guild_id == 123 + assert event.old_role.id == 456 diff --git a/tests/hikari/events/test_shard_events.py b/tests/hikari/events/test_shard_events.py index 74ae41ad11..d5eec275f9 100644 --- a/tests/hikari/events/test_shard_events.py +++ b/tests/hikari/events/test_shard_events.py @@ -20,34 +20,37 @@ # SOFTWARE. from __future__ import annotations +import typing + import mock import pytest +from hikari import applications from hikari import snowflakes from hikari.events import shard_events class TestShardReadyEvent: @pytest.fixture - def event(self): + def event(self) -> shard_events.ShardReadyEvent: return shard_events.ShardReadyEvent( my_user=mock.Mock(), resume_gateway_url="testing", - shard=None, + shard=mock.Mock(), actual_gateway_version=1, session_id="ok", - application_id=1, - application_flags=1, + application_id=snowflakes.Snowflake(1), + application_flags=applications.ApplicationFlags.EMBEDDED, unavailable_guilds=[], ) - def test_app_property(self, event): + def test_app_property(self, event: shard_events.ShardReadyEvent): assert event.app is event.my_user.app class TestMemberChunkEvent: @pytest.fixture - def event(self): + def event(self) -> shard_events.MemberChunkEvent: return shard_events.MemberChunkEvent( app=mock.Mock(), shard=mock.Mock(), @@ -65,26 +68,34 @@ def event(self): nonce="blah", ) - def test___getitem___with_slice(self, event): - mock_member_0 = object() - mock_member_1 = object() - event.members = {1: object(), 55: object(), 99: mock_member_0, 455: object(), 5444: mock_member_1} + def test___getitem___with_slice(self, event: shard_events.MemberChunkEvent): + mock_member_0 = mock.Mock() + mock_member_1 = mock.Mock() + event.members = { + snowflakes.Snowflake(1): mock.Mock(), + snowflakes.Snowflake(55): mock.Mock(), + snowflakes.Snowflake(99): mock_member_0, + snowflakes.Snowflake(455): mock.Mock(), + snowflakes.Snowflake(5444): mock_member_1, + } assert event[2:5:2] == (mock_member_0, mock_member_1) - def test___getitem___with_valid_index(self, event): - mock_member = object() + def test___getitem___with_valid_index(self, event: shard_events.MemberChunkEvent): + mock_member = mock.Mock() + assert isinstance(event.members, typing.MutableMapping) # FIXME: This seems hacky + event.members[snowflakes.Snowflake(99)] = mock_member assert event[2] is mock_member with pytest.raises(IndexError): assert event[55] - def test___getitem___with_invalid_index(self, event): + def test___getitem___with_invalid_index(self, event: shard_events.MemberChunkEvent): with pytest.raises(IndexError): assert event[123] - def test___iter___(self, event): + def test___iter___(self, event: shard_events.MemberChunkEvent): member_0 = mock.Mock() member_1 = mock.Mock() member_2 = mock.Mock() @@ -97,5 +108,5 @@ def test___iter___(self, event): assert list(event) == [member_0, member_1, member_2] - def test___len___(self, event): + def test___len___(self, event: shard_events.MemberChunkEvent): assert len(event) == 4 diff --git a/tests/hikari/events/test_stage_events.py b/tests/hikari/events/test_stage_events.py index 03cd9c4977..3a2975d373 100644 --- a/tests/hikari/events/test_stage_events.py +++ b/tests/hikari/events/test_stage_events.py @@ -29,30 +29,30 @@ class TestStageInstanceCreateEvent: @pytest.fixture - def event(self): - return stage_events.StageInstanceCreateEvent(shard=object(), stage_instance=mock.Mock()) + def event(self) -> stage_events.StageInstanceCreateEvent: + return stage_events.StageInstanceCreateEvent(shard=mock.Mock(), stage_instance=mock.Mock()) - def test_app_property(self, event): + def test_app_property(self, event: stage_events.StageInstanceCreateEvent): assert event.app is event.stage_instance.app class TestStageInstanceUpdateEvent: @pytest.fixture - def event(self): + def event(self) -> stage_events.StageInstanceUpdateEvent: return stage_events.StageInstanceUpdateEvent( - shard=object(), stage_instance=mock.Mock(stage_instances.StageInstance) + shard=mock.Mock(), stage_instance=mock.Mock(stage_instances.StageInstance) ) - def test_app_property(self, event): + def test_app_property(self, event: stage_events.StageInstanceUpdateEvent): assert event.app is event.stage_instance.app class TestStageInstanceDeleteEvent: @pytest.fixture - def event(self): + def event(self) -> stage_events.StageInstanceDeleteEvent: return stage_events.StageInstanceDeleteEvent( - shard=object(), stage_instance=mock.Mock(stage_instances.StageInstance) + shard=mock.Mock(), stage_instance=mock.Mock(stage_instances.StageInstance) ) - def test_app_property(self, event): + def test_app_property(self, event: stage_events.StageInstanceDeleteEvent): assert event.app is event.stage_instance.app diff --git a/tests/hikari/events/test_typing_events.py b/tests/hikari/events/test_typing_events.py index 8a224b04f5..f1684c4776 100644 --- a/tests/hikari/events/test_typing_events.py +++ b/tests/hikari/events/test_typing_events.py @@ -20,136 +20,196 @@ # SOFTWARE. from __future__ import annotations +import datetime +import typing + import mock import pytest from hikari import channels +from hikari import snowflakes +from hikari import traits +from hikari.api import shard as shard_api from hikari.events import typing_events -from tests.hikari import hikari_test_helpers class TestTypingEvent: - @pytest.fixture - def event(self): - cls = hikari_test_helpers.mock_class_namespace( - typing_events.TypingEvent, channel_id=123, user_id=456, timestamp=object(), shard=object() - ) - - return cls() + class MockTypingEvent(typing_events.TypingEvent): + def __init__(self, app: traits.RESTAware): + self._app = app + self._shard = mock.Mock() + self._channel_id = snowflakes.Snowflake(123) + self._user_id = snowflakes.Snowflake(456) + self._timestamp = datetime.datetime.fromtimestamp(4000) + + @property + def app(self) -> traits.RESTAware: + return self._app + + @property + def shard(self) -> shard_api.GatewayShard: + return self._shard + + @property + def channel_id(self) -> snowflakes.Snowflake: + return self._channel_id + + @property + def user_id(self) -> snowflakes.Snowflake: + return self._user_id + + @property + def timestamp(self) -> datetime.datetime: + return self._timestamp - def test_get_user_when_no_cache(self, event): - event = hikari_test_helpers.mock_class_namespace(typing_events.TypingEvent, app=None)() + @pytest.fixture + def typing_event(self, hikari_app: traits.RESTAware) -> typing_events.TypingEvent: + return TestTypingEvent.MockTypingEvent(hikari_app) - assert event.get_user() is None + def test_get_user_when_no_cache(self, typing_event: typing_events.TypingEvent): + with mock.patch.object(typing_event, "_app", None): + assert typing_event.get_user() is None - def test_get_user(self, event): - assert event.get_user() is event.app.cache.get_user.return_value + def test_get_user(self, typing_event: typing_events.TypingEvent): + with ( + mock.patch.object(typing_event, "_app", mock.Mock(traits.CacheAware)) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_user") as patched_get_user, + ): + assert typing_event.get_user() is patched_get_user.return_value - def test_trigger_typing(self, event): - event.app.rest.trigger_typing = mock.Mock() - result = event.trigger_typing() - event.app.rest.trigger_typing.assert_called_once_with(123) - assert result is event.app.rest.trigger_typing.return_value + def test_trigger_typing(self, typing_event: typing_events.TypingEvent): + typing_event.app.rest.trigger_typing = mock.Mock() + result = typing_event.trigger_typing() + typing_event.app.rest.trigger_typing.assert_called_once_with(123) + assert result is typing_event.app.rest.trigger_typing.return_value class TestGuildTypingEvent: @pytest.fixture - def event(self): - cls = hikari_test_helpers.mock_class_namespace(typing_events.GuildTypingEvent) - - return cls( - channel_id=123, - timestamp=object(), - shard=object(), - guild_id=789, + def guild_typing_event(self) -> typing_events.GuildTypingEvent: + return typing_events.GuildTypingEvent( + channel_id=snowflakes.Snowflake(123), + timestamp=mock.Mock(), + shard=mock.Mock(), + guild_id=snowflakes.Snowflake(789), member=mock.Mock(id=456, app=mock.Mock(rest=mock.AsyncMock())), ) - def test_app_property(self, event): - assert event.app is event.member.app - - def test_get_channel_when_no_cache(self): - event = hikari_test_helpers.mock_class_namespace(typing_events.GuildTypingEvent, app=None, init_=False)() + def test_app_property(self, guild_typing_event: typing_events.GuildTypingEvent): + assert guild_typing_event.app is guild_typing_event.member.app - assert event.get_channel() is None + def test_get_channel_when_no_cache(self, guild_typing_event: typing_events.GuildTypingEvent): + with mock.patch.object(typing_events.GuildTypingEvent, "app", None): + assert guild_typing_event.get_channel() is None @pytest.mark.parametrize("guild_channel_impl", [channels.GuildNewsChannel, channels.GuildTextChannel]) - def test_get_channel(self, event, guild_channel_impl): - event.app.cache.get_guild_channel = mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl)) - result = event.get_channel() - - assert result is event.app.cache.get_guild_channel.return_value - event.app.cache.get_guild_channel.assert_called_once_with(123) + def test_get_channel( + self, + guild_typing_event: typing_events.GuildTypingEvent, + guild_channel_impl: typing.Union[channels.GuildNewsChannel, channels.GuildTextChannel], + ): + with ( + mock.patch.object(typing_events.GuildTypingEvent, "app", mock.Mock(traits.CacheAware)) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_guild_channel", mock.Mock(return_value=mock.Mock(spec_set=guild_channel_impl)) + ) as patched_get_guild_channel, + ): + result = guild_typing_event.get_channel() + + assert result is patched_get_guild_channel.return_value + patched_get_guild_channel.assert_called_once_with(123) @pytest.mark.asyncio - async def test_get_guild_when_no_cache(self): - event = hikari_test_helpers.mock_class_namespace(typing_events.GuildTypingEvent, app=None, init_=False)() - - assert event.get_guild() is None - - def test_get_guild_when_available(self, event): - result = event.get_guild() - - assert result is event.app.cache.get_available_guild.return_value - event.app.cache.get_available_guild.assert_called_once_with(789) - event.app.cache.get_unavailable_guild.assert_not_called() - - def test_get_guild_when_unavailable(self, event): - event.app.cache.get_available_guild.return_value = None - result = event.get_guild() - - assert result is event.app.cache.get_unavailable_guild.return_value - event.app.cache.get_unavailable_guild.assert_called_once_with(789) - event.app.cache.get_available_guild.assert_called_once_with(789) - - def test_user_id(self, event): - assert event.user_id == event.member.id - assert event.user_id == 456 + async def test_get_guild_when_no_cache(self, guild_typing_event: typing_events.GuildTypingEvent): + with mock.patch.object(typing_events.GuildTypingEvent, "app", None): + assert guild_typing_event.get_guild() is None + + def test_get_guild_when_available(self, guild_typing_event: typing_events.GuildTypingEvent): + with ( + mock.patch.object(typing_events.GuildTypingEvent, "app", mock.Mock(traits.CacheAware)) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_available_guild") as patched_get_available_guild, + mock.patch.object(patched_cache, "get_unavailable_guild") as patched_get_unavailable_guild, + ): + result = guild_typing_event.get_guild() + + assert result is patched_get_available_guild.return_value + patched_get_available_guild.assert_called_once_with(789) + patched_get_unavailable_guild.assert_not_called() + + def test_get_guild_when_unavailable(self, guild_typing_event: typing_events.GuildTypingEvent): + with ( + mock.patch.object(typing_events.GuildTypingEvent, "app", mock.Mock(traits.CacheAware)) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_available_guild", return_value=None) as patched_get_available_guild, + mock.patch.object(patched_cache, "get_unavailable_guild") as patched_get_unavailable_guild, + ): + result = guild_typing_event.get_guild() + + assert result is patched_get_unavailable_guild.return_value + patched_get_unavailable_guild.assert_called_once_with(789) + patched_get_available_guild.assert_called_once_with(789) + + def test_user_id(self, guild_typing_event: typing_events.GuildTypingEvent): + assert guild_typing_event.user_id == guild_typing_event.member.id + assert guild_typing_event.user_id == 456 @pytest.mark.asyncio @pytest.mark.parametrize("guild_channel_impl", [channels.GuildNewsChannel, channels.GuildTextChannel]) - async def test_fetch_channel(self, event, guild_channel_impl): - event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock.Mock(spec_set=guild_channel_impl)) - await event.fetch_channel() + async def test_fetch_channel( + self, + guild_typing_event: typing_events.GuildTypingEvent, + guild_channel_impl: typing.Union[channels.GuildNewsChannel, channels.GuildTextChannel], + ): + guild_typing_event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock.Mock(spec_set=guild_channel_impl)) + await guild_typing_event.fetch_channel() - event.app.rest.fetch_channel.assert_awaited_once_with(123) + guild_typing_event.app.rest.fetch_channel.assert_awaited_once_with(123) @pytest.mark.asyncio - async def test_fetch_guild(self, event): - await event.fetch_guild() + async def test_fetch_guild(self, guild_typing_event: typing_events.GuildTypingEvent): + with mock.patch.object(guild_typing_event.app.rest, "fetch_guild") as patched_fetch_guild: + await guild_typing_event.fetch_guild() - event.app.rest.fetch_guild.assert_awaited_once_with(789) + patched_fetch_guild.assert_awaited_once_with(789) @pytest.mark.asyncio - async def test_fetch_guild_preview(self, event): - await event.fetch_guild_preview() + async def test_fetch_guild_preview(self, guild_typing_event: typing_events.GuildTypingEvent): + with mock.patch.object(guild_typing_event.app.rest, "fetch_guild_preview") as patched_fetch_guild_preview: + await guild_typing_event.fetch_guild_preview() - event.app.rest.fetch_guild_preview.assert_awaited_once_with(789) + patched_fetch_guild_preview.assert_awaited_once_with(789) @pytest.mark.asyncio - async def test_fetch_member(self, event): - await event.fetch_member() + async def test_fetch_member(self, guild_typing_event: typing_events.GuildTypingEvent): + with mock.patch.object(guild_typing_event.app.rest, "fetch_member") as patched_fetch_member: + await guild_typing_event.fetch_member() - event.app.rest.fetch_member.assert_awaited_once_with(789, 456) + patched_fetch_member.assert_awaited_once_with(789, 456) @pytest.mark.asyncio class TestDMTypingEvent: @pytest.fixture - def event(self): - cls = hikari_test_helpers.mock_class_namespace(typing_events.DMTypingEvent) - - return cls( - channel_id=123, timestamp=object(), shard=object(), app=mock.Mock(rest=mock.AsyncMock()), user_id=456 + def dm_typing_event(self) -> typing_events.DMTypingEvent: + return typing_events.DMTypingEvent( + channel_id=snowflakes.Snowflake(123), + timestamp=mock.Mock(), + shard=mock.Mock(), + app=mock.Mock(rest=mock.AsyncMock()), + user_id=snowflakes.Snowflake(456), ) - async def test_fetch_channel(self, event): - event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock.Mock(spec_set=channels.DMChannel)) - await event.fetch_channel() + async def test_fetch_channel(self, dm_typing_event: typing_events.DMTypingEvent): + dm_typing_event.app.rest.fetch_channel = mock.AsyncMock(return_value=mock.Mock(spec_set=channels.DMChannel)) + await dm_typing_event.fetch_channel() - event.app.rest.fetch_channel.assert_awaited_once_with(123) + dm_typing_event.app.rest.fetch_channel.assert_awaited_once_with(123) - async def test_fetch_user(self, event): - await event.fetch_user() + async def test_fetch_user(self, dm_typing_event: typing_events.DMTypingEvent): + with mock.patch.object(dm_typing_event.app.rest, "fetch_user") as patched_fetch_user: + await dm_typing_event.fetch_user() - event.app.rest.fetch_user.assert_awaited_once_with(456) + patched_fetch_user.assert_awaited_once_with(456) diff --git a/tests/hikari/events/test_user_events.py b/tests/hikari/events/test_user_events.py index 97f62db67f..350df4be3a 100644 --- a/tests/hikari/events/test_user_events.py +++ b/tests/hikari/events/test_user_events.py @@ -28,8 +28,8 @@ class TestOwnUserUpdateEvent: @pytest.fixture - def event(self): - return user_events.OwnUserUpdateEvent(shard=None, old_user=None, user=mock.Mock()) + def event(self) -> user_events.OwnUserUpdateEvent: + return user_events.OwnUserUpdateEvent(shard=mock.Mock(), old_user=None, user=mock.Mock()) - def test_app_property(self, event): + def test_app_property(self, event: user_events.OwnUserUpdateEvent): assert event.app is event.user.app diff --git a/tests/hikari/events/test_voice_events.py b/tests/hikari/events/test_voice_events.py index d86810de4f..26c5cdbefd 100644 --- a/tests/hikari/events/test_voice_events.py +++ b/tests/hikari/events/test_voice_events.py @@ -23,39 +23,45 @@ import mock import pytest +from hikari import snowflakes from hikari import voices from hikari.events import voice_events class TestVoiceStateUpdateEvent: @pytest.fixture - def event(self): + def event(self) -> voice_events.VoiceStateUpdateEvent: return voice_events.VoiceStateUpdateEvent( - shard=object(), state=mock.Mock(voices.VoiceState), old_state=mock.Mock(voices.VoiceState) + shard=mock.Mock(), state=mock.Mock(voices.VoiceState), old_state=mock.Mock(voices.VoiceState) ) - def test_app_property(self, event): + def test_app_property(self, event: voice_events.VoiceStateUpdateEvent): assert event.app is event.state.app - def test_guild_id_property(self, event): - event.state.guild_id = 123 + def test_guild_id_property(self, event: voice_events.VoiceStateUpdateEvent): + event.state.guild_id = snowflakes.Snowflake(123) assert event.guild_id == 123 - def test_old_voice_state(self, event): - event.old_state.guild_id = 123 + def test_old_voice_state(self, event: voice_events.VoiceStateUpdateEvent): + assert event.old_state is not None + event.old_state.guild_id = snowflakes.Snowflake(123) assert event.old_state.guild_id == 123 class TestVoiceServerUpdateEvent: @pytest.fixture - def event(self): + def event(self) -> voice_events.VoiceServerUpdateEvent: return voice_events.VoiceServerUpdateEvent( - app=None, shard=object(), guild_id=123, token="token", raw_endpoint="voice.discord.com:123" + app=mock.Mock(), + shard=mock.Mock(), + guild_id=snowflakes.Snowflake(123), + token="token", + raw_endpoint="voice.discord.com:123", ) - def test_endpoint_property(self, event): + def test_endpoint_property(self, event: voice_events.VoiceServerUpdateEvent): assert event.endpoint == "wss://voice.discord.com:123" - def test_endpoint_property_when_raw_endpoint_is_None(self, event): + def test_endpoint_property_when_raw_endpoint_is_None(self, event: voice_events.VoiceServerUpdateEvent): event.raw_endpoint = None assert event.endpoint is None diff --git a/tests/hikari/hikari_test_helpers.py b/tests/hikari/hikari_test_helpers.py index 7b126cd45f..6c8c9f0c06 100644 --- a/tests/hikari/hikari_test_helpers.py +++ b/tests/hikari/hikari_test_helpers.py @@ -81,7 +81,7 @@ def mock_class_namespace( return type(name, (klass,), namespace) -def retry(max_retries): +def retry(max_retries: int): def decorator(func): assert inspect.iscoroutinefunction(func), "retry only supports coroutine functions currently" diff --git a/tests/hikari/impl/test_buckets.py b/tests/hikari/impl/test_buckets.py index a10d7cb4e0..53a2ba7704 100644 --- a/tests/hikari/impl/test_buckets.py +++ b/tests/hikari/impl/test_buckets.py @@ -21,10 +21,9 @@ from __future__ import annotations import asyncio -import contextlib -import math import time +import typing import mock import pytest @@ -33,7 +32,6 @@ from hikari.impl import buckets from hikari.impl import rate_limits from hikari.internal import routes -from hikari.internal import time as hikari_date class TestRESTBucket: @@ -42,11 +40,11 @@ def template(self): return routes.Route("GET", "/foo/bar") @pytest.fixture - def compiled_route(self, template): + def compiled_route(self, template: routes.Route): return routes.CompiledRoute("/foo/bar", template, "1a2b3c") @pytest.mark.asyncio - async def test_usage_when_unknown(self, compiled_route): + async def test_usage_when_unknown(self, compiled_route: routes.CompiledRoute): bucket = buckets.RESTBucket(buckets.UNKNOWN_HASH, compiled_route, mock.Mock(), float("inf")) bucket._is_unknown = True @@ -58,7 +56,7 @@ async def test_usage_when_unknown(self, compiled_route): assert bucket._in_transit == 0 @pytest.mark.asyncio - async def test_usage_when_resolved(self, compiled_route): + async def test_usage_when_resolved(self, compiled_route: routes.CompiledRoute): global_ratelimit = mock.Mock(acquire=mock.AsyncMock(), reset_at=None) bucket = buckets.RESTBucket("resolved bucket", compiled_route, global_ratelimit, float("inf")) bucket._is_unknown = False @@ -71,7 +69,7 @@ async def test_usage_when_resolved(self, compiled_route): assert bucket._in_transit == 0 - def test_update_rate_limit_when_no_issues(self, compiled_route): + def test_update_rate_limit_when_no_issues(self, compiled_route: routes.CompiledRoute): bucket = buckets.RESTBucket("updating ratelimit test", compiled_route, mock.Mock(), float("inf")) now = time.time() @@ -90,7 +88,7 @@ def test_update_rate_limit_when_no_issues(self, compiled_route): assert bucket.reset_at == now_update assert bucket.period == 2 - def test_update_rate_limit_when_period_far_apart(self, compiled_route): + def test_update_rate_limit_when_period_far_apart(self, compiled_route: routes.CompiledRoute): bucket = buckets.RESTBucket("updating ratelimit test", compiled_route, mock.Mock(), float("inf")) now = 12123123 @@ -113,7 +111,7 @@ def test_update_rate_limit_when_period_far_apart(self, compiled_route): assert bucket._out_of_sync is False @pytest.mark.asyncio - async def test_acquire_when_too_long_ratelimit(self, compiled_route): + async def test_acquire_when_too_long_ratelimit(self, compiled_route: routes.CompiledRoute): bucket = buckets.RESTBucket("spaghetti", compiled_route, mock.Mock(), 60) bucket.move_at = time.time() + 999999999999999999999999999 bucket._is_unknown = False @@ -127,7 +125,7 @@ async def test_acquire_when_too_long_ratelimit(self, compiled_route): assert bucket._in_transit == 0 @pytest.mark.asyncio - async def test_acquire_when_too_long_global_ratelimit(self, compiled_route): + async def test_acquire_when_too_long_global_ratelimit(self, compiled_route: routes.CompiledRoute): global_ratelimit = mock.Mock(reset_at=time.time() + 999999999999999999999999999) bucket = buckets.RESTBucket("spaghetti", compiled_route, global_ratelimit, 1) @@ -144,7 +142,7 @@ async def test_acquire_when_too_long_global_ratelimit(self, compiled_route): global_ratelimit.acquire.assert_not_called() @pytest.mark.asyncio - async def test_acquire_when_unknown_bucket(self, compiled_route): + async def test_acquire_when_unknown_bucket(self, compiled_route: routes.CompiledRoute): global_ratelimit = mock.Mock(acquire=mock.AsyncMock(), reset_at=None) bucket = buckets.RESTBucket("UNKNOWN", compiled_route, global_ratelimit, float("inf")) @@ -158,7 +156,7 @@ async def test_acquire_when_unknown_bucket(self, compiled_route): global_ratelimit.acquire.assert_not_called() @pytest.mark.asyncio - async def test_acquire_when_resolved_bucket(self, compiled_route): + async def test_acquire_when_resolved_bucket(self, compiled_route: routes.CompiledRoute): global_ratelimit = mock.Mock(acquire=mock.AsyncMock(), reset_at=None) bucket = buckets.RESTBucket("spaghetti", compiled_route, global_ratelimit, float("inf")) bucket._is_unknown = False @@ -170,7 +168,7 @@ async def test_acquire_when_resolved_bucket(self, compiled_route): super_acquire.assert_awaited_once_with() global_ratelimit.acquire.assert_awaited_once_with() - def test_resolve_when_not_unknown(self, compiled_route): + def test_resolve_when_not_unknown(self, compiled_route: routes.CompiledRoute): bucket = buckets.RESTBucket("spaghetti", compiled_route, mock.Mock(), float("inf")) bucket._is_unknown = False @@ -179,7 +177,7 @@ def test_resolve_when_not_unknown(self, compiled_route): assert bucket.name == "spaghetti" - def test_resolve(self, compiled_route): + def test_resolve(self, compiled_route: routes.CompiledRoute): bucket = buckets.RESTBucket(buckets.UNKNOWN_HASH, compiled_route, mock.Mock(), float("inf")) bucket.resolve("test", 1, 3, 123123124, 4) @@ -195,17 +193,17 @@ class TestRESTBucketManager: @pytest.fixture def bucket_manager(self): manager = buckets.RESTBucketManager(max_rate_limit=float("inf")) - manager._gc_task = object() + manager._gc_task = mock.Mock() return manager - def test_max_rate_limit_property(self, bucket_manager): - bucket_manager._max_rate_limit = object() + def test_max_rate_limit_property(self, bucket_manager: buckets.RESTBucketManager): + bucket_manager._max_rate_limit = mock.Mock() assert bucket_manager.max_rate_limit is bucket_manager._max_rate_limit @pytest.mark.asyncio - async def test_close(self, bucket_manager): + async def test_close(self, bucket_manager: buckets.RESTBucketManager): class GcTaskMock: def __init__(self): self._awaited_count = 0 @@ -239,7 +237,7 @@ def assert_awaited_once(self): gc_task.assert_awaited_once() @pytest.mark.asyncio - async def test_start(self, bucket_manager): + async def test_start(self, bucket_manager: buckets.RESTBucketManager): bucket_manager._gc_task = None bucket_manager.start() @@ -253,26 +251,28 @@ async def test_start(self, bucket_manager): pass @pytest.mark.asyncio - async def test_start_when_already_started(self, bucket_manager): - bucket_manager._gc_task = object() + async def test_start_when_already_started(self, bucket_manager: buckets.RESTBucketManager): + bucket_manager._gc_task = mock.Mock() with pytest.raises(errors.ComponentStateConflictError): bucket_manager.start() @pytest.mark.asyncio - async def test_gc_makes_gc_pass(self, bucket_manager): + async def test_gc_makes_gc_pass(self, bucket_manager: buckets.RESTBucketManager): class ExitError(Exception): ... - with mock.patch.object(buckets.RESTBucketManager, "_purge_stale_buckets") as purge_stale_buckets: - with mock.patch.object(asyncio, "sleep", side_effect=[None, ExitError]): - with pytest.raises(ExitError): - await bucket_manager._gc(0.001, 33) + with ( + mock.patch.object(buckets.RESTBucketManager, "_purge_stale_buckets") as purge_stale_buckets, + mock.patch.object(asyncio, "sleep", side_effect=[None, ExitError]), + pytest.raises(ExitError), + ): + await bucket_manager._gc(0.001, 33) purge_stale_buckets.assert_called_with(33) @pytest.mark.asyncio async def test_purge_stale_buckets_any_buckets_that_are_empty_but_still_rate_limited_are_kept_alive( - self, bucket_manager + self, bucket_manager: buckets.RESTBucketManager ): bucket = mock.Mock() bucket.is_empty = True @@ -288,7 +288,7 @@ async def test_purge_stale_buckets_any_buckets_that_are_empty_but_still_rate_lim @pytest.mark.asyncio async def test_purge_stale_buckets_any_buckets_that_are_empty_but_not_rate_limited_and_not_expired_are_kept_alive( - self, bucket_manager + self, bucket_manager: buckets.RESTBucketManager ): bucket = mock.Mock() bucket.is_empty = True @@ -304,7 +304,7 @@ async def test_purge_stale_buckets_any_buckets_that_are_empty_but_not_rate_limit @pytest.mark.asyncio async def test_purge_stale_buckets_any_buckets_that_are_empty_but_not_rate_limited_and_expired_are_closed( - self, bucket_manager + self, bucket_manager: buckets.RESTBucketManager ): bucket = mock.Mock() bucket.is_empty = True @@ -319,7 +319,9 @@ async def test_purge_stale_buckets_any_buckets_that_are_empty_but_not_rate_limit bucket.close.assert_called_once() @pytest.mark.asyncio - async def test_purge_stale_buckets_any_buckets_that_are_not_empty_are_kept_alive(self, bucket_manager): + async def test_purge_stale_buckets_any_buckets_that_are_not_empty_are_kept_alive( + self, bucket_manager: buckets.RESTBucketManager + ): bucket = mock.Mock() bucket.is_empty = False bucket.is_unknown = True @@ -334,7 +336,7 @@ async def test_purge_stale_buckets_any_buckets_that_are_not_empty_are_kept_alive @pytest.mark.asyncio async def test_acquire_route_when_not_in_routes_to_real_hashes_makes_new_bucket_using_initial_hash( - self, bucket_manager + self, bucket_manager: buckets.RESTBucketManager ): route = mock.Mock() @@ -349,7 +351,9 @@ async def test_acquire_route_when_not_in_routes_to_real_hashes_makes_new_bucket_ create_unknown_hash.assert_has_calls((mock.call(route, "auth_hash"), mock.call(route, "auth_hash"))) @pytest.mark.asyncio - async def test_acquire_route_when_not_in_routes_to_real_hashes_doesnt_cache_route(self, bucket_manager): + async def test_acquire_route_when_not_in_routes_to_real_hashes_doesnt_cache_route( + self, bucket_manager: buckets.RESTBucketManager + ): route = mock.Mock() route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash, auth: initial_hash + ";" + auth + ";bobs") @@ -359,7 +363,7 @@ async def test_acquire_route_when_not_in_routes_to_real_hashes_doesnt_cache_rout @pytest.mark.asyncio async def test_acquire_route_when_route_cached_already_obtains_hash_from_route_and_bucket_from_hash( - self, bucket_manager + self, bucket_manager: buckets.RESTBucketManager ): mock_route_hash = 123123123 route = mock.Mock() @@ -372,7 +376,7 @@ async def test_acquire_route_when_route_cached_already_obtains_hash_from_route_a assert bucket_manager.acquire_bucket(route, "auth") is bucket @pytest.mark.asyncio - async def test_acquire_route_returns_context_manager(self, bucket_manager): + async def test_acquire_route_returns_context_manager(self, bucket_manager: buckets.RESTBucketManager): route = mock.Mock() bucket = mock.Mock(reset_at=time.time() + 999999999999999999999999999) @@ -384,7 +388,9 @@ async def test_acquire_route_returns_context_manager(self, bucket_manager): assert bucket_manager.acquire_bucket(route, "auth") is bucket @pytest.mark.asyncio - async def test_acquire_unknown_route_returns_context_manager_for_new_bucket(self, bucket_manager): + async def test_acquire_unknown_route_returns_context_manager_for_new_bucket( + self, bucket_manager: buckets.RESTBucketManager + ): mock_route_hash = 123123123 route = mock.Mock() route.route = mock.Mock(__hash__=lambda _: mock_route_hash) @@ -396,7 +402,9 @@ async def test_acquire_unknown_route_returns_context_manager_for_new_bucket(self assert bucket_manager.acquire_bucket(route, "auth") is bucket @pytest.mark.asyncio - async def test_update_rate_limits_if_wrong_bucket_hash_reroutes_route(self, bucket_manager): + async def test_update_rate_limits_if_wrong_bucket_hash_reroutes_route( + self, bucket_manager: buckets.RESTBucketManager + ): mock_route_hash = 123123123 route = mock.Mock() route.route = mock.Mock(__hash__=lambda _: mock_route_hash) @@ -414,7 +422,9 @@ async def test_update_rate_limits_if_wrong_bucket_hash_reroutes_route(self, buck bucket.return_value.resolve.assert_called_once_with("blep;auth_hash;bobs", 22, 23, 123123.56, 3.56) @pytest.mark.asyncio - async def test_update_rate_limits_if_unknown_bucket_hash_reroutes_route(self, bucket_manager): + async def test_update_rate_limits_if_unknown_bucket_hash_reroutes_route( + self, bucket_manager: buckets.RESTBucketManager + ): mock_route_hash = 123123123 route = mock.Mock() route.route = mock.Mock(__hash__=lambda _: mock_route_hash) @@ -441,7 +451,9 @@ async def test_update_rate_limits_if_unknown_bucket_hash_reroutes_route(self, bu create_authentication_hash.assert_called_once_with("auth") @pytest.mark.asyncio - async def test_update_rate_limits_if_right_bucket_hash_does_nothing_to_hash(self, bucket_manager): + async def test_update_rate_limits_if_right_bucket_hash_does_nothing_to_hash( + self, bucket_manager: buckets.RESTBucketManager + ): route = mock.Mock() route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash, auth: initial_hash + ";" + auth + ";bobs") bucket_manager._route_hash_to_bucket_hash[route.route] = "123" @@ -456,7 +468,7 @@ async def test_update_rate_limits_if_right_bucket_hash_does_nothing_to_hash(self bucket.update_rate_limit.assert_called_once_with(22, 23, 123123.53, 7.65) @pytest.mark.asyncio - async def test_update_rate_limits_updates_params(self, bucket_manager): + async def test_update_rate_limits_updates_params(self, bucket_manager: buckets.RESTBucketManager): route = mock.Mock() route.create_real_bucket_hash = mock.Mock(wraps=lambda initial_hash, auth: initial_hash + ";" + auth + ";bobs") bucket_manager._route_hash_to_bucket_hash[route.route] = "123" @@ -469,6 +481,6 @@ async def test_update_rate_limits_updates_params(self, bucket_manager): bucket.update_rate_limit.assert_called_once_with(22, 23, 123123123.53, 5.32) @pytest.mark.parametrize(("gc_task", "is_alive"), [(None, False), ("some", True)]) - def test_is_alive(self, bucket_manager, gc_task, is_alive): + def test_is_alive(self, bucket_manager: buckets.RESTBucketManager, gc_task: typing.Optional[str], is_alive: bool): bucket_manager._gc_task = gc_task assert bucket_manager.is_alive is is_alive diff --git a/tests/hikari/impl/test_cache.py b/tests/hikari/impl/test_cache.py index 3e5e271c65..c02c5776bf 100644 --- a/tests/hikari/impl/test_cache.py +++ b/tests/hikari/impl/test_cache.py @@ -33,10 +33,12 @@ from hikari import polls from hikari import snowflakes from hikari import stickers +from hikari import traits from hikari import undefined from hikari import users from hikari import voices from hikari.api import config as config_api +from hikari.channels import GuildTextChannel from hikari.impl import cache as cache_impl_ from hikari.impl import config from hikari.internal import cache as cache_utilities @@ -44,96 +46,124 @@ from tests.hikari import hikari_test_helpers -class StubModel(snowflakes.Unique): - id = None - - def __init__(self, id=0): - self.id = snowflakes.Snowflake(id) - - class TestCacheImpl: @pytest.fixture - def app_impl(self): + def app_impl(self) -> traits.RESTAware: return mock.Mock() @pytest.fixture - def cache_impl(self, app_impl): + def cache_impl(self, app_impl: traits.RESTAware) -> cache_impl_.CacheImpl: return hikari_test_helpers.mock_class_namespace(cache_impl_.CacheImpl, slots_=False)( app=app_impl, settings=config.CacheSettings() ) - def test__init___(self, app_impl): + def test__init___(self, app_impl: traits.RESTAware): with mock.patch.object(cache_impl_.CacheImpl, "_create_cache") as create_cache: cache_impl_.CacheImpl(app_impl, config.CacheSettings()) create_cache.assert_called_once_with() - def test__is_cache_enabled_for(self, cache_impl): + def test__is_cache_enabled_for(self, cache_impl: cache_impl_.CacheImpl): cache_impl._settings.components = config_api.CacheComponents.MESSAGES | config_api.CacheComponents.GUILDS assert cache_impl._is_cache_enabled_for(config_api.CacheComponents.MESSAGES) is True - def test__increment_ref_count(self, cache_impl): + def test__increment_ref_count(self, cache_impl: cache_impl_.CacheImpl): mock_obj = mock.Mock(ref_count=10) cache_impl._increment_ref_count(mock_obj, 10) assert mock_obj.ref_count == 20 - def test_clear(self, cache_impl): + def test_clear(self, cache_impl: cache_impl_.CacheImpl): cache_impl._create_cache = mock.Mock() cache_impl.clear() cache_impl._create_cache.assert_called_once_with() - def test_clear_dm_channel_ids(self, cache_impl): - cache_impl._dm_channel_entries = collections.FreezableDict({123: 5423, 23123: 54123123}) + def test_clear_dm_channel_ids(self, cache_impl: cache_impl_.CacheImpl): + cache_impl._dm_channel_entries = collections.FreezableDict( + { + snowflakes.Snowflake(123): snowflakes.Snowflake(5423), + snowflakes.Snowflake(23123): snowflakes.Snowflake(54123123), + } + ) result = cache_impl.clear_dm_channel_ids() assert result == {123: 5423, 23123: 54123123} assert cache_impl._dm_channel_entries == {} - def test_delete_dm_channel_id(self, cache_impl): - cache_impl._dm_channel_entries = collections.FreezableDict({54123: 2123123, 5434: 1234}) + def test_delete_dm_channel_id(self, cache_impl: cache_impl_.CacheImpl): + cache_impl._dm_channel_entries = collections.FreezableDict( + { + snowflakes.Snowflake(54123): snowflakes.Snowflake(2123123), + snowflakes.Snowflake(5434): snowflakes.Snowflake(1234), + } + ) result = cache_impl.delete_dm_channel_id(54123) assert result == 2123123 assert cache_impl._dm_channel_entries == {5434: 1234} - def test_delete_dm_channel_id_for_unknown_user(self, cache_impl): - cache_impl._dm_channel_entries = collections.FreezableDict({54123: 2123123, 5434: 1234}) + def test_delete_dm_channel_id_for_unknown_user(self, cache_impl: cache_impl_.CacheImpl): + cache_impl._dm_channel_entries = collections.FreezableDict( + { + snowflakes.Snowflake(54123): snowflakes.Snowflake(2123123), + snowflakes.Snowflake(5434): snowflakes.Snowflake(1234), + } + ) result = cache_impl.delete_dm_channel_id(65234123123) assert result is None assert cache_impl._dm_channel_entries == {54123: 2123123, 5434: 1234} - def test_get_dm_channel_id(self, cache_impl): - cache_impl._dm_channel_entries = collections.FreezableDict({24123123: 453123, 5423: 123, 653: 1223}) + def test_get_dm_channel_id(self, cache_impl: cache_impl_.CacheImpl): + cache_impl._dm_channel_entries = collections.FreezableDict( + { + snowflakes.Snowflake(24123123): snowflakes.Snowflake(453123), + snowflakes.Snowflake(5423): snowflakes.Snowflake(123), + snowflakes.Snowflake(653): snowflakes.Snowflake(1223), + } + ) assert cache_impl.get_dm_channel_id(5423) == 123 - def test_get_dm_channel_id_for_unknown_user(self, cache_impl): - cache_impl._dm_channel_entries = collections.FreezableDict({24123123: 453123, 5423: 123, 653: 1223}) + def test_get_dm_channel_id_for_unknown_user(self, cache_impl: cache_impl_.CacheImpl): + cache_impl._dm_channel_entries = collections.FreezableDict( + { + snowflakes.Snowflake(24123123): snowflakes.Snowflake(453123), + snowflakes.Snowflake(5423): snowflakes.Snowflake(123), + snowflakes.Snowflake(653): snowflakes.Snowflake(1223), + } + ) assert cache_impl.get_dm_channel_id(65656565) is None - def test_get_dm_channel_ids_view(self, cache_impl): - cache_impl._dm_channel_entries = collections.FreezableDict({222: 333, 643: 213, 54234: 1231321}) + def test_get_dm_channel_ids_view(self, cache_impl: cache_impl_.CacheImpl): + cache_impl._dm_channel_entries = collections.FreezableDict( + { + snowflakes.Snowflake(222): snowflakes.Snowflake(333), + snowflakes.Snowflake(643): snowflakes.Snowflake(213), + snowflakes.Snowflake(54234): snowflakes.Snowflake(1231321), + } + ) assert cache_impl.get_dm_channel_ids_view() == {222: 333, 643: 213, 54234: 1231321} - def test_set_dm_channel_id(self, cache_impl): - cache_impl._user_entries = collections.FreezableDict({43123123: object()}) + def test_set_dm_channel_id( + self, cache_impl: cache_impl_.CacheImpl, hikari_user: users.User, hikari_guild_text_channel: GuildTextChannel + ): + cache_impl._user_entries = collections.FreezableDict({snowflakes.Snowflake(789): mock.Mock()}) - cache_impl.set_dm_channel_id(StubModel(43123123), StubModel(12222)) + cache_impl.set_dm_channel_id(hikari_user, hikari_guild_text_channel) - assert cache_impl._dm_channel_entries == {43123123: 12222} + assert cache_impl._dm_channel_entries == {789: 4560} - def test__build_emoji(self, cache_impl): + def test__build_emoji(self, cache_impl: cache_impl_.CacheImpl): mock_user = mock.MagicMock(users.User) emoji_data = cache_utilities.KnownCustomEmojiData( id=snowflakes.Snowflake(1233534234), @@ -160,7 +190,7 @@ def test__build_emoji(self, cache_impl): assert emoji.is_managed is False assert emoji.is_available is True - def test__build_emoji_with_no_user(self, cache_impl): + def test__build_emoji_with_no_user(self, cache_impl: cache_impl_.CacheImpl): emoji_data = cache_utilities.KnownCustomEmojiData( id=snowflakes.Snowflake(1233534234), name="OKOKOKOKOK", @@ -172,14 +202,14 @@ def test__build_emoji_with_no_user(self, cache_impl): is_managed=False, is_available=True, ) - cache_impl._build_user = mock.Mock() + cache_impl._build_user = mock.Mock() # FIXME: Seems this is calling an object that does not actually exist. emoji = cache_impl._build_emoji(emoji_data) - cache_impl._build_user.assert_not_called() + cache_impl._build_user.assert_not_called() # FIXME: Seems this is calling an object that does not actually exist. assert emoji.user is None - def test_clear_emojis(self, cache_impl): + def test_clear_emojis(self, cache_impl: cache_impl_.CacheImpl): mock_user_1 = mock.Mock(cache_utilities.RefCell[users.User]) mock_user_2 = mock.Mock(cache_utilities.RefCell[users.User]) mock_emoji_data_1 = mock.Mock(cache_utilities.KnownCustomEmojiData, user=mock_user_1) @@ -213,7 +243,7 @@ def test_clear_emojis(self, cache_impl): [mock.call(mock_emoji_data_1), mock.call(mock_emoji_data_2), mock.call(mock_emoji_data_3)] ) - def test_clear_emojis_for_guild(self, cache_impl): + def test_clear_emojis_for_guild(self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild): mock_user_1 = mock.Mock(cache_utilities.RefCell[users.User]) mock_user_2 = mock.Mock(cache_utilities.RefCell[users.User]) mock_emoji_data_1 = mock.Mock(cache_utilities.KnownCustomEmojiData, user=mock_user_1) @@ -237,21 +267,18 @@ def test_clear_emojis_for_guild(self, cache_impl): ) guild_record = cache_utilities.GuildRecord(emojis=emoji_ids) cache_impl._guild_entries = collections.FreezableDict( - { - snowflakes.Snowflake(432123123): guild_record, - snowflakes.Snowflake(1): mock.Mock(cache_utilities.GuildRecord), - } + {snowflakes.Snowflake(123): guild_record, snowflakes.Snowflake(1): mock.Mock(cache_utilities.GuildRecord)} ) cache_impl._build_emoji = mock.Mock(side_effect=[mock_emoji_1, mock_emoji_2, mock_emoji_3]) cache_impl._remove_guild_record_if_empty = mock.Mock() cache_impl._garbage_collect_user = mock.Mock() - emoji_mapping = cache_impl.clear_emojis_for_guild(StubModel(432123123)) + emoji_mapping = cache_impl.clear_emojis_for_guild(hikari_partial_guild) cache_impl._garbage_collect_user.assert_has_calls( [mock.call(mock_user_1, decrement=1), mock.call(mock_user_2, decrement=1)] ) - cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(432123123), guild_record) + cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(123), guild_record) assert emoji_mapping == { snowflakes.Snowflake(6873451): mock_emoji_1, snowflakes.Snowflake(43123123): mock_emoji_2, @@ -260,16 +287,20 @@ def test_clear_emojis_for_guild(self, cache_impl): assert cache_impl._emoji_entries == collections.FreezableDict( {snowflakes.Snowflake(111): mock_other_emoji_data} ) - assert cache_impl._guild_entries[snowflakes.Snowflake(432123123)].emojis is None + assert cache_impl._guild_entries[snowflakes.Snowflake(123)].emojis is None cache_impl._build_emoji.assert_has_calls( [mock.call(mock_emoji_data_1), mock.call(mock_emoji_data_2), mock.call(mock_emoji_data_3)] ) - def test_clear_emojis_for_guild_for_unknown_emoji_cache(self, cache_impl): - cache_impl._emoji_entries = {snowflakes.Snowflake(3123): mock.Mock(cache_utilities.KnownCustomEmojiData)} + def test_clear_emojis_for_guild_for_unknown_emoji_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): + cache_impl._emoji_entries = collections.FreezableDict( + {snowflakes.Snowflake(3123): mock.Mock(cache_utilities.KnownCustomEmojiData)} + ) cache_impl._guild_entries = collections.FreezableDict( { - snowflakes.Snowflake(432123123): cache_utilities.GuildRecord(), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(), snowflakes.Snowflake(1): mock.Mock(cache_utilities.GuildRecord), } ) @@ -277,15 +308,19 @@ def test_clear_emojis_for_guild_for_unknown_emoji_cache(self, cache_impl): cache_impl._remove_guild_record_if_empty = mock.Mock() cache_impl._garbage_collect_user = mock.Mock() - emoji_mapping = cache_impl.clear_emojis_for_guild(StubModel(432123123)) + emoji_mapping = cache_impl.clear_emojis_for_guild(hikari_partial_guild) cache_impl._garbage_collect_user.assert_not_called() cache_impl._remove_guild_record_if_empty.assert_not_called() assert emoji_mapping == {} cache_impl._build_emoji.assert_not_called() - def test_clear_emojis_for_guild_for_unknown_record(self, cache_impl): - cache_impl._emoji_entries = {snowflakes.Snowflake(123124): mock.Mock(cache_utilities.KnownCustomEmojiData)} + def test_clear_emojis_for_guild_for_unknown_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): + cache_impl._emoji_entries = collections.FreezableDict( + {snowflakes.Snowflake(123124): mock.Mock(cache_utilities.KnownCustomEmojiData)} + ) cache_impl._guild_entries = collections.FreezableDict( {snowflakes.Snowflake(1): mock.Mock(cache_utilities.GuildRecord)} ) @@ -293,24 +328,24 @@ def test_clear_emojis_for_guild_for_unknown_record(self, cache_impl): cache_impl._remove_guild_record_if_empty = mock.Mock() cache_impl._garbage_collect_user = mock.Mock() - emoji_mapping = cache_impl.clear_emojis_for_guild(StubModel(432123123)) + emoji_mapping = cache_impl.clear_emojis_for_guild(hikari_partial_guild) cache_impl._garbage_collect_user.assert_not_called() cache_impl._remove_guild_record_if_empty.assert_not_called() assert emoji_mapping == {} cache_impl._build_emoji.assert_not_called() - def test_delete_emoji(self, cache_impl): - mock_user = object() + def test_delete_emoji(self, cache_impl: cache_impl_.CacheImpl, hikari_custom_emoji: emojis.CustomEmoji): + mock_user = mock.Mock() mock_emoji_data = mock.Mock( cache_utilities.KnownCustomEmojiData, user=mock_user, guild_id=snowflakes.Snowflake(123333) ) mock_other_emoji_data = mock.Mock(cache_utilities.KnownCustomEmojiData) mock_emoji = mock.Mock(emojis.KnownCustomEmoji) emoji_ids = collections.SnowflakeSet() - emoji_ids.add_all([snowflakes.Snowflake(12354123), snowflakes.Snowflake(432123)]) + emoji_ids.add_all([snowflakes.Snowflake(444), snowflakes.Snowflake(432123)]) cache_impl._emoji_entries = collections.FreezableDict( - {snowflakes.Snowflake(12354123): mock_emoji_data, snowflakes.Snowflake(999): mock_other_emoji_data} + {snowflakes.Snowflake(444): mock_emoji_data, snowflakes.Snowflake(999): mock_other_emoji_data} ) cache_impl._guild_entries = collections.FreezableDict( {snowflakes.Snowflake(123333): cache_utilities.GuildRecord(emojis=emoji_ids)} @@ -318,7 +353,7 @@ def test_delete_emoji(self, cache_impl): cache_impl._garbage_collect_user = mock.Mock() cache_impl._build_emoji = mock.Mock(return_value=mock_emoji) - result = cache_impl.delete_emoji(StubModel(12354123)) + result = cache_impl.delete_emoji(hikari_custom_emoji) assert result is mock_emoji assert cache_impl._emoji_entries == {snowflakes.Snowflake(999): mock_other_emoji_data} @@ -326,16 +361,18 @@ def test_delete_emoji(self, cache_impl): cache_impl._build_emoji.assert_called_once_with(mock_emoji_data) cache_impl._garbage_collect_user.assert_called_once_with(mock_user, decrement=1) - def test_delete_emoji_without_user(self, cache_impl): + def test_delete_emoji_without_user( + self, cache_impl: cache_impl_.CacheImpl, hikari_custom_emoji: emojis.CustomEmoji + ): mock_emoji_data = mock.Mock( cache_utilities.KnownCustomEmojiData, user=None, guild_id=snowflakes.Snowflake(123333) ) mock_other_emoji_data = mock.Mock(cache_utilities.KnownCustomEmojiData) mock_emoji = mock.Mock(emojis.KnownCustomEmoji) emoji_ids = collections.SnowflakeSet() - emoji_ids.add_all([snowflakes.Snowflake(12354123), snowflakes.Snowflake(432123)]) + emoji_ids.add_all([snowflakes.Snowflake(444), snowflakes.Snowflake(432123)]) cache_impl._emoji_entries = collections.FreezableDict( - {snowflakes.Snowflake(12354123): mock_emoji_data, snowflakes.Snowflake(999): mock_other_emoji_data} + {snowflakes.Snowflake(444): mock_emoji_data, snowflakes.Snowflake(999): mock_other_emoji_data} ) cache_impl._guild_entries = collections.FreezableDict( {snowflakes.Snowflake(123333): cache_utilities.GuildRecord(emojis=emoji_ids)} @@ -343,7 +380,7 @@ def test_delete_emoji_without_user(self, cache_impl): cache_impl._garbage_collect_user = mock.Mock() cache_impl._build_emoji = mock.Mock(return_value=mock_emoji) - result = cache_impl.delete_emoji(StubModel(12354123)) + result = cache_impl.delete_emoji(hikari_custom_emoji) assert result is mock_emoji assert cache_impl._emoji_entries == {snowflakes.Snowflake(999): mock_other_emoji_data} @@ -351,36 +388,40 @@ def test_delete_emoji_without_user(self, cache_impl): cache_impl._build_emoji.assert_called_once_with(mock_emoji_data) cache_impl._garbage_collect_user.assert_not_called() - def test_delete_emoji_for_unknown_emoji(self, cache_impl): + def test_delete_emoji_for_unknown_emoji( + self, cache_impl: cache_impl_.CacheImpl, hikari_custom_emoji: emojis.CustomEmoji + ): cache_impl._garbage_collect_user = mock.Mock() cache_impl._build_emoji = mock.Mock() - result = cache_impl.delete_emoji(StubModel(12354123)) + result = cache_impl.delete_emoji(hikari_custom_emoji) assert result is None cache_impl._build_emoji.assert_not_called() cache_impl._garbage_collect_user.assert_not_called() - def test_get_emoji(self, cache_impl): + def test_get_emoji(self, cache_impl: cache_impl_.CacheImpl, hikari_custom_emoji: emojis.CustomEmoji): mock_emoji_data = mock.Mock(cache_utilities.KnownCustomEmojiData) mock_emoji = mock.Mock(emojis.KnownCustomEmoji) cache_impl._build_emoji = mock.Mock(return_value=mock_emoji) - cache_impl._emoji_entries = collections.FreezableDict({snowflakes.Snowflake(3422123): mock_emoji_data}) + cache_impl._emoji_entries = collections.FreezableDict({snowflakes.Snowflake(444): mock_emoji_data}) - result = cache_impl.get_emoji(StubModel(3422123)) + result = cache_impl.get_emoji(hikari_custom_emoji) assert result is mock_emoji cache_impl._build_emoji.assert_called_once_with(mock_emoji_data) - def test_get_emoji_with_unknown_emoji(self, cache_impl): + def test_get_emoji_with_unknown_emoji( + self, cache_impl: cache_impl_.CacheImpl, hikari_custom_emoji: emojis.CustomEmoji + ): cache_impl._build_emoji = mock.Mock() - result = cache_impl.get_emoji(StubModel(3422123)) + result = cache_impl.get_emoji(hikari_custom_emoji) assert result is None cache_impl._build_emoji.assert_not_called() - def test_get_emojis_view(self, cache_impl): + def test_get_emojis_view(self, cache_impl: cache_impl_.CacheImpl): mock_emoji_data_1 = mock.Mock(cache_utilities.KnownCustomEmojiData) mock_emoji_data_2 = mock.Mock(cache_utilities.KnownCustomEmojiData) mock_emoji_1 = mock.Mock(emojis.KnownCustomEmoji) @@ -395,7 +436,9 @@ def test_get_emojis_view(self, cache_impl): assert result == {snowflakes.Snowflake(123123123): mock_emoji_1, snowflakes.Snowflake(43156234): mock_emoji_2} cache_impl._build_emoji.assert_has_calls([mock.call(mock_emoji_data_1), mock.call(mock_emoji_data_2)]) - def test_get_emojis_view_for_guild(self, cache_impl): + def test_get_emojis_view_for_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_emoji_data_1 = mock.Mock(cache_utilities.KnownCustomEmojiData) mock_emoji_data_2 = mock.Mock(cache_utilities.KnownCustomEmojiData) mock_emoji_1 = mock.Mock(emojis.KnownCustomEmoji) @@ -412,48 +455,52 @@ def test_get_emojis_view_for_guild(self, cache_impl): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(99999): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(9342123): cache_utilities.GuildRecord(emojis=emoji_ids), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(emojis=emoji_ids), } ) cache_impl._build_emoji = mock.Mock(side_effect=[mock_emoji_1, mock_emoji_2]) - result = cache_impl.get_emojis_view_for_guild(StubModel(9342123)) + result = cache_impl.get_emojis_view_for_guild(hikari_partial_guild) assert result == {snowflakes.Snowflake(65123): mock_emoji_1, snowflakes.Snowflake(43156234): mock_emoji_2} cache_impl._build_emoji.assert_has_calls([mock.call(mock_emoji_data_1), mock.call(mock_emoji_data_2)]) - def test_get_emojis_view_for_guild_for_unknown_emoji_cache(self, cache_impl): + def test_get_emojis_view_for_guild_for_unknown_emoji_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._emoji_entries = collections.FreezableDict( {snowflakes.Snowflake(9999): mock.Mock(cache_utilities.KnownCustomEmojiData)} ) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(99999): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(9342123): cache_utilities.GuildRecord(), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(), } ) cache_impl._build_emoji = mock.Mock() - result = cache_impl.get_emojis_view_for_guild(StubModel(9342123)) + result = cache_impl.get_emojis_view_for_guild(hikari_partial_guild) assert result == {} cache_impl._build_emoji.assert_not_called() - def test_get_emojis_view_for_guild_for_unknown_record(self, cache_impl): + def test_get_emojis_view_for_guild_for_unknown_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._emoji_entries = collections.FreezableDict( {snowflakes.Snowflake(12354345): mock.Mock(cache_utilities.KnownCustomEmojiData)} ) cache_impl._guild_entries = collections.FreezableDict( - {snowflakes.Snowflake(9342123): cache_utilities.GuildRecord()} + {snowflakes.Snowflake(123): cache_utilities.GuildRecord()} ) cache_impl._build_emoji = mock.Mock() - result = cache_impl.get_emojis_view_for_guild(StubModel(9342123)) + result = cache_impl.get_emojis_view_for_guild(hikari_partial_guild) assert result == {} cache_impl._build_emoji.assert_not_called() - def test_set_emoji(self, cache_impl): + def test_set_emoji(self, cache_impl: cache_impl_.CacheImpl): mock_user = mock.Mock(users.User, id=snowflakes.Snowflake(654234)) mock_reffed_user = cache_utilities.RefCell(mock_user) emoji = emojis.KnownCustomEmoji( @@ -474,8 +521,9 @@ def test_set_emoji(self, cache_impl): cache_impl.set_emoji(emoji) assert 65234 in cache_impl._guild_entries - assert cache_impl._guild_entries[snowflakes.Snowflake(65234)].emojis - assert 5123123 in cache_impl._guild_entries[snowflakes.Snowflake(65234)].emojis + guild_entry_emojis = cache_impl._guild_entries[snowflakes.Snowflake(65234)].emojis + assert guild_entry_emojis is not None + assert snowflakes.Snowflake(5123123) in guild_entry_emojis assert 5123123 in cache_impl._emoji_entries emoji_data = cache_impl._emoji_entries[snowflakes.Snowflake(5123123)] cache_impl._set_user.assert_called_once_with(mock_user) @@ -491,7 +539,7 @@ def test_set_emoji(self, cache_impl): assert emoji_data.is_managed is True assert emoji_data.is_available is False - def test_set_emoji_with_pre_cached_emoji(self, cache_impl): + def test_set_emoji_with_pre_cached_emoji(self, cache_impl: cache_impl_.CacheImpl): mock_user = mock.Mock(users.User, id=snowflakes.Snowflake(654234)) emoji = emojis.KnownCustomEmoji( app=cache_impl._app, @@ -509,15 +557,17 @@ def test_set_emoji_with_pre_cached_emoji(self, cache_impl): {snowflakes.Snowflake(5123123): mock.Mock(cache_utilities.KnownCustomEmojiData)} ) cache_impl._set_user = mock.Mock() - cache_impl._increment_user_ref_count = mock.Mock() + cache_impl._increment_user_ref_count = ( + mock.Mock() + ) # FIXME: Seems this is calling an object that does not actually exist. cache_impl.set_emoji(emoji) assert 5123123 in cache_impl._emoji_entries cache_impl._set_user.assert_called_once_with(mock_user) - cache_impl._increment_user_ref_count.assert_not_called() + cache_impl._increment_user_ref_count.assert_not_called() # FIXME: Seems this is calling an object that does not actually exist. - def test_update_emoji(self, cache_impl): + def test_update_emoji(self, cache_impl: cache_impl_.CacheImpl): mock_cached_emoji_1 = mock.Mock(emojis.KnownCustomEmoji) mock_cached_emoji_2 = mock.Mock(emojis.KnownCustomEmoji) mock_emoji = mock.Mock(emojis.KnownCustomEmoji, id=snowflakes.Snowflake(54123123)) @@ -532,7 +582,7 @@ def test_update_emoji(self, cache_impl): ) cache_impl.set_emoji.assert_called_once_with(mock_emoji) - def test__build_sticker(self, cache_impl): + def test__build_sticker(self, cache_impl: cache_impl_.CacheImpl): mock_user = mock.MagicMock(users.User) sticker_data = cache_utilities.GuildStickerData( id=snowflakes.Snowflake(1233534234), @@ -557,7 +607,7 @@ def test__build_sticker(self, cache_impl): assert sticker.is_available is True assert sticker.description == "hi" - def test__build_sticker_with_no_user(self, cache_impl): + def test__build_sticker_with_no_user(self, cache_impl: cache_impl_.CacheImpl): sticker_data = cache_utilities.GuildStickerData( id=snowflakes.Snowflake(1233534234), name="OKOKOKOKOK", @@ -568,14 +618,14 @@ def test__build_sticker_with_no_user(self, cache_impl): user=None, is_available=True, ) - cache_impl._build_user = mock.Mock() + cache_impl._build_user = mock.Mock() # FIXME: Seems this is calling an object that does not actually exist. sticker = cache_impl._build_sticker(sticker_data) - cache_impl._build_user.assert_not_called() + cache_impl._build_user.assert_not_called() # FIXME: Seems this is calling an object that does not actually exist. assert sticker.user is None - def test_clear_stickers(self, cache_impl): + def test_clear_stickers(self, cache_impl: cache_impl_.CacheImpl): mock_user_1 = mock.Mock(cache_utilities.RefCell[users.User]) mock_user_2 = mock.Mock(cache_utilities.RefCell[users.User]) mock_sticker_data_1 = mock.Mock(cache_utilities.GuildStickerData, user=mock_user_1) @@ -609,7 +659,9 @@ def test_clear_stickers(self, cache_impl): [mock.call(mock_sticker_data_1), mock.call(mock_sticker_data_2), mock.call(mock_sticker_data_3)] ) - def test_clear_stickers_for_guild(self, cache_impl): + def test_clear_stickers_for_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_user_1 = mock.Mock(cache_utilities.RefCell[users.User]) mock_user_2 = mock.Mock(cache_utilities.RefCell[users.User]) mock_sticker_data_1 = mock.Mock(cache_utilities.GuildStickerData, user=mock_user_1) @@ -633,21 +685,18 @@ def test_clear_stickers_for_guild(self, cache_impl): ) guild_record = cache_utilities.GuildRecord(stickers=sticker_ids) cache_impl._guild_entries = collections.FreezableDict( - { - snowflakes.Snowflake(432123123): guild_record, - snowflakes.Snowflake(1): mock.Mock(cache_utilities.GuildRecord), - } + {snowflakes.Snowflake(123): guild_record, snowflakes.Snowflake(1): mock.Mock(cache_utilities.GuildRecord)} ) cache_impl._build_sticker = mock.Mock(side_effect=[mock_sticker_1, mock_sticker_2, mock_sticker_3]) cache_impl._remove_guild_record_if_empty = mock.Mock() cache_impl._garbage_collect_user = mock.Mock() - sticker_mapping = cache_impl.clear_stickers_for_guild(StubModel(432123123)) + sticker_mapping = cache_impl.clear_stickers_for_guild(hikari_partial_guild) cache_impl._garbage_collect_user.assert_has_calls( [mock.call(mock_user_1, decrement=1), mock.call(mock_user_2, decrement=1)] ) - cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(432123123), guild_record) + cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(123), guild_record) assert sticker_mapping == { snowflakes.Snowflake(6873451): mock_sticker_1, snowflakes.Snowflake(43123123): mock_sticker_2, @@ -656,16 +705,20 @@ def test_clear_stickers_for_guild(self, cache_impl): assert cache_impl._sticker_entries == collections.FreezableDict( {snowflakes.Snowflake(111): mock_other_sticker_data} ) - assert cache_impl._guild_entries[snowflakes.Snowflake(432123123)].stickers is None + assert cache_impl._guild_entries[snowflakes.Snowflake(123)].stickers is None cache_impl._build_sticker.assert_has_calls( [mock.call(mock_sticker_data_1), mock.call(mock_sticker_data_2), mock.call(mock_sticker_data_3)] ) - def test_clear_stickers_for_guild_for_unknown_sticker_cache(self, cache_impl): - cache_impl._sticker_entries = {snowflakes.Snowflake(3123): mock.Mock(cache_utilities.GuildStickerData)} + def test_clear_stickers_for_guild_for_unknown_sticker_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): + cache_impl._sticker_entries = collections.FreezableDict( + {snowflakes.Snowflake(3123): mock.Mock(cache_utilities.GuildStickerData)} + ) cache_impl._guild_entries = collections.FreezableDict( { - snowflakes.Snowflake(432123123): cache_utilities.GuildRecord(), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(), snowflakes.Snowflake(1): mock.Mock(cache_utilities.GuildRecord), } ) @@ -673,15 +726,19 @@ def test_clear_stickers_for_guild_for_unknown_sticker_cache(self, cache_impl): cache_impl._remove_guild_record_if_empty = mock.Mock() cache_impl._garbage_collect_user = mock.Mock() - sticker_mapping = cache_impl.clear_stickers_for_guild(StubModel(432123123)) + sticker_mapping = cache_impl.clear_stickers_for_guild(hikari_partial_guild) cache_impl._garbage_collect_user.assert_not_called() cache_impl._remove_guild_record_if_empty.assert_not_called() assert sticker_mapping == {} cache_impl._build_sticker.assert_not_called() - def test_clear_stickers_for_guild_for_unknown_record(self, cache_impl): - cache_impl._sticker_entries = {snowflakes.Snowflake(123124): mock.Mock(cache_utilities.GuildStickerData)} + def test_clear_stickers_for_guild_for_unknown_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): + cache_impl._sticker_entries = collections.FreezableDict( + {snowflakes.Snowflake(123124): mock.Mock(cache_utilities.GuildStickerData)} + ) cache_impl._guild_entries = collections.FreezableDict( {snowflakes.Snowflake(1): mock.Mock(cache_utilities.GuildRecord)} ) @@ -689,24 +746,24 @@ def test_clear_stickers_for_guild_for_unknown_record(self, cache_impl): cache_impl._remove_guild_record_if_empty = mock.Mock() cache_impl._garbage_collect_user = mock.Mock() - sticker_mapping = cache_impl.clear_stickers_for_guild(StubModel(432123123)) + sticker_mapping = cache_impl.clear_stickers_for_guild(hikari_partial_guild) cache_impl._garbage_collect_user.assert_not_called() cache_impl._remove_guild_record_if_empty.assert_not_called() assert sticker_mapping == {} cache_impl._build_sticker.assert_not_called() - def test_delete_sticker(self, cache_impl): - mock_user = object() + def test_delete_sticker(self, cache_impl: cache_impl_.CacheImpl, hikari_guild_sticker: stickers.GuildSticker): + mock_user = mock.Mock() mock_sticker_data = mock.Mock( cache_utilities.GuildStickerData, user=mock_user, guild_id=snowflakes.Snowflake(123333) ) mock_other_sticker_data = mock.Mock(cache_utilities.GuildStickerData) mock_sticker = mock.Mock(stickers.GuildSticker) sticker_ids = collections.SnowflakeSet() - sticker_ids.add_all([snowflakes.Snowflake(12354123), snowflakes.Snowflake(432123)]) + sticker_ids.add_all([snowflakes.Snowflake(2220), snowflakes.Snowflake(432123)]) cache_impl._sticker_entries = collections.FreezableDict( - {snowflakes.Snowflake(12354123): mock_sticker_data, snowflakes.Snowflake(999): mock_other_sticker_data} + {snowflakes.Snowflake(2220): mock_sticker_data, snowflakes.Snowflake(999): mock_other_sticker_data} ) cache_impl._guild_entries = collections.FreezableDict( {snowflakes.Snowflake(123333): cache_utilities.GuildRecord(stickers=sticker_ids)} @@ -714,7 +771,7 @@ def test_delete_sticker(self, cache_impl): cache_impl._garbage_collect_user = mock.Mock() cache_impl._build_sticker = mock.Mock(return_value=mock_sticker) - result = cache_impl.delete_sticker(StubModel(12354123)) + result = cache_impl.delete_sticker(hikari_guild_sticker) assert result is mock_sticker assert cache_impl._sticker_entries == {snowflakes.Snowflake(999): mock_other_sticker_data} @@ -722,16 +779,18 @@ def test_delete_sticker(self, cache_impl): cache_impl._build_sticker.assert_called_once_with(mock_sticker_data) cache_impl._garbage_collect_user.assert_called_once_with(mock_user, decrement=1) - def test_delete_sticker_without_user(self, cache_impl): + def test_delete_sticker_without_user( + self, cache_impl: cache_impl_.CacheImpl, hikari_guild_sticker: stickers.GuildSticker + ): mock_sticker_data = mock.Mock( cache_utilities.GuildStickerData, user=None, guild_id=snowflakes.Snowflake(123333) ) mock_other_sticker_data = mock.Mock(cache_utilities.GuildStickerData) mock_sticker = mock.Mock(stickers.GuildSticker) sticker_ids = collections.SnowflakeSet() - sticker_ids.add_all([snowflakes.Snowflake(12354123), snowflakes.Snowflake(432123)]) + sticker_ids.add_all([snowflakes.Snowflake(2220), snowflakes.Snowflake(432123)]) cache_impl._sticker_entries = collections.FreezableDict( - {snowflakes.Snowflake(12354123): mock_sticker_data, snowflakes.Snowflake(999): mock_other_sticker_data} + {snowflakes.Snowflake(2220): mock_sticker_data, snowflakes.Snowflake(999): mock_other_sticker_data} ) cache_impl._guild_entries = collections.FreezableDict( {snowflakes.Snowflake(123333): cache_utilities.GuildRecord(stickers=sticker_ids)} @@ -739,7 +798,7 @@ def test_delete_sticker_without_user(self, cache_impl): cache_impl._garbage_collect_user = mock.Mock() cache_impl._build_sticker = mock.Mock(return_value=mock_sticker) - result = cache_impl.delete_sticker(StubModel(12354123)) + result = cache_impl.delete_sticker(hikari_guild_sticker) assert result is mock_sticker assert cache_impl._sticker_entries == {snowflakes.Snowflake(999): mock_other_sticker_data} @@ -747,36 +806,40 @@ def test_delete_sticker_without_user(self, cache_impl): cache_impl._build_sticker.assert_called_once_with(mock_sticker_data) cache_impl._garbage_collect_user.assert_not_called() - def test_delete_sticker_for_unknown_sticker(self, cache_impl): + def test_delete_sticker_for_unknown_sticker( + self, cache_impl: cache_impl_.CacheImpl, hikari_guild_sticker: stickers.GuildSticker + ): cache_impl._garbage_collect_user = mock.Mock() cache_impl._build_sticker = mock.Mock() - result = cache_impl.delete_sticker(StubModel(12354123)) + result = cache_impl.delete_sticker(hikari_guild_sticker) assert result is None cache_impl._build_sticker.assert_not_called() cache_impl._garbage_collect_user.assert_not_called() - def test_get_sticker(self, cache_impl): + def test_get_sticker(self, cache_impl: cache_impl_.CacheImpl, hikari_guild_sticker: stickers.GuildSticker): mock_sticker_data = mock.Mock(cache_utilities.GuildStickerData) mock_sticker = mock.Mock(emojis.KnownCustomEmoji) cache_impl._build_sticker = mock.Mock(return_value=mock_sticker) - cache_impl._sticker_entries = collections.FreezableDict({snowflakes.Snowflake(3422123): mock_sticker_data}) + cache_impl._sticker_entries = collections.FreezableDict({snowflakes.Snowflake(2220): mock_sticker_data}) - result = cache_impl.get_sticker(StubModel(3422123)) + result = cache_impl.get_sticker(hikari_guild_sticker) assert result is mock_sticker cache_impl._build_sticker.assert_called_once_with(mock_sticker_data) - def test_get_sticker_with_unknown_sticker(self, cache_impl): + def test_get_sticker_with_unknown_sticker( + self, cache_impl: cache_impl_.CacheImpl, hikari_guild_sticker: stickers.GuildSticker + ): cache_impl._build_sticker = mock.Mock() - result = cache_impl.get_sticker(StubModel(3422123)) + result = cache_impl.get_sticker(hikari_guild_sticker) assert result is None cache_impl._build_sticker.assert_not_called() - def test_get_stickers_view(self, cache_impl): + def test_get_stickers_view(self, cache_impl: cache_impl_.CacheImpl): mock_sticker_data_1 = mock.Mock(cache_utilities.GuildStickerData) mock_sticker_data_2 = mock.Mock(cache_utilities.GuildStickerData) mock_sticker_1 = mock.Mock(stickers.GuildSticker) @@ -794,7 +857,9 @@ def test_get_stickers_view(self, cache_impl): } cache_impl._build_sticker.assert_has_calls([mock.call(mock_sticker_data_1), mock.call(mock_sticker_data_2)]) - def test_get_stickers_view_for_guild(self, cache_impl): + def test_get_stickers_view_for_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_sticker_data_1 = mock.Mock(cache_utilities.GuildStickerData) mock_sticker_data_2 = mock.Mock(cache_utilities.GuildStickerData) mock_sticker_1 = mock.Mock(stickers.GuildSticker) @@ -811,48 +876,52 @@ def test_get_stickers_view_for_guild(self, cache_impl): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(99999): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(9342123): cache_utilities.GuildRecord(stickers=sticker_ids), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(stickers=sticker_ids), } ) cache_impl._build_sticker = mock.Mock(side_effect=[mock_sticker_1, mock_sticker_2]) - result = cache_impl.get_stickers_view_for_guild(StubModel(9342123)) + result = cache_impl.get_stickers_view_for_guild(hikari_partial_guild) assert result == {snowflakes.Snowflake(65123): mock_sticker_1, snowflakes.Snowflake(43156234): mock_sticker_2} cache_impl._build_sticker.assert_has_calls([mock.call(mock_sticker_data_1), mock.call(mock_sticker_data_2)]) - def test_get_stickers_view_for_guild_for_unknown_sticker_cache(self, cache_impl): + def test_get_stickers_view_for_guild_for_unknown_sticker_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._sticker_entries = collections.FreezableDict( {snowflakes.Snowflake(9999): mock.Mock(cache_utilities.GuildStickerData)} ) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(99999): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(9342123): cache_utilities.GuildRecord(), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(), } ) cache_impl._build_sticker = mock.Mock() - result = cache_impl.get_stickers_view_for_guild(StubModel(9342123)) + result = cache_impl.get_stickers_view_for_guild(hikari_partial_guild) assert result == {} cache_impl._build_sticker.assert_not_called() - def test_get_stickers_view_for_guild_for_unknown_record(self, cache_impl): + def test_get_stickers_view_for_guild_for_unknown_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._sticker_entries = collections.FreezableDict( {snowflakes.Snowflake(12354345): mock.Mock(cache_utilities.GuildStickerData)} ) cache_impl._guild_entries = collections.FreezableDict( - {snowflakes.Snowflake(9342123): cache_utilities.GuildRecord()} + {snowflakes.Snowflake(123): cache_utilities.GuildRecord()} ) cache_impl._build_sticker = mock.Mock() - result = cache_impl.get_stickers_view_for_guild(StubModel(9342123)) + result = cache_impl.get_stickers_view_for_guild(hikari_partial_guild) assert result == {} cache_impl._build_sticker.assert_not_called() - def test_set_sticker(self, cache_impl): + def test_set_sticker(self, cache_impl: cache_impl_.CacheImpl): mock_user = mock.Mock(users.User, id=snowflakes.Snowflake(654234)) mock_reffed_user = cache_utilities.RefCell(mock_user) sticker = stickers.GuildSticker( @@ -871,8 +940,9 @@ def test_set_sticker(self, cache_impl): cache_impl.set_sticker(sticker) assert 65234 in cache_impl._guild_entries - assert cache_impl._guild_entries[snowflakes.Snowflake(65234)].stickers - assert 5123123 in cache_impl._guild_entries[snowflakes.Snowflake(65234)].stickers + guild_entry_stickers = cache_impl._guild_entries[snowflakes.Snowflake(65234)].stickers + assert guild_entry_stickers is not None + assert 5123123 in guild_entry_stickers assert 5123123 in cache_impl._sticker_entries sticker_data = cache_impl._sticker_entries[snowflakes.Snowflake(5123123)] cache_impl._set_user.assert_called_once_with(mock_user) @@ -886,7 +956,7 @@ def test_set_sticker(self, cache_impl): assert sticker_data.tag == "lul" assert sticker_data.description == "Jax cute" - def test_set_sticker_with_pre_cached_sticker(self, cache_impl): + def test_set_sticker_with_pre_cached_sticker(self, cache_impl: cache_impl_.CacheImpl): mock_user = mock.Mock(users.User, id=snowflakes.Snowflake(654234)) sticker = stickers.GuildSticker( id=snowflakes.Snowflake(5123123), @@ -902,15 +972,17 @@ def test_set_sticker_with_pre_cached_sticker(self, cache_impl): {snowflakes.Snowflake(5123123): mock.Mock(cache_utilities.GuildStickerData)} ) cache_impl._set_user = mock.Mock() - cache_impl._increment_user_ref_count = mock.Mock() + cache_impl._increment_user_ref_count = ( + mock.Mock() + ) # FIXME: Seems this is calling an object that does not actually exist. cache_impl.set_sticker(sticker) assert 5123123 in cache_impl._sticker_entries cache_impl._set_user.assert_called_once_with(mock_user) - cache_impl._increment_user_ref_count.assert_not_called() + cache_impl._increment_user_ref_count.assert_not_called() # FIXME: Seems this is calling an object that does not actually exist. - def test_clear_guilds_when_no_guilds_cached(self, cache_impl): + def test_clear_guilds_when_no_guilds_cached(self, cache_impl: cache_impl_.CacheImpl): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(423123): cache_utilities.GuildRecord(), @@ -926,7 +998,7 @@ def test_clear_guilds_when_no_guilds_cached(self, cache_impl): snowflakes.Snowflake(675345): cache_utilities.GuildRecord(), } - def test_clear_guilds(self, cache_impl): + def test_clear_guilds(self, cache_impl: cache_impl_.CacheImpl): mock_guild_1 = mock.MagicMock(guilds.GatewayGuild) mock_guild_2 = mock.MagicMock(guilds.GatewayGuild) mock_member = mock.MagicMock(guilds.Member) @@ -949,18 +1021,20 @@ def test_clear_guilds(self, cache_impl): assert cache_impl._guild_entries == { snowflakes.Snowflake(423123): cache_utilities.GuildRecord(), snowflakes.Snowflake(32142): cache_utilities.GuildRecord( - members={snowflakes.Snowflake(3241123): mock_member} + members=collections.FreezableDict({snowflakes.Snowflake(3241123): mock_member}) ), snowflakes.Snowflake(321132): cache_utilities.GuildRecord(), } - def test_delete_guild_for_known_guild(self, cache_impl): + def test_delete_guild_for_known_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_guild = mock.Mock(guilds.GatewayGuild) mock_member = mock.Mock(guilds.Member) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(354123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(543123): cache_utilities.GuildRecord( + snowflakes.Snowflake(123): cache_utilities.GuildRecord( guild=mock_guild, is_available=True, members=collections.FreezableDict({snowflakes.Snowflake(43123): mock_member}), @@ -968,181 +1042,207 @@ def test_delete_guild_for_known_guild(self, cache_impl): } ) - result = cache_impl.delete_guild(StubModel(543123)) + result = cache_impl.delete_guild(hikari_partial_guild) assert result is mock_guild assert cache_impl._guild_entries == { snowflakes.Snowflake(354123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(543123): cache_utilities.GuildRecord( - members={snowflakes.Snowflake(43123): mock_member} + snowflakes.Snowflake(123): cache_utilities.GuildRecord( + members=collections.FreezableDict({snowflakes.Snowflake(43123): mock_member}) ), } - def test_delete_guild_for_removes_emptied_record(self, cache_impl): + def test_delete_guild_for_removes_emptied_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_guild = mock.Mock(guilds.GatewayGuild) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(354123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(543123): cache_utilities.GuildRecord(guild=mock_guild, is_available=True), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(guild=mock_guild, is_available=True), } ) - result = cache_impl.delete_guild(StubModel(543123)) + result = cache_impl.delete_guild(hikari_partial_guild) assert result is mock_guild assert cache_impl._guild_entries == {snowflakes.Snowflake(354123): cache_utilities.GuildRecord()} - def test_delete_guild_for_unknown_guild(self, cache_impl): + def test_delete_guild_for_unknown_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(354123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(543123): cache_utilities.GuildRecord(), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(), } ) - result = cache_impl.delete_guild(StubModel(543123)) + result = cache_impl.delete_guild(hikari_partial_guild) assert result is None assert cache_impl._guild_entries == { snowflakes.Snowflake(354123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(543123): cache_utilities.GuildRecord(), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(), } - def test_delete_guild_for_unknown_record(self, cache_impl): + def test_delete_guild_for_unknown_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._guild_entries = collections.FreezableDict( {snowflakes.Snowflake(354123): cache_utilities.GuildRecord()} ) - result = cache_impl.delete_guild(StubModel(543123)) + result = cache_impl.delete_guild(hikari_partial_guild) assert result is None assert cache_impl._guild_entries == {snowflakes.Snowflake(354123): cache_utilities.GuildRecord()} - def test_get_guild_first_tries_get_available_guilds(self, cache_impl): + def test_get_guild_first_tries_get_available_guilds( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_guild = mock.MagicMock(guilds.GatewayGuild) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54234123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(543123): cache_utilities.GuildRecord(guild=mock_guild, is_available=True), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(guild=mock_guild, is_available=True), } ) - cached_guild = cache_impl.get_guild(StubModel(543123)) + cached_guild = cache_impl.get_guild(hikari_partial_guild) assert cached_guild == mock_guild assert cache_impl is not mock_guild - def test_get_guild_then_tries_get_unavailable_guilds(self, cache_impl): + def test_get_guild_then_tries_get_unavailable_guilds( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_guild = mock.MagicMock(guilds.GatewayGuild) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(543123): cache_utilities.GuildRecord(is_available=True), - snowflakes.Snowflake(54234123): cache_utilities.GuildRecord(guild=mock_guild, is_available=False), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(guild=mock_guild, is_available=False), } ) - cached_guild = cache_impl.get_guild(StubModel(54234123)) + cached_guild = cache_impl.get_guild(hikari_partial_guild) assert cached_guild == mock_guild assert cache_impl is not mock_guild - def test_get_available_guild_for_known_guild_when_available(self, cache_impl): + def test_get_available_guild_for_known_guild_when_available( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_guild = mock.MagicMock(guilds.GatewayGuild) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54234123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(543123): cache_utilities.GuildRecord(guild=mock_guild, is_available=True), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(guild=mock_guild, is_available=True), } ) - cached_guild = cache_impl.get_available_guild(StubModel(543123)) + cached_guild = cache_impl.get_available_guild(hikari_partial_guild) assert cached_guild == mock_guild assert cache_impl is not mock_guild - def test_get_available_guild_for_known_guild_when_unavailable(self, cache_impl): + def test_get_available_guild_for_known_guild_when_unavailable( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_guild = mock.Mock(guilds.GatewayGuild) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54234123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(543123): cache_utilities.GuildRecord(guild=mock_guild, is_available=False), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(guild=mock_guild, is_available=False), } ) - result = cache_impl.get_available_guild(StubModel(543123)) + result = cache_impl.get_available_guild(hikari_partial_guild) assert result is None - def test_get_available_guild_for_unknown_guild(self, cache_impl): + def test_get_available_guild_for_unknown_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54234123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(543123): cache_utilities.GuildRecord(), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(), } ) - result = cache_impl.get_available_guild(StubModel(543123)) + result = cache_impl.get_available_guild(hikari_partial_guild) assert result is None - def test_get_available_guild_for_unknown_guild_record(self, cache_impl): + def test_get_available_guild_for_unknown_guild_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._guild_entries = collections.FreezableDict( {snowflakes.Snowflake(54234123): cache_utilities.GuildRecord()} ) - result = cache_impl.get_available_guild(StubModel(543123)) + result = cache_impl.get_available_guild(hikari_partial_guild) assert result is None - def test_get_unavailable_guild_for_known_guild_when_unavailable(self, cache_impl): + def test_get_unavailable_guild_for_known_guild_when_unavailable( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_guild = mock.MagicMock(guilds.GatewayGuild) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54234123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(452131): cache_utilities.GuildRecord(guild=mock_guild, is_available=False), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(guild=mock_guild, is_available=False), } ) - cached_guild = cache_impl.get_unavailable_guild(StubModel(452131)) + cached_guild = cache_impl.get_unavailable_guild(hikari_partial_guild) assert cached_guild == mock_guild assert cache_impl is not mock_guild - def test_get_unavailable_guild_for_known_guild_when_available(self, cache_impl): + def test_get_unavailable_guild_for_known_guild_when_available( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_guild = mock.Mock(guilds.GatewayGuild) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54234123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(654234): cache_utilities.GuildRecord(guild=mock_guild, is_available=True), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(guild=mock_guild, is_available=True), } ) - result = cache_impl.get_unavailable_guild(StubModel(654234)) + result = cache_impl.get_unavailable_guild(hikari_partial_guild) assert result is None - def test_get_unavailable_guild_for_unknown_guild(self, cache_impl): + def test_get_unavailable_guild_for_unknown_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54234123): cache_utilities.GuildRecord(), - snowflakes.Snowflake(543123): cache_utilities.GuildRecord(), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(), } ) - result = cache_impl.get_unavailable_guild(StubModel(543123)) + result = cache_impl.get_unavailable_guild(hikari_partial_guild) assert result is None - def test_get_unavailable_guild_for_unknown_guild_record(self, cache_impl): + def test_get_unavailable_guild_for_unknown_guild_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._guild_entries = collections.FreezableDict( {snowflakes.Snowflake(54234123): cache_utilities.GuildRecord()} ) - result = cache_impl.get_unavailable_guild(StubModel(543123)) + result = cache_impl.get_unavailable_guild(hikari_partial_guild) assert result is None - def test_get_guilds_view(self, cache_impl): + def test_get_guilds_view(self, cache_impl: cache_impl_.CacheImpl): mock_guild_1 = mock.MagicMock(guilds.GatewayGuild) mock_guild_2 = mock.MagicMock(guilds.GatewayGuild) mock_guild_3 = mock.MagicMock(guilds.GatewayGuild) @@ -1163,7 +1263,7 @@ def test_get_guilds_view(self, cache_impl): snowflakes.Snowflake(65234): mock_guild_3, } - def test_get_available_guilds_view(self, cache_impl): + def test_get_available_guilds_view(self, cache_impl: cache_impl_.CacheImpl): mock_guild_1 = mock.MagicMock(guilds.GatewayGuild) mock_guild_2 = mock.MagicMock(guilds.GatewayGuild) cache_impl._guild_entries = collections.FreezableDict( @@ -1171,7 +1271,7 @@ def test_get_available_guilds_view(self, cache_impl): snowflakes.Snowflake(4312312): cache_utilities.GuildRecord(guild=mock_guild_1, is_available=True), snowflakes.Snowflake(34123): cache_utilities.GuildRecord(), snowflakes.Snowflake(73453): cache_utilities.GuildRecord(guild=mock_guild_2, is_available=True), - snowflakes.Snowflake(6554234): cache_utilities.GuildRecord(guild=object(), is_available=False), + snowflakes.Snowflake(6554234): cache_utilities.GuildRecord(guild=mock.Mock(), is_available=False), } ) @@ -1179,7 +1279,7 @@ def test_get_available_guilds_view(self, cache_impl): assert result == {snowflakes.Snowflake(4312312): mock_guild_1, snowflakes.Snowflake(73453): mock_guild_2} - def test_get_available_guilds_view_when_no_guilds_cached(self, cache_impl): + def test_get_available_guilds_view_when_no_guilds_cached(self, cache_impl: cache_impl_.CacheImpl): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(4312312): cache_utilities.GuildRecord(), @@ -1192,7 +1292,7 @@ def test_get_available_guilds_view_when_no_guilds_cached(self, cache_impl): assert result == {} - def test_get_unavailable_guilds_view(self, cache_impl): + def test_get_unavailable_guilds_view(self, cache_impl: cache_impl_.CacheImpl): mock_guild_1 = mock.MagicMock(guilds.GatewayGuild) mock_guild_2 = mock.MagicMock(guilds.GatewayGuild) cache_impl._guild_entries = collections.FreezableDict( @@ -1200,7 +1300,7 @@ def test_get_unavailable_guilds_view(self, cache_impl): snowflakes.Snowflake(4312312): cache_utilities.GuildRecord(guild=mock_guild_1, is_available=False), snowflakes.Snowflake(34123): cache_utilities.GuildRecord(), snowflakes.Snowflake(73453): cache_utilities.GuildRecord(guild=mock_guild_2, is_available=False), - snowflakes.Snowflake(6554234): cache_utilities.GuildRecord(guild=object(), is_available=True), + snowflakes.Snowflake(6554234): cache_utilities.GuildRecord(guild=mock.Mock(), is_available=True), } ) @@ -1208,7 +1308,7 @@ def test_get_unavailable_guilds_view(self, cache_impl): assert result == {snowflakes.Snowflake(4312312): mock_guild_1, snowflakes.Snowflake(73453): mock_guild_2} - def test_get_unavailable_guilds_view_when_no_guilds_cached(self, cache_impl): + def test_get_unavailable_guilds_view_when_no_guilds_cached(self, cache_impl: cache_impl_.CacheImpl): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(4312312): cache_utilities.GuildRecord(), @@ -1221,7 +1321,7 @@ def test_get_unavailable_guilds_view_when_no_guilds_cached(self, cache_impl): assert result == {} - def test_set_guild(self, cache_impl): + def test_set_guild(self, cache_impl: cache_impl_.CacheImpl): mock_guild = mock.MagicMock(guilds.GatewayGuild, id=snowflakes.Snowflake(5123123)) cache_impl.set_guild(mock_guild) @@ -1231,49 +1331,55 @@ def test_set_guild(self, cache_impl): assert cache_impl._guild_entries[snowflakes.Snowflake(5123123)].guild is not mock_guild assert cache_impl._guild_entries[snowflakes.Snowflake(5123123)].is_available is True - def test_set_guild_availability_for_cached_guild(self, cache_impl): - cache_impl._guild_entries = {snowflakes.Snowflake(43123): cache_utilities.GuildRecord(guild=object())} + def test_set_guild_availability_for_cached_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): + cache_impl._guild_entries = collections.FreezableDict( + {snowflakes.Snowflake(123): cache_utilities.GuildRecord(guild=mock.Mock())} + ) - cache_impl.set_guild_availability(StubModel(43123), True) + cache_impl.set_guild_availability(hikari_partial_guild, True) - assert cache_impl._guild_entries[snowflakes.Snowflake(43123)].is_available is True + assert cache_impl._guild_entries[snowflakes.Snowflake(123)].is_available is True - def test_set_guild_availability_for_uncached_guild(self, cache_impl): - cache_impl.set_guild_availability(StubModel(452234123), True) + def test_set_guild_availability_for_uncached_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): + cache_impl.set_guild_availability(hikari_partial_guild, True) assert 452234123 not in cache_impl._guild_entries @pytest.mark.skip(reason="TODO") - def test_update_guild(self, cache_impl): ... + def test_update_guild(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_clear_guild_channels(self, cache_impl): ... + def test_clear_guild_channels(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_clear_guild_channels_for_guild(self, cache_impl): ... + def test_clear_guild_channels_for_guild(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_delete_guild_channel(self, cache_impl): ... + def test_delete_guild_channel(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_get_guild_channel(self, cache_impl): ... + def test_get_guild_channel(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_get_guild_channels_view(self, cache_impl): ... + def test_get_guild_channels_view(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_get_guild_channels_view_for_guild(self, cache_impl): ... + def test_get_guild_channels_view_for_guild(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_set_guild_channel(self, cache_impl): ... + def test_set_guild_channel(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_update_guild_channel(self, cache_impl): ... + def test_update_guild_channel(self, cache_impl: cache_impl_.CacheImpl): ... - def test__build_invite(self, cache_impl): + def test__build_invite(self, cache_impl: cache_impl_.CacheImpl): mock_inviter = mock.MagicMock(users.User) mock_target_user = mock.MagicMock(users.User) - mock_application = object() + mock_application = mock.Mock() invite_data = cache_utilities.InviteData( code="okokok", guild_id=snowflakes.Snowflake(965234), @@ -1311,7 +1417,7 @@ def test__build_invite(self, cache_impl): assert invite.is_temporary is True assert invite.created_at == datetime.datetime(2020, 7, 30, 7, 22, 9, 550233, tzinfo=datetime.timezone.utc) - def test__build_invite_without_users(self, cache_impl): + def test__build_invite_without_users(self, cache_impl: cache_impl_.CacheImpl): invite_data = cache_utilities.InviteData( code="okokok", guild_id=snowflakes.Snowflake(965234), @@ -1332,7 +1438,7 @@ def test__build_invite_without_users(self, cache_impl): assert invite.inviter is None assert invite.target_user is None - def test_clear_invites(self, cache_impl): + def test_clear_invites(self, cache_impl: cache_impl_.CacheImpl): mock_target_user = mock.Mock(cache_utilities.RefCell[users.User], ref_count=5) mock_inviter = mock.Mock(cache_utilities.RefCell[users.User], ref_count=3) mock_invite_data_1 = mock.Mock(cache_utilities.InviteData, target_user=mock_target_user, inviter=mock_inviter) @@ -1354,7 +1460,9 @@ def test_clear_invites(self, cache_impl): ) cache_impl._build_invite.assert_has_calls([mock.call(mock_invite_data_1), mock.call(mock_invite_data_2)]) - def test_clear_invites_for_guild(self, cache_impl): + def test_clear_invites_for_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_target_user = mock.Mock(cache_utilities.RefCell[users.User], ref_count=4) mock_inviter = mock.Mock(cache_utilities.RefCell[users.User], ref_count=42) mock_invite_data_1 = mock.Mock(cache_utilities.InviteData, target_user=mock_target_user, inviter=mock_inviter) @@ -1372,15 +1480,13 @@ def test_clear_invites_for_guild(self, cache_impl): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54123): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(999888777): cache_utilities.GuildRecord( - invites=["oeoeoeoeooe", "owowowowoowowow"] - ), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(invites=["oeoeoeoeooe", "owowowowoowowow"]), } ) cache_impl._garbage_collect_user = mock.Mock() cache_impl._build_invite = mock.Mock(side_effect=[mock_invite_1, mock_invite_2]) - result = cache_impl.clear_invites_for_guild(StubModel(999888777)) + result = cache_impl.clear_invites_for_guild(hikari_partial_guild) assert result == {"oeoeoeoeooe": mock_invite_1, "owowowowoowowow": mock_invite_2} assert cache_impl._invite_entries == {"oeoeoeoeoeoeoe": mock_other_invite_data} @@ -1389,9 +1495,11 @@ def test_clear_invites_for_guild(self, cache_impl): ) cache_impl._build_invite.assert_has_calls([mock.call(mock_invite_data_1), mock.call(mock_invite_data_2)]) - def test_clear_invites_for_guild_unknown_invite_cache(self, cache_impl): + def test_clear_invites_for_guild_unknown_invite_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_other_invite_data = mock.Mock(cache_utilities.InviteData) - cache_impl._invite_entries = {"oeoeoeoeoeoeoe": mock_other_invite_data} + cache_impl._invite_entries = collections.FreezableDict({"oeoeoeoeoeoeoe": mock_other_invite_data}) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54123): mock.Mock(cache_utilities.GuildRecord), @@ -1400,13 +1508,15 @@ def test_clear_invites_for_guild_unknown_invite_cache(self, cache_impl): ) cache_impl._build_invite = mock.Mock() - result = cache_impl.clear_invites_for_guild(StubModel(765234123)) + result = cache_impl.clear_invites_for_guild(hikari_partial_guild) assert result == {} assert cache_impl._invite_entries == {"oeoeoeoeoeoeoe": mock_other_invite_data} cache_impl._build_invite.assert_not_called() - def test_clear_invites_for_guild_unknown_record(self, cache_impl): + def test_clear_invites_for_guild_unknown_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_other_invite_data = mock.Mock(cache_utilities.InviteData) cache_impl._invite_entries = collections.FreezableDict({"oeoeoeoeoeoeoe": mock_other_invite_data}) cache_impl._guild_entries = collections.FreezableDict( @@ -1414,23 +1524,28 @@ def test_clear_invites_for_guild_unknown_record(self, cache_impl): ) cache_impl._build_invite = mock.Mock() - result = cache_impl.clear_invites_for_guild(StubModel(765234123)) + result = cache_impl.clear_invites_for_guild(hikari_partial_guild) assert result == {} assert cache_impl._invite_entries == {"oeoeoeoeoeoeoe": mock_other_invite_data} cache_impl._build_invite.assert_not_called() - def test_clear_invites_for_channel(self, cache_impl): + def test_clear_invites_for_channel( + self, + cache_impl: cache_impl_.CacheImpl, + hikari_partial_guild: guilds.PartialGuild, + hikari_guild_text_channel: GuildTextChannel, + ): mock_target_user = mock.Mock(cache_utilities.RefCell[users.User], ref_count=42) mock_inviter = mock.Mock(cache_utilities.RefCell[users.User], ref_count=280) mock_invite_data_1 = mock.Mock( cache_utilities.InviteData, target_user=mock_target_user, inviter=mock_inviter, - channel_id=snowflakes.Snowflake(34123123), + channel_id=snowflakes.Snowflake(4560), ) mock_invite_data_2 = mock.Mock( - cache_utilities.InviteData, target_user=None, inviter=None, channel_id=snowflakes.Snowflake(34123123) + cache_utilities.InviteData, target_user=None, inviter=None, channel_id=snowflakes.Snowflake(4560) ) mock_other_invite_data = mock.Mock(cache_utilities.InviteData, channel_id=snowflakes.Snowflake(9484732)) mock_other_invite_data_2 = mock.Mock(cache_utilities.InviteData) @@ -1447,7 +1562,7 @@ def test_clear_invites_for_channel(self, cache_impl): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(54123): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(999888777): cache_utilities.GuildRecord( + snowflakes.Snowflake(123): cache_utilities.GuildRecord( invites=["oeoeoeoeooe", "owowowowoowowow", "oeoeoeoeoeoeoe"] ), } @@ -1455,18 +1570,23 @@ def test_clear_invites_for_channel(self, cache_impl): cache_impl._build_invite = mock.Mock(side_effect=[mock_invite_1, mock_invite_2]) cache_impl._garbage_collect_user = mock.Mock() - result = cache_impl.clear_invites_for_channel(StubModel(999888777), StubModel(34123123)) + result = cache_impl.clear_invites_for_channel(hikari_partial_guild, hikari_guild_text_channel) assert result == {"oeoeoeoeooe": mock_invite_1, "owowowowoowowow": mock_invite_2} cache_impl._garbage_collect_user.assert_has_calls( [mock.call(mock_target_user, decrement=1), mock.call(mock_inviter, decrement=1)], any_order=True ) - assert cache_impl._guild_entries[snowflakes.Snowflake(999888777)].invites == ["oeoeoeoeoeoeoe"] + assert cache_impl._guild_entries[snowflakes.Snowflake(123)].invites == ["oeoeoeoeoeoeoe"] assert cache_impl._invite_entries == {"oeoeoeoeoeoeoe": mock_other_invite_data, "oeo": mock_other_invite_data_2} cache_impl._build_invite.assert_has_calls([mock.call(mock_invite_data_1), mock.call(mock_invite_data_2)]) - def test_clear_invites_for_channel_unknown_invite_cache(self, cache_impl): + def test_clear_invites_for_channel_unknown_invite_cache( + self, + cache_impl: cache_impl_.CacheImpl, + hikari_partial_guild: guilds.PartialGuild, + hikari_guild_text_channel: GuildTextChannel, + ): mock_other_invite_data = mock.Mock(cache_utilities.InviteData) cache_impl._invite_entries = collections.FreezableDict({"oeoeoeoeoeoeoe": mock_other_invite_data}) cache_impl._user_entries = collections.FreezableDict( @@ -1480,13 +1600,18 @@ def test_clear_invites_for_channel_unknown_invite_cache(self, cache_impl): ) cache_impl._build_invite = mock.Mock() - result = cache_impl.clear_invites_for_channel(StubModel(765234123), StubModel(12365345)) + result = cache_impl.clear_invites_for_channel(hikari_partial_guild, hikari_guild_text_channel) assert result == {} assert cache_impl._invite_entries == {"oeoeoeoeoeoeoe": mock_other_invite_data} cache_impl._build_invite.assert_not_called() - def test_clear_invites_for_channel_unknown_record(self, cache_impl): + def test_clear_invites_for_channel_unknown_record( + self, + cache_impl: cache_impl_.CacheImpl, + hikari_partial_guild: guilds.PartialGuild, + hikari_guild_text_channel: GuildTextChannel, + ): mock_other_invite_data = mock.Mock(cache_utilities.InviteData) cache_impl._invite_entries = collections.FreezableDict({"oeoeoeoeoeoeoe": mock_other_invite_data}) cache_impl._user_entries = collections.FreezableDict( @@ -1497,13 +1622,13 @@ def test_clear_invites_for_channel_unknown_record(self, cache_impl): ) cache_impl._build_invite = mock.Mock() - result = cache_impl.clear_invites_for_channel(StubModel(765234123), StubModel(76234123)) + result = cache_impl.clear_invites_for_channel(hikari_partial_guild, hikari_guild_text_channel) assert result == {} assert cache_impl._invite_entries == {"oeoeoeoeoeoeoe": mock_other_invite_data} cache_impl._build_invite.assert_not_called() - def test_delete_invite(self, cache_impl): + def test_delete_invite(self, cache_impl: cache_impl_.CacheImpl): mock_inviter = mock.Mock(users.User, id=snowflakes.Snowflake(543123)) mock_target_user = mock.Mock(users.User, id=snowflakes.Snowflake(9191919)) mock_invite_data = mock.Mock( @@ -1513,7 +1638,7 @@ def test_delete_invite(self, cache_impl): target_user=mock_target_user, ) mock_other_invite_data = mock.Mock(cache_utilities.InviteData) - mock_invite = object() + mock_invite = mock.Mock() cache_impl._invite_entries = collections.FreezableDict( {"blamSpat": mock_other_invite_data, "oooooooooooooo": mock_invite_data} ) @@ -1538,7 +1663,7 @@ def test_delete_invite(self, cache_impl): assert cache_impl._invite_entries == {"blamSpat": mock_other_invite_data} assert cache_impl._guild_entries[snowflakes.Snowflake(999999999)].invites == ["ok", "blat"] - def test_delete_invite_with_invite_object(self, cache_impl): + def test_delete_invite_with_invite_object(self, cache_impl: cache_impl_.CacheImpl): mock_invite_data = mock.Mock(cache_utilities.InviteData) mock_other_invite_data = mock.Mock(cache_utilities.InviteData) mock_invite = mock.Mock(invites.InviteWithMetadata, inviter=None, target_user=None, guild_id=None) @@ -1554,7 +1679,7 @@ def test_delete_invite_with_invite_object(self, cache_impl): cache_impl._build_invite.assert_called_once_with(mock_invite_data) assert cache_impl._invite_entries == {"blamSpat": mock_other_invite_data} - def test_delete_invite_when_guild_id_is_None(self, cache_impl): + def test_delete_invite_when_guild_id_is_None(self, cache_impl: cache_impl_.CacheImpl): mock_invite_data = mock.Mock(cache_utilities.InviteData) mock_other_invite_data = mock.Mock(cache_utilities.InviteData) mock_invite = mock.Mock(invites.InviteWithMetadata, inviter=None, target_user=None, guild_id=None) @@ -1572,12 +1697,12 @@ def test_delete_invite_when_guild_id_is_None(self, cache_impl): cache_impl._remove_guild_record_if_empty.assert_not_called() assert cache_impl._invite_entries == {"blamSpat": mock_other_invite_data} - def test_delete_invite_without_users(self, cache_impl): + def test_delete_invite_without_users(self, cache_impl: cache_impl_.CacheImpl): mock_invite_data = mock.Mock( cache_utilities.InviteData, inviter=None, target_user=None, guild_id=snowflakes.Snowflake(999999999) ) mock_other_invite_data = mock.Mock(cache_utilities.InviteData) - mock_invite = object() + mock_invite = mock.Mock() cache_impl._invite_entries = collections.FreezableDict( {"blamSpat": mock_other_invite_data, "oooooooooooooo": mock_invite_data} ) @@ -1599,7 +1724,7 @@ def test_delete_invite_without_users(self, cache_impl): assert cache_impl._invite_entries == {"blamSpat": mock_other_invite_data} assert cache_impl._guild_entries[snowflakes.Snowflake(999999999)].invites == ["ok", "blat"] - def test_delete_invite_for_unknown_invite(self, cache_impl): + def test_delete_invite_for_unknown_invite(self, cache_impl: cache_impl_.CacheImpl): cache_impl._build_invite = mock.Mock() cache_impl._garbage_collect_user = mock.Mock() # TODO: test this is called @@ -1611,7 +1736,7 @@ def test_delete_invite_for_unknown_invite(self, cache_impl): cache_impl._build_invite.assert_not_called() cache_impl._garbage_collect_user.assert_not_called() - def test_get_invite(self, cache_impl): + def test_get_invite(self, cache_impl: cache_impl_.CacheImpl): mock_invite_data = mock.Mock(cache_utilities.InviteData) mock_invite = mock.Mock(invites.InviteWithMetadata) cache_impl._build_invite = mock.Mock(return_value=mock_invite) @@ -1624,7 +1749,7 @@ def test_get_invite(self, cache_impl): assert result is mock_invite cache_impl._build_invite.assert_called_once_with(mock_invite_data) - def test_get_invite_with_invite_object(self, cache_impl): + def test_get_invite_with_invite_object(self, cache_impl: cache_impl_.CacheImpl): mock_invite_data = mock.Mock(cache_utilities.InviteData) mock_invite = mock.Mock(invites.InviteWithMetadata) cache_impl._build_invite = mock.Mock(return_value=mock_invite) @@ -1637,13 +1762,13 @@ def test_get_invite_with_invite_object(self, cache_impl): assert result is mock_invite cache_impl._build_invite.assert_called_once_with(mock_invite_data) - def test_get_invite_for_unknown_invite(self, cache_impl): + def test_get_invite_for_unknown_invite(self, cache_impl: cache_impl_.CacheImpl): cache_impl._build_invite = mock.Mock() cache_impl._invite_entries = collections.FreezableDict({"blam": mock.Mock(cache_utilities.InviteData)}) assert cache_impl.get_invite("okokok") is None cache_impl._build_invite.assert_not_called() - def test_get_invites_view(self, cache_impl): + def test_get_invites_view(self, cache_impl: cache_impl_.CacheImpl): mock_invite_data_1 = mock.Mock(cache_utilities.InviteData) mock_invite_data_2 = mock.Mock(cache_utilities.InviteData) mock_invite_1 = mock.Mock(invites.InviteWithMetadata) @@ -1658,7 +1783,9 @@ def test_get_invites_view(self, cache_impl): assert result == {"okok": mock_invite_1, "blamblam": mock_invite_2} cache_impl._build_invite.assert_has_calls([mock.call(mock_invite_data_1), mock.call(mock_invite_data_2)]) - def test_get_invites_view_for_guild(self, cache_impl): + def test_get_invites_view_for_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_invite_data_1 = mock.Mock(cache_utilities.InviteData) mock_invite_data_2 = mock.Mock(cache_utilities.InviteData) mock_invite_1 = mock.Mock(invites.InviteWithMetadata) @@ -1673,34 +1800,38 @@ def test_get_invites_view_for_guild(self, cache_impl): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(9544994): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(4444444): cache_utilities.GuildRecord(invites=["okok", "dsaytert"]), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(invites=["okok", "dsaytert"]), } ) cache_impl._build_invite = mock.Mock(side_effect=[mock_invite_1, mock_invite_2]) - result = cache_impl.get_invites_view_for_guild(StubModel(4444444)) + result = cache_impl.get_invites_view_for_guild(hikari_partial_guild) assert result == {"okok": mock_invite_1, "dsaytert": mock_invite_2} cache_impl._build_invite.assert_has_calls([mock.call(mock_invite_data_1), mock.call(mock_invite_data_2)]) - def test_get_invites_view_for_guild_unknown_emoji_cache(self, cache_impl): + def test_get_invites_view_for_guild_unknown_emoji_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._invite_entries = collections.FreezableDict( {"okok": mock.Mock(cache_utilities.InviteData), "dsaytert": mock.Mock(cache_utilities.InviteData)} ) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(9544994): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(4444444): cache_utilities.GuildRecord(invites=None), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(invites=None), } ) cache_impl._build_invite = mock.Mock() - result = cache_impl.get_invites_view_for_guild(StubModel(4444444)) + result = cache_impl.get_invites_view_for_guild(hikari_partial_guild) assert result == {} cache_impl._build_invite.assert_not_called() - def test_get_invites_view_for_guild_unknown_record(self, cache_impl): + def test_get_invites_view_for_guild_unknown_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._invite_entries = collections.FreezableDict( {"okok": mock.Mock(cache_utilities.InviteData), "dsaytert": mock.Mock(cache_utilities.InviteData)} ) @@ -1709,14 +1840,19 @@ def test_get_invites_view_for_guild_unknown_record(self, cache_impl): ) cache_impl._build_invite = mock.Mock() - result = cache_impl.get_invites_view_for_guild(StubModel(4444444)) + result = cache_impl.get_invites_view_for_guild(hikari_partial_guild) assert result == {} cache_impl._build_invite.assert_not_called() - def test_get_invites_view_for_channel(self, cache_impl): - mock_invite_data_1 = mock.Mock(channel_id=snowflakes.Snowflake(987987), code="blamBang") - mock_invite_data_2 = mock.Mock(channel_id=snowflakes.Snowflake(987987), code="bingBong") + def test_get_invites_view_for_channel( + self, + cache_impl: cache_impl_.CacheImpl, + hikari_partial_guild: guilds.PartialGuild, + hikari_guild_text_channel: GuildTextChannel, + ): + mock_invite_data_1 = mock.Mock(channel_id=snowflakes.Snowflake(4560), code="blamBang") + mock_invite_data_2 = mock.Mock(channel_id=snowflakes.Snowflake(4560), code="bingBong") mock_invite_1 = mock.Mock(invites.InviteWithMetadata) mock_invite_2 = mock.Mock(invites.InviteWithMetadata) cache_impl._invite_entries = collections.FreezableDict( @@ -1730,34 +1866,44 @@ def test_get_invites_view_for_channel(self, cache_impl): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(31423): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(83452134): cache_utilities.GuildRecord(invites=["blamBang", "bingBong", "Pop"]), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(invites=["blamBang", "bingBong", "Pop"]), } ) cache_impl._build_invite = mock.Mock(side_effect=[mock_invite_1, mock_invite_2]) - result = cache_impl.get_invites_view_for_channel(StubModel(83452134), StubModel(987987)) + result = cache_impl.get_invites_view_for_channel(hikari_partial_guild, hikari_guild_text_channel) assert result == {"blamBang": mock_invite_1, "bingBong": mock_invite_2} cache_impl._build_invite.assert_has_calls([mock.call(mock_invite_data_1), mock.call(mock_invite_data_2)]) - def test_get_invites_view_for_channel_unknown_emoji_cache(self, cache_impl): + def test_get_invites_view_for_channel_unknown_emoji_cache( + self, + cache_impl: cache_impl_.CacheImpl, + hikari_partial_guild: guilds.PartialGuild, + hikari_guild_text_channel: GuildTextChannel, + ): cache_impl._invite_entries = collections.FreezableDict( {"okok": mock.Mock(cache_utilities.InviteData), "dsaytert": mock.Mock(cache_utilities.InviteData)} ) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(9544994): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(4444444): cache_utilities.GuildRecord(invites=None), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(invites=None), } ) cache_impl._build_invite = mock.Mock() - result = cache_impl.get_invites_view_for_channel(StubModel(4444444), StubModel(942123)) + result = cache_impl.get_invites_view_for_channel(hikari_partial_guild, hikari_guild_text_channel) assert result == {} cache_impl._build_invite.assert_not_called() - def test_get_invites_view_for_channel_unknown_record(self, cache_impl): + def test_get_invites_view_for_channel_unknown_record( + self, + cache_impl: cache_impl_.CacheImpl, + hikari_partial_guild: guilds.PartialGuild, + hikari_guild_text_channel: GuildTextChannel, + ): cache_impl._invite_entries = collections.FreezableDict( {"okok": mock.Mock(cache_utilities.InviteData), "dsaytert": mock.Mock(cache_utilities.InviteData)} ) @@ -1766,12 +1912,12 @@ def test_get_invites_view_for_channel_unknown_record(self, cache_impl): ) cache_impl._build_invite = mock.Mock() - result = cache_impl.get_invites_view_for_channel(StubModel(4444444), StubModel(9543123)) + result = cache_impl.get_invites_view_for_channel(hikari_partial_guild, hikari_guild_text_channel) assert result == {} cache_impl._build_invite.assert_not_called() - def test_update_invite(self, cache_impl): + def test_update_invite(self, cache_impl: cache_impl_.CacheImpl): mock_old_invite = mock.Mock(invites.InviteWithMetadata) mock_new_invite = mock.Mock(invites.InviteWithMetadata) mock_invite = mock.Mock(invites.InviteWithMetadata, code="biggieSmall") @@ -1784,7 +1930,7 @@ def test_update_invite(self, cache_impl): cache_impl.get_invite.assert_has_calls([mock.call("biggieSmall"), mock.call("biggieSmall")]) cache_impl.set_invite.assert_called_once_with(mock_invite) - def test_delete_me_for_known_me(self, cache_impl): + def test_delete_me_for_known_me(self, cache_impl: cache_impl_.CacheImpl): mock_own_user = mock.Mock(users.OwnUser) cache_impl._me = mock_own_user @@ -1793,13 +1939,13 @@ def test_delete_me_for_known_me(self, cache_impl): assert result is mock_own_user assert cache_impl._me is None - def test_delete_me_for_unknown_me(self, cache_impl): + def test_delete_me_for_unknown_me(self, cache_impl: cache_impl_.CacheImpl): result = cache_impl.delete_me() assert result is None assert cache_impl._me is None - def test_get_me_for_known_me(self, cache_impl): + def test_get_me_for_known_me(self, cache_impl: cache_impl_.CacheImpl): mock_own_user = mock.MagicMock(users.OwnUser) cache_impl._me = mock_own_user @@ -1808,10 +1954,10 @@ def test_get_me_for_known_me(self, cache_impl): assert cached_me == mock_own_user assert cached_me is not mock_own_user - def test_get_me_for_unknown_me(self, cache_impl): + def test_get_me_for_unknown_me(self, cache_impl: cache_impl_.CacheImpl): assert cache_impl.get_me() is None - def test_set_me(self, cache_impl): + def test_set_me(self, cache_impl: cache_impl_.CacheImpl): mock_own_user = mock.MagicMock(users.OwnUser) cache_impl.set_me(mock_own_user) @@ -1819,14 +1965,14 @@ def test_set_me(self, cache_impl): assert cache_impl._me == mock_own_user assert cache_impl._me is not mock_own_user - def test_set_me_when_not_enabled(self, cache_impl): + def test_set_me_when_not_enabled(self, cache_impl: cache_impl_.CacheImpl): cache_impl._settings.components = 0 - cache_impl.set_me(object()) + cache_impl.set_me(mock.Mock()) assert cache_impl._me is None - def test_update_me_for_cached_me(self, cache_impl): + def test_update_me_for_cached_me(self, cache_impl: cache_impl_.CacheImpl): mock_cached_own_user = mock.MagicMock(users.OwnUser) mock_own_user = mock.MagicMock(users.OwnUser) cache_impl._me = mock_cached_own_user @@ -1836,7 +1982,7 @@ def test_update_me_for_cached_me(self, cache_impl): assert result == (mock_cached_own_user, mock_own_user) assert cache_impl._me == mock_own_user - def test_update_me_for_uncached_me(self, cache_impl): + def test_update_me_for_uncached_me(self, cache_impl: cache_impl_.CacheImpl): mock_own_user = mock.MagicMock(users.OwnUser) result = cache_impl.update_me(mock_own_user) @@ -1844,18 +1990,18 @@ def test_update_me_for_uncached_me(self, cache_impl): assert result == (None, mock_own_user) assert cache_impl._me == mock_own_user - def test_update_me_for_when_not_enabled(self, cache_impl): + def test_update_me_for_when_not_enabled(self, cache_impl: cache_impl_.CacheImpl): cache_impl._settings.components = 0 cache_impl.get_me = mock.Mock() cache_impl.set_me = mock.Mock() - result = cache_impl.update_me(object()) + result = cache_impl.update_me(mock.Mock()) assert result == (None, None) cache_impl.get_me.assert_not_called() cache_impl.set_me.assert_not_called() - def test__build_member(self, cache_impl): + def test__build_member(self, cache_impl: cache_impl_.CacheImpl): mock_user = mock.MagicMock(users.User) member_data = cache_utilities.MemberData( user=cache_utilities.RefCell(mock_user), @@ -1895,13 +2041,13 @@ def test__build_member(self, cache_impl): ) assert member.guild_flags == guilds.GuildMemberFlags.DID_REJOIN - def test_clear_members(self, cache_impl): + def test_clear_members(self, cache_impl: cache_impl_.CacheImpl): mock_user_1 = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(2123123))) mock_user_2 = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(212314423))) mock_user_3 = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(2123166623))) mock_user_4 = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(21237777123))) mock_user_5 = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(212399999123))) - mock_data_member_1 = cache_utilities.RefCell( + mock_data_member_1: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell( mock.Mock( cache_utilities.MemberData, user=mock_user_1, @@ -1909,7 +2055,7 @@ def test_clear_members(self, cache_impl): has_been_deleted=False, ) ) - mock_data_member_2 = cache_utilities.RefCell( + mock_data_member_2: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell( mock.Mock( cache_utilities.MemberData, user=mock_user_2, @@ -1917,7 +2063,7 @@ def test_clear_members(self, cache_impl): has_been_deleted=False, ) ) - mock_data_member_3 = cache_utilities.RefCell( + mock_data_member_3: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell( mock.Mock( cache_utilities.MemberData, user=mock_user_3, @@ -1925,7 +2071,7 @@ def test_clear_members(self, cache_impl): has_been_deleted=False, ) ) - mock_data_member_4 = cache_utilities.RefCell( + mock_data_member_4: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell( mock.Mock( cache_utilities.MemberData, user=mock_user_4, @@ -1933,7 +2079,7 @@ def test_clear_members(self, cache_impl): has_been_deleted=False, ) ) - mock_data_member_5 = cache_utilities.RefCell( + mock_data_member_5: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell( mock.Mock( cache_utilities.MemberData, user=mock_user_5, @@ -1941,11 +2087,11 @@ def test_clear_members(self, cache_impl): has_been_deleted=False, ) ) - mock_member_1 = object() - mock_member_2 = object() - mock_member_3 = object() - mock_member_4 = object() - mock_member_5 = object() + mock_member_1 = mock.Mock() + mock_member_2 = mock.Mock() + mock_member_3 = mock.Mock() + mock_member_4 = mock.Mock() + mock_member_5 = mock.Mock() guild_record_1 = cache_utilities.GuildRecord( members=collections.FreezableDict( {snowflakes.Snowflake(2123123): mock_data_member_1, snowflakes.Snowflake(212314423): mock_data_member_2} @@ -2013,92 +2159,118 @@ def test_clear_members(self, cache_impl): assert guild_record_2.members is None @pytest.mark.skip(reason="TODO") - def test_clear_members_for_guild(self, cache_impl): ... + def test_clear_members_for_guild(self, cache_impl: cache_impl_.CacheImpl): ... - def test_delete_member_for_unknown_guild_record(self, cache_impl): - result = cache_impl.delete_member(StubModel(42123), StubModel(67876)) + def test_delete_member_for_unknown_guild_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): + result = cache_impl.delete_member(hikari_partial_guild, hikari_user) assert result is None - def test_delete_member_for_unknown_member_cache(self, cache_impl): - cache_impl._guild_entries = {snowflakes.Snowflake(42123): cache_utilities.GuildRecord()} + def test_delete_member_for_unknown_member_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): + cache_impl._guild_entries = collections.FreezableDict( + {snowflakes.Snowflake(123): cache_utilities.GuildRecord()} + ) - result = cache_impl.delete_member(StubModel(42123), StubModel(67876)) + result = cache_impl.delete_member(hikari_partial_guild, hikari_user) assert result is None - def test_delete_member_for_known_member(self, cache_impl): + def test_delete_member_for_known_member( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): mock_member = mock.Mock(guilds.Member) - mock_user = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(67876))) + mock_user = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(789))) mock_member_data = mock.Mock( - cache_utilities.MemberData, user=mock_user, guild_id=snowflakes.Snowflake(42123), has_been_deleted=False + cache_utilities.MemberData, user=mock_user, guild_id=snowflakes.Snowflake(123), has_been_deleted=False + ) + mock_reffed_member: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell( + mock_member_data ) - mock_reffed_member = cache_utilities.RefCell(mock_member_data) - guild_record = cache_utilities.GuildRecord(members={snowflakes.Snowflake(67876): mock_reffed_member}) - cache_impl._guild_entries = collections.FreezableDict({snowflakes.Snowflake(42123): guild_record}) + guild_record = cache_utilities.GuildRecord( + members=collections.FreezableDict({snowflakes.Snowflake(789): mock_reffed_member}) + ) + cache_impl._guild_entries = collections.FreezableDict({snowflakes.Snowflake(123): guild_record}) cache_impl._remove_guild_record_if_empty = mock.Mock() cache_impl._garbage_collect_user = mock.Mock() cache_impl._build_member = mock.Mock(return_value=mock_member) - result = cache_impl.delete_member(StubModel(42123), StubModel(67876)) + result = cache_impl.delete_member(hikari_partial_guild, hikari_user) assert result is mock_member - assert cache_impl._guild_entries[snowflakes.Snowflake(42123)].members is None + assert cache_impl._guild_entries[snowflakes.Snowflake(123)].members is None cache_impl._build_member.assert_called_once_with(mock_reffed_member) cache_impl._garbage_collect_user.assert_called_once_with(mock_user, decrement=1) - cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(42123), guild_record) + cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(123), guild_record) - def test_delete_member_for_known_hard_referenced_member(self, cache_impl): - mock_member = cache_utilities.RefCell(mock.Mock(has_been_deleted=False), ref_count=1) + def test_delete_member_for_known_hard_referenced_member( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): + mock_member: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell( + mock.Mock(has_been_deleted=False), ref_count=1 + ) cache_impl._guild_entries = collections.FreezableDict( { - snowflakes.Snowflake(42123): cache_utilities.GuildRecord( - members=collections.FreezableDict({snowflakes.Snowflake(67876): mock_member}) + snowflakes.Snowflake(123): cache_utilities.GuildRecord( + members=collections.FreezableDict({snowflakes.Snowflake(789): mock_member}) ) } ) - result = cache_impl.delete_member(StubModel(42123), StubModel(67876)) + result = cache_impl.delete_member(hikari_partial_guild, hikari_user) assert result is None assert mock_member.object.has_been_deleted is True - def test_get_member_for_unknown_member_cache(self, cache_impl): + def test_get_member_for_unknown_member_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): cache_impl._guild_entries = collections.FreezableDict( - {snowflakes.Snowflake(1234213): cache_utilities.GuildRecord()} + {snowflakes.Snowflake(123): cache_utilities.GuildRecord()} ) - result = cache_impl.get_member(StubModel(1234213), StubModel(512312354)) + result = cache_impl.get_member(hikari_partial_guild, hikari_user) assert result is None - def test_get_member_for_unknown_member(self, cache_impl): + def test_get_member_for_unknown_member( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): cache_impl._guild_entries = collections.FreezableDict( { - snowflakes.Snowflake(1234213): cache_utilities.GuildRecord( - members={snowflakes.Snowflake(43123): mock.Mock(cache_utilities.MemberData)} + snowflakes.Snowflake(123): cache_utilities.GuildRecord( + members=collections.FreezableDict( + {snowflakes.Snowflake(43123): mock.Mock(cache_utilities.MemberData)} + ) ) } ) - result = cache_impl.get_member(StubModel(1234213), StubModel(512312354)) + result = cache_impl.get_member(hikari_partial_guild, hikari_user) assert result is None - def test_get_member_for_unknown_guild_record(self, cache_impl): - result = cache_impl.get_member(StubModel(1234213), StubModel(512312354)) + def test_get_member_for_unknown_guild_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): + result = cache_impl.get_member(hikari_partial_guild, hikari_user) assert result is None - def test_get_member_for_known_member(self, cache_impl): + def test_get_member_for_known_member( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): mock_member_data = mock.Mock(cache_utilities.MemberData) mock_member = mock.Mock(guilds.Member) cache_impl._guild_entries = collections.FreezableDict( { - snowflakes.Snowflake(1234213): cache_utilities.GuildRecord( + snowflakes.Snowflake(123): cache_utilities.GuildRecord( members=collections.FreezableDict( { - snowflakes.Snowflake(512312354): mock_member_data, + snowflakes.Snowflake(789): mock_member_data, snowflakes.Snowflake(321): mock.Mock(cache_utilities.MemberData), } ) @@ -2108,22 +2280,22 @@ def test_get_member_for_known_member(self, cache_impl): cache_impl._user_entries = collections.FreezableDict({}) cache_impl._build_member = mock.Mock(return_value=mock_member) - result = cache_impl.get_member(StubModel(1234213), StubModel(512312354)) + result = cache_impl.get_member(hikari_partial_guild, hikari_user) assert result is mock_member cache_impl._build_member.assert_called_once_with(mock_member_data) - def test_get_members_view(self, cache_impl): - mock_member_data_1 = cache_utilities.RefCell(object()) - mock_member_data_2 = cache_utilities.RefCell(object()) - mock_member_data_3 = cache_utilities.RefCell(object()) - mock_member_data_4 = cache_utilities.RefCell(object()) - mock_member_data_5 = cache_utilities.RefCell(object()) - mock_member_1 = object() - mock_member_2 = object() - mock_member_3 = object() - mock_member_4 = object() - mock_member_5 = object() + def test_get_members_view(self, cache_impl: cache_impl_.CacheImpl): + mock_member_data_1: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell(mock.Mock()) + mock_member_data_2: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell(mock.Mock()) + mock_member_data_3: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell(mock.Mock()) + mock_member_data_4: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell(mock.Mock()) + mock_member_data_5: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell(mock.Mock()) + mock_member_1 = mock.Mock() + mock_member_2 = mock.Mock() + mock_member_3 = mock.Mock() + mock_member_4 = mock.Mock() + mock_member_5 = mock.Mock() cache_impl._build_member = mock.Mock( side_effect=[mock_member_1, mock_member_2, mock_member_3, mock_member_4, mock_member_5] ) @@ -2171,23 +2343,37 @@ def test_get_members_view(self, cache_impl): ] ) - def test_get_members_view_for_guild_unknown_record(self, cache_impl): - members_mapping = cache_impl.get_members_view_for_guild(StubModel(42334)) + def test_get_members_view_for_guild_unknown_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): + members_mapping = cache_impl.get_members_view_for_guild(hikari_partial_guild.id) assert members_mapping == {} - def test_get_members_view_for_guild_unknown_member_cache(self, cache_impl): + def test_get_members_view_for_guild_unknown_member_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): cache_impl._guild_entries = collections.FreezableDict( - {snowflakes.Snowflake(42334): cache_utilities.GuildRecord()} + {snowflakes.Snowflake(123): cache_utilities.GuildRecord()} ) - members_mapping = cache_impl.get_members_view_for_guild(StubModel(42334)) + members_mapping = cache_impl.get_members_view_for_guild(hikari_partial_guild.id) assert members_mapping == {} - def test_get_members_view_for_guild(self, cache_impl): - mock_member_data_1 = cache_utilities.RefCell(mock.Mock(cache_utilities.MemberData, has_been_deleted=False)) - mock_member_data_2 = cache_utilities.RefCell(mock.Mock(cache_utilities.MemberData, has_been_deleted=False)) + def test_get_members_view_for_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): + mock_member_data_1: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell( + mock.Mock(cache_utilities.MemberData, has_been_deleted=False) + ) + mock_member_data_2: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell( + mock.Mock(cache_utilities.MemberData, has_been_deleted=False) + ) + mock_member_data_3: cache_utilities.RefCell[cache_utilities.MemberData] = cache_utilities.RefCell( + mock.Mock(cache_utilities.MemberData, has_been_deleted=True) + ) + mock_member_1 = mock.Mock(guilds.Member) mock_member_2 = mock.Mock(guilds.Member) guild_record = cache_utilities.GuildRecord( @@ -2195,21 +2381,19 @@ def test_get_members_view_for_guild(self, cache_impl): { snowflakes.Snowflake(3214321): mock_member_data_1, snowflakes.Snowflake(53224): mock_member_data_2, - snowflakes.Snowflake(9000): cache_utilities.RefCell( - mock.Mock(cache_utilities.MemberData, has_been_deleted=True) - ), + snowflakes.Snowflake(9000): mock_member_data_3, } ) ) - cache_impl._guild_entries = collections.FreezableDict({snowflakes.Snowflake(42334): guild_record}) + cache_impl._guild_entries = collections.FreezableDict({snowflakes.Snowflake(123): guild_record}) cache_impl._build_member = mock.Mock(side_effect=[mock_member_1, mock_member_2]) - result = cache_impl.get_members_view_for_guild(StubModel(42334)) + result = cache_impl.get_members_view_for_guild(hikari_partial_guild.id) assert result == {snowflakes.Snowflake(3214321): mock_member_1, snowflakes.Snowflake(53224): mock_member_2} cache_impl._build_member.assert_has_calls([mock.call(mock_member_data_1), mock.call(mock_member_data_2)]) - def test_set_member(self, cache_impl): + def test_set_member(self, cache_impl: cache_impl_.CacheImpl): mock_user = mock.Mock(users.User, id=snowflakes.Snowflake(645234123)) mock_user_ref = cache_utilities.RefCell(mock_user) member_model = guilds.Member( @@ -2237,11 +2421,10 @@ def test_set_member(self, cache_impl): cache_impl._set_user.assert_called_once_with(mock_user) cache_impl._increment_ref_count.assert_called_once_with(mock_user_ref) - assert 67345234 in cache_impl._guild_entries - assert 645234123 in cache_impl._guild_entries[snowflakes.Snowflake(67345234)].members - member_entry = cache_impl._guild_entries[snowflakes.Snowflake(67345234)].members[ - snowflakes.Snowflake(645234123) - ] + guild_entry_members = cache_impl._guild_entries[snowflakes.Snowflake(67345234)].members + assert guild_entry_members is not None + assert 645234123 in guild_entry_members + member_entry = guild_entry_members[snowflakes.Snowflake(645234123)] assert member_entry.object.user is mock_user_ref assert member_entry.object.guild_id == 67345234 assert member_entry.object.nickname == "A NICK LOL" @@ -2264,11 +2447,13 @@ def test_set_member(self, cache_impl): ) assert member_entry.object.guild_flags == guilds.GuildMemberFlags.DID_REJOIN - def test_set_member_doesnt_increment_user_ref_count_for_pre_cached_member(self, cache_impl): + def test_set_member_doesnt_increment_user_ref_count_for_pre_cached_member(self, cache_impl: cache_impl_.CacheImpl): mock_user = mock.Mock(users.User, id=snowflakes.Snowflake(645234123)) member_model = mock.MagicMock(guilds.Member, user=mock_user, guild_id=snowflakes.Snowflake(67345234)) cache_impl._set_user = mock.Mock() - cache_impl._increment_user_ref_count = mock.Mock() + cache_impl._increment_user_ref_count = ( + mock.Mock() + ) # FIXME: Seems this is calling an object that does not actually exist. cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(67345234): cache_utilities.GuildRecord( @@ -2282,9 +2467,9 @@ def test_set_member_doesnt_increment_user_ref_count_for_pre_cached_member(self, cache_impl.set_member(member_model) cache_impl._set_user.assert_called_once_with(mock_user) - cache_impl._increment_user_ref_count.assert_not_called() + cache_impl._increment_user_ref_count.assert_not_called() # FIXME: Seems this is calling an object that does not actually exist. - def test_update_member(self, cache_impl): + def test_update_member(self, cache_impl: cache_impl_.CacheImpl): mock_old_cached_member = mock.Mock(guilds.Member) mock_new_cached_member = mock.Mock(guilds.Member) mock_member = mock.Mock( @@ -2302,56 +2487,58 @@ def test_update_member(self, cache_impl): cache_impl.set_member.assert_called_once_with(mock_member) @pytest.mark.skip(reason="TODO") - def test_clear_presences(self, cache_impl): ... + def test_clear_presences(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_clear_presences_for_guild(self, cache_impl): ... + def test_clear_presences_for_guild(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_delete_presence(self, cache_impl): ... + def test_delete_presence(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_get_presence(self, cache_impl): ... + def test_get_presence(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_get_presences_view(self, cache_impl): ... + def test_get_presences_view(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_get_presences_view_for_guild(self, cache_impl): ... + def test_get_presences_view_for_guild(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_set_presence(self, cache_impl): ... + def test_set_presence(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_update_presence(self, cache_impl): ... + def test_update_presence(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_clear_roles(self, cache_impl): ... + def test_clear_roles(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_clear_roles_for_guild(self, cache_impl): ... + def test_clear_roles_for_guild(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_delete_role(self, cache_impl): ... + def test_delete_role(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_get_role(self, cache_impl): ... + def test_get_role(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_get_roles_view(self, cache_impl): ... + def test_get_roles_view(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_get_roles_view_for_guild(self, cache_impl): ... + def test_get_roles_view_for_guild(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_set_role(self, cache_impl): ... + def test_set_role(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_update_role(self, cache_impl): ... + def test_update_role(self, cache_impl: cache_impl_.CacheImpl): ... - def test__garbage_collect_user_for_known_unreferenced_user(self, cache_impl): - mock_user = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(21231234)), ref_count=1) - mock_other_user = mock.Mock(cache_utilities.RefCell, ref_count=1) + def test__garbage_collect_user_for_known_unreferenced_user(self, cache_impl: cache_impl_.CacheImpl): + mock_user: cache_utilities.RefCell[users.User] = cache_utilities.RefCell( + mock.Mock(id=snowflakes.Snowflake(21231234)), ref_count=1 + ) + mock_other_user: cache_utilities.RefCell[users.User] = mock.Mock(cache_utilities.RefCell, ref_count=1) cache_impl._user_entries = collections.FreezableDict( {snowflakes.Snowflake(21231234): mock_user, snowflakes.Snowflake(645234): mock_other_user} ) @@ -2360,9 +2547,15 @@ def test__garbage_collect_user_for_known_unreferenced_user(self, cache_impl): assert cache_impl._user_entries == {snowflakes.Snowflake(645234): mock_other_user} - def test__garbage_collect_user_for_known_unreferenced_user_removes_cached_dm_channelo(self, cache_impl): - mock_user = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(21231234)), ref_count=1) - cache_impl._dm_channel_entries = collections.FreezableDict({21231234: 123123123}) + def test__garbage_collect_user_for_known_unreferenced_user_removes_cached_dm_channelo( + self, cache_impl: cache_impl_.CacheImpl + ): + mock_user: cache_utilities.RefCell[users.User] = cache_utilities.RefCell( + mock.Mock(id=snowflakes.Snowflake(21231234)), ref_count=1 + ) + cache_impl._dm_channel_entries = collections.FreezableDict( + {snowflakes.Snowflake(21231234): snowflakes.Snowflake(123123123)} + ) mock_other_user = mock.Mock(cache_utilities.RefCell, ref_count=1) cache_impl._user_entries = collections.FreezableDict( {snowflakes.Snowflake(21231234): mock_user, snowflakes.Snowflake(645234): mock_other_user} @@ -2373,9 +2566,11 @@ def test__garbage_collect_user_for_known_unreferenced_user_removes_cached_dm_cha assert cache_impl._user_entries == {snowflakes.Snowflake(645234): mock_other_user} assert cache_impl._dm_channel_entries == {} - def test_garbage_collect_user_for_referenced_user(self, cache_impl): - mock_user = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(21231234)), ref_count=2) - mock_other_user = mock.Mock(cache_utilities.RefCell) + def test_garbage_collect_user_for_referenced_user(self, cache_impl: cache_impl_.CacheImpl): + mock_user: cache_utilities.RefCell[users.User] = cache_utilities.RefCell( + mock.Mock(id=snowflakes.Snowflake(21231234)), ref_count=2 + ) + mock_other_user: cache_utilities.RefCell[users.User] = mock.Mock(cache_utilities.RefCell) cache_impl._user_entries = collections.FreezableDict( {snowflakes.Snowflake(21231234): mock_user, snowflakes.Snowflake(645234): mock_other_user} ) @@ -2388,29 +2583,33 @@ def test_garbage_collect_user_for_referenced_user(self, cache_impl): } assert mock_user.ref_count == 1 - def test_garbage_collect_user_for_unknown_user(self, cache_impl): - mock_user = cache_utilities.RefCell(mock.Mock(id=snowflakes.Snowflake(21235432), ref_count=0)) + def test_garbage_collect_user_for_unknown_user(self, cache_impl: cache_impl_.CacheImpl): + mock_user: cache_utilities.RefCell[users.User] = cache_utilities.RefCell( + mock.Mock(id=snowflakes.Snowflake(21235432), ref_count=0) + ) cache_impl._user_entries = collections.FreezableDict({snowflakes.Snowflake(21231234): mock_user}) cache_impl._garbage_collect_user(mock_user) assert cache_impl._user_entries == {snowflakes.Snowflake(21231234): mock_user} - def test_get_user_for_known_user(self, cache_impl): + def test_get_user_for_known_user(self, cache_impl: cache_impl_.CacheImpl, hikari_user: users.User): mock_user = mock.MagicMock(users.User) cache_impl._user_entries = collections.FreezableDict( { - snowflakes.Snowflake(21231234): cache_utilities.RefCell(mock_user), + snowflakes.Snowflake(789): cache_utilities.RefCell(mock_user), snowflakes.Snowflake(645234): mock.Mock(cache_utilities.RefCell), } ) - cache_impl._build_user = mock.Mock(return_value=mock_user) + cache_impl._build_user = mock.Mock( + return_value=mock_user + ) # FIXME: Seems this is calling an object that does not actually exist. - result = cache_impl.get_user(StubModel(21231234)) + result = cache_impl.get_user(hikari_user) assert result == mock_user - def test_get_users_view_for_filled_user_cache(self, cache_impl): + def test_get_users_view_for_filled_user_cache(self, cache_impl: cache_impl_.CacheImpl): mock_user_1 = mock.MagicMock(users.User) mock_user_2 = mock.MagicMock(users.User) cache_impl._user_entries = collections.FreezableDict( @@ -2424,10 +2623,10 @@ def test_get_users_view_for_filled_user_cache(self, cache_impl): assert result == {snowflakes.Snowflake(54123): mock_user_1, snowflakes.Snowflake(76345): mock_user_2} - def test_get_users_view_for_empty_user_cache(self, cache_impl): + def test_get_users_view_for_empty_user_cache(self, cache_impl: cache_impl_.CacheImpl): assert cache_impl.get_users_view() == {} - def test__set_user(self, cache_impl): + def test__set_user(self, cache_impl: cache_impl_.CacheImpl): mock_user = mock.MagicMock(users.User, id=snowflakes.Snowflake(6451234123)) cache_impl._user_entries = collections.FreezableDict( {snowflakes.Snowflake(542143): mock.Mock(cache_utilities.RefCell)} @@ -2440,7 +2639,7 @@ def test__set_user(self, cache_impl): assert cache_impl._user_entries[snowflakes.Snowflake(6451234123)].object == mock_user assert cache_impl._user_entries[snowflakes.Snowflake(6451234123)].object is not mock_user - def test__set_user_carries_over_ref_count(self, cache_impl): + def test__set_user_carries_over_ref_count(self, cache_impl: cache_impl_.CacheImpl): mock_user = mock.MagicMock(users.User, id=snowflakes.Snowflake(6451234123)) cache_impl._user_entries = collections.FreezableDict( { @@ -2457,7 +2656,7 @@ def test__set_user_carries_over_ref_count(self, cache_impl): assert cache_impl._user_entries[snowflakes.Snowflake(6451234123)].object is not mock_user assert cache_impl._user_entries[snowflakes.Snowflake(6451234123)].ref_count == 42 - def test__build_voice_state(self, cache_impl): + def test__build_voice_state(self, cache_impl: cache_impl_.CacheImpl): mock_member = mock.Mock(guilds.Member, user=mock.Mock(users.User, id=snowflakes.Snowflake(7512312))) mock_member_data = mock.Mock(cache_utilities.MemberData, build_entity=mock.Mock(return_value=mock_member)) voice_state_data = cache_utilities.VoiceStateData( @@ -2496,14 +2695,16 @@ def test__build_voice_state(self, cache_impl): ) @pytest.mark.skip(reason="TODO") - def test_clear_voice_states(self, cache_impl): ... + def test_clear_voice_states(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_clear_voice_states_for_channel(self, cache_impl): ... + def test_clear_voice_states_for_channel(self, cache_impl: cache_impl_.CacheImpl): ... - def test_clear_voice_states_for_guild(self, cache_impl): - mock_member_data_1 = object() - mock_member_data_2 = object() + def test_clear_voice_states_for_guild( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): + mock_member_data_1 = mock.Mock() + mock_member_data_2 = mock.Mock() mock_voice_state_data_1 = mock.Mock(cache_utilities.VoiceStateData, member=mock_member_data_1) mock_voice_state_data_2 = mock.Mock(cache_utilities.VoiceStateData, member=mock_member_data_2) mock_voice_state_1 = mock.Mock(voices.VoiceState) @@ -2518,10 +2719,10 @@ def test_clear_voice_states_for_guild(self, cache_impl): ) cache_impl._remove_guild_record_if_empty = mock.Mock() cache_impl._garbage_collect_member = mock.Mock() - cache_impl._guild_entries = collections.FreezableDict({snowflakes.Snowflake(54123123): record}) + cache_impl._guild_entries = collections.FreezableDict({snowflakes.Snowflake(123): record}) cache_impl._build_voice_state = mock.Mock(side_effect=[mock_voice_state_1, mock_voice_state_2]) - result = cache_impl.clear_voice_states_for_guild(StubModel(54123123)) + result = cache_impl.clear_voice_states_for_guild(hikari_partial_guild) assert result == { snowflakes.Snowflake(7512312): mock_voice_state_1, @@ -2530,25 +2731,31 @@ def test_clear_voice_states_for_guild(self, cache_impl): cache_impl._garbage_collect_member.assert_has_calls( [mock.call(record, mock_member_data_1, decrement=1), mock.call(record, mock_member_data_2, decrement=1)] ) - cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(54123123), record) + cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(123), record) cache_impl._build_voice_state.assert_has_calls( [mock.call(mock_voice_state_data_1), mock.call(mock_voice_state_data_2)] ) - def test_clear_voice_states_for_guild_unknown_voice_state_cache(self, cache_impl): - cache_impl._guild_entries[snowflakes.Snowflake(24123)] = cache_utilities.GuildRecord() + def test_clear_voice_states_for_guild_unknown_voice_state_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): + cache_impl._guild_entries[snowflakes.Snowflake(123)] = cache_utilities.GuildRecord() - result = cache_impl.clear_voice_states_for_guild(StubModel(24123)) + result = cache_impl.clear_voice_states_for_guild(hikari_partial_guild) assert result == {} - def test_clear_voice_states_for_guild_unknown_record(self, cache_impl): - result = cache_impl.clear_voice_states_for_guild(StubModel(24123)) + def test_clear_voice_states_for_guild_unknown_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild + ): + result = cache_impl.clear_voice_states_for_guild(hikari_partial_guild) assert result == {} - def test_delete_voice_state(self, cache_impl): - mock_member_data = object() + def test_delete_voice_state( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): + mock_member_data = mock.Mock() mock_voice_state_data = mock.Mock(cache_utilities.VoiceStateData, member=mock_member_data) mock_other_voice_state_data = mock.Mock(cache_utilities.VoiceStateData) mock_voice_state = mock.Mock(voices.VoiceState) @@ -2556,36 +2763,38 @@ def test_delete_voice_state(self, cache_impl): guild_record = cache_utilities.GuildRecord( voice_states=collections.FreezableDict( { - snowflakes.Snowflake(12354345): mock_voice_state_data, + snowflakes.Snowflake(789): mock_voice_state_data, snowflakes.Snowflake(6541234): mock_other_voice_state_data, } ), members=collections.FreezableDict( - {snowflakes.Snowflake(12354345): mock_member_data, snowflakes.Snowflake(9955959): object()} + {snowflakes.Snowflake(789): mock_member_data, snowflakes.Snowflake(9955959): mock.Mock()} ), ) cache_impl._user_entries = collections.FreezableDict( - {snowflakes.Snowflake(12354345): object(), snowflakes.Snowflake(9393): object()} + {snowflakes.Snowflake(789): mock.Mock(), snowflakes.Snowflake(9393): mock.Mock()} ) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(65234): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(43123): guild_record, + snowflakes.Snowflake(123): guild_record, } ) cache_impl._remove_guild_record_if_empty = mock.Mock() cache_impl._garbage_collect_member = mock.Mock() - result = cache_impl.delete_voice_state(StubModel(43123), StubModel(12354345)) + result = cache_impl.delete_voice_state(hikari_partial_guild, hikari_user) assert result is mock_voice_state cache_impl._garbage_collect_member.assert_called_once_with(guild_record, mock_member_data, decrement=1) - cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(43123), guild_record) - assert cache_impl._guild_entries[snowflakes.Snowflake(43123)].voice_states == { + cache_impl._remove_guild_record_if_empty.assert_called_once_with(snowflakes.Snowflake(123), guild_record) + assert cache_impl._guild_entries[snowflakes.Snowflake(123)].voice_states == { snowflakes.Snowflake(6541234): mock_other_voice_state_data } - def test_delete_voice_state_unknown_state(self, cache_impl): + def test_delete_voice_state_unknown_state( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): mock_other_voice_state_data = mock.Mock(cache_utilities.VoiceStateData) cache_impl._build_voice_state = mock.Mock() guild_record = cache_utilities.GuildRecord( @@ -2594,68 +2803,78 @@ def test_delete_voice_state_unknown_state(self, cache_impl): cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(65234): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(43123): guild_record, + snowflakes.Snowflake(123): guild_record, } ) cache_impl._remove_guild_record_if_empty = mock.Mock() - result = cache_impl.delete_voice_state(StubModel(43123), StubModel(12354345)) + result = cache_impl.delete_voice_state(hikari_partial_guild, hikari_user) assert result is None cache_impl._remove_guild_record_if_empty.assert_not_called() - assert cache_impl._guild_entries[snowflakes.Snowflake(43123)].voice_states == { + assert cache_impl._guild_entries[snowflakes.Snowflake(123)].voice_states == { snowflakes.Snowflake(6541234): mock_other_voice_state_data } - def test_delete_voice_state_unknown_state_cache(self, cache_impl): + def test_delete_voice_state_unknown_state_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): cache_impl._build_voice_state = mock.Mock() guild_record = cache_utilities.GuildRecord(voice_states=None) cache_impl._guild_entries = collections.FreezableDict( { snowflakes.Snowflake(65234): mock.Mock(cache_utilities.GuildRecord), - snowflakes.Snowflake(43123): guild_record, + snowflakes.Snowflake(123): guild_record, } ) cache_impl._remove_guild_record_if_empty = mock.Mock() - result = cache_impl.delete_voice_state(StubModel(43123), StubModel(12354345)) + result = cache_impl.delete_voice_state(hikari_partial_guild, hikari_user) assert result is None cache_impl._remove_guild_record_if_empty.assert_not_called() - def test_delete_voice_state_unknown_record(self, cache_impl): + def test_delete_voice_state_unknown_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): cache_impl._build_voice_state = mock.Mock() cache_impl._guild_entries = collections.FreezableDict( {snowflakes.Snowflake(65234): mock.Mock(cache_utilities.GuildRecord)} ) cache_impl._remove_guild_record_if_empty = mock.Mock() - result = cache_impl.delete_voice_state(StubModel(43123), StubModel(12354345)) + result = cache_impl.delete_voice_state(hikari_partial_guild, hikari_user) assert result is None cache_impl._remove_guild_record_if_empty.assert_not_called() - def test_get_voice_state_for_known_voice_state(self, cache_impl): + def test_get_voice_state_for_known_voice_state( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): mock_voice_state_data = mock.Mock(cache_utilities.VoiceStateData) mock_voice_state = mock.Mock(voices.VoiceState) cache_impl._build_voice_state = mock.Mock(return_value=mock_voice_state) - guild_record = cache_utilities.GuildRecord(voice_states={snowflakes.Snowflake(43124): mock_voice_state_data}) + guild_record = cache_utilities.GuildRecord( + voice_states=collections.FreezableDict({snowflakes.Snowflake(789): mock_voice_state_data}) + ) cache_impl._guild_entries = collections.FreezableDict( { - snowflakes.Snowflake(1235123): guild_record, + snowflakes.Snowflake(123): guild_record, snowflakes.Snowflake(73245): mock.Mock(cache_utilities.GuildRecord), } ) - result = cache_impl.get_voice_state(StubModel(1235123), StubModel(43124)) + result = cache_impl.get_voice_state(hikari_partial_guild, hikari_user) assert result is mock_voice_state cache_impl._build_voice_state.assert_called_once_with(mock_voice_state_data) - def test_get_voice_state_for_unknown_voice_state(self, cache_impl): + def test_get_voice_state_for_unknown_voice_state( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): cache_impl._guild_entries = collections.FreezableDict( { - snowflakes.Snowflake(1235123): cache_utilities.GuildRecord( + snowflakes.Snowflake(123): cache_utilities.GuildRecord( voice_states=collections.FreezableDict( {snowflakes.Snowflake(54123): mock.Mock(cache_utilities.VoiceStateData)} ) @@ -2664,43 +2883,49 @@ def test_get_voice_state_for_unknown_voice_state(self, cache_impl): } ) - result = cache_impl.get_voice_state(StubModel(1235123), StubModel(43124)) + result = cache_impl.get_voice_state(hikari_partial_guild, hikari_user) assert result is None - def test_get_voice_state_for_unknown_voice_state_cache(self, cache_impl): + def test_get_voice_state_for_unknown_voice_state_cache( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): cache_impl._guild_entries = collections.FreezableDict( { - snowflakes.Snowflake(1235123): cache_utilities.GuildRecord(), + snowflakes.Snowflake(123): cache_utilities.GuildRecord(), snowflakes.Snowflake(73245): mock.Mock(cache_utilities.GuildRecord), } ) - result = cache_impl.get_voice_state(StubModel(1235123), StubModel(43124)) + result = cache_impl.get_voice_state(hikari_partial_guild, hikari_user) assert result is None - def test_get_voice_state_for_unknown_record(self, cache_impl): - cache_impl._guild_entries = {snowflakes.Snowflake(73245): mock.Mock(cache_utilities.GuildRecord)} + def test_get_voice_state_for_unknown_record( + self, cache_impl: cache_impl_.CacheImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): + cache_impl._guild_entries = collections.FreezableDict( + {snowflakes.Snowflake(73245): mock.Mock(cache_utilities.GuildRecord)} + ) - result = cache_impl.get_voice_state(StubModel(1235123), StubModel(43124)) + result = cache_impl.get_voice_state(hikari_partial_guild, hikari_user) assert result is None @pytest.mark.skip(reason="TODO") - def test_get_voice_state_view(self, cache_impl): ... + def test_get_voice_state_view(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_get_voice_states_view_for_channel(self, cache_impl): ... + def test_get_voice_states_view_for_channel(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_get_voice_states_view_for_guild(self, cache_impl): ... + def test_get_voice_states_view_for_guild(self, cache_impl: cache_impl_.CacheImpl): ... - def test_set_voice_state(self, cache_impl): - mock_member = object() - mock_reffed_member = cache_utilities.RefCell(object()) + def test_set_voice_state(self, cache_impl: cache_impl_.CacheImpl): + mock_member = mock.Mock() + mock_reffed_member = cache_utilities.RefCell(mock.Mock()) voice_state = voices.VoiceState( - app=None, + app=mock.Mock(), channel_id=snowflakes.Snowflake(239211023123), guild_id=snowflakes.Snowflake(43123123), is_guild_muted=True, @@ -2722,7 +2947,9 @@ def test_set_voice_state(self, cache_impl): cache_impl._increment_ref_count.assert_called_with(mock_reffed_member) cache_impl._set_member.assert_called_once_with(mock_member) - voice_state_data = cache_impl._guild_entries[43123123].voice_states[4531231] + guild_entry_voice_states = cache_impl._guild_entries[snowflakes.Snowflake(43123123)].voice_states + assert guild_entry_voice_states is not None + voice_state_data = guild_entry_voice_states[snowflakes.Snowflake(4531231)] assert voice_state_data.channel_id == 239211023123 assert voice_state_data.guild_id == 43123123 assert voice_state_data.is_guild_muted is True @@ -2735,7 +2962,7 @@ def test_set_voice_state(self, cache_impl): 2021, 4, 17, 10, 13, 56, 939273, tzinfo=datetime.timezone.utc ) - def test_update_voice_state(self, cache_impl): + def test_update_voice_state(self, cache_impl: cache_impl_.CacheImpl): mock_old_voice_state = mock.Mock(voices.VoiceState) mock_new_voice_state = mock.Mock(voices.VoiceState) voice_state = mock.Mock( @@ -2755,13 +2982,15 @@ def test_update_voice_state(self, cache_impl): ] ) - def test__build_message(self, cache_impl): + def test__build_message(self, cache_impl: cache_impl_.CacheImpl): mock_author = mock.MagicMock(users.User) - mock_member = object() + mock_member = mock.Mock() member_data = mock.Mock(build_entity=mock.Mock(return_value=mock_member)) mock_channel = mock.MagicMock() - mock_mention_user = mock.MagicMock() - mock_user_mentions = {snowflakes.Snowflake(4231): cache_utilities.RefCell(mock_mention_user)} + mock_mention_user = mock.MagicMock(users.User) + mock_user_mentions: collections.FreezableDict[snowflakes.Snowflake, cache_utilities.RefCell[users.User]] = ( + collections.FreezableDict({snowflakes.Snowflake(4231): cache_utilities.RefCell(mock_mention_user)}) + ) mock_role_mention_ids = (snowflakes.Snowflake(21323123),) mock_channel_mentions = {snowflakes.Snowflake(4444): mock_channel} mock_attachment = mock.MagicMock(messages.Attachment) @@ -2773,9 +3002,9 @@ def test__build_message(self, cache_impl): mock_activity = mock.MagicMock(messages.MessageActivity) mock_application = mock.MagicMock(messages.MessageApplication) mock_reference = mock.MagicMock(messages.MessageReference) - mock_referenced_message = object() + mock_referenced_message = mock.Mock() mock_message_snapshots = mock.MagicMock() - mock_component = object() + mock_component = mock.Mock() mock_referenced_message_data = mock.Mock( cache_utilities.MessageData, build_entity=mock.Mock(return_value=mock_referenced_message) ) @@ -2872,12 +3101,12 @@ def test__build_message(self, cache_impl): assert result.thread == mock_thread assert result.interaction_metadata == mock_interaction_metadata - def test__build_message_with_null_fields(self, cache_impl): + def test__build_message_with_null_fields(self, cache_impl: cache_impl_.CacheImpl): message_data = cache_utilities.MessageData( id=snowflakes.Snowflake(32123123), channel_id=snowflakes.Snowflake(3123123123), guild_id=snowflakes.Snowflake(5555555), - author=cache_utilities.RefCell(object()), + author=cache_utilities.RefCell(mock.Mock()), member=None, content=None, timestamp=datetime.datetime(2020, 7, 30, 7, 10, 9, 550233, tzinfo=datetime.timezone.utc), @@ -2897,7 +3126,7 @@ def test__build_message_with_null_fields(self, cache_impl): activity=None, application=None, message_reference=None, - message_snapshots=None, + message_snapshots=[], flags=messages.MessageFlag.CROSSPOSTED, nonce=None, referenced_message=None, @@ -2932,50 +3161,48 @@ def test__build_message_with_null_fields(self, cache_impl): assert result.interaction_metadata is None @pytest.mark.skip(reason="TODO") - def test_clear_messages(self, cache_impl): - raise NotImplementedError + def test_clear_messages(self, cache_impl: cache_impl_.CacheImpl): ... @pytest.mark.skip(reason="TODO") - def test_delete_message(self, cache_impl): - raise NotImplementedError + def test_delete_message(self, cache_impl: cache_impl_.CacheImpl): ... - def test_get_message(self, cache_impl): - mock_message_data = object() - mock_message = object() + def test_get_message(self, cache_impl: cache_impl_.CacheImpl, hikari_message: messages.Message): + mock_message_data = mock.Mock() + mock_message = mock.Mock() cache_impl._build_message = mock.Mock(return_value=mock_message) - cache_impl._message_entries[snowflakes.Snowflake(32332123)] = mock_message_data + cache_impl._message_entries[snowflakes.Snowflake(101)] = mock_message_data - result = cache_impl.get_message(StubModel(32332123)) + result = cache_impl.get_message(hikari_message) assert result is mock_message cache_impl._build_message.assert_called_once_with(mock_message_data) - def test_get_message_reference_only(self, cache_impl): - mock_message_data = object() - mock_message = object() + def test_get_message_reference_only(self, cache_impl: cache_impl_.CacheImpl, hikari_message: messages.Message): + mock_message_data = mock.Mock() + mock_message = mock.Mock() cache_impl._build_message = mock.Mock(return_value=mock_message) - cache_impl._referenced_messages[snowflakes.Snowflake(32332123)] = mock_message_data + cache_impl._referenced_messages[snowflakes.Snowflake(101)] = mock_message_data - result = cache_impl.get_message(StubModel(32332123)) + result = cache_impl.get_message(hikari_message) assert result is mock_message cache_impl._build_message.assert_called_once_with(mock_message_data) - def test_get_message_for_unknown_message(self, cache_impl): + def test_get_message_for_unknown_message(self, cache_impl: cache_impl_.CacheImpl, hikari_message: messages.Message): cache_impl._build_message = mock.Mock() - result = cache_impl.get_message(StubModel(32332123)) + result = cache_impl.get_message(hikari_message) assert result is None cache_impl._build_message.assert_not_called() - def test_get_messages_view(self, cache_impl): - mock_message_data_1 = object() - mock_message_data_2 = object() - mock_message_data_3 = object() - mock_message_1 = object() - mock_message_2 = object() - mock_message_3 = object() + def test_get_messages_view(self, cache_impl: cache_impl_.CacheImpl): + mock_message_data_1 = mock.Mock() + mock_message_data_2 = mock.Mock() + mock_message_data_3 = mock.Mock() + mock_message_1 = mock.Mock() + mock_message_2 = mock.Mock() + mock_message_3 = mock.Mock() cache_impl._build_message = mock.Mock(side_effect=[mock_message_1, mock_message_2, mock_message_3]) cache_impl._message_entries = collections.FreezableDict( {snowflakes.Snowflake(32123): mock_message_data_1, snowflakes.Snowflake(451231): mock_message_data_2} @@ -2990,12 +3217,11 @@ def test_get_messages_view(self, cache_impl): ) @pytest.mark.skip(reason="TODO") - def test_set_message(self, cache_impl): - raise NotImplementedError + def test_set_message(self, cache_impl: cache_impl_.CacheImpl): ... - def test_update_message_for_full_message(self, cache_impl): + def test_update_message_for_full_message(self, cache_impl: cache_impl_.CacheImpl): message = mock.Mock(messages.Message, id=snowflakes.Snowflake(45312312)) - cached_message = object() + cached_message = mock.Mock() cache_impl.get_message = mock.Mock(side_effect=(None, cached_message)) cache_impl.set_message = mock.Mock() @@ -3006,10 +3232,9 @@ def test_update_message_for_full_message(self, cache_impl): cache_impl.get_message.assert_has_calls([mock.call(45312312), mock.call(45312312)]) @pytest.mark.skip(reason="TODO") - def test_update_message_for_partial_message(self, cache_impl): - raise NotImplementedError + def test_update_message_for_partial_message(self, cache_impl: cache_impl_.CacheImpl): ... - def test_update_message_for_unknown_partial_message(self, cache_impl): + def test_update_message_for_unknown_partial_message(self, cache_impl: cache_impl_.CacheImpl): message = mock.Mock(messages.PartialMessage, id=snowflakes.Snowflake(2123123123)) cache_impl.get_message = mock.Mock(side_effect=(None, None)) cache_impl.set_message = mock.Mock() @@ -3151,7 +3376,13 @@ def test_update_message_for_unknown_partial_message(self, cache_impl): ("update_voice_state", config_api.CacheComponents.VOICE_STATES, (None, None)), ], ) - def test_function_default(self, cache_impl, name, component, expected): + def test_function_default( + self, + cache_impl: cache_impl_.CacheImpl, + name: str, + component: config_api.CacheComponents, + expected: cache_utilities.EmptyCacheView | None | tuple[None, None], + ): cache_impl._is_cache_enabled_for = mock.Mock(return_value=False) fn = getattr(cache_impl, name) diff --git a/tests/hikari/impl/test_config.py b/tests/hikari/impl/test_config.py index 09fd4d6c17..547a25cb7e 100644 --- a/tests/hikari/impl/test_config.py +++ b/tests/hikari/impl/test_config.py @@ -21,7 +21,9 @@ from __future__ import annotations import ssl +import typing +import mock import pytest from hikari.impl import config as config_ @@ -41,7 +43,7 @@ def test_when_value_is_False(self): assert returned.verify_mode is ssl.CERT_NONE def test_when_value_is_non_bool(self): - value = object() + value = mock.Mock() assert config_._ssl_factory(value) is value @@ -50,46 +52,46 @@ class TestBasicAuthHeader: def config(self): return config_.BasicAuthHeader(username="davfsa", password="securepassword123") - def test_header_property(self, config): + def test_header_property(self, config: config_.BasicAuthHeader): assert config.header == f"{config_._BASICAUTH_TOKEN_PREFIX} ZGF2ZnNhOnNlY3VyZXBhc3N3b3JkMTIz" - def test_str(self, config): + def test_str(self, config: config_.BasicAuthHeader): assert str(config) == f"{config_._BASICAUTH_TOKEN_PREFIX} ZGF2ZnNhOnNlY3VyZXBhc3N3b3JkMTIz" class TestHTTPTimeoutSettings: @pytest.mark.parametrize("arg", ["acquire_and_connect", "request_socket_connect", "request_socket_read", "total"]) - def test_max_redirects_validator_when_not_None_nor_int_nor_float(self, arg): + def test_max_redirects_validator_when_not_None_nor_int_nor_float(self, arg: str): with pytest.raises(ValueError, match=rf"HTTPTimeoutSettings.{arg} must be None, or a POSITIVE float/int"): - config_.HTTPTimeoutSettings(**{arg: object()}) + config_.HTTPTimeoutSettings(**{arg: mock.Mock()}) @pytest.mark.parametrize("arg", ["acquire_and_connect", "request_socket_connect", "request_socket_read", "total"]) - def test_max_redirects_validator_when_negative_int(self, arg): + def test_max_redirects_validator_when_negative_int(self, arg: str): with pytest.raises(ValueError, match=rf"HTTPTimeoutSettings.{arg} must be None, or a POSITIVE float/int"): config_.HTTPTimeoutSettings(**{arg: -1}) @pytest.mark.parametrize("arg", ["acquire_and_connect", "request_socket_connect", "request_socket_read", "total"]) - def test_max_redirects_validator_when_negative_float(self, arg): + def test_max_redirects_validator_when_negative_float(self, arg: str): with pytest.raises(ValueError, match=rf"HTTPTimeoutSettings.{arg} must be None, or a POSITIVE float/int"): config_.HTTPTimeoutSettings(**{arg: -1.1}) @pytest.mark.parametrize("arg", ["acquire_and_connect", "request_socket_connect", "request_socket_read", "total"]) @pytest.mark.parametrize("value", [1, 1.1, None]) - def test_max_redirects_validator(self, arg, value): + def test_max_redirects_validator(self, arg: str, value: typing.Optional[typing.Union[float, int]]): config_.HTTPTimeoutSettings(**{arg: value}) class TestHTTPSettings: def test_max_redirects_validator_when_not_None_nor_int(self): with pytest.raises(ValueError, match=r"http_settings.max_redirects must be None or a POSITIVE integer"): - config_.HTTPSettings(max_redirects=object()) + config_.HTTPSettings(max_redirects=mock.Mock()) def test_max_redirects_validator_when_negative(self): with pytest.raises(ValueError, match=r"http_settings.max_redirects must be None or a POSITIVE integer"): config_.HTTPSettings(max_redirects=-1) @pytest.mark.parametrize("value", [1, None]) - def test_max_redirects_validator(self, value): + def test_max_redirects_validator(self, value: typing.Optional[int]): config_.HTTPSettings(max_redirects=value) def test_ssl(self): diff --git a/tests/hikari/impl/test_entity_factory.py b/tests/hikari/impl/test_entity_factory.py index 1eeac95f92..b01f35494b 100644 --- a/tests/hikari/impl/test_entity_factory.py +++ b/tests/hikari/impl/test_entity_factory.py @@ -64,12 +64,12 @@ @pytest.fixture -def permission_overwrite_payload(): +def permission_overwrite_payload() -> dict[str, typing.Any]: return {"id": "4242", "type": 1, "allow": 65, "deny": 49152, "allow_new": "65", "deny_new": "49152"} @pytest.fixture -def guild_text_channel_payload(permission_overwrite_payload): +def guild_text_channel_payload(permission_overwrite_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "123", "guild_id": "567", @@ -88,7 +88,7 @@ def guild_text_channel_payload(permission_overwrite_payload): @pytest.fixture -def guild_voice_channel_payload(permission_overwrite_payload): +def guild_voice_channel_payload(permission_overwrite_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "555", "guild_id": "789", @@ -107,7 +107,7 @@ def guild_voice_channel_payload(permission_overwrite_payload): @pytest.fixture -def guild_news_channel_payload(permission_overwrite_payload): +def guild_news_channel_payload(permission_overwrite_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "7777", "guild_id": "123", @@ -219,7 +219,7 @@ def primary_guild_payload() -> dict[str, typing.Any]: @pytest.fixture -def user_payload(primary_guild_payload: dict[str, typing.Any]): +def user_payload(primary_guild_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "115590097100865541", "username": "nyaa", @@ -235,12 +235,12 @@ def user_payload(primary_guild_payload: dict[str, typing.Any]): @pytest.fixture -def custom_emoji_payload(): +def custom_emoji_payload() -> dict[str, typing.Any]: return {"id": "691225175349395456", "name": "test", "animated": True} @pytest.fixture -def known_custom_emoji_payload(user_payload): +def known_custom_emoji_payload(user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "12345", "name": "testing", @@ -254,7 +254,7 @@ def known_custom_emoji_payload(user_payload): @pytest.fixture -def member_payload(user_payload): +def member_payload(user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "nick": "foobarbaz", "roles": ["11111", "22222", "33333", "44444"], @@ -271,7 +271,7 @@ def member_payload(user_payload): @pytest.fixture -def presence_activity_payload(custom_emoji_payload): +def presence_activity_payload(custom_emoji_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "name": "an activity", "type": 1, @@ -297,7 +297,9 @@ def presence_activity_payload(custom_emoji_payload): @pytest.fixture -def member_presence_payload(user_payload, presence_activity_payload): +def member_presence_payload( + user_payload: dict[str, typing.Any], presence_activity_payload: dict[str, typing.Any] +) -> dict[str, typing.Any]: return { "user": user_payload, "activity": presence_activity_payload, @@ -309,7 +311,7 @@ def member_presence_payload(user_payload, presence_activity_payload): @pytest.fixture -def guild_role_payload(): +def guild_role_payload() -> dict[str, typing.Any]: return { "id": "41771983423143936", "name": "WE DEM BOYZZ!!!!!!", @@ -333,7 +335,7 @@ def guild_role_payload(): @pytest.fixture -def voice_state_payload(member_payload): +def voice_state_payload(member_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "guild_id": "929292929292992", "channel_id": "157733188964188161", @@ -386,17 +388,12 @@ def test__deserialize_max_age_returns_null(): @pytest.fixture -def mock_app() -> traits.RESTAware: - return mock.Mock() - - -@pytest.fixture -def entity_factory_impl(mock_app) -> entity_factory.EntityFactoryImpl: - return hikari_test_helpers.mock_class_namespace(entity_factory.EntityFactoryImpl, slots_=False)(mock_app) +def entity_factory_impl(hikari_app: traits.RESTAware) -> entity_factory.EntityFactoryImpl: + return hikari_test_helpers.mock_class_namespace(entity_factory.EntityFactoryImpl, slots_=False)(hikari_app) class TestGatewayGuildDefinition: - def test_id_property(self, entity_factory_impl): + def test_id_property(self, entity_factory_impl: entity_factory.EntityFactoryImpl): guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "123123451234"}, user_id=snowflakes.Snowflake(43123) ) @@ -404,7 +401,11 @@ def test_id_property(self, entity_factory_impl): assert guild_definition.id == 123123451234 def test_channels( - self, entity_factory_impl, guild_text_channel_payload, guild_voice_channel_payload, guild_news_channel_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + guild_text_channel_payload: dict[str, typing.Any], + guild_voice_channel_payload: dict[str, typing.Any], + guild_news_channel_payload: dict[str, typing.Any], ): guild_definition = entity_factory_impl.deserialize_gateway_guild( { @@ -426,30 +427,33 @@ def test_channels( ), } - def test_channels_returns_cached_values(self, entity_factory_impl): + def test_channels_returns_cached_values(self, entity_factory_impl: entity_factory.EntityFactoryImpl): guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "265828729970753537"}, user_id=snowflakes.Snowflake(43123) ) - mock_channel = object() - guild_definition._channels = {"123321": mock_channel} - entity_factory_impl.deserialize_guild_text_channel = mock.Mock() - entity_factory_impl.deserialize_guild_voice_channel = mock.Mock() - entity_factory_impl.deserialize_guild_news_channel = mock.Mock() + mock_channel = mock.Mock() - assert guild_definition.channels() == {"123321": mock_channel} + with mock.patch.object(guild_definition, "_channels", {"123321": mock_channel}): + entity_factory_impl.deserialize_guild_text_channel = mock.Mock() + entity_factory_impl.deserialize_guild_voice_channel = mock.Mock() + entity_factory_impl.deserialize_guild_news_channel = mock.Mock() - entity_factory_impl.deserialize_guild_text_channel.assert_not_called() - entity_factory_impl.deserialize_guild_voice_channel.assert_not_called() - entity_factory_impl.deserialize_guild_news_channel.assert_not_called() + assert guild_definition.channels() == {"123321": mock_channel} - def test_channels_ignores_unrecognised_channels(self, entity_factory_impl): + entity_factory_impl.deserialize_guild_text_channel.assert_not_called() + entity_factory_impl.deserialize_guild_voice_channel.assert_not_called() + entity_factory_impl.deserialize_guild_news_channel.assert_not_called() + + def test_channels_ignores_unrecognised_channels(self, entity_factory_impl: entity_factory.EntityFactoryImpl): guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "9494949", "channels": [{"id": 123, "type": 1000}]}, user_id=snowflakes.Snowflake(43123) ) assert guild_definition.channels() == {} - def test_emojis(self, entity_factory_impl, known_custom_emoji_payload): + def test_emojis( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, known_custom_emoji_payload: dict[str, typing.Any] + ): guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "265828729970753537", "emojis": [known_custom_emoji_payload]}, user_id=snowflakes.Snowflake(43123) ) @@ -460,19 +464,19 @@ def test_emojis(self, entity_factory_impl, known_custom_emoji_payload): ) } - def test_emojis_returns_cached_values(self, entity_factory_impl): - mock_emoji = object() + def test_emojis_returns_cached_values(self, entity_factory_impl: entity_factory.EntityFactoryImpl): + mock_emoji = mock.Mock() entity_factory_impl.deserialize_known_custom_emoji = mock.Mock() guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "265828729970753537"}, user_id=snowflakes.Snowflake(43123) ) - guild_definition._emojis = {"21323232": mock_emoji} - assert guild_definition.emojis() == {"21323232": mock_emoji} + with mock.patch.object(guild_definition, "_emojis", {"21323232": mock_emoji}): + assert guild_definition.emojis() == {"21323232": mock_emoji} - entity_factory_impl.deserialize_known_custom_emoji.assert_not_called() + entity_factory_impl.deserialize_known_custom_emoji.assert_not_called() - def test_guild(self, entity_factory_impl, mock_app): + def test_guild(self, entity_factory_impl: entity_factory.EntityFactoryImpl, hikari_app: traits.RESTAware): guild_definition = entity_factory_impl.deserialize_gateway_guild( { "afk_channel_id": "99998888777766", @@ -522,7 +526,7 @@ def test_guild(self, entity_factory_impl, mock_app): ) guild = guild_definition.guild() - assert guild.app is mock_app + assert guild.app is hikari_app assert guild.id == 265828729970753537 assert guild.name == "L33t guild" assert guild.icon_hash == "1a2b3c4d" @@ -566,7 +570,7 @@ def test_guild(self, entity_factory_impl, mock_app): assert guild.public_updates_channel_id == 33333333 assert guild.nsfw_level == guild_models.GuildNSFWLevel.DEFAULT - def test_guild_with_unset_fields(self, entity_factory_impl): + def test_guild_with_unset_fields(self, entity_factory_impl: entity_factory.EntityFactoryImpl): guild_definition = entity_factory_impl.deserialize_gateway_guild( { "afk_channel_id": "99998888777766", @@ -613,7 +617,7 @@ def test_guild_with_unset_fields(self, entity_factory_impl): assert guild.widget_channel_id is None assert guild.is_widget_enabled is None - def test_guild_with_null_fields(self, entity_factory_impl): + def test_guild_with_null_fields(self, entity_factory_impl: entity_factory.EntityFactoryImpl): guild_definition = entity_factory_impl.deserialize_gateway_guild( { "afk_channel_id": None, @@ -683,19 +687,24 @@ def test_guild_with_null_fields(self, entity_factory_impl): assert guild.premium_subscription_count is None assert guild.public_updates_channel_id is None - def test_guild_returns_cached_values(self, entity_factory_impl): - mock_guild = object() - entity_factory_impl.set_guild_attributes = mock.Mock() + def test_guild_returns_cached_values(self, entity_factory_impl: entity_factory.EntityFactoryImpl): + mock_guild = mock.Mock() + + entity_factory_impl.set_guild_attributes = ( + mock.Mock() + ) # FIXME: Seems this is calling an object that does not actually exist. guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "9393939"}, user_id=snowflakes.Snowflake(43123) ) - guild_definition._guild = mock_guild - assert guild_definition.guild() is mock_guild + with mock.patch.object(guild_definition, "_guild", mock_guild): + assert guild_definition.guild() is mock_guild - entity_factory_impl.set_guild_attributes.assert_not_called() + entity_factory_impl.set_guild_attributes.assert_not_called() # FIXME: Seems this is calling an object that does not actually exist. - def test_members(self, entity_factory_impl, member_payload): + def test_members( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, member_payload: dict[str, typing.Any] + ): guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "265828729970753537", "members": [member_payload]}, user_id=snowflakes.Snowflake(43123) ) @@ -706,19 +715,20 @@ def test_members(self, entity_factory_impl, member_payload): ) } - def test_members_returns_cached_values(self, entity_factory_impl): - mock_member = object() + def test_members_returns_cached_values(self, entity_factory_impl: entity_factory.EntityFactoryImpl): + mock_member = mock.Mock() entity_factory_impl.deserialize_member = mock.Mock() guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "92929292"}, user_id=snowflakes.Snowflake(43123) ) - guild_definition._members = {"93939393": mock_member} - - assert guild_definition.members() == {"93939393": mock_member} - entity_factory_impl.deserialize_member.assert_not_called() + with mock.patch.object(guild_definition, "_members", {"93939393": mock_member}): + assert guild_definition.members() == {"93939393": mock_member} + entity_factory_impl.deserialize_member.assert_not_called() - def test_presences(self, entity_factory_impl, member_presence_payload): + def test_presences( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, member_presence_payload: dict[str, typing.Any] + ): guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "265828729970753537", "presences": [member_presence_payload]}, user_id=snowflakes.Snowflake(43123) ) @@ -729,19 +739,21 @@ def test_presences(self, entity_factory_impl, member_presence_payload): ) } - def test_presences_returns_cached_values(self, entity_factory_impl): - mock_presence = object() + def test_presences_returns_cached_values(self, entity_factory_impl: entity_factory.EntityFactoryImpl): + mock_presence = mock.Mock() entity_factory_impl.deserialize_member_presence = mock.Mock() guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "29292992"}, user_id=snowflakes.Snowflake(43123) ) - guild_definition._presences = {"3939393993": mock_presence} - assert guild_definition.presences() == {"3939393993": mock_presence} + with mock.patch.object(guild_definition, "_presences", {"3939393993": mock_presence}): + assert guild_definition.presences() == {"3939393993": mock_presence} - entity_factory_impl.deserialize_member_presence.assert_not_called() + entity_factory_impl.deserialize_member_presence.assert_not_called() - def test_roles(self, entity_factory_impl, guild_role_payload): + def test_roles( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, guild_role_payload: dict[str, typing.Any] + ): guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "265828729970753537", "roles": [guild_role_payload]}, user_id=snowflakes.Snowflake(43123) ) @@ -752,17 +764,17 @@ def test_roles(self, entity_factory_impl, guild_role_payload): ) } - def test_roles_returns_cached_values(self, entity_factory_impl): - mock_role = object() + def test_roles_returns_cached_values(self, entity_factory_impl: entity_factory.EntityFactoryImpl): + mock_role = mock.Mock() entity_factory_impl.deserialize_role = mock.Mock() guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "9292929"}, user_id=snowflakes.Snowflake(43123) ) - guild_definition._roles = {"32132123123": mock_role} - assert guild_definition.roles() == {"32132123123": mock_role} + with mock.patch.object(guild_definition, "_roles", {"32132123123": mock_role}): + assert guild_definition.roles() == {"32132123123": mock_role} - entity_factory_impl.deserialize_role.assert_not_called() + entity_factory_impl.deserialize_role.assert_not_called() def test_threads( self, @@ -798,16 +810,15 @@ def test_threads( } def test_threads_returns_cached_values(self, entity_factory_impl: entity_factory.EntityFactoryImpl): - mock_thread = object() + mock_thread = mock.Mock() entity_factory_impl.deserialize_guild_thread = mock.Mock() guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "92929292"}, user_id=snowflakes.Snowflake(43123) ) - guild_definition._threads = {54312312: mock_thread} - assert guild_definition.threads() == {54312312: mock_thread} - - entity_factory_impl.deserialize_guild_thread.assert_not_called() + with mock.patch.object(guild_definition, "_threads", {54312312: mock_thread}): + assert guild_definition.threads() == {54312312: mock_thread} + entity_factory_impl.deserialize_guild_thread.assert_not_called() def test_threads_when_no_threads_field(self, entity_factory_impl: entity_factory.EntityFactoryImpl): entity_factory_impl.deserialize_guild_thread = mock.Mock() @@ -831,12 +842,17 @@ def test_threads_ignores_unrecognised_and_threads(self, entity_factory_impl: ent threads = [{"id": str(id_), "type": type_} for id_, type_ in zip(iter(range(len(thread_types))), thread_types)] assert threads guild_definition = entity_factory_impl.deserialize_gateway_guild( - {"id": "4212312", "threads": threads}, user_id=123321 + {"id": "4212312", "threads": threads}, user_id=snowflakes.Snowflake(123321) ) assert guild_definition.threads() == {} - def test_voice_states(self, entity_factory_impl, member_payload, voice_state_payload): + def test_voice_states( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + member_payload: dict[str, typing.Any], + voice_state_payload: dict[str, typing.Any], + ): guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "265828729970753537", "voice_states": [voice_state_payload], "members": [member_payload]}, user_id=snowflakes.Snowflake(43123), @@ -851,13 +867,15 @@ def test_voice_states(self, entity_factory_impl, member_payload, voice_state_pay ) } - def test_voice_states_returns_cached_values(self, entity_factory_impl): - mock_voice_state = object() + def test_voice_states_returns_cached_values(self, entity_factory_impl: entity_factory.EntityFactoryImpl): + mock_voice_state = mock.Mock() entity_factory_impl.deserialize_voice_state = mock.Mock() guild_definition = entity_factory_impl.deserialize_gateway_guild( {"id": "292929"}, user_id=snowflakes.Snowflake(43123) ) - guild_definition._voice_states = {"9393939393": mock_voice_state} + guild_definition._voice_states = { + "9393939393": mock_voice_state + } # FIXME: Seems this is calling an object that does not actually exist. assert guild_definition.voice_states() == {"9393939393": mock_voice_state} @@ -865,15 +883,15 @@ def test_voice_states_returns_cached_values(self, entity_factory_impl): class TestEntityFactoryImpl: - def test_app(self, entity_factory_impl, mock_app): - assert entity_factory_impl.app is mock_app + def test_app(self, entity_factory_impl: entity_factory.EntityFactoryImpl, hikari_app: traits.RESTAware): + assert entity_factory_impl.app is hikari_app ###################### # APPLICATION MODELS # ###################### @pytest.fixture - def partial_integration(self): + def partial_integration(self) -> dict[str, typing.Any]: return { "id": "123123123123123", "name": "A Name", @@ -882,7 +900,7 @@ def partial_integration(self): } @pytest.fixture - def own_connection_payload(self, partial_integration): + def own_connection_payload(self, partial_integration: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "friend_sync": False, "id": "2513849648abc", @@ -895,7 +913,12 @@ def own_connection_payload(self, partial_integration): "visibility": 0, } - def test_deserialize_own_connection(self, entity_factory_impl, own_connection_payload, partial_integration): + def test_deserialize_own_connection( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + own_connection_payload: dict[str, typing.Any], + partial_integration: dict[str, typing.Any], + ): own_connection = entity_factory_impl.deserialize_own_connection(own_connection_payload) assert own_connection.id == "2513849648abc" assert own_connection.name == "FS" @@ -909,7 +932,7 @@ def test_deserialize_own_connection(self, entity_factory_impl, own_connection_pa assert isinstance(own_connection, application_models.OwnConnection) def test_deserialize_own_connection_with_nullable_and_optional_fields( - self, entity_factory_impl, own_connection_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl, own_connection_payload: dict[str, typing.Any] ): del own_connection_payload["integrations"] del own_connection_payload["revoked"] @@ -926,7 +949,7 @@ def test_deserialize_own_connection_with_nullable_and_optional_fields( assert isinstance(own_connection, application_models.OwnConnection) @pytest.fixture - def own_guild_payload(self): + def own_guild_payload(self) -> dict[str, typing.Any]: return { "id": "152559372126519269", "name": "Isopropyl", @@ -938,7 +961,12 @@ def own_guild_payload(self): "approximate_presence_count": 784, } - def test_deserialize_own_guild(self, entity_factory_impl, mock_app, own_guild_payload): + def test_deserialize_own_guild( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + own_guild_payload: dict[str, typing.Any], + ): own_guild = entity_factory_impl.deserialize_own_guild(own_guild_payload) assert own_guild.id == 152559372126519269 @@ -950,7 +978,9 @@ def test_deserialize_own_guild(self, entity_factory_impl, mock_app, own_guild_pa assert own_guild.approximate_member_count == 3268 assert own_guild.approximate_active_member_count == 784 - def test_deserialize_own_guild_with_null_and_unset_fields(self, entity_factory_impl): + def test_deserialize_own_guild_with_null_and_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): own_guild = entity_factory_impl.deserialize_own_guild( { "id": "152559372126519269", @@ -966,14 +996,16 @@ def test_deserialize_own_guild_with_null_and_unset_fields(self, entity_factory_i assert own_guild.icon_hash is None @pytest.fixture - def role_connection_payload(self): + def role_connection_payload(self) -> dict[str, typing.Any]: return { "platform_name": "Muck", "platform_username": "Muck Muck Muck", "metadata": {"key": "value", "key2": "value2"}, } - def test_deserialize_own_application_role_connection(self, entity_factory_impl, role_connection_payload): + def test_deserialize_own_application_role_connection( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, role_connection_payload: dict[str, typing.Any] + ): role_connection = entity_factory_impl.deserialize_own_application_role_connection(role_connection_payload) assert role_connection.platform_name == "Muck" @@ -982,11 +1014,13 @@ def test_deserialize_own_application_role_connection(self, entity_factory_impl, assert isinstance(role_connection, application_models.OwnApplicationRoleConnection) @pytest.fixture - def owner_payload(self, user_payload): + def owner_payload(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return {**user_payload, "flags": 1 << 10} @pytest.fixture - def application_payload(self, owner_payload, user_payload): + def application_payload( + self, owner_payload: dict[str, typing.Any], user_payload: dict[str, typing.Any] + ) -> dict[str, typing.Any]: return { "id": "209333111222", "name": "Dream Sweet in Sea Major", @@ -1023,11 +1057,16 @@ def application_payload(self, owner_payload, user_payload): } def test_deserialize_application( - self, entity_factory_impl, mock_app, application_payload, owner_payload, user_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + application_payload: dict[str, typing.Any], + owner_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], ): application = entity_factory_impl.deserialize_application(application_payload) - assert application.app is mock_app + assert application.app is hikari_app assert application.id == 209333111222 assert application.name == "Dream Sweet in Sea Major" assert application.description == "I am an application" @@ -1049,6 +1088,7 @@ def test_deserialize_application( assert application.approximate_guild_count == 10000 assert application.approximate_user_install_count == 10001 # Install Parameters + assert application.install_parameters is not None assert application.install_parameters.scopes == [ application_models.OAuth2Scope.BOT, application_models.OAuth2Scope.APPLICATIONS_COMMANDS, @@ -1056,6 +1096,7 @@ def test_deserialize_application( assert application.install_parameters.permissions == permission_models.Permissions.ADMINISTRATOR assert isinstance(application.install_parameters, application_models.ApplicationInstallParameters) # Team + assert application.team is not None assert application.team.id == 202020202 assert application.team.name == "Hikari Development" assert application.team.icon_hash == "hashtag" @@ -1063,7 +1104,7 @@ def test_deserialize_application( assert isinstance(application.team, application_models.Team) # TeamMember assert len(application.team.members) == 1 - member = application.team.members[115590097100865541] + member = application.team.members[snowflakes.Snowflake(115590097100865541)] assert member.membership_state == application_models.TeamMembershipState.INVITED assert member.permissions == ["*"] assert member.team_id == 209333111222 @@ -1093,7 +1134,12 @@ def test_deserialize_application( assert application.cover_image_hash == "hashmebaby" assert isinstance(application, application_models.Application) - def test_deserialize_application_with_unset_fields(self, entity_factory_impl, mock_app, owner_payload): + def test_deserialize_application_with_unset_fields( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + owner_payload: dict[str, typing.Any], + ): application = entity_factory_impl.deserialize_application( { "id": "209333111222", @@ -1117,7 +1163,12 @@ def test_deserialize_application_with_unset_fields(self, entity_factory_impl, mo assert application.terms_of_service_url is None assert application.role_connections_verification_url is None - def test_deserialize_application_with_null_fields(self, entity_factory_impl, mock_app, owner_payload): + def test_deserialize_application_with_null_fields( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + owner_payload: dict[str, typing.Any], + ): application = entity_factory_impl.deserialize_application( { "id": "209333111222", @@ -1147,7 +1198,7 @@ def test_deserialize_application_with_null_fields(self, entity_factory_impl, moc assert application.tags == [] @pytest.fixture - def invite_application_payload(self): + def invite_application_payload(self) -> dict[str, typing.Any]: return { "id": "773336526917861400", "name": "Betrayal.io", @@ -1158,7 +1209,7 @@ def invite_application_payload(self): } @pytest.fixture - def authorization_information_payload(self, user_payload): + def authorization_information_payload(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "application": { "id": "4123123123123", @@ -1178,7 +1229,10 @@ def authorization_information_payload(self, user_payload): } def test_deserialize_authorization_information( - self, entity_factory_impl, authorization_information_payload, user_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + authorization_information_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], ): authorization_information = entity_factory_impl.deserialize_authorization_information( authorization_information_payload @@ -1203,7 +1257,9 @@ def test_deserialize_authorization_information( assert authorization_information.user == entity_factory_impl.deserialize_user(user_payload) def test_deserialize_authorization_information_with_unset_fields( - self, entity_factory_impl, authorization_information_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + authorization_information_payload: dict[str, typing.Any], ): del authorization_information_payload["application"]["icon"] del authorization_information_payload["application"]["bot_public"] @@ -1223,7 +1279,7 @@ def test_deserialize_authorization_information_with_unset_fields( assert authorization_information.application.privacy_policy_url is None @pytest.fixture - def application_connection_metadata_record_payload(self): + def application_connection_metadata_record_payload(self) -> dict[str, typing.Any]: return { "type": 7, "key": "developer_value", @@ -1237,7 +1293,9 @@ def application_connection_metadata_record_payload(self): } def test_deserialize_application_connection_metadata_record( - self, entity_factory_impl, application_connection_metadata_record_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + application_connection_metadata_record_payload: dict[str, typing.Any], ): record = entity_factory_impl.deserialize_application_connection_metadata_record( application_connection_metadata_record_payload @@ -1254,7 +1312,9 @@ def test_deserialize_application_connection_metadata_record( } def test_deserialize_application_connection_metadata_record_with_missing_fields( - self, entity_factory_impl, application_connection_metadata_record_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + application_connection_metadata_record_payload: dict[str, typing.Any], ): del application_connection_metadata_record_payload["name_localizations"] del application_connection_metadata_record_payload["description_localizations"] @@ -1266,7 +1326,9 @@ def test_deserialize_application_connection_metadata_record_with_missing_fields( assert record.name_localizations == {} assert record.description_localizations == {} - def test_serialize_application_connection_metadata_record(self, entity_factory_impl): + def test_serialize_application_connection_metadata_record( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): record = application_models.ApplicationRoleConnectionMetadataRecord( type=application_models.ApplicationRoleConnectionMetadataRecordType.BOOLEAN_EQUAL, key="some_key", @@ -1288,7 +1350,7 @@ def test_serialize_application_connection_metadata_record(self, entity_factory_i assert entity_factory_impl.serialize_application_connection_metadata_record(record) == expected_result @pytest.fixture - def client_credentials_payload(self): + def client_credentials_payload(self) -> dict[str, typing.Any]: return { "access_token": "6qrZcUqja7812RVdnEKjpzOL4CvHBFG", "token_type": "Bearer", @@ -1296,7 +1358,9 @@ def client_credentials_payload(self): "scope": "identify connections", } - def test_deserialize_partial_token(self, entity_factory_impl, client_credentials_payload): + def test_deserialize_partial_token( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, client_credentials_payload: dict[str, typing.Any] + ): partial_token = entity_factory_impl.deserialize_partial_token(client_credentials_payload) assert partial_token.access_token == "6qrZcUqja7812RVdnEKjpzOL4CvHBFG" @@ -1309,7 +1373,9 @@ def test_deserialize_partial_token(self, entity_factory_impl, client_credentials assert isinstance(partial_token, application_models.PartialOAuth2Token) @pytest.fixture - def access_token_payload(self, rest_guild_payload, incoming_webhook_payload): + def access_token_payload( + self, rest_guild_payload: dict[str, typing.Any], incoming_webhook_payload: dict[str, typing.Any] + ) -> dict[str, typing.Any]: return { "token_type": "Bearer", "guild": rest_guild_payload, @@ -1321,7 +1387,11 @@ def access_token_payload(self, rest_guild_payload, incoming_webhook_payload): } def test_deserialize_authorization_token( - self, entity_factory_impl, access_token_payload, rest_guild_payload, incoming_webhook_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + access_token_payload: dict[str, typing.Any], + rest_guild_payload: dict[str, typing.Any], + incoming_webhook_payload: dict[str, typing.Any], ): access_token = entity_factory_impl.deserialize_authorization_token(access_token_payload) @@ -1336,7 +1406,9 @@ def test_deserialize_authorization_token( assert access_token.refresh_token == "mgp8qnvBwJcmadwgCYKyYD5CAzGAX4" assert access_token.webhook == entity_factory_impl.deserialize_incoming_webhook(incoming_webhook_payload) - def test_deserialize_authorization_token_without_optional_fields(self, entity_factory_impl, access_token_payload): + def test_deserialize_authorization_token_without_optional_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, access_token_payload: dict[str, typing.Any] + ): del access_token_payload["guild"] del access_token_payload["webhook"] @@ -1346,7 +1418,7 @@ def test_deserialize_authorization_token_without_optional_fields(self, entity_fa assert access_token.webhook is None @pytest.fixture - def implicit_token_payload(self): + def implicit_token_payload(self) -> dict[str, typing.Any]: return { "access_token": "RTfP0OK99U3kbRtHOoKLmJbOn45PjL", "token_type": "Basic", @@ -1355,7 +1427,9 @@ def implicit_token_payload(self): "state": "15773059ghq9183habn", } - def test_deserialize_implicit_token(self, entity_factory_impl, implicit_token_payload): + def test_deserialize_implicit_token( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, implicit_token_payload: dict[str, str] + ): implicit_token = entity_factory_impl.deserialize_implicit_token(implicit_token_payload) assert implicit_token.access_token == "RTfP0OK99U3kbRtHOoKLmJbOn45PjL" @@ -1365,7 +1439,9 @@ def test_deserialize_implicit_token(self, entity_factory_impl, implicit_token_pa assert implicit_token.state == "15773059ghq9183habn" assert isinstance(implicit_token, application_models.OAuth2ImplicitToken) - def test_deserialize_implicit_token_without_state(self, entity_factory_impl, implicit_token_payload): + def test_deserialize_implicit_token_without_state( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, implicit_token_payload: dict[str, str] + ): del implicit_token_payload["state"] implicit_token = entity_factory_impl.deserialize_implicit_token(implicit_token_payload) @@ -1376,16 +1452,16 @@ def test_deserialize_implicit_token_without_state(self, entity_factory_impl, imp # AUDIT LOGS MODELS # ##################### - def test__deserialize_audit_log_change_roles(self, entity_factory_impl): + def test__deserialize_audit_log_change_roles(self, entity_factory_impl: entity_factory.EntityFactoryImpl): test_role_payloads = [{"id": "24", "name": "roleA"}] roles = entity_factory_impl._deserialize_audit_log_change_roles(test_role_payloads) assert len(roles) == 1 - role = roles[24] + role = roles[snowflakes.Snowflake(24)] assert role.id == 24 assert role.name == "roleA" assert isinstance(role, guild_models.PartialRole) - def test__deserialize_audit_log_overwrites(self, entity_factory_impl): + def test__deserialize_audit_log_overwrites(self, entity_factory_impl: entity_factory.EntityFactoryImpl): test_overwrite_payloads = [ {"id": "24", "type": 0, "allow": "21", "deny": "0"}, {"id": "48", "type": 1, "deny": "42", "allow": "0"}, @@ -1401,10 +1477,12 @@ def test__deserialize_audit_log_overwrites(self, entity_factory_impl): } @pytest.fixture - def overwrite_info_payload(self): + def overwrite_info_payload(self) -> dict[str, typing.Any]: return {"id": "123123123", "type": 0, "role_name": "aRole"} - def test__deserialize_channel_overwrite_entry_info(self, entity_factory_impl, overwrite_info_payload): + def test__deserialize_channel_overwrite_entry_info( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, overwrite_info_payload: dict[str, typing.Any] + ): overwrite_entry_info = entity_factory_impl._deserialize_channel_overwrite_entry_info(overwrite_info_payload) assert overwrite_entry_info.id == 123123123 assert overwrite_entry_info.type is channel_models.PermissionOverwriteType.ROLE @@ -1412,30 +1490,38 @@ def test__deserialize_channel_overwrite_entry_info(self, entity_factory_impl, ov assert isinstance(overwrite_entry_info, audit_log_models.ChannelOverwriteEntryInfo) @pytest.fixture - def message_pin_info_payload(self): + def message_pin_info_payload(self) -> dict[str, typing.Any]: return {"channel_id": "123123123", "message_id": "69696969"} - def test__deserialize_message_pin_entry_info(self, entity_factory_impl, message_pin_info_payload): + def test__deserialize_message_pin_entry_info( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, message_pin_info_payload: dict[str, typing.Any] + ): message_pin_info = entity_factory_impl._deserialize_message_pin_entry_info(message_pin_info_payload) assert message_pin_info.channel_id == 123123123 assert message_pin_info.message_id == 69696969 assert isinstance(message_pin_info, audit_log_models.MessagePinEntryInfo) @pytest.fixture - def member_prune_info_payload(self): + def member_prune_info_payload(self) -> dict[str, typing.Any]: return {"delete_member_days": "7", "members_removed": "1"} - def test__deserialize_member_prune_entry_info(self, entity_factory_impl, member_prune_info_payload): + def test__deserialize_member_prune_entry_info( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, member_prune_info_payload: dict[str, typing.Any] + ): member_prune_info = entity_factory_impl._deserialize_member_prune_entry_info(member_prune_info_payload) assert member_prune_info.delete_member_days == datetime.timedelta(days=7) assert member_prune_info.members_removed == 1 assert isinstance(member_prune_info, audit_log_models.MemberPruneEntryInfo) @pytest.fixture - def message_bulk_delete_info_payload(self): + def message_bulk_delete_info_payload(self) -> dict[str, typing.Any]: return {"count": "42"} - def test__deserialize_message_bulk_delete_entry_info(self, entity_factory_impl, message_bulk_delete_info_payload): + def test__deserialize_message_bulk_delete_entry_info( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + message_bulk_delete_info_payload: dict[str, typing.Any], + ): message_bulk_delete_entry_info = entity_factory_impl._deserialize_message_bulk_delete_entry_info( message_bulk_delete_info_payload ) @@ -1443,10 +1529,12 @@ def test__deserialize_message_bulk_delete_entry_info(self, entity_factory_impl, assert isinstance(message_bulk_delete_entry_info, audit_log_models.MessageBulkDeleteEntryInfo) @pytest.fixture - def message_delete_info_payload(self): + def message_delete_info_payload(self) -> dict[str, typing.Any]: return {"count": "42", "channel_id": "4206942069"} - def test__deserialize_message_delete_entry_info(self, entity_factory_impl, message_delete_info_payload): + def test__deserialize_message_delete_entry_info( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, message_delete_info_payload: dict[str, typing.Any] + ): message_delete_entry_info = entity_factory_impl._deserialize_message_delete_entry_info( message_delete_info_payload ) @@ -1455,10 +1543,14 @@ def test__deserialize_message_delete_entry_info(self, entity_factory_impl, messa assert isinstance(message_delete_entry_info, audit_log_models.MessageDeleteEntryInfo) @pytest.fixture - def member_disconnect_info_payload(self): + def member_disconnect_info_payload(self) -> dict[str, typing.Any]: return {"count": "42"} - def test__deserialize_member_disconnect_entry_info(self, entity_factory_impl, member_disconnect_info_payload): + def test__deserialize_member_disconnect_entry_info( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + member_disconnect_info_payload: dict[str, typing.Any], + ): member_disconnect_entry_info = entity_factory_impl._deserialize_member_disconnect_entry_info( member_disconnect_info_payload ) @@ -1466,16 +1558,18 @@ def test__deserialize_member_disconnect_entry_info(self, entity_factory_impl, me assert isinstance(member_disconnect_entry_info, audit_log_models.MemberDisconnectEntryInfo) @pytest.fixture - def member_move_info_payload(self): + def member_move_info_payload(self) -> dict[str, typing.Any]: return {"count": "42", "channel_id": "22222222"} - def test__deserialize_member_move_entry_info(self, entity_factory_impl, member_move_info_payload): + def test__deserialize_member_move_entry_info( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, member_move_info_payload: dict[str, typing.Any] + ): member_move_entry_info = entity_factory_impl._deserialize_member_move_entry_info(member_move_info_payload) assert member_move_entry_info.channel_id == 22222222 assert isinstance(member_move_entry_info, audit_log_models.MemberMoveEntryInfo) @pytest.fixture - def audit_log_entry_payload(self): + def audit_log_entry_payload(self) -> dict[str, typing.Any]: return { "action_type": 14, "changes": [ @@ -1493,32 +1587,34 @@ def audit_log_entry_payload(self): } @pytest.fixture - def partial_integration_payload(self): + def partial_integration_payload(self) -> dict[str, typing.Any]: return {"id": "4949494949", "name": "Blah blah", "type": "twitch", "account": {"id": "543453", "name": "Blam"}} def test_deserialize_audit_log_entry( self, - entity_factory_impl, - auto_mod_rule_payload, - audit_log_entry_payload, - application_webhook_payload, - incoming_webhook_payload, - follower_webhook_payload, - partial_integration_payload, - mock_app, + entity_factory_impl: entity_factory.EntityFactoryImpl, + auto_mod_rule_payload: dict[str, typing.Any], + audit_log_entry_payload: dict[str, typing.Any], + application_webhook_payload: dict[str, typing.Any], + incoming_webhook_payload: dict[str, typing.Any], + follower_webhook_payload: dict[str, typing.Any], + partial_integration_payload: dict[str, typing.Any], + hikari_app: traits.RESTAware, ): entry = entity_factory_impl.deserialize_audit_log_entry( audit_log_entry_payload, guild_id=snowflakes.Snowflake(123321) ) - assert entry.app is mock_app + assert entry.app is hikari_app assert entry.id == 694026906592477214 assert entry.target_id == 115590097100865541 assert entry.user_id == 560984860634644482 assert entry.action_type == audit_log_models.AuditLogEventType.CHANNEL_OVERWRITE_UPDATE - assert entry.options.id == 115590097100865541 - assert entry.options.type == channel_models.PermissionOverwriteType.MEMBER - assert entry.options.role_name is None + assert entry.options.id == 115590097100865541 # FIXME: I am unsure as to how to fix this + assert ( + entry.options.type == channel_models.PermissionOverwriteType.MEMBER + ) # FIXME: I am unsure as to how to fix this + assert entry.options.role_name is None # FIXME: I am unsure as to how to fix this assert entry.guild_id == 123321 assert entry.reason == "An artificial insanity." @@ -1526,20 +1622,22 @@ def test_deserialize_audit_log_entry( change = entry.changes[0] assert change.key == audit_log_models.AuditLogChangeKey.ADD_ROLE_TO_MEMBER + assert change.new_value is not None assert len(change.new_value) == 1 role = change.new_value[568651298858074123] - role.app is mock_app - role.id == 568651298858074123 - role.name == "Casual" + assert role.app is hikari_app + assert role.id == 568651298858074123 + assert role.name == "Casual" + assert change.old_value is not None assert len(change.old_value) == 1 role = change.old_value[123123123312312] - role.app is mock_app - role.id == 123123123312312 - role.name == "aRole" + assert role.app is hikari_app + assert role.id == 123123123312312 + assert role.name == "aRole" def test_deserialize_audit_log_entry_when_guild_id_in_payload( - self, entity_factory_impl, audit_log_entry_payload, mock_app + self, entity_factory_impl: entity_factory.EntityFactoryImpl, audit_log_entry_payload: dict[str, typing.Any] ): audit_log_entry_payload["guild_id"] = 431123123 @@ -1548,7 +1646,7 @@ def test_deserialize_audit_log_entry_when_guild_id_in_payload( assert entry.guild_id == 431123123 def test_deserialize_audit_log_entry_with_unset_or_unknown_fields( - self, entity_factory_impl, audit_log_entry_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl, audit_log_entry_payload: dict[str, typing.Any] ): # Unset fields audit_log_entry_payload["changes"] = None @@ -1569,7 +1667,9 @@ def test_deserialize_audit_log_entry_with_unset_or_unknown_fields( assert entry.options is None assert entry.reason is None - def test_deserialize_audit_log_entry_with_unhandled_change_key(self, entity_factory_impl, audit_log_entry_payload): + def test_deserialize_audit_log_entry_with_unhandled_change_key( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, audit_log_entry_payload: dict[str, typing.Any] + ): # Unset fields audit_log_entry_payload["changes"][0]["key"] = "name" @@ -1583,7 +1683,9 @@ def test_deserialize_audit_log_entry_with_unhandled_change_key(self, entity_fact assert change.new_value == [{"id": "568651298858074123", "name": "Casual"}] assert change.old_value == [{"id": "123123123312312", "name": "aRole"}] - def test_deserialize_audit_log_entry_with_change_key_unknown(self, entity_factory_impl, audit_log_entry_payload): + def test_deserialize_audit_log_entry_with_change_key_unknown( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, audit_log_entry_payload: dict[str, typing.Any] + ): # Unset fields audit_log_entry_payload["changes"][0]["key"] = "unknown" @@ -1597,7 +1699,9 @@ def test_deserialize_audit_log_entry_with_change_key_unknown(self, entity_factor assert change.new_value == [{"id": "568651298858074123", "name": "Casual"}] assert change.old_value == [{"id": "123123123312312", "name": "aRole"}] - def test_deserialize_audit_log_entry_for_unknown_action_type(self, entity_factory_impl, audit_log_entry_payload): + def test_deserialize_audit_log_entry_for_unknown_action_type( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, audit_log_entry_payload: dict[str, typing.Any] + ): # Unset fields audit_log_entry_payload["action_type"] = 1000 audit_log_entry_payload["options"] = {"field1": "value1", "field2": 96} @@ -1610,17 +1714,17 @@ def test_deserialize_audit_log_entry_for_unknown_action_type(self, entity_factor @pytest.fixture def audit_log_payload( self, - audit_log_entry_payload, - auto_mod_rule_payload, - user_payload, - incoming_webhook_payload, - application_webhook_payload, - follower_webhook_payload, - partial_integration_payload, - guild_public_thread_payload, - guild_private_thread_payload, - guild_news_thread_payload, - ): + audit_log_entry_payload: dict[str, typing.Any], + auto_mod_rule_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + incoming_webhook_payload: dict[str, typing.Any], + application_webhook_payload: dict[str, typing.Any], + follower_webhook_payload: dict[str, typing.Any], + partial_integration_payload: dict[str, typing.Any], + guild_public_thread_payload: dict[str, typing.Any], + guild_private_thread_payload: dict[str, typing.Any], + guild_news_thread_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: return { "audit_log_entries": [audit_log_entry_payload], "auto_moderation_rules": [auto_mod_rule_payload], @@ -1632,19 +1736,19 @@ def audit_log_payload( def test_deserialize_audit_log( self, - entity_factory_impl, - mock_app, - audit_log_payload, - audit_log_entry_payload, - auto_mod_rule_payload, - user_payload, - incoming_webhook_payload, - application_webhook_payload, - follower_webhook_payload, - partial_integration_payload, - guild_public_thread_payload, - guild_private_thread_payload, - guild_news_thread_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + audit_log_payload: dict[str, typing.Any], + audit_log_entry_payload: dict[str, typing.Any], + auto_mod_rule_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + incoming_webhook_payload: dict[str, typing.Any], + application_webhook_payload: dict[str, typing.Any], + follower_webhook_payload: dict[str, typing.Any], + partial_integration_payload: dict[str, typing.Any], + guild_public_thread_payload: dict[str, typing.Any], + guild_private_thread_payload: dict[str, typing.Any], + guild_news_thread_payload: dict[str, typing.Any], ): audit_log = entity_factory_impl.deserialize_audit_log(audit_log_payload, guild_id=snowflakes.Snowflake(123321)) @@ -1673,7 +1777,9 @@ def test_deserialize_audit_log( 752831914402115456: entity_factory_impl.deserialize_channel_follower_webhook(follower_webhook_payload), } - def test_deserialize_audit_log_with_action_type_unknown_gets_ignored(self, entity_factory_impl, audit_log_payload): + def test_deserialize_audit_log_with_action_type_unknown_gets_ignored( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, audit_log_payload: dict[str, typing.Any] + ): # Unset fields audit_log_payload["audit_log_entries"][0]["action_type"] = 1000 audit_log_payload["audit_log_entries"][0]["options"] = {"field1": "value1", "field2": 96} @@ -1683,7 +1789,10 @@ def test_deserialize_audit_log_with_action_type_unknown_gets_ignored(self, entit assert len(audit_log.entries) == 0 def test_deserialize_audit_log_skips_unknown_webhook_type( - self, entity_factory_impl, incoming_webhook_payload, application_webhook_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + incoming_webhook_payload: dict[str, typing.Any], + application_webhook_payload: dict[str, typing.Any], ): audit_log = entity_factory_impl.deserialize_audit_log( { @@ -1703,7 +1812,10 @@ def test_deserialize_audit_log_skips_unknown_webhook_type( } def test_deserialize_audit_log_skips_unknown_thread_type( - self, entity_factory_impl, guild_public_thread_payload, guild_private_thread_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + guild_public_thread_payload: dict[str, typing.Any], + guild_private_thread_payload: dict[str, typing.Any], ): audit_log = entity_factory_impl.deserialize_audit_log( { @@ -1722,7 +1834,9 @@ def test_deserialize_audit_log_skips_unknown_thread_type( 947690637610844210: entity_factory_impl.deserialize_guild_private_thread(guild_private_thread_payload), } - def test_deserialize_audit_log_skips_unknown_auto_mod_rule_type(self, entity_factory_impl, auto_mod_rule_payload): + def test_deserialize_audit_log_skips_unknown_auto_mod_rule_type( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, auto_mod_rule_payload: dict[str, typing.Any] + ): audit_log = entity_factory_impl.deserialize_audit_log( { "auto_moderation_rules": [{"id": "4949", "trigger_type": -6959595}, auto_mod_rule_payload], @@ -1743,14 +1857,16 @@ def test_deserialize_audit_log_skips_unknown_auto_mod_rule_type(self, entity_fac # CHANNEL MODELS # ################## - def test_deserialize_channel_follow(self, entity_factory_impl, mock_app): + def test_deserialize_channel_follow( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, hikari_app: traits.RESTAware + ): follow = entity_factory_impl.deserialize_channel_follow({"channel_id": "41231", "webhook_id": "939393"}) - assert follow.app is mock_app + assert follow.app is hikari_app assert follow.channel_id == 41231 assert follow.webhook_id == 939393 @pytest.mark.parametrize("type", [0, 1]) - def test_deserialize_permission_overwrite(self, entity_factory_impl, type): + def test_deserialize_permission_overwrite(self, entity_factory_impl: entity_factory.EntityFactoryImpl, type: int): permission_overwrite_payload = { "id": "4242", "type": type, @@ -1768,33 +1884,46 @@ def test_deserialize_permission_overwrite(self, entity_factory_impl, type): @pytest.mark.parametrize( "type", [channel_models.PermissionOverwriteType.MEMBER, channel_models.PermissionOverwriteType.ROLE] ) - def test_serialize_permission_overwrite(self, entity_factory_impl, type): + def test_serialize_permission_overwrite( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, type: channel_models.PermissionOverwriteType + ): overwrite = channel_models.PermissionOverwrite(id=123123, type=type, allow=42, deny=62) payload = entity_factory_impl.serialize_permission_overwrite(overwrite) assert payload == {"id": "123123", "type": int(type), "allow": "42", "deny": "62"} @pytest.fixture - def partial_channel_payload(self): + def partial_channel_payload(self) -> dict[str, typing.Any]: return {"id": "561884984214814750", "name": "general", "type": 0} - def test_deserialize_partial_channel(self, entity_factory_impl, mock_app, partial_channel_payload): + def test_deserialize_partial_channel( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + partial_channel_payload: dict[str, typing.Any], + ): partial_channel = entity_factory_impl.deserialize_partial_channel(partial_channel_payload) - assert partial_channel.app is mock_app + assert partial_channel.app is hikari_app assert partial_channel.id == 561884984214814750 assert partial_channel.name == "general" assert partial_channel.type == channel_models.ChannelType.GUILD_TEXT assert isinstance(partial_channel, channel_models.PartialChannel) - def test_deserialize_partial_channel_with_unset_fields(self, entity_factory_impl): + def test_deserialize_partial_channel_with_unset_fields(self, entity_factory_impl: entity_factory.EntityFactoryImpl): assert entity_factory_impl.deserialize_partial_channel({"id": "22", "type": 0}).name is None @pytest.fixture - def dm_channel_payload(self, user_payload): + def dm_channel_payload(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return {"id": "123", "last_message_id": "456", "type": 1, "recipients": [user_payload]} - def test_deserialize_dm_channel(self, entity_factory_impl, mock_app, dm_channel_payload, user_payload): + def test_deserialize_dm_channel( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + dm_channel_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + ): dm_channel = entity_factory_impl.deserialize_dm(dm_channel_payload) - assert dm_channel.app is mock_app + assert dm_channel.app is hikari_app assert dm_channel.id == 123 assert dm_channel.name is None assert dm_channel.last_message_id == 456 @@ -1802,18 +1931,22 @@ def test_deserialize_dm_channel(self, entity_factory_impl, mock_app, dm_channel_ assert dm_channel.recipient == entity_factory_impl.deserialize_user(user_payload) assert isinstance(dm_channel, channel_models.DMChannel) - def test_deserialize_dm_channel_with_null_fields(self, entity_factory_impl, user_payload): + def test_deserialize_dm_channel_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, user_payload: dict[str, typing.Any] + ): dm_channel = entity_factory_impl.deserialize_dm( {"id": "123", "last_message_id": None, "type": 1, "recipients": [user_payload]} ) assert dm_channel.last_message_id is None - def test_deserialize_dm_channel_with_unsetfields(self, entity_factory_impl, user_payload): + def test_deserialize_dm_channel_with_unsetfields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, user_payload: dict[str, typing.Any] + ): dm_channel = entity_factory_impl.deserialize_dm({"id": "123", "type": 1, "recipients": [user_payload]}) assert dm_channel.last_message_id is None @pytest.fixture - def group_dm_channel_payload(self, user_payload): + def group_dm_channel_payload(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "123", "name": "Secret Developer Group", @@ -1826,9 +1959,15 @@ def group_dm_channel_payload(self, user_payload): "recipients": [user_payload], } - def test_deserialize_group_dm_channel(self, entity_factory_impl, mock_app, group_dm_channel_payload, user_payload): + def test_deserialize_group_dm_channel( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + group_dm_channel_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + ): group_dm = entity_factory_impl.deserialize_group_dm(group_dm_channel_payload) - assert group_dm.app is mock_app + assert group_dm.app is hikari_app assert group_dm.id == 123 assert group_dm.name == "Secret Developer Group" assert group_dm.icon_hash == "123asdf123adsf" @@ -1839,7 +1978,9 @@ def test_deserialize_group_dm_channel(self, entity_factory_impl, mock_app, group assert group_dm.recipients == {115590097100865541: entity_factory_impl.deserialize_user(user_payload)} assert isinstance(group_dm, channel_models.GroupDMChannel) - def test_test_deserialize_group_dm_channel_with_unset_fields(self, entity_factory_impl, user_payload): + def test_test_deserialize_group_dm_channel_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, user_payload: dict[str, typing.Any] + ): group_dm = entity_factory_impl.deserialize_group_dm( { "id": "123", @@ -1855,7 +1996,7 @@ def test_test_deserialize_group_dm_channel_with_unset_fields(self, entity_factor assert group_dm.last_message_id is None @pytest.fixture - def guild_category_payload(self, permission_overwrite_payload): + def guild_category_payload(self, permission_overwrite_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "123", "permission_overwrites": [permission_overwrite_payload], @@ -1868,10 +2009,14 @@ def guild_category_payload(self, permission_overwrite_payload): } def test_deserialize_guild_category( - self, entity_factory_impl, mock_app, guild_category_payload, permission_overwrite_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + guild_category_payload: dict[str, typing.Any], + permission_overwrite_payload: dict[str, typing.Any], ): guild_category = entity_factory_impl.deserialize_guild_category(guild_category_payload) - assert guild_category.app is mock_app + assert guild_category.app is hikari_app assert guild_category.id == 123 assert guild_category.name == "Test" assert guild_category.type == channel_models.ChannelType.GUILD_CATEGORY @@ -1885,7 +2030,9 @@ def test_deserialize_guild_category( assert guild_category.parent_id is None assert isinstance(guild_category, channel_models.GuildCategory) - def test_deserialize_guild_category_with_unset_fields(self, entity_factory_impl, permission_overwrite_payload): + def test_deserialize_guild_category_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, permission_overwrite_payload: dict[str, typing.Any] + ): guild_category = entity_factory_impl.deserialize_guild_category( { "id": "123", @@ -1899,7 +2046,9 @@ def test_deserialize_guild_category_with_unset_fields(self, entity_factory_impl, assert guild_category.parent_id is None assert guild_category.is_nsfw is False - def test_deserialize_guild_category_with_null_fields(self, entity_factory_impl, permission_overwrite_payload): + def test_deserialize_guild_category_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, permission_overwrite_payload: dict[str, typing.Any] + ): guild_category = entity_factory_impl.deserialize_guild_category( { "id": "123", @@ -1915,10 +2064,14 @@ def test_deserialize_guild_category_with_null_fields(self, entity_factory_impl, assert guild_category.parent_id is None def test_deserialize_guild_text_channel( - self, entity_factory_impl, mock_app, guild_text_channel_payload, permission_overwrite_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + guild_text_channel_payload: dict[str, typing.Any], + permission_overwrite_payload: dict[str, typing.Any], ): guild_text_channel = entity_factory_impl.deserialize_guild_text_channel(guild_text_channel_payload) - assert guild_text_channel.app is mock_app + assert guild_text_channel.app is hikari_app assert guild_text_channel.id == 123 assert guild_text_channel.name == "general" assert guild_text_channel.type == channel_models.ChannelType.GUILD_TEXT @@ -1938,7 +2091,9 @@ def test_deserialize_guild_text_channel( assert guild_text_channel.default_auto_archive_duration == datetime.timedelta(minutes=10080) assert isinstance(guild_text_channel, channel_models.GuildTextChannel) - def test_deserialize_guild_text_channel_with_unset_fields(self, entity_factory_impl): + def test_deserialize_guild_text_channel_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): guild_text_channel = entity_factory_impl.deserialize_guild_text_channel( { "id": "123", @@ -1957,7 +2112,9 @@ def test_deserialize_guild_text_channel_with_unset_fields(self, entity_factory_i assert guild_text_channel.last_message_id is None assert guild_text_channel.default_auto_archive_duration == datetime.timedelta(minutes=1440) - def test_deserialize_guild_text_channel_with_null_fields(self, entity_factory_impl): + def test_deserialize_guild_text_channel_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): guild_text_channel = entity_factory_impl.deserialize_guild_text_channel( { "id": "123", @@ -1980,10 +2137,14 @@ def test_deserialize_guild_text_channel_with_null_fields(self, entity_factory_im assert guild_text_channel.parent_id is None def test_deserialize_guild_news_channel( - self, entity_factory_impl, mock_app, guild_news_channel_payload, permission_overwrite_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + guild_news_channel_payload: dict[str, typing.Any], + permission_overwrite_payload: dict[str, typing.Any], ): news_channel = entity_factory_impl.deserialize_guild_news_channel(guild_news_channel_payload) - assert news_channel.app is mock_app + assert news_channel.app is hikari_app assert news_channel.id == 7777 assert news_channel.name == "Important Announcements" assert news_channel.type == channel_models.ChannelType.GUILD_NEWS @@ -2002,7 +2163,9 @@ def test_deserialize_guild_news_channel( assert news_channel.default_auto_archive_duration == datetime.timedelta(minutes=4320) assert isinstance(news_channel, channel_models.GuildNewsChannel) - def test_deserialize_guild_news_channel_with_unset_fields(self, entity_factory_impl): + def test_deserialize_guild_news_channel_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): news_channel = entity_factory_impl.deserialize_guild_news_channel( { "id": "567", @@ -2020,7 +2183,9 @@ def test_deserialize_guild_news_channel_with_unset_fields(self, entity_factory_i assert news_channel.last_message_id is None assert news_channel.default_auto_archive_duration == datetime.timedelta(minutes=1440) - def test_deserialize_guild_news_channel_with_null_fields(self, entity_factory_impl): + def test_deserialize_guild_news_channel_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): news_channel = entity_factory_impl.deserialize_guild_news_channel( { "id": "567", @@ -2042,7 +2207,11 @@ def test_deserialize_guild_news_channel_with_null_fields(self, entity_factory_im assert news_channel.last_pin_timestamp is None def test_deserialize_guild_voice_channel( - self, entity_factory_impl, mock_app, guild_voice_channel_payload, permission_overwrite_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + guild_voice_channel_payload: dict[str, typing.Any], + permission_overwrite_payload: dict[str, typing.Any], ): voice_channel = entity_factory_impl.deserialize_guild_voice_channel(guild_voice_channel_payload) assert voice_channel.id == 555 @@ -2061,7 +2230,9 @@ def test_deserialize_guild_voice_channel( assert voice_channel.user_limit == 3 assert isinstance(voice_channel, channel_models.GuildVoiceChannel) - def test_deserialize_guild_voice_channel_with_null_fields(self, entity_factory_impl): + def test_deserialize_guild_voice_channel_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): voice_channel = entity_factory_impl.deserialize_guild_voice_channel( { "id": "123", @@ -2081,7 +2252,9 @@ def test_deserialize_guild_voice_channel_with_null_fields(self, entity_factory_i assert voice_channel.parent_id is None assert voice_channel.region is None - def test_deserialize_guild_voice_channel_with_unset_fields(self, entity_factory_impl): + def test_deserialize_guild_voice_channel_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): voice_channel = entity_factory_impl.deserialize_guild_voice_channel( { "id": "123", @@ -2100,7 +2273,7 @@ def test_deserialize_guild_voice_channel_with_unset_fields(self, entity_factory_ assert voice_channel.region is None @pytest.fixture - def guild_stage_channel_payload(self, permission_overwrite_payload): + def guild_stage_channel_payload(self, permission_overwrite_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "555", "guild_id": "666", @@ -2117,7 +2290,11 @@ def guild_stage_channel_payload(self, permission_overwrite_payload): } def test_deserialize_guild_stage_channel( - self, entity_factory_impl, mock_app, guild_stage_channel_payload, permission_overwrite_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + guild_stage_channel_payload: dict[str, typing.Any], + permission_overwrite_payload: dict[str, typing.Any], ): voice_channel = entity_factory_impl.deserialize_guild_stage_channel(guild_stage_channel_payload) assert voice_channel.id == 555 @@ -2136,7 +2313,9 @@ def test_deserialize_guild_stage_channel( assert voice_channel.last_message_id == 1000101 assert isinstance(voice_channel, channel_models.GuildStageChannel) - def test_deserialize_guild_stage_channel_with_null_fields(self, entity_factory_impl): + def test_deserialize_guild_stage_channel_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): voice_channel = entity_factory_impl.deserialize_guild_stage_channel( { "id": "123", @@ -2157,7 +2336,9 @@ def test_deserialize_guild_stage_channel_with_null_fields(self, entity_factory_i assert voice_channel.region is None assert voice_channel.last_message_id is None - def test_deserialize_guild_stage_channel_with_unset_fields(self, entity_factory_impl): + def test_deserialize_guild_stage_channel_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): voice_channel = entity_factory_impl.deserialize_guild_stage_channel( { "id": "123", @@ -2176,7 +2357,7 @@ def test_deserialize_guild_stage_channel_with_unset_fields(self, entity_factory_ assert voice_channel.last_message_id is None @pytest.fixture - def guild_forum_channel_payload(self, permission_overwrite_payload): + def guild_forum_channel_payload(self, permission_overwrite_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "961367432532987974", "type": 15, @@ -2208,10 +2389,14 @@ def guild_forum_channel_payload(self, permission_overwrite_payload): } def test_deserialize_guild_forum_channel( - self, entity_factory_impl, mock_app, guild_forum_channel_payload, permission_overwrite_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + guild_forum_channel_payload: dict[str, typing.Any], + permission_overwrite_payload: dict[str, typing.Any], ): forum_channel = entity_factory_impl.deserialize_guild_forum_channel(guild_forum_channel_payload) - assert forum_channel.app is mock_app + assert forum_channel.app is hikari_app assert forum_channel.id == 961367432532987974 assert forum_channel.name == "testing_forum_channel" assert forum_channel.topic == "A fun place to discuss fun stuff!" @@ -2249,7 +2434,9 @@ def test_deserialize_guild_forum_channel( assert isinstance(tag2, channel_models.ForumTag) assert isinstance(forum_channel, channel_models.GuildForumChannel) - def test_deserialize_guild_forum_channel_with_null_fields(self, entity_factory_impl, guild_forum_channel_payload): + def test_deserialize_guild_forum_channel_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, guild_forum_channel_payload: dict[str, typing.Any] + ): guild_forum_channel_payload["topic"] = None guild_forum_channel_payload["parent_id"] = None guild_forum_channel_payload["last_message_id"] = None @@ -2266,7 +2453,9 @@ def test_deserialize_guild_forum_channel_with_null_fields(self, entity_factory_i assert forum_channel.default_reaction_emoji_id is None assert forum_channel.default_reaction_emoji_name is None - def test_deserialize_guild_forum_channel_with_unset_fields(self, entity_factory_impl, guild_forum_channel_payload): + def test_deserialize_guild_forum_channel_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, guild_forum_channel_payload: dict[str, typing.Any] + ): del guild_forum_channel_payload["available_tags"] del guild_forum_channel_payload["default_reaction_emoji"] del guild_forum_channel_payload["nsfw"] @@ -2291,12 +2480,16 @@ def test_deserialize_guild_forum_channel_with_unset_fields(self, entity_factory_ assert forum_channel.default_layout == channel_models.ForumLayoutType.NOT_SET def test_deserialize_guild_media_channel( - self, entity_factory_impl, mock_app, guild_forum_channel_payload, permission_overwrite_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + guild_forum_channel_payload: dict[str, typing.Any], + permission_overwrite_payload: dict[str, typing.Any], ): guild_forum_channel_payload["type"] = 16 media_channel = entity_factory_impl.deserialize_guild_media_channel(guild_forum_channel_payload) - assert media_channel.app is mock_app + assert media_channel.app is hikari_app assert media_channel.id == 961367432532987974 assert media_channel.name == "testing_forum_channel" assert media_channel.topic == "A fun place to discuss fun stuff!" @@ -2334,7 +2527,9 @@ def test_deserialize_guild_media_channel( assert isinstance(tag2, channel_models.ForumTag) assert isinstance(media_channel, channel_models.GuildMediaChannel) - def test_deserialize_guild_media_channel_with_null_fields(self, entity_factory_impl, guild_forum_channel_payload): + def test_deserialize_guild_media_channel_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, guild_forum_channel_payload: dict[str, typing.Any] + ): guild_forum_channel_payload["type"] = 16 guild_forum_channel_payload["topic"] = None @@ -2353,7 +2548,9 @@ def test_deserialize_guild_media_channel_with_null_fields(self, entity_factory_i assert media_channel.default_reaction_emoji_id is None assert media_channel.default_reaction_emoji_name is None - def test_deserialize_guild_media_channel_with_unset_fields(self, entity_factory_impl, guild_forum_channel_payload): + def test_deserialize_guild_media_channel_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, guild_forum_channel_payload: dict[str, typing.Any] + ): guild_forum_channel_payload["type"] = 16 del guild_forum_channel_payload["available_tags"] @@ -2379,10 +2576,10 @@ def test_deserialize_guild_media_channel_with_unset_fields(self, entity_factory_ assert media_channel.default_sort_order == channel_models.ForumSortOrderType.LATEST_ACTIVITY assert media_channel.default_layout == channel_models.ForumLayoutType.NOT_SET - def test_serialize_forum_tag(self, entity_factory_impl): + def test_serialize_forum_tag(self, entity_factory_impl: entity_factory.EntityFactoryImpl): tag = channel_models.ForumTag(id=snowflakes.Snowflake(123), name="test", moderated=True, emoji=None) - unicode_emoji = object() - emoji_id = object() + unicode_emoji = mock.Mock() + emoji_id = mock.Mock() with mock.patch.object(channel_models.ForumTag, "unicode_emoji", new=unicode_emoji): with mock.patch.object(channel_models.ForumTag, "emoji_id", new=emoji_id): @@ -2408,7 +2605,9 @@ def test_deserialize_thread_member_with_passed_fields( self, entity_factory_impl: entity_factory.EntityFactoryImpl, thread_member_payload: dict[str, typing.Any] ): thread_member = entity_factory_impl.deserialize_thread_member( - {"join_timestamp": "2022-02-28T01:49:03.599821+00:00", "flags": 494949}, thread_id=123321, user_id=65132123 + {"join_timestamp": "2022-02-28T01:49:03.599821+00:00", "flags": 494949}, + thread_id=snowflakes.Snowflake(123321), + user_id=snowflakes.Snowflake(65132123), ) assert thread_member.thread_id == 123321 @@ -2460,6 +2659,7 @@ def test_deserialize_guild_thread_returns_right_type_with_passed_user_id( result = entity_factory_impl.deserialize_guild_thread(payload, user_id=snowflakes.Snowflake(763423454)) + assert result.member is not None assert result.member.user_id == 763423454 @pytest.mark.parametrize( @@ -2481,14 +2681,14 @@ def test_deserialize_guild_thread_handles_unknown_channel_type( def test_deserialize_guild_news_thread( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: traits.RESTAware, + hikari_app: traits.RESTAware, guild_news_thread_payload: dict[str, typing.Any], thread_member_payload: dict[str, typing.Any], ): thread = entity_factory_impl.deserialize_guild_news_thread(guild_news_thread_payload) assert thread.id == 946900871160164393 - assert thread.app is mock_app + assert thread.app is hikari_app assert thread.guild_id == 574921006817476608 assert thread.parent_id == 881729820747268137 assert thread.owner_id == 115590097100865541 @@ -2508,7 +2708,7 @@ def test_deserialize_guild_news_thread( assert thread.approximate_member_count == 3 assert thread.rate_limit_per_user == datetime.timedelta(seconds=53) assert thread.member == entity_factory_impl.deserialize_thread_member( - thread_member_payload, thread_id=946900871160164393 + thread_member_payload, thread_id=snowflakes.Snowflake(946900871160164393) ) assert isinstance(thread, channel_models.GuildNewsThread) @@ -2557,19 +2757,20 @@ def test_deserialize_guild_news_thread_when_passed_through_user_id( guild_news_thread_payload, user_id=snowflakes.Snowflake(763423454) ) + assert thread.member is not None assert thread.member.user_id == 763423454 def test_deserialize_guild_public_thread( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: traits.RESTAware, + hikari_app: traits.RESTAware, guild_public_thread_payload: dict[str, typing.Any], thread_member_payload: dict[str, typing.Any], ): thread = entity_factory_impl.deserialize_guild_public_thread(guild_public_thread_payload) assert thread.id == 947643783913308301 - assert thread.app is mock_app + assert thread.app is hikari_app assert thread.guild_id == 574921006817476608 assert thread.parent_id == 744183190998089820 assert thread.owner_id == 115590097100865541 @@ -2588,7 +2789,7 @@ def test_deserialize_guild_public_thread( assert thread.approximate_member_count == 3 assert thread.rate_limit_per_user == datetime.timedelta(seconds=23) assert thread.member == entity_factory_impl.deserialize_thread_member( - thread_member_payload, thread_id=947643783913308301 + thread_member_payload, thread_id=snowflakes.Snowflake(947643783913308301) ) assert thread.applied_tag_ids == [123, 456] @@ -2641,19 +2842,20 @@ def test_deserialize_guild_public_thread_when_passed_through_user_id( guild_public_thread_payload, user_id=snowflakes.Snowflake(22123) ) + assert thread.member is not None assert thread.member.user_id == 22123 def test_deserialize_guild_private_thread( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: traits.RESTAware, + hikari_app: traits.RESTAware, guild_private_thread_payload: dict[str, typing.Any], thread_member_payload: dict[str, typing.Any], ): thread = entity_factory_impl.deserialize_guild_private_thread(guild_private_thread_payload) assert thread.id == 947690637610844210 - assert thread.app is mock_app + assert thread.app is hikari_app assert thread.guild_id == 574921006817476608 assert thread.parent_id == 744183190998089820 assert thread.owner_id == 115590097100865541 @@ -2674,7 +2876,7 @@ def test_deserialize_guild_private_thread( assert thread.approximate_member_count == 3 assert thread.rate_limit_per_user == datetime.timedelta(seconds=0) assert thread.member == entity_factory_impl.deserialize_thread_member( - thread_member_payload, thread_id=947690637610844210 + thread_member_payload, thread_id=snowflakes.Snowflake(947690637610844210) ) def test_deserialize_guild_private_thread_when_null_fields( @@ -2722,6 +2924,7 @@ def test_deserialize_guild_private_thread_when_passed_through_user_id( guild_private_thread_payload, user_id=snowflakes.Snowflake(22123) ) + assert thread.member is not None assert thread.member.user_id == 22123 def test_deserialize_channel_returns_right_type( @@ -2781,7 +2984,9 @@ def test_deserialize_channel_when_passed_guild_id( assert isinstance(result, channel_models.GuildChannel) assert result.guild_id == 2394949234123 - def test_deserialize_channel_handles_unknown_channel_type(self, entity_factory_impl): + def test_deserialize_channel_handles_unknown_channel_type( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): with pytest.raises(errors.UnrecognisedEntityError): entity_factory_impl.deserialize_channel({"type": -9999999999}) @@ -2795,32 +3000,38 @@ def test_deserialize_channel_handles_unknown_channel_type(self, entity_factory_i (13, "deserialize_guild_stage_channel"), ], ) - def test_deserialize_channel_when_guild(self, mock_app, type_, fn): + def test_deserialize_channel_when_guild(self, hikari_app: traits.RESTAware, type_: int, fn: str): payload = {"type": type_} with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: # We need to instantiate it after the mock so that the functions that are stored in the dicts # are the ones we mock - entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) + entity_factory_impl = entity_factory.EntityFactoryImpl(app=hikari_app) - assert entity_factory_impl.deserialize_channel(payload, guild_id=123) is expected_fn.return_value + assert ( + entity_factory_impl.deserialize_channel(payload, guild_id=snowflakes.Snowflake(123)) + is expected_fn.return_value + ) expected_fn.assert_called_once_with(payload, guild_id=123) @pytest.mark.parametrize(("type_", "fn"), [(1, "deserialize_dm"), (3, "deserialize_group_dm")]) - def test_deserialize_channel_when_dm(self, mock_app, type_, fn): + def test_deserialize_channel_when_dm(self, hikari_app: traits.RESTAware, type_: int, fn: str): payload = {"type": type_} with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: # We need to instantiate it after the mock so that the functions that are stored in the dicts # are the ones we mock - entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) + entity_factory_impl = entity_factory.EntityFactoryImpl(app=hikari_app) - assert entity_factory_impl.deserialize_channel(payload, guild_id=123123123) is expected_fn.return_value + assert ( + entity_factory_impl.deserialize_channel(payload, guild_id=snowflakes.Snowflake(123123123)) + is expected_fn.return_value + ) expected_fn.assert_called_once_with(payload) - def test_deserialize_channel_when_unknown_type(self, entity_factory_impl): + def test_deserialize_channel_when_unknown_type(self, entity_factory_impl: entity_factory.EntityFactoryImpl): with pytest.raises(errors.UnrecognisedEntityError): entity_factory_impl.deserialize_channel({"type": -111}) @@ -2829,7 +3040,7 @@ def test_deserialize_channel_when_unknown_type(self, entity_factory_impl): ################ @pytest.fixture - def embed_payload(self): + def embed_payload(self) -> dict[str, typing.Any]: return { "title": "embed title", "description": "embed description", @@ -2869,7 +3080,9 @@ def embed_payload(self): "fields": [{"name": "title", "value": "some value", "inline": True}], } - def test_deserialize_embed_with_full_embed(self, entity_factory_impl, embed_payload): + def test_deserialize_embed_with_full_embed( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, embed_payload: dict[str, typing.Any] + ): embed = entity_factory_impl.deserialize_embed(embed_payload) assert embed.title == "embed title" assert embed.description == "embed description" @@ -2878,36 +3091,49 @@ def test_deserialize_embed_with_full_embed(self, entity_factory_impl, embed_payl assert embed.color == color_models.Color(14014915) assert isinstance(embed.color, color_models.Color) # EmbedFooter + assert embed.footer is not None assert embed.footer.text == "footer text" + assert embed.footer.icon is not None assert embed.footer.icon.resource.url == "https://somewhere.com/footer.png" + assert embed.footer.icon.proxy_resource is not None assert embed.footer.icon.proxy_resource.url == "https://media.somewhere.com/footer.png" assert isinstance(embed.footer, embed_models.EmbedFooter) # EmbedImage + assert embed.image is not None assert embed.image.resource.url == "https://somewhere.com/image.png" + assert embed.image.proxy_resource is not None assert embed.image.proxy_resource.url == "https://media.somewhere.com/image.png" assert embed.image.height == 122 assert embed.image.width == 133 assert isinstance(embed.image, embed_models.EmbedImage) # EmbedThumbnail + assert embed.thumbnail is not None assert embed.thumbnail.resource.url == "https://somewhere.com/thumbnail.png" + assert embed.thumbnail.proxy_resource is not None assert embed.thumbnail.proxy_resource.url == "https://media.somewhere.com/thumbnail.png" assert embed.thumbnail.height == 123 assert embed.thumbnail.width == 456 assert isinstance(embed.thumbnail, embed_models.EmbedImage) # EmbedVideo + assert embed.video is not None assert embed.video.resource.url == "https://somewhere.com/video.mp4" + assert embed.video.proxy_resource is not None assert embed.video.proxy_resource.url == "https://somewhere.com/proxy/video.mp4" assert embed.video.height == 1234 assert embed.video.width == 4567 assert isinstance(embed.video, embed_models.EmbedVideo) # EmbedProvider + assert embed.provider is not None assert embed.provider.name == "some name" assert embed.provider.url == "https://somewhere.com/provider" assert isinstance(embed.provider, embed_models.EmbedProvider) # EmbedAuthor + assert embed.author is not None assert embed.author.name == "some name" assert embed.author.url == "https://somewhere.com/author-url" + assert embed.author.icon is not None assert embed.author.icon.resource.url == "https://somewhere.com/author.png" + assert embed.author.icon.proxy_resource is not None assert embed.author.icon.proxy_resource.url == "https://media.somewhere.com/author.png" assert isinstance(embed.author, embed_models.EmbedAuthor) # EmbedField @@ -2918,7 +3144,9 @@ def test_deserialize_embed_with_full_embed(self, entity_factory_impl, embed_payl assert field.is_inline is True assert isinstance(field, embed_models.EmbedField) - def test_deserialize_embed_with_partial_sub_fields(self, entity_factory_impl, embed_payload): + def test_deserialize_embed_with_partial_sub_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, embed_payload: dict[str, typing.Any] + ): embed = entity_factory_impl.deserialize_embed( { "footer": {"text": "footer text"}, @@ -2930,32 +3158,40 @@ def test_deserialize_embed_with_partial_sub_fields(self, entity_factory_impl, em } ) # EmbedFooter + assert embed.footer is not None assert embed.footer.text == "footer text" assert embed.footer.icon is None # EmbedImage + assert embed.image is not None assert embed.image.resource.url == "https://blahblah.blahblahblah" assert embed.image.proxy_resource is None assert embed.image.width is None assert embed.image.height is None # EmbedThumbnail + assert embed.thumbnail is not None assert embed.thumbnail.resource.url == "https://blahblah2.blahblahblah" assert embed.thumbnail.proxy_resource is None assert embed.thumbnail.height is None assert embed.thumbnail.width is None # EmbedVideo + assert embed.video is not None assert embed.video.resource.url == "https://blahblah3.blahblahblah" assert embed.video.proxy_resource is None assert embed.video.height is None assert embed.video.width is None # EmbedProvider + assert embed.provider is not None assert embed.provider.name is None assert embed.provider.url == "https://blahbla5h.blahblahblah" # EmbedAuthor + assert embed.author is not None assert embed.author.name == "author name" assert embed.author.url is None assert embed.author.icon is None - def test_deserialize_embed_with_other_null_sub_fields(self, entity_factory_impl, embed_payload): + def test_deserialize_embed_with_other_null_sub_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, embed_payload: dict[str, typing.Any] + ): embed = entity_factory_impl.deserialize_embed( { "footer": {"text": "footer text"}, @@ -2965,14 +3201,18 @@ def test_deserialize_embed_with_other_null_sub_fields(self, entity_factory_impl, } ) # EmbedProvider + assert embed.provider is not None assert embed.provider.name == "name name" assert embed.provider.url is None # EmbedAuthor + assert embed.author is not None assert embed.author.name is None assert embed.author.url == "urlurlurl" assert embed.author.icon is None - def test_deserialize_embed_with_partial_fields(self, entity_factory_impl, embed_payload): + def test_deserialize_embed_with_partial_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, embed_payload: dict[str, typing.Any] + ): embed = entity_factory_impl.deserialize_embed( { "footer": {"text": "footer text"}, @@ -2985,6 +3225,7 @@ def test_deserialize_embed_with_partial_fields(self, entity_factory_impl, embed_ } ) # EmbedFooter + assert embed.footer is not None assert embed.footer.text == "footer text" assert embed.footer.icon is None # EmbedImage @@ -2996,6 +3237,7 @@ def test_deserialize_embed_with_partial_fields(self, entity_factory_impl, embed_ # EmbedProvider assert embed.provider is None # EmbedAuthor + assert embed.author is not None assert embed.author.name == "author name" assert embed.author.url is None assert embed.author.icon is None @@ -3007,7 +3249,7 @@ def test_deserialize_embed_with_partial_fields(self, entity_factory_impl, embed_ assert field.is_inline is False assert isinstance(field, embed_models.EmbedField) - def test_deserialize_embed_with_empty_embed(self, entity_factory_impl): + def test_deserialize_embed_with_empty_embed(self, entity_factory_impl: entity_factory.EntityFactoryImpl): embed = entity_factory_impl.deserialize_embed({}) assert embed.title is None assert embed.description is None @@ -3022,7 +3264,10 @@ def test_deserialize_embed_with_empty_embed(self, entity_factory_impl): assert embed.author is None assert embed.fields == [] - def test_serialize_embed_with_non_url_resources_provides_attachments(self, entity_factory_impl): + def test_serialize_embed_with_non_url_resources_provides_attachments( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): + # FIXME: I am unsure as to how to fix these below files.File() types. footer_icon = embed_models.EmbedResource(resource=files.File("cat.png")) thumbnail = embed_models.EmbedImage(resource=files.File("dog.png")) image = embed_models.EmbedImage(resource=files.Bytes(b"potato kung fu", "sushi.pdf")) @@ -3066,7 +3311,9 @@ def test_serialize_embed_with_non_url_resources_provides_attachments(self, entit "fields": [{"value": "VALUE", "name": "NAME", "inline": True}], } - def test_serialize_embed_with_url_resources_does_not_provide_attachments(self, entity_factory_impl): + def test_serialize_embed_with_url_resources_does_not_provide_attachments( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): class DummyWebResource(files.WebResource): @property def url(self) -> str: @@ -3076,6 +3323,7 @@ def url(self) -> str: def filename(self) -> str: return "lolbook.png" + # FIXME: I am unsure as to how to fix these below files.File() types. footer_icon = embed_models.EmbedResource(resource=files.URL("http://http.cat")) thumbnail = embed_models.EmbedImage(resource=DummyWebResource()) image = embed_models.EmbedImage(resource=files.URL("http://bazbork.com")) @@ -3120,7 +3368,7 @@ def filename(self) -> str: "fields": [{"value": "VALUE", "name": "NAME", "inline": True}], } - def test_serialize_embed_with_null_sub_fields(self, entity_factory_impl): + def test_serialize_embed_with_null_sub_fields(self, entity_factory_impl: entity_factory.EntityFactoryImpl): payload, resources = entity_factory_impl.serialize_embed( embed_models.Embed.from_received_embed( title="Title", @@ -3148,13 +3396,15 @@ def test_serialize_embed_with_null_sub_fields(self, entity_factory_impl): } assert resources == [] - def test_serialize_embed_with_null_attributes(self, entity_factory_impl): + def test_serialize_embed_with_null_attributes(self, entity_factory_impl: entity_factory.EntityFactoryImpl): assert entity_factory_impl.serialize_embed(embed_models.Embed()) == ({}, []) @pytest.mark.parametrize( "field_kwargs", [{"name": None, "value": "correct value"}, {"name": "correct value", "value": None}] ) - def test_serialize_embed_validators(self, entity_factory_impl, field_kwargs): + def test_serialize_embed_validators( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, field_kwargs: dict[str, typing.Any] + ): embed_obj = embed_models.Embed() embed_obj._fields = [embed_models.EmbedField(**field_kwargs)] with pytest.raises(TypeError): @@ -3164,12 +3414,14 @@ def test_serialize_embed_validators(self, entity_factory_impl, field_kwargs): # EMOJI MODELS # ################ - def test_deserialize_unicode_emoji(self, entity_factory_impl): + def test_deserialize_unicode_emoji(self, entity_factory_impl: entity_factory.EntityFactoryImpl): emoji = entity_factory_impl.deserialize_unicode_emoji({"name": "🤷"}) assert emoji.name == "🤷" assert isinstance(emoji, emoji_models.UnicodeEmoji) - def test_deserialize_custom_emoji(self, entity_factory_impl, mock_app, custom_emoji_payload): + def test_deserialize_custom_emoji( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, custom_emoji_payload: dict[str, typing.Any] + ): emoji = entity_factory_impl.deserialize_custom_emoji(custom_emoji_payload) assert emoji.id == snowflakes.Snowflake(691225175349395456) assert emoji.name == "test" @@ -3177,19 +3429,23 @@ def test_deserialize_custom_emoji(self, entity_factory_impl, mock_app, custom_em assert isinstance(emoji, emoji_models.CustomEmoji) def test_deserialize_custom_emoji_with_unset_and_null_fields( - self, entity_factory_impl, mock_app, custom_emoji_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl ): emoji = entity_factory_impl.deserialize_custom_emoji({"id": "691225175349395456", "name": None}) assert emoji.is_animated is False assert emoji.name is None def test_deserialize_known_custom_emoji( - self, entity_factory_impl, mock_app, user_payload, known_custom_emoji_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + user_payload: dict[str, typing.Any], + known_custom_emoji_payload: dict[str, typing.Any], ): emoji = entity_factory_impl.deserialize_known_custom_emoji( known_custom_emoji_payload, guild_id=snowflakes.Snowflake(1235123) ) - assert emoji.app is mock_app + assert emoji.app is hikari_app assert emoji.id == 12345 assert emoji.guild_id == 1235123 assert emoji.name == "testing" @@ -3201,7 +3457,9 @@ def test_deserialize_known_custom_emoji( assert emoji.is_available is True assert isinstance(emoji, emoji_models.KnownCustomEmoji) - def test_deserialize_known_custom_emoji_with_unset_fields(self, entity_factory_impl): + def test_deserialize_known_custom_emoji_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): emoji = entity_factory_impl.deserialize_known_custom_emoji( { "id": "12345", @@ -3220,7 +3478,12 @@ def test_deserialize_known_custom_emoji_with_unset_fields(self, entity_factory_i ("payload", "expected_type"), [({"name": "🤷"}, emoji_models.UnicodeEmoji), ({"id": "1234", "name": "test"}, emoji_models.CustomEmoji)], ) - def test_deserialize_emoji_returns_expected_type(self, entity_factory_impl, payload, expected_type): + def test_deserialize_emoji_returns_expected_type( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + payload: dict[str, typing.Any], + expected_type: typing.Union[typing.Type[emoji_models.UnicodeEmoji], typing.Type[emoji_models.CustomEmoji]], + ): isinstance(entity_factory_impl.deserialize_emoji(payload), expected_type) ################## @@ -3228,14 +3491,16 @@ def test_deserialize_emoji_returns_expected_type(self, entity_factory_impl, payl ################## @pytest.fixture - def gateway_bot_payload(self): + def gateway_bot_payload(self) -> dict[str, typing.Any]: return { "url": "wss://gateway.discord.gg", "shards": 1, "session_start_limit": {"total": 1000, "remaining": 991, "reset_after": 14170186, "max_concurrency": 5}, } - def test_deserialize_gateway_bot(self, entity_factory_impl, gateway_bot_payload): + def test_deserialize_gateway_bot( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, gateway_bot_payload: dict[str, typing.Any] + ): gateway_bot = entity_factory_impl.deserialize_gateway_bot_info(gateway_bot_payload) assert isinstance(gateway_bot, gateway_models.GatewayBotInfo) assert gateway_bot.url == "wss://gateway.discord.gg" @@ -3252,21 +3517,28 @@ def test_deserialize_gateway_bot(self, entity_factory_impl, gateway_bot_payload) ################ @pytest.fixture - def guild_embed_payload(self): + def guild_embed_payload(self) -> dict[str, typing.Any]: return {"channel_id": "123123123", "enabled": True} - def test_deserialize_widget_embed(self, entity_factory_impl, mock_app, guild_embed_payload): + def test_deserialize_widget_embed( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + guild_embed_payload: dict[str, typing.Any], + ): guild_embed = entity_factory_impl.deserialize_guild_widget(guild_embed_payload) - assert guild_embed.app is mock_app + assert guild_embed.app is hikari_app assert guild_embed.channel_id == 123123123 assert guild_embed.is_enabled is True assert isinstance(guild_embed, guild_models.GuildWidget) - def test_deserialize_guild_embed_with_null_fields(self, entity_factory_impl, mock_app): + def test_deserialize_guild_embed_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, hikari_app: traits.RESTAware + ): assert entity_factory_impl.deserialize_guild_widget({"channel_id": None, "enabled": True}).channel_id is None @pytest.fixture - def guild_welcome_screen_payload(self): + def guild_welcome_screen_payload(self) -> dict[str, typing.Any]: return { "description": "What does the fox say? Nico Nico Nico NIIIIIIIIIIIIIIIIIIIIIII!!!!", "welcome_channels": [ @@ -3292,7 +3564,12 @@ def guild_welcome_screen_payload(self): ], } - def test_deserialize_welcome_screen(self, entity_factory_impl, mock_app, guild_welcome_screen_payload): + def test_deserialize_welcome_screen( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + guild_welcome_screen_payload: dict[str, typing.Any], + ): welcome_screen = entity_factory_impl.deserialize_welcome_screen(guild_welcome_screen_payload) assert welcome_screen.description == "What does the fox say? Nico Nico Nico NIIIIIIIIIIIIIIIIIIIIIII!!!!" @@ -3314,7 +3591,9 @@ def test_deserialize_welcome_screen(self, entity_factory_impl, mock_app, guild_w assert welcome_screen.channels[3].emoji_name is None assert welcome_screen.channels[3].emoji_id == 49494949 - def test_serialize_welcome_channel_with_custom_emoji(self, entity_factory_impl, mock_app): + def test_serialize_welcome_channel_with_custom_emoji( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, hikari_app: traits.RESTAware + ): channel = guild_models.WelcomeChannel( channel_id=snowflakes.Snowflake(431231), description="meow", @@ -3325,7 +3604,9 @@ def test_serialize_welcome_channel_with_custom_emoji(self, entity_factory_impl, assert result == {"channel_id": "431231", "description": "meow", "emoji_id": "564123"} - def test_serialize_welcome_channel_with_unicode_emoji(self, entity_factory_impl, mock_app): + def test_serialize_welcome_channel_with_unicode_emoji( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, hikari_app: traits.RESTAware + ): channel = guild_models.WelcomeChannel( channel_id=snowflakes.Snowflake(4312311), description="meow1", @@ -3336,7 +3617,9 @@ def test_serialize_welcome_channel_with_unicode_emoji(self, entity_factory_impl, assert result == {"channel_id": "4312311", "description": "meow1", "emoji_name": "a"} - def test_serialize_welcome_channel_with_no_emoji(self, entity_factory_impl, mock_app): + def test_serialize_welcome_channel_with_no_emoji( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, hikari_app: traits.RESTAware + ): channel = guild_models.WelcomeChannel( channel_id=snowflakes.Snowflake(4312312), description="meow2", emoji_id=None, emoji_name=None ) @@ -3366,7 +3649,9 @@ def guild_onboarding_prompt_payload(self) -> typing.MutableMapping[str, typing.A } def test_deserialize_guild_onboarding_prompt( - self, entity_factory_impl, guild_onboarding_prompt_payload: typing.MutableMapping[str, typing.Any] + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + guild_onboarding_prompt_payload: typing.MutableMapping[str, typing.Any], ): prompt = entity_factory_impl._deserialize_guild_onboarding_prompt(guild_onboarding_prompt_payload) assert isinstance(prompt, guild_models.GuildOnboardingPrompt) @@ -3388,7 +3673,9 @@ def test_deserialize_guild_onboarding_prompt( ] def test_deserialize_guild_onboarding( - self, entity_factory_impl, guild_onboarding_prompt_payload: typing.MutableMapping[str, typing.Any] + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + guild_onboarding_prompt_payload: typing.MutableMapping[str, typing.Any], ): payload = { "default_channel_ids": ["123"], @@ -3405,10 +3692,16 @@ def test_deserialize_guild_onboarding( assert len(onboarding.prompts) == 1 assert isinstance(onboarding.prompts[0], guild_models.GuildOnboardingPrompt) - def test_deserialize_member(self, entity_factory_impl, mock_app, member_payload, user_payload): + def test_deserialize_member( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + member_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + ): member_payload = {**member_payload, "guild_id": "76543325"} member = entity_factory_impl.deserialize_member(member_payload) - assert member.app is mock_app + assert member.app is hikari_app assert member.guild_id == 76543325 assert member.guild_avatar_hash == "estrogen" assert member.user == entity_factory_impl.deserialize_user(user_payload) @@ -3426,14 +3719,18 @@ def test_deserialize_member(self, entity_factory_impl, mock_app, member_payload, assert isinstance(member, guild_models.Member) def test_deserialize_member_when_guild_id_already_in_role_array( - self, entity_factory_impl, mock_app, member_payload, user_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + member_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], ): # While this isn't a legitimate case based on the current behaviour of the API, we still want to cover this # to ensure no duplication occurs. member_payload = {**member_payload, "guild_id": "76543325"} member_payload["roles"] = [11111, 22222, 76543325, 33333, 44444] member = entity_factory_impl.deserialize_member(member_payload) - assert member.app is mock_app + assert member.app is hikari_app assert member.guild_id == 76543325 assert member.user == entity_factory_impl.deserialize_user(user_payload) assert member.nickname == "foobarbaz" @@ -3445,7 +3742,9 @@ def test_deserialize_member_when_guild_id_already_in_role_array( assert member.guild_flags == guild_models.GuildMemberFlags.DID_REJOIN assert isinstance(member, guild_models.Member) - def test_deserialize_member_with_null_fields(self, entity_factory_impl, user_payload): + def test_deserialize_member_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, user_payload: dict[str, typing.Any] + ): member = entity_factory_impl.deserialize_member( { "nick": None, @@ -3467,7 +3766,9 @@ def test_deserialize_member_with_null_fields(self, entity_factory_impl, user_pay assert member.joined_at is None assert isinstance(member, guild_models.Member) - def test_deserialize_member_with_undefined_fields(self, entity_factory_impl, user_payload): + def test_deserialize_member_with_undefined_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, user_payload: dict[str, typing.Any] + ): member = entity_factory_impl.deserialize_member( { "roles": ["11111", "22222", "33333", "44444"], @@ -3483,7 +3784,9 @@ def test_deserialize_member_with_undefined_fields(self, entity_factory_impl, use assert member.is_mute is undefined.UNDEFINED assert member.is_pending is undefined.UNDEFINED - def test_deserialize_member_with_passed_through_user_object_and_guild_id(self, entity_factory_impl): + def test_deserialize_member_with_passed_through_user_object_and_guild_id( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): mock_user = mock.Mock(user_models.UserImpl) member = entity_factory_impl.deserialize_member( { @@ -3500,9 +3803,14 @@ def test_deserialize_member_with_passed_through_user_object_and_guild_id(self, e assert member.user is mock_user assert member.guild_id == 64234 - def test_deserialize_role(self, entity_factory_impl, mock_app, guild_role_payload): + def test_deserialize_role( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + guild_role_payload: dict[str, typing.Any], + ): guild_role = entity_factory_impl.deserialize_role(guild_role_payload, guild_id=snowflakes.Snowflake(76534453)) - assert guild_role.app is mock_app + assert guild_role.app is hikari_app assert guild_role.id == 41771983423143936 assert guild_role.guild_id == 76534453 assert guild_role.name == "WE DEM BOYZZ!!!!!!" @@ -3523,7 +3831,9 @@ def test_deserialize_role(self, entity_factory_impl, mock_app, guild_role_payloa assert guild_role.is_available_for_purchase is True assert isinstance(guild_role, guild_models.Role) - def test_deserialize_role_with_missing_or_unset_fields(self, entity_factory_impl, guild_role_payload): + def test_deserialize_role_with_missing_or_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, guild_role_payload: dict[str, typing.Any] + ): guild_role_payload["tags"] = {} guild_role_payload["unicode_emoji"] = None guild_role = entity_factory_impl.deserialize_role(guild_role_payload, guild_id=snowflakes.Snowflake(76534453)) @@ -3535,14 +3845,18 @@ def test_deserialize_role_with_missing_or_unset_fields(self, entity_factory_impl assert guild_role.is_available_for_purchase is False assert guild_role.unicode_emoji is None - def test_deserialize_role_with_no_tags(self, entity_factory_impl, guild_role_payload): + def test_deserialize_role_with_no_tags( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, guild_role_payload: dict[str, typing.Any] + ): del guild_role_payload["tags"] guild_role = entity_factory_impl.deserialize_role(guild_role_payload, guild_id=snowflakes.Snowflake(76534453)) assert guild_role.bot_id is None assert guild_role.integration_id is None assert guild_role.is_premium_subscriber_role is False - def test_deserialize_partial_integration(self, entity_factory_impl, partial_integration_payload): + def test_deserialize_partial_integration( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, partial_integration_payload: dict[str, typing.Any] + ): partial_integration = entity_factory_impl.deserialize_partial_integration(partial_integration_payload) assert partial_integration.id == 4949494949 assert partial_integration.name == "Blah blah" @@ -3554,7 +3868,7 @@ def test_deserialize_partial_integration(self, entity_factory_impl, partial_inte assert isinstance(partial_integration.account, guild_models.IntegrationAccount) @pytest.fixture - def integration_payload(self, user_payload): + def integration_payload(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "420", "name": "blaze it", @@ -3586,7 +3900,12 @@ def integration_payload(self, user_payload): }, } - def test_deserialize_integration(self, entity_factory_impl, integration_payload, user_payload): + def test_deserialize_integration( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + integration_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + ): integration = entity_factory_impl.deserialize_integration(integration_payload) assert integration.id == 420 assert integration.guild_id == 9292929292 @@ -3610,6 +3929,7 @@ def test_deserialize_integration(self, entity_factory_impl, integration_payload, ) assert integration.subscriber_count == 69 # IntegrationApplication + assert integration.application is not None assert integration.application.id == 123 assert integration.application.name == "some bot" assert integration.application.icon_hash == "123abc" @@ -3619,7 +3939,9 @@ def test_deserialize_integration(self, entity_factory_impl, integration_payload, ) assert isinstance(integration, guild_models.Integration) - def test_deserialize_guild_integration_with_null_and_unset_fields(self, entity_factory_impl): + def test_deserialize_guild_integration_with_null_and_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): integration = entity_factory_impl.deserialize_integration( { "id": "420", @@ -3644,7 +3966,7 @@ def test_deserialize_guild_integration_with_null_and_unset_fields(self, entity_f assert integration.is_syncing is None assert integration.subscriber_count is None - def test_deserialize_guild_integration_with_unset_bot(self, entity_factory_impl): + def test_deserialize_guild_integration_with_unset_bot(self, entity_factory_impl: entity_factory.EntityFactoryImpl): integration = entity_factory_impl.deserialize_integration( { "id": "420", @@ -3664,23 +3986,31 @@ def test_deserialize_guild_integration_with_unset_bot(self, entity_factory_impl) }, } ) + assert integration.application is not None assert integration.application.bot is None @pytest.fixture - def guild_member_ban_payload(self, user_payload): + def guild_member_ban_payload(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return {"reason": "Get nyaa'ed", "user": user_payload} - def test_deserialize_guild_member_ban(self, entity_factory_impl, guild_member_ban_payload, user_payload): + def test_deserialize_guild_member_ban( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + guild_member_ban_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + ): member_ban = entity_factory_impl.deserialize_guild_member_ban(guild_member_ban_payload) assert member_ban.reason == "Get nyaa'ed" assert member_ban.user == entity_factory_impl.deserialize_user(user_payload) assert isinstance(member_ban, guild_models.GuildBan) - def test_deserialize_guild_member_ban_with_null_fields(self, entity_factory_impl, user_payload): + def test_deserialize_guild_member_ban_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, user_payload: dict[str, typing.Any] + ): assert entity_factory_impl.deserialize_guild_member_ban({"reason": None, "user": user_payload}).reason is None @pytest.fixture - def guild_preview_payload(self, known_custom_emoji_payload): + def guild_preview_payload(self, known_custom_emoji_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "152559372126519269", "name": "Isopropyl", @@ -3695,10 +4025,14 @@ def guild_preview_payload(self, known_custom_emoji_payload): } def test_deserialize_guild_preview( - self, entity_factory_impl, mock_app, guild_preview_payload, known_custom_emoji_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + guild_preview_payload: dict[str, typing.Any], + known_custom_emoji_payload: dict[str, typing.Any], ): guild_preview = entity_factory_impl.deserialize_guild_preview(guild_preview_payload) - assert guild_preview.app is mock_app + assert guild_preview.app is hikari_app assert guild_preview.id == 152559372126519269 assert guild_preview.name == "Isopropyl" assert guild_preview.icon_hash == "d4a983885dsaa7691ce8bcaaf945a" @@ -3715,7 +4049,7 @@ def test_deserialize_guild_preview( assert guild_preview.description == "A DESCRIPTION." assert isinstance(guild_preview, guild_models.GuildPreview) - def test_deserialize_guild_preview_with_null_fields(self, entity_factory_impl, mock_app, guild_preview_payload): + def test_deserialize_guild_preview_with_null_fields(self, entity_factory_impl: entity_factory.EntityFactoryImpl): guild_preview = entity_factory_impl.deserialize_guild_preview( { "id": "152559372126519269", @@ -3736,7 +4070,7 @@ def test_deserialize_guild_preview_with_null_fields(self, entity_factory_impl, m assert guild_preview.description is None @pytest.fixture - def guild_incidents_payload(self): + def guild_incidents_payload(self) -> dict[str, typing.Any]: return { "invites_disabled_until": "2023-10-01T00:00:00.000000+00:00", "dms_disabled_until": None, @@ -3745,10 +4079,7 @@ def guild_incidents_payload(self): } def test_deserialize_guild_incidents( - self, - entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: traits.RESTAware, - guild_incidents_payload: dict[str, typing.Any], + self, entity_factory_impl: entity_factory.EntityFactoryImpl, guild_incidents_payload: dict[str, typing.Any] ): incidents = entity_factory_impl.deserialize_guild_incidents(guild_incidents_payload) assert incidents.invites_disabled_until == datetime.datetime( @@ -3768,7 +4099,12 @@ def test__deserialize_guild_incidents_with_null_payload( assert incidents.raid_detected_at is None @pytest.fixture - def rest_guild_payload(self, known_custom_emoji_payload, guild_sticker_payload, guild_role_payload): + def rest_guild_payload( + self, + known_custom_emoji_payload: dict[str, typing.Any], + guild_sticker_payload: dict[str, typing.Any], + guild_role_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: return { "afk_channel_id": "99998888777766", "afk_timeout": 1200, @@ -3817,15 +4153,15 @@ def rest_guild_payload(self, known_custom_emoji_payload, guild_sticker_payload, def test_deserialize_rest_guild( self, - entity_factory_impl, - mock_app, - rest_guild_payload, - known_custom_emoji_payload, - guild_role_payload, - guild_sticker_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + rest_guild_payload: dict[str, typing.Any], + known_custom_emoji_payload: dict[str, typing.Any], + guild_role_payload: dict[str, typing.Any], + guild_sticker_payload: dict[str, typing.Any], ): guild = entity_factory_impl.deserialize_rest_guild(rest_guild_payload) - assert guild.app is mock_app + assert guild.app is hikari_app assert guild.id == 265828729970753537 assert guild.name == "L33t guild" assert guild.icon_hash == "1a2b3c4d" @@ -3884,7 +4220,7 @@ def test_deserialize_rest_guild( assert guild.approximate_active_member_count == 7 assert guild.nsfw_level == guild_models.GuildNSFWLevel.DEFAULT - def test_deserialize_rest_guild_with_unset_fields(self, entity_factory_impl): + def test_deserialize_rest_guild_with_unset_fields(self, entity_factory_impl: entity_factory.EntityFactoryImpl): guild = entity_factory_impl.deserialize_rest_guild( { "afk_channel_id": "99998888777766", @@ -3931,7 +4267,7 @@ def test_deserialize_rest_guild_with_unset_fields(self, entity_factory_impl): assert guild.approximate_active_member_count is None assert guild.approximate_member_count is None - def test_deserialize_rest_guild_with_null_fields(self, entity_factory_impl): + def test_deserialize_rest_guild_with_null_fields(self, entity_factory_impl: entity_factory.EntityFactoryImpl): guild = entity_factory_impl.deserialize_rest_guild( { "afk_channel_id": None, @@ -3998,18 +4334,18 @@ def test_deserialize_rest_guild_with_null_fields(self, entity_factory_impl): @pytest.fixture def gateway_guild_payload( self, - guild_text_channel_payload, - guild_voice_channel_payload, - guild_news_channel_payload, - known_custom_emoji_payload, - guild_news_thread_payload, - guild_public_thread_payload, - guild_private_thread_payload, - member_payload, - member_presence_payload, - guild_role_payload, - voice_state_payload, - ): + guild_text_channel_payload: dict[str, typing.Any], + guild_voice_channel_payload: dict[str, typing.Any], + guild_news_channel_payload: dict[str, typing.Any], + known_custom_emoji_payload: dict[str, typing.Any], + guild_news_thread_payload: dict[str, typing.Any], + guild_public_thread_payload: dict[str, typing.Any], + guild_private_thread_payload: dict[str, typing.Any], + member_payload: dict[str, typing.Any], + member_presence_payload: dict[str, typing.Any], + guild_role_payload: dict[str, typing.Any], + voice_state_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: return { "afk_channel_id": "99998888777766", "afk_timeout": 1200, @@ -4065,14 +4401,11 @@ def gateway_guild_payload( def test_deserialize_gateway_guild( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: traits.RESTAware, + hikari_app: traits.RESTAware, gateway_guild_payload: dict[str, typing.Any], guild_text_channel_payload: dict[str, typing.Any], guild_voice_channel_payload: dict[str, typing.Any], guild_news_channel_payload: dict[str, typing.Any], - guild_news_thread_payload: dict[str, typing.Any], - guild_public_thread_payload: dict[str, typing.Any], - guild_private_thread_payload: dict[str, typing.Any], known_custom_emoji_payload: dict[str, typing.Any], member_payload: dict[str, typing.Any], member_presence_payload: dict[str, typing.Any], @@ -4083,7 +4416,7 @@ def test_deserialize_gateway_guild( gateway_guild_payload, user_id=snowflakes.Snowflake(43123) ) guild = guild_definition.guild() - assert guild.app is mock_app + assert guild.app is hikari_app assert guild.id == 265828729970753537 assert guild.name == "L33t guild" assert guild.icon_hash == "1a2b3c4d" @@ -4169,7 +4502,7 @@ def test_deserialize_gateway_guild( ) } - def test_deserialize_gateway_guild_with_unset_fields(self, entity_factory_impl): + def test_deserialize_gateway_guild_with_unset_fields(self, entity_factory_impl: entity_factory.EntityFactoryImpl): guild_definition = entity_factory_impl.deserialize_gateway_guild( { "afk_channel_id": "99998888777766", @@ -4205,7 +4538,7 @@ def test_deserialize_gateway_guild_with_unset_fields(self, entity_factory_impl): "verification_level": 4, "nsfw_level": 0, }, - user_id=65123, + user_id=snowflakes.Snowflake(65123), ) guild = guild_definition.guild() assert guild.joined_at is None @@ -4225,7 +4558,7 @@ def test_deserialize_gateway_guild_with_unset_fields(self, entity_factory_impl): with pytest.raises(LookupError, match=r"'voice_states' not in payload"): guild_definition.voice_states() - def test_deserialize_gateway_guild_with_null_fields(self, entity_factory_impl): + def test_deserialize_gateway_guild_with_null_fields(self, entity_factory_impl: entity_factory.EntityFactoryImpl): guild_definition = entity_factory_impl.deserialize_gateway_guild( { "afk_channel_id": None, @@ -4278,7 +4611,7 @@ def test_deserialize_gateway_guild_with_null_fields(self, entity_factory_impl): "widget_enabled": True, "nsfw_level": 0, }, - user_id=1343123, + user_id=snowflakes.Snowflake(1343123), ) guild = guild_definition.guild() assert guild.icon_hash is None @@ -4300,7 +4633,7 @@ def test_deserialize_gateway_guild_with_null_fields(self, entity_factory_impl): ###################### @pytest.fixture - def slash_command_payload(self): + def slash_command_payload(self) -> dict[str, typing.Any]: return { "id": "1231231231", "application_id": "12354123", @@ -4344,10 +4677,15 @@ def slash_command_payload(self): "contexts": ["0"], } - def test_deserialize_slash_command(self, entity_factory_impl, mock_app, slash_command_payload): - command = entity_factory_impl.deserialize_slash_command(payload=slash_command_payload) - - assert command.app is mock_app + def test_deserialize_slash_command( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + slash_command_payload: dict[str, typing.Any], + ): + command = entity_factory_impl.deserialize_slash_command(payload=slash_command_payload) + + assert command.app is hikari_app assert command.id == 1231231231 assert command.application_id == 12354123 assert command.guild_id == 49949494 @@ -4360,6 +4698,7 @@ def test_deserialize_slash_command(self, entity_factory_impl, mock_app, slash_co assert command.context_types == [application_models.ApplicationContextType.GUILD] # CommandOption + assert command.options is not None assert len(command.options) == 1 option = command.options[0] assert option.type is commands.OptionType.SUB_COMMAND @@ -4378,6 +4717,7 @@ def test_deserialize_slash_command(self, entity_factory_impl, mock_app, slash_co assert option.min_length == 1 assert option.max_length == 44 + assert option.options is not None assert len(option.options) == 1 suboption = option.options[0] assert suboption.type is commands.OptionType.USER @@ -4388,6 +4728,7 @@ def test_deserialize_slash_command(self, entity_factory_impl, mock_app, slash_co assert suboption.channel_types is None # CommandChoice + assert suboption.choices is not None assert len(suboption.choices) == 1 choice = suboption.choices[0] assert isinstance(choice, commands.CommandChoice) @@ -4399,24 +4740,18 @@ def test_deserialize_slash_command(self, entity_factory_impl, mock_app, slash_co assert isinstance(option, commands.CommandOption) assert isinstance(command, commands.SlashCommand) - def test_deserialize_slash_command_with_passed_through_guild_id(self, entity_factory_impl): - payload = { - "id": "1231231231", - "guild_id": "987654321", - "application_id": "12354123", - "type": 1, - "name": "good name", - "description": "very good description", - "options": [], - "default_member_permissions": 0, - "version": "123312", - } - - command = entity_factory_impl.deserialize_slash_command(payload, guild_id=123123) + def test_deserialize_slash_command_with_passed_through_guild_id( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, slash_command_payload: dict[str, typing.Any] + ): + command = entity_factory_impl.deserialize_slash_command( + slash_command_payload, guild_id=snowflakes.Snowflake(123123) + ) assert command.guild_id == 123123 - def test_deserialize_slash_command_with_unset_values(self, entity_factory_impl, slash_command_payload): + def test_deserialize_slash_command_with_null_and_unset_values( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, slash_command_payload: dict[str, typing.Any] + ): del slash_command_payload["options"] del slash_command_payload["nsfw"] del slash_command_payload["integration_types"] @@ -4430,7 +4765,9 @@ def test_deserialize_slash_command_with_unset_values(self, entity_factory_impl, assert command.context_types == [] assert isinstance(command, commands.SlashCommand) - def test_deserialize_slash_command_with_null_values(self, entity_factory_impl, slash_command_payload): + def test_deserialize_slash_command_with_null_values( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, slash_command_payload: dict[str, typing.Any] + ): slash_command_payload["contexts"] = None command = entity_factory_impl.deserialize_slash_command(payload=slash_command_payload) @@ -4438,7 +4775,7 @@ def test_deserialize_slash_command_with_null_values(self, entity_factory_impl, s assert command.context_types == [] def test_deserialize_slash_command_standardizes_default_member_permissions( - self, entity_factory_impl, slash_command_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl, slash_command_payload: dict[str, typing.Any] ): slash_command_payload["default_member_permissions"] = 0 @@ -4454,24 +4791,27 @@ def test_deserialize_slash_command_standardizes_default_member_permissions( (3, "deserialize_context_menu_command"), ], ) - def test_deserialize_command(self, mock_app, type_, fn): + def test_deserialize_command(self, hikari_app: traits.RESTAware, type_: int, fn: str): payload = {"type": type_} with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: # We need to instantiate it after the mock so that the functions that are stored in the dicts # are the ones we mock - entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) + entity_factory_impl = entity_factory.EntityFactoryImpl(app=hikari_app) - assert entity_factory_impl.deserialize_command(payload, guild_id=123) is expected_fn.return_value + assert ( + entity_factory_impl.deserialize_command(payload, guild_id=snowflakes.Snowflake(123)) + is expected_fn.return_value + ) expected_fn.assert_called_once_with(payload, guild_id=123) - def test_deserialize_command_when_unknown_type(self, entity_factory_impl): + def test_deserialize_command_when_unknown_type(self, entity_factory_impl: entity_factory.EntityFactoryImpl): with pytest.raises(errors.UnrecognisedEntityError): entity_factory_impl.deserialize_command({"type": -111}) @pytest.fixture - def guild_command_permissions_payload(self): + def guild_command_permissions_payload(self) -> dict[str, typing.Any]: return { "id": "123321", "application_id": "431321123", @@ -4479,7 +4819,11 @@ def guild_command_permissions_payload(self): "permissions": [{"id": "22222", "type": 1, "permission": True}], } - def test_deserialize_guild_command_permissions(self, entity_factory_impl, guild_command_permissions_payload): + def test_deserialize_guild_command_permissions( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + guild_command_permissions_payload: dict[str, typing.Any], + ): command = entity_factory_impl.deserialize_guild_command_permissions(guild_command_permissions_payload) assert command.command_id == 123321 @@ -4494,7 +4838,7 @@ def test_deserialize_guild_command_permissions(self, entity_factory_impl, guild_ assert permission.has_access is True assert isinstance(permission, commands.CommandPermission) - def test_serialize_command_permission(self, entity_factory_impl): + def test_serialize_command_permission(self, entity_factory_impl: entity_factory.EntityFactoryImpl): command = commands.CommandPermission(type=commands.CommandPermissionType.ROLE, has_access=True, id=123321) assert entity_factory_impl.serialize_command_permission(command) == { @@ -4504,7 +4848,7 @@ def test_serialize_command_permission(self, entity_factory_impl): } @pytest.fixture - def partial_interaction_payload(self): + def partial_interaction_payload(self) -> dict[str, typing.Any]: return { "id": "795459528803745843", "token": "-- token redacted --", @@ -4516,7 +4860,7 @@ def partial_interaction_payload(self): } @pytest.fixture - def interaction_member_payload(self, user_payload): + def interaction_member_payload(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "user": user_payload, "is_pending": False, @@ -4530,8 +4874,15 @@ def interaction_member_payload(self, user_payload): "roles": ["582345963851743243", "582689893965365248", "734164204679856290", "757331666388910181"], } - def test__deserialize_interaction_member(self, entity_factory_impl, interaction_member_payload, user_payload): - member = entity_factory_impl._deserialize_interaction_member(interaction_member_payload, guild_id=43123123) + def test__deserialize_interaction_member( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + interaction_member_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + ): + member = entity_factory_impl._deserialize_interaction_member( + interaction_member_payload, guild_id=snowflakes.Snowflake(43123123) + ) assert member.id == 115590097100865541 assert member.joined_at == datetime.datetime(2020, 9, 27, 22, 58, 10, 282000, tzinfo=datetime.timezone.utc) assert member.nickname == "Snab" @@ -4556,7 +4907,7 @@ def test__deserialize_interaction_member(self, entity_factory_impl, interaction_ assert isinstance(member, base_interactions.InteractionMember) def test__deserialize_interaction_member_when_guild_id_already_in_roles_doesnt_duplicate( - self, entity_factory_impl, interaction_member_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl, interaction_member_payload: dict[str, typing.Any] ): interaction_member_payload["roles"] = [ 582345963851743243, @@ -4566,7 +4917,9 @@ def test__deserialize_interaction_member_when_guild_id_already_in_roles_doesnt_d 43123123, ] - member = entity_factory_impl._deserialize_interaction_member(interaction_member_payload, guild_id=43123123) + member = entity_factory_impl._deserialize_interaction_member( + interaction_member_payload, guild_id=snowflakes.Snowflake(43123123) + ) assert member.role_ids == [ 582345963851743243, 582689893965365248, @@ -4575,64 +4928,76 @@ def test__deserialize_interaction_member_when_guild_id_already_in_roles_doesnt_d 43123123, ] - def test__deserialize_interaction_member_with_unset_fields(self, entity_factory_impl, interaction_member_payload): + def test__deserialize_interaction_member_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, interaction_member_payload: dict[str, typing.Any] + ): del interaction_member_payload["premium_since"] del interaction_member_payload["avatar"] del interaction_member_payload["communication_disabled_until"] - member = entity_factory_impl._deserialize_interaction_member(interaction_member_payload, guild_id=43123123) + member = entity_factory_impl._deserialize_interaction_member( + interaction_member_payload, guild_id=snowflakes.Snowflake(43123123) + ) assert member.guild_avatar_hash is None assert member.premium_since is None assert member.raw_communication_disabled_until is None - def test__deserialize_interaction_member_with_passed_user(self, entity_factory_impl, interaction_member_payload): - mock_user = object() + def test__deserialize_interaction_member_with_passed_user( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, interaction_member_payload: dict[str, typing.Any] + ): + mock_user = mock.Mock() member = entity_factory_impl._deserialize_interaction_member( - interaction_member_payload, guild_id=43123123, user=mock_user + interaction_member_payload, guild_id=snowflakes.Snowflake(43123123), user=mock_user ) assert member.user is mock_user def test__deserialize_resolved_option_data( self, - entity_factory_impl, - interaction_resolved_data_payload, - attachment_payload, - user_payload, - guild_role_payload, - interaction_member_payload, - message_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + interaction_resolved_data_payload: dict[str, typing.Any], + attachment_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + guild_role_payload: dict[str, typing.Any], + interaction_member_payload: dict[str, typing.Any], + message_payload: dict[str, typing.Any], ): resolved = entity_factory_impl._deserialize_resolved_option_data( - interaction_resolved_data_payload, guild_id=123321 + interaction_resolved_data_payload, guild_id=snowflakes.Snowflake(123321) ) assert len(resolved.channels) == 1 - channel = resolved.channels[695382395666300958] + channel = resolved.channels[snowflakes.Snowflake(695382395666300958)] assert channel.type is channel_models.ChannelType.GUILD_TEXT assert channel.id == 695382395666300958 assert channel.name == "discord-announcements" assert channel.permissions == permission_models.Permissions(17179869183) assert isinstance(channel, base_interactions.InteractionChannel) assert len(resolved.members) == 1 - member = resolved.members[115590097100865541] + member = resolved.members[snowflakes.Snowflake(115590097100865541)] assert member == entity_factory_impl._deserialize_interaction_member( - interaction_member_payload, guild_id=123321, user=entity_factory_impl.deserialize_user(user_payload) + interaction_member_payload, + guild_id=snowflakes.Snowflake(123321), + user=entity_factory_impl.deserialize_user(user_payload), ) assert resolved.attachments == { 690922406474154014: entity_factory_impl._deserialize_message_attachment(attachment_payload) } assert resolved.roles == { - 41771983423143936: entity_factory_impl.deserialize_role(guild_role_payload, guild_id=123321) + 41771983423143936: entity_factory_impl.deserialize_role( + guild_role_payload, guild_id=snowflakes.Snowflake(123321) + ) } assert resolved.users == {115590097100865541: entity_factory_impl.deserialize_user(user_payload)} assert resolved.messages == {123: entity_factory_impl.deserialize_message(message_payload)} assert isinstance(resolved, base_interactions.ResolvedOptionData) - def test__deserialize_resolved_option_data_with_empty_resolved_resources(self, entity_factory_impl): + def test__deserialize_resolved_option_data_with_empty_resolved_resources( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): resolved = entity_factory_impl._deserialize_resolved_option_data({}) assert resolved.attachments == {} @@ -4644,8 +5009,13 @@ def test__deserialize_resolved_option_data_with_empty_resolved_resources(self, e @pytest.fixture def interaction_resolved_data_payload( - self, interaction_member_payload, attachment_payload, guild_role_payload, user_payload, message_payload - ): + self, + interaction_member_payload: dict[str, typing.Any], + attachment_payload: dict[str, typing.Any], + guild_role_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + message_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: return { "attachments": {"690922406474154014": attachment_payload}, "channels": { @@ -4664,8 +5034,11 @@ def interaction_resolved_data_payload( @pytest.fixture def command_interaction_payload( - self, interaction_member_payload, interaction_resolved_data_payload, guild_text_channel_payload - ): + self, + interaction_member_payload: dict[str, typing.Any], + interaction_resolved_data_payload: dict[str, typing.Any], + guild_text_channel_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: return { "id": "3490190239012093", "type": 2, @@ -4715,15 +5088,15 @@ def command_interaction_payload( def test_deserialize_command_interaction( self, - entity_factory_impl, - mock_app, - command_interaction_payload, - interaction_member_payload, - interaction_resolved_data_payload, - guild_text_channel_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + command_interaction_payload: dict[str, typing.Any], + interaction_member_payload: dict[str, typing.Any], + interaction_resolved_data_payload: dict[str, typing.Any], + guild_text_channel_payload: dict[str, typing.Any], ): interaction = entity_factory_impl.deserialize_command_interaction(command_interaction_payload) - assert interaction.app is mock_app + assert interaction.app is hikari_app assert interaction.application_id == 76234234 assert interaction.id == 3490190239012093 assert interaction.type is base_interactions.InteractionType.APPLICATION_COMMAND @@ -4736,13 +5109,14 @@ def test_deserialize_command_interaction( assert interaction.guild_locale == "en-US" assert interaction.guild_locale is locales.Locale.EN_US assert interaction.member == entity_factory_impl._deserialize_interaction_member( - interaction_member_payload, guild_id=43123123 + interaction_member_payload, guild_id=snowflakes.Snowflake(43123123) ) + assert interaction.member is not None assert interaction.user is interaction.member.user assert interaction.command_id == 43123123 assert interaction.command_name == "okokokok" assert interaction.resolved == entity_factory_impl._deserialize_resolved_option_data( - interaction_resolved_data_payload, guild_id=43123123 + interaction_resolved_data_payload, guild_id=snowflakes.Snowflake(43123123) ) assert interaction.app_permissions == 54123 assert len(interaction.entitlements) == 1 @@ -4762,6 +5136,7 @@ def test_deserialize_command_interaction( assert option.name == "an option" assert option.value is None assert option.type is commands.OptionType.SUB_COMMAND + assert option.options is not None assert len(option.options) == 2 sub_option1 = option.options[0] @@ -4784,8 +5159,11 @@ def test_deserialize_command_interaction( @pytest.fixture def context_menu_command_interaction_payload( - self, interaction_member_payload, user_payload, guild_text_channel_payload - ): + self, + interaction_member_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + guild_text_channel_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: return { "id": "3490190239012093", "type": 4, @@ -4825,7 +5203,9 @@ def context_menu_command_interaction_payload( } def test_deserialize_command_interaction_with_context_menu_field( - self, entity_factory_impl, context_menu_command_interaction_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + context_menu_command_interaction_payload: dict[str, typing.Any], ): interaction = entity_factory_impl.deserialize_command_interaction(context_menu_command_interaction_payload) assert interaction.target_id == 115590097100865541 @@ -4840,7 +5220,10 @@ def test_deserialize_command_interaction_with_context_menu_field( assert interaction.context == application_models.ApplicationContextType.PRIVATE_CHANNEL def test_deserialize_command_interaction_with_null_attributes( - self, entity_factory_impl, command_interaction_payload, user_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + command_interaction_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], ): del command_interaction_payload["guild_id"] del command_interaction_payload["member"] @@ -4861,7 +5244,13 @@ def test_deserialize_command_interaction_with_null_attributes( assert interaction.registered_guild_id is None @pytest.fixture - def autocomplete_interaction_payload(self, member_payload, user_payload, guild_text_channel_payload): + def autocomplete_interaction_payload( + self, + member_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + interaction_resolved_data_payload: dict[str, typing.Any], + guild_text_channel_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: return { "id": "3490190239012093", "type": 4, @@ -4911,17 +5300,18 @@ def autocomplete_interaction_payload(self, member_payload, user_payload, guild_t def test_deserialize_autocomplete_interaction( self, - entity_factory_impl, - mock_app, - member_payload, - autocomplete_interaction_payload, - guild_text_channel_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + member_payload: dict[str, typing.Any], + autocomplete_interaction_payload: dict[str, typing.Any], + interaction_resolved_data_payload: dict[str, typing.Any], + guild_text_channel_payload: dict[str, typing.Any], ): entity_factory_impl._deserialize_interaction_member = mock.Mock() entity_factory_impl._deserialize_resolved_option_data = mock.Mock() interaction = entity_factory_impl.deserialize_autocomplete_interaction(autocomplete_interaction_payload) - assert interaction.app is mock_app + assert interaction.app is hikari_app assert interaction.application_id == 76234234 assert interaction.id == 3490190239012093 assert interaction.type is base_interactions.InteractionType.AUTOCOMPLETE @@ -4948,6 +5338,7 @@ def test_deserialize_autocomplete_interaction( assert option.name == "options" assert option.value is None assert option.type is commands.OptionType.SUB_COMMAND + assert option.options is not None assert len(option.options) == 2 sub_option1 = option.options[0] @@ -4971,7 +5362,10 @@ def test_deserialize_autocomplete_interaction( assert isinstance(interaction, command_interactions.AutocompleteInteraction) def test_deserialize_autocomplete_interaction_with_null_fields( - self, entity_factory_impl, user_payload, autocomplete_interaction_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + user_payload: dict[str, typing.Any], + autocomplete_interaction_payload: dict[str, typing.Any], ): del autocomplete_interaction_payload["guild_locale"] del autocomplete_interaction_payload["guild_id"] @@ -4996,23 +5390,23 @@ def test_deserialize_autocomplete_interaction_with_null_fields( (4, "deserialize_autocomplete_interaction"), ], ) - def test_deserialize_interaction(self, mock_app, type_, fn): + def test_deserialize_interaction(self, hikari_app: traits.RESTAware, type_: int, fn: str): payload = {"type": type_} with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: # We need to instantiate it after the mock so that the functions that are stored in the dicts # are the ones we mock - entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) + entity_factory_impl = entity_factory.EntityFactoryImpl(app=hikari_app) assert entity_factory_impl.deserialize_interaction(payload) is expected_fn.return_value expected_fn.assert_called_once_with(payload) - def test_deserialize_interaction_handles_unknown_type(self, entity_factory_impl): + def test_deserialize_interaction_handles_unknown_type(self, entity_factory_impl: entity_factory.EntityFactoryImpl): with pytest.raises(errors.UnrecognisedEntityError): entity_factory_impl.deserialize_interaction({"type": -999}) - def test_serialize_command_option(self, entity_factory_impl): + def test_serialize_command_option(self, entity_factory_impl: entity_factory.EntityFactoryImpl): option = commands.CommandOption( type=commands.OptionType.INTEGER, name="a name", @@ -5083,7 +5477,7 @@ def test_serialize_command_option(self, entity_factory_impl): } @pytest.fixture - def context_menu_command_payload(self): + def context_menu_command_payload(self) -> dict[str, typing.Any]: return { "id": "1231231231", "application_id": "12354123", @@ -5098,7 +5492,9 @@ def context_menu_command_payload(self): "contexts": ["0"], } - def test_deserialize_context_menu_command(self, entity_factory_impl, context_menu_command_payload): + def test_deserialize_context_menu_command( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, context_menu_command_payload: dict[str, typing.Any] + ): command = entity_factory_impl.deserialize_context_menu_command(context_menu_command_payload) assert isinstance(command, commands.ContextMenuCommand) @@ -5113,8 +5509,12 @@ def test_deserialize_context_menu_command(self, entity_factory_impl, context_men assert command.integration_types == [application_models.ApplicationIntegrationType.GUILD_INSTALL] assert command.context_types == [application_models.ApplicationContextType.GUILD] - def test_deserialize_context_menu_command_with_guild_id(self, entity_factory_impl, context_menu_command_payload): - command = entity_factory_impl.deserialize_command(context_menu_command_payload, guild_id=123) + def test_deserialize_context_menu_command_with_guild_id( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, context_menu_command_payload: dict[str, typing.Any] + ): + command = entity_factory_impl.deserialize_command( + context_menu_command_payload, guild_id=snowflakes.Snowflake(123) + ) assert isinstance(command, commands.ContextMenuCommand) assert command.id == 1231231231 @@ -5128,15 +5528,17 @@ def test_deserialize_context_menu_command_with_guild_id(self, entity_factory_imp assert command.integration_types == [application_models.ApplicationIntegrationType.GUILD_INSTALL] assert command.context_types == [application_models.ApplicationContextType.GUILD] - def test_deserialize_context_menu_command_with_null_values(self, entity_factory_impl, slash_command_payload): - slash_command_payload["contexts"] = None + def test_deserialize_context_menu_command_with_null_values( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, context_menu_command_payload: dict[str, typing.Any] + ): + context_menu_command_payload["contexts"] = None - command = entity_factory_impl.deserialize_slash_command(payload=slash_command_payload) + context_menu = entity_factory_impl.deserialize_context_menu_command(payload=context_menu_command_payload) - assert command.context_types == [] + assert context_menu.context_types == [] def test_deserialize_context_menu_command_with_with__unset_values( - self, entity_factory_impl, context_menu_command_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl, context_menu_command_payload: dict[str, typing.Any] ): del context_menu_command_payload["dm_permission"] del context_menu_command_payload["nsfw"] @@ -5151,7 +5553,7 @@ def test_deserialize_context_menu_command_with_with__unset_values( assert command.context_types == [] def test_deserialize_context_menu_command_default_member_permissions( - self, entity_factory_impl, context_menu_command_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl, context_menu_command_payload: dict[str, typing.Any] ): context_menu_command_payload["default_member_permissions"] = 0 @@ -5161,8 +5563,12 @@ def test_deserialize_context_menu_command_default_member_permissions( @pytest.fixture def component_interaction_payload( - self, interaction_member_payload, message_payload, interaction_resolved_data_payload, guild_text_channel_payload - ): + self, + interaction_member_payload: dict[str, typing.Any], + message_payload: dict[str, typing.Any], + interaction_resolved_data_payload: dict[str, typing.Any], + guild_text_channel_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: return { "version": 1, "type": 3, @@ -5203,17 +5609,17 @@ def component_interaction_payload( def test_deserialize_component_interaction( self, - entity_factory_impl, - component_interaction_payload, - interaction_member_payload, - mock_app, - message_payload, - interaction_resolved_data_payload, - guild_text_channel_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + component_interaction_payload: dict[str, typing.Any], + interaction_member_payload: dict[str, typing.Any], + hikari_app: traits.RESTAware, + message_payload: dict[str, typing.Any], + interaction_resolved_data_payload: dict[str, typing.Any], + guild_text_channel_payload: dict[str, typing.Any], ): interaction = entity_factory_impl.deserialize_component_interaction(component_interaction_payload) - assert interaction.app is mock_app + assert interaction.app is hikari_app assert interaction.id == 846462639134605312 assert interaction.application_id == 290926444748734465 assert interaction.type is base_interactions.InteractionType.MESSAGE_COMPONENT @@ -5225,8 +5631,9 @@ def test_deserialize_component_interaction( assert interaction.guild_id == 290926798626357999 assert interaction.message == entity_factory_impl.deserialize_message(message_payload) assert interaction.member == entity_factory_impl._deserialize_interaction_member( - interaction_member_payload, guild_id=290926798626357999 + interaction_member_payload, guild_id=snowflakes.Snowflake(290926798626357999) ) + assert interaction.member is not None assert interaction.user is interaction.member.user assert interaction.values == ["1", "2", "67"] assert interaction.locale == "es-ES" @@ -5243,7 +5650,7 @@ def test_deserialize_component_interaction( assert interaction.context == application_models.ApplicationContextType.PRIVATE_CHANNEL # ResolvedData assert interaction.resolved == entity_factory_impl._deserialize_resolved_option_data( - interaction_resolved_data_payload, guild_id=290926798626357999 + interaction_resolved_data_payload, guild_id=snowflakes.Snowflake(290926798626357999) ) assert isinstance(interaction, component_interactions.ComponentInteraction) @@ -5251,7 +5658,11 @@ def test_deserialize_component_interaction( assert interaction.entitlements[0].id == 696969696969696 def test_deserialize_component_interaction_with_undefined_fields( - self, entity_factory_impl, user_payload, message_payload, guild_text_channel_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + user_payload: dict[str, typing.Any], + message_payload: dict[str, typing.Any], + guild_text_channel_payload: dict[str, typing.Any], ): interaction = entity_factory_impl.deserialize_component_interaction( { @@ -5294,7 +5705,12 @@ def test_deserialize_component_interaction_with_undefined_fields( assert isinstance(interaction, component_interactions.ComponentInteraction) @pytest.fixture - def modal_interaction_payload(self, interaction_member_payload, message_payload, guild_text_channel_payload): + def modal_interaction_payload( + self, + interaction_member_payload: dict[str, typing.Any], + message_payload: dict[str, typing.Any], + guild_text_channel_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: return { "version": 1, "type": 5, @@ -5340,15 +5756,15 @@ def modal_interaction_payload(self, interaction_member_payload, message_payload, def test_deserialize_modal_interaction( self, - entity_factory_impl, - mock_app, - modal_interaction_payload, - interaction_member_payload, - message_payload, - guild_text_channel_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + modal_interaction_payload: dict[str, typing.Any], + interaction_member_payload: dict[str, typing.Any], + guild_text_channel_payload: dict[str, typing.Any], + message_payload: dict[str, typing.Any], ): interaction = entity_factory_impl.deserialize_modal_interaction(modal_interaction_payload) - assert interaction.app is mock_app + assert interaction.app is hikari_app assert interaction.id == 846462639134605312 assert interaction.application_id == 290926444748734465 assert interaction.type is base_interactions.InteractionType.MODAL_SUBMIT @@ -5358,8 +5774,9 @@ def test_deserialize_modal_interaction( assert interaction.guild_id == 290926798626357999 assert interaction.message == entity_factory_impl.deserialize_message(message_payload) assert interaction.member == entity_factory_impl._deserialize_interaction_member( - interaction_member_payload, guild_id=290926798626357999 + interaction_member_payload, guild_id=snowflakes.Snowflake(290926798626357999) ) + assert interaction.member is not None assert interaction.user is interaction.member.user assert isinstance(interaction, modal_interactions.ModalInteraction) @@ -5375,7 +5792,10 @@ def test_deserialize_modal_interaction( assert short_text_input.custom_id == "name" def test_deserialize_modal_interaction_with_user( - self, entity_factory_impl, modal_interaction_payload, user_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + modal_interaction_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], ): modal_interaction_payload["member"] = None modal_interaction_payload["user"] = user_payload @@ -5391,7 +5811,7 @@ def test_deserialize_modal_interaction_with_user( assert interaction.context == application_models.ApplicationContextType.PRIVATE_CHANNEL def test_deserialize_modal_interaction_with_unrecognized_component( - self, entity_factory_impl, modal_interaction_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl, modal_interaction_payload: dict[str, typing.Any] ): modal_interaction_payload["data"]["components"] = [{"type": 0}] @@ -5403,11 +5823,11 @@ def test_deserialize_modal_interaction_with_unrecognized_component( ################## @pytest.fixture - def partial_sticker_payload(self): + def partial_sticker_payload(self) -> dict[str, typing.Any]: return {"id": "749046696482439188", "name": "Thinking", "format_type": 3} @pytest.fixture - def standard_sticker_payload(self): + def standard_sticker_payload(self) -> dict[str, typing.Any]: return { "id": "749046696482439188", "name": "Thinking", @@ -5419,7 +5839,7 @@ def standard_sticker_payload(self): } @pytest.fixture - def guild_sticker_payload(self, user_payload): + def guild_sticker_payload(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "749046696482439188", "name": "Thinking", @@ -5432,7 +5852,7 @@ def guild_sticker_payload(self, user_payload): } @pytest.fixture - def sticker_pack_payload(self, standard_sticker_payload): + def sticker_pack_payload(self, standard_sticker_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "123", "name": "My sticker pack", @@ -5443,14 +5863,18 @@ def sticker_pack_payload(self, standard_sticker_payload): "banner_asset_id": "342123321", } - def test_deserialize_partial_sticker(self, entity_factory_impl, partial_sticker_payload): + def test_deserialize_partial_sticker( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, partial_sticker_payload: dict[str, typing.Any] + ): partial_sticker = entity_factory_impl.deserialize_partial_sticker(partial_sticker_payload) assert partial_sticker.id == 749046696482439188 assert partial_sticker.name == "Thinking" assert partial_sticker.format_type is sticker_models.StickerFormatType.LOTTIE - def test_deserialize_standard_sticker(self, entity_factory_impl, standard_sticker_payload): + def test_deserialize_standard_sticker( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, standard_sticker_payload: dict[str, typing.Any] + ): standard_sticker = entity_factory_impl.deserialize_standard_sticker(standard_sticker_payload) assert standard_sticker.id == 749046696482439188 @@ -5461,7 +5885,12 @@ def test_deserialize_standard_sticker(self, entity_factory_impl, standard_sticke assert standard_sticker.sort_value == 96 assert standard_sticker.tags == ["thinking", "thonkang"] - def test_deserialize_guild_sticker(self, entity_factory_impl, guild_sticker_payload, user_payload): + def test_deserialize_guild_sticker( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + guild_sticker_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + ): guild_sticker = entity_factory_impl.deserialize_guild_sticker(guild_sticker_payload) assert guild_sticker.id == 749046696482439188 @@ -5473,14 +5902,18 @@ def test_deserialize_guild_sticker(self, entity_factory_impl, guild_sticker_payl assert guild_sticker.tag == "tag" assert guild_sticker.user == entity_factory_impl.deserialize_user(user_payload) - def test_deserialize_guild_sticker_with_unset_fields(self, entity_factory_impl, guild_sticker_payload): + def test_deserialize_guild_sticker_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, guild_sticker_payload: dict[str, typing.Any] + ): del guild_sticker_payload["user"] guild_sticker = entity_factory_impl.deserialize_guild_sticker(guild_sticker_payload) assert guild_sticker.user is None - def test_deserialize_sticker_pack(self, entity_factory_impl, sticker_pack_payload): + def test_deserialize_sticker_pack( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, sticker_pack_payload: dict[str, typing.Any] + ): pack = entity_factory_impl.deserialize_sticker_pack(sticker_pack_payload) assert pack.id == 123 @@ -5500,7 +5933,9 @@ def test_deserialize_sticker_pack(self, entity_factory_impl, sticker_pack_payloa assert sticker.sort_value == 96 assert sticker.tags == ["thinking", "thonkang"] - def test_deserialize_sticker_pack_with_optional_fields(self, entity_factory_impl, sticker_pack_payload): + def test_deserialize_sticker_pack_with_optional_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, sticker_pack_payload: dict[str, typing.Any] + ): del sticker_pack_payload["cover_sticker_id"] del sticker_pack_payload["banner_asset_id"] @@ -5509,57 +5944,64 @@ def test_deserialize_sticker_pack_with_optional_fields(self, entity_factory_impl assert pack.cover_sticker_id is None assert pack.banner_asset_id is None - def test_stickers(self, entity_factory_impl, guild_sticker_payload): + def test_stickers( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, guild_sticker_payload: dict[str, typing.Any] + ): guild_definition = entity_factory_impl.deserialize_gateway_guild( - {"id": "265828729970753537", "stickers": [guild_sticker_payload]}, user_id=123321 + {"id": "265828729970753537", "stickers": [guild_sticker_payload]}, user_id=snowflakes.Snowflake(123321) ) assert guild_definition.stickers() == { 749046696482439188: entity_factory_impl.deserialize_guild_sticker(guild_sticker_payload) } - def test_stickers_returns_cached_values(self, entity_factory_impl): + def test_stickers_returns_cached_values(self, entity_factory_impl: entity_factory.EntityFactoryImpl): with mock.patch.object( entity_factory.EntityFactoryImpl, "deserialize_guild_sticker" ) as mock_deserialize_guild_sticker: guild_definition = entity_factory_impl.deserialize_gateway_guild( - {"id": "265828729970753537"}, user_id=123321 + {"id": "265828729970753537"}, user_id=snowflakes.Snowflake(123321) ) - mock_sticker = object() - guild_definition._stickers = {"54545454": mock_sticker} + mock_sticker = mock.Mock() - assert guild_definition.stickers() == {"54545454": mock_sticker} - mock_deserialize_guild_sticker.assert_not_called() + with mock.patch.object(guild_definition, "_stickers", {"54545454": mock_sticker}): + assert guild_definition.stickers() == {"54545454": mock_sticker} + mock_deserialize_guild_sticker.assert_not_called() ################# # INVITE MODELS # ################# @pytest.fixture - def vanity_url_payload(self): + def vanity_url_payload(self) -> dict[str, typing.Any]: return {"code": "iamacode", "uses": 42} - def test_deserialize_vanity_url(self, entity_factory_impl, mock_app, vanity_url_payload): + def test_deserialize_vanity_url( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + vanity_url_payload: dict[str, typing.Any], + ): vanity_url = entity_factory_impl.deserialize_vanity_url(vanity_url_payload) - assert vanity_url.app is mock_app + assert vanity_url.app is hikari_app assert vanity_url.code == "iamacode" assert vanity_url.uses == 42 assert isinstance(vanity_url, invite_models.VanityURL) @pytest.fixture - def alternative_user_payload(self): + def alternative_user_payload(self) -> dict[str, typing.Any]: return {"id": "1231231", "username": "soad", "discriminator": "3333", "avatar": None} @pytest.fixture def invite_payload( self, - partial_channel_payload, - user_payload, - alternative_user_payload, - guild_welcome_screen_payload, - invite_application_payload, - ): + partial_channel_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + alternative_user_payload: dict[str, typing.Any], + guild_welcome_screen_payload: dict[str, typing.Any], + invite_application_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: return { "code": "aCode", "guild": { @@ -5587,19 +6029,20 @@ def invite_payload( def test_deserialize_invite( self, - entity_factory_impl, - mock_app, - invite_payload, - partial_channel_payload, - user_payload, - guild_welcome_screen_payload, - alternative_user_payload, - application_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + invite_payload: dict[str, typing.Any], + partial_channel_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + guild_welcome_screen_payload: dict[str, typing.Any], + alternative_user_payload: dict[str, typing.Any], + application_payload: dict[str, typing.Any], ): invite = entity_factory_impl.deserialize_invite(invite_payload) - assert invite.app is mock_app + assert invite.app is hikari_app assert invite.code == "aCode" # InviteGuild + assert invite.guild is not None assert invite.guild.id == 56188492224814744 assert invite.guild.name == "Testin' Your Scene" assert invite.guild.icon_hash == "bb71f469c158984e265093a81b3397fb" @@ -5627,7 +6070,8 @@ def test_deserialize_invite( # InviteApplication application = invite.target_application - assert application.app is mock_app + assert application is not None + assert application.app is hikari_app assert application.id == 773336526917861400 assert application.name == "Betrayal.io" assert application.description == "Play inside Discord with your friends!" @@ -5639,9 +6083,7 @@ def test_deserialize_invite( assert application.cover_image_hash == "0227b2e89ea08d666c43003fbadbc72a (but as cover)" assert isinstance(application, application_models.InviteApplication) - def test_deserialize_invite_with_null_fields( - self, entity_factory_impl, partial_channel_payload, invite_application_payload - ): + def test_deserialize_invite_with_null_fields(self, entity_factory_impl: entity_factory.EntityFactoryImpl): invite = entity_factory_impl.deserialize_invite( { "code": "aCode", @@ -5658,9 +6100,10 @@ def test_deserialize_invite_with_null_fields( } ) assert invite.expires_at is None + assert invite.target_application is not None assert invite.target_application.description is None - def test_deserialize_invite_with_unset_fields(self, entity_factory_impl, partial_channel_payload): + def test_deserialize_invite_with_unset_fields(self, entity_factory_impl: entity_factory.EntityFactoryImpl): invite = entity_factory_impl.deserialize_invite( { "code": "aCode", @@ -5678,7 +6121,9 @@ def test_deserialize_invite_with_unset_fields(self, entity_factory_impl, partial assert invite.target_application is None assert invite.expires_at is None - def test_deserialize_invite_with_unset_sub_fields(self, entity_factory_impl, invite_payload): + def test_deserialize_invite_with_unset_sub_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, invite_payload: dict[str, typing.Any] + ): del invite_payload["guild"]["welcome_screen"] invite_payload["target_application"] = { "id": "773336526917861400", @@ -5689,11 +6134,15 @@ def test_deserialize_invite_with_unset_sub_fields(self, entity_factory_impl, inv invite = entity_factory_impl.deserialize_invite(invite_payload) + assert invite.guild is not None assert invite.guild.welcome_screen is None + assert invite.target_application is not None assert invite.target_application.icon_hash is None assert invite.target_application.cover_image_hash is None - def test_deserialize_invite_with_guild_and_channel_ids_without_objects(self, entity_factory_impl): + def test_deserialize_invite_with_guild_and_channel_ids_without_objects( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): invite = entity_factory_impl.deserialize_invite({"code": "aCode", "guild_id": "42", "channel_id": "202020"}) assert invite.channel is None assert invite.channel_id == 202020 @@ -5703,12 +6152,12 @@ def test_deserialize_invite_with_guild_and_channel_ids_without_objects(self, ent @pytest.fixture def invite_with_metadata_payload( self, - partial_channel_payload, - user_payload, - alternative_user_payload, - guild_welcome_screen_payload, - invite_application_payload, - ): + partial_channel_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + alternative_user_payload: dict[str, typing.Any], + guild_welcome_screen_payload: dict[str, typing.Any], + invite_application_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: return { "code": "aCode", "guild": { @@ -5740,19 +6189,19 @@ def invite_with_metadata_payload( def test_deserialize_invite_with_metadata( self, - entity_factory_impl, - mock_app, - invite_with_metadata_payload, - partial_channel_payload, - user_payload, - alternative_user_payload, - guild_welcome_screen_payload, - invite_application_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + invite_with_metadata_payload: dict[str, typing.Any], + partial_channel_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + alternative_user_payload: dict[str, typing.Any], + guild_welcome_screen_payload: dict[str, typing.Any], ): invite_with_metadata = entity_factory_impl.deserialize_invite_with_metadata(invite_with_metadata_payload) - assert invite_with_metadata.app is mock_app + assert invite_with_metadata.app is hikari_app assert invite_with_metadata.code == "aCode" # InviteGuild + assert invite_with_metadata.guild is not None assert invite_with_metadata.guild.id == 56188492224814744 assert invite_with_metadata.guild.name == "Testin' Your Scene" assert invite_with_metadata.guild.icon_hash == "bb71f469c158984e265093a81b3397fb" @@ -5787,7 +6236,8 @@ def test_deserialize_invite_with_metadata( # InviteApplication application = invite_with_metadata.target_application - assert application.app is mock_app + assert application is not None + assert application.app is hikari_app assert application.id == 773336526917861400 assert application.name == "Betrayal.io" assert application.description == "Play inside Discord with your friends!" @@ -5800,7 +6250,7 @@ def test_deserialize_invite_with_metadata( assert isinstance(application, application_models.InviteApplication) def test_deserialize_invite_with_metadata_with_unset_and_0_fields( - self, entity_factory_impl, partial_channel_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl, partial_channel_payload: dict[str, typing.Any] ): invite_with_metadata = entity_factory_impl.deserialize_invite_with_metadata( { @@ -5825,14 +6275,17 @@ def test_deserialize_invite_with_metadata_with_unset_and_0_fields( assert invite_with_metadata.expires_at is None def test_deserialize_invite_with_metadata_with_null_guild_fields( - self, entity_factory_impl, invite_with_metadata_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl, invite_with_metadata_payload: dict[str, typing.Any] ): del invite_with_metadata_payload["guild"]["welcome_screen"] invite = entity_factory_impl.deserialize_invite_with_metadata(invite_with_metadata_payload) + assert invite.guild is not None assert invite.guild.welcome_screen is None - def test_max_age_when_zero(self, entity_factory_impl, invite_with_metadata_payload): + def test_max_age_when_zero( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, invite_with_metadata_payload: dict[str, typing.Any] + ): invite_with_metadata_payload["max_age"] = 0 assert entity_factory_impl.deserialize_invite_with_metadata(invite_with_metadata_payload).max_age is None @@ -5841,11 +6294,11 @@ def test_max_age_when_zero(self, entity_factory_impl, invite_with_metadata_paylo #################### @pytest.fixture - def action_row_payload(self, button_payload): + def action_row_payload(self, button_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return {"type": 1, "id": 8394572, "components": [button_payload]} @pytest.fixture - def button_payload(self, custom_emoji_payload): + def button_payload(self, custom_emoji_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "type": 2, "id": 9173652, @@ -5857,7 +6310,12 @@ def button_payload(self, custom_emoji_payload): "disabled": True, } - def test__deserialize_button(self, entity_factory_impl, button_payload, custom_emoji_payload): + def test_deserialize__deserialize_button( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + button_payload: dict[str, typing.Any], + custom_emoji_payload: dict[str, typing.Any], + ): button = entity_factory_impl._deserialize_button(button_payload) assert button.type is component_models.ComponentType.BUTTON @@ -5870,7 +6328,7 @@ def test__deserialize_button(self, entity_factory_impl, button_payload, custom_e assert button.url == "okokok" def test_deserialize__deserialize_button_with_unset_fields( - self, entity_factory_impl, button_payload, custom_emoji_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl ): button = entity_factory_impl._deserialize_button({"id": 0, "type": 2, "style": 5}) @@ -5884,7 +6342,7 @@ def test_deserialize__deserialize_button_with_unset_fields( assert button.is_disabled is False @pytest.fixture - def select_menu_payload(self, custom_emoji_payload): + def select_menu_payload(self, custom_emoji_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "type": 5, "id": 9830741, @@ -5904,7 +6362,12 @@ def select_menu_payload(self, custom_emoji_payload): "disabled": True, } - def test__deserialize_text_select_menu(self, entity_factory_impl, select_menu_payload, custom_emoji_payload): + def test__deserialize_text_select_menu( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + select_menu_payload: dict[str, typing.Any], + custom_emoji_payload: dict[str, typing.Any], + ): menu = entity_factory_impl._deserialize_text_select_menu(select_menu_payload) assert menu.type is component_models.ComponentType.USER_SELECT_MENU @@ -5926,7 +6389,7 @@ def test__deserialize_text_select_menu(self, entity_factory_impl, select_menu_pa assert menu.max_values == 420 assert menu.is_disabled is True - def test__deserialize_text_select_menu_partial(self, entity_factory_impl): + def test__deserialize_text_select_menu_partial(self, entity_factory_impl: entity_factory.EntityFactoryImpl): menu = entity_factory_impl._deserialize_text_select_menu( {"type": 3, "id": 0, "custom_id": "Not an ID", "options": [{"label": "Trans", "value": "very trans"}]} ) @@ -5949,7 +6412,9 @@ def test__deserialize_text_select_menu_partial(self, entity_factory_impl): def text_input_payload(self): return {"type": 4, "id": 3904875, "custom_id": "name", "value": "Wumpus"} - def test__deserialize_text_input(self, entity_factory_impl, text_input_payload): + def test__deserialize_text_input( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, text_input_payload: dict[str, typing.Any] + ): text_input = entity_factory_impl._deserialize_text_input(text_input_payload) assert text_input.type == component_models.ComponentType.TEXT_INPUT @@ -5975,7 +6440,7 @@ def text_display_payload(self): return {"type": 10, "id": 9840745, "content": "A text display!"} @pytest.fixture - def thumbnail_payload(self, media_payload): + def thumbnail_payload(self, media_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "type": 11, "id": 9824133, @@ -5985,30 +6450,32 @@ def thumbnail_payload(self, media_payload): } @pytest.fixture - def section_payload(self, button_payload, text_display_payload): + def section_payload(self, button_payload: dict[str, typing.Any], text_display_payload: dict[str, typing.Any]): return {"type": 9, "id": 9478385, "accessory": button_payload, "components": [text_display_payload]} @pytest.fixture - def media_gallery_item_payload(self, media_payload): + def media_gallery_item_payload(self, media_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return {"media": media_payload, "description": "Gallery item description?", "spoiler": True} @pytest.fixture - def media_gallery_payload(self, media_gallery_item_payload): + def media_gallery_payload(self, media_gallery_item_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return {"type": 12, "id": 9267351, "items": [media_gallery_item_payload]} @pytest.fixture - def separator_payload(self): + def separator_payload(self) -> dict[str, typing.Any]: return {"type": 14, "id": 4920478, "spacing": 1, "divider": True} @pytest.fixture - def file_payload(self, media_payload): + def file_payload(self, media_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return {"type": 13, "id": 2407385, "file": media_payload, "spoiler": False} @pytest.fixture - def container_payload(self, file_payload): + def container_payload(self, file_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return {"type": 17, "id": 5830957, "accent_color": 16757027, "spoiler": True, "components": [file_payload]} - def test__deserialize_media(self, entity_factory_impl, media_payload): + def test__deserialize_media( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, media_payload: dict[str, typing.Any] + ): media = entity_factory_impl._deserialize_media(media_payload) assert media.url == "https://com.com.com.com.com.com.com.com.com.com/" @@ -6023,7 +6490,9 @@ def test__deserialize_media(self, entity_factory_impl, media_payload): assert isinstance(media, component_models.MediaResource) - def test__deserialize_media_with_unset_fields(self, entity_factory_impl, media_payload): + def test__deserialize_media_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, media_payload: dict[str, typing.Any] + ): del media_payload["proxy_url"] del media_payload["width"] del media_payload["height"] @@ -6040,7 +6509,9 @@ def test__deserialize_media_with_unset_fields(self, entity_factory_impl, media_p assert isinstance(media, component_models.MediaResource) - def test__deserialize_media_with_nullable_fields(self, entity_factory_impl, media_payload): + def test__deserialize_media_with_nullable_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, media_payload: dict[str, typing.Any] + ): media_payload["width"] = None media_payload["height"] = None media_payload["content_type"] = None @@ -6055,7 +6526,12 @@ def test__deserialize_media_with_nullable_fields(self, entity_factory_impl, medi assert isinstance(media, component_models.MediaResource) - def test__deserialize_action_row_component(self, entity_factory_impl, action_row_payload, button_payload): + def test__deserialize_action_row_component( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + action_row_payload: dict[str, typing.Any], + button_payload: dict[str, typing.Any], + ): action_row = entity_factory_impl._deserialize_action_row_component(action_row_payload) assert action_row.type == component_models.ComponentType.ACTION_ROW @@ -6066,7 +6542,7 @@ def test__deserialize_action_row_component(self, entity_factory_impl, action_row assert isinstance(action_row, component_models.ActionRowComponent) def test__deserialize_action_row_component_with_unknown_component_type( - self, entity_factory_impl, action_row_payload + self, entity_factory_impl: entity_factory.EntityFactoryImpl, action_row_payload: dict[str, typing.Any] ): action_row_payload["components"] = [{"type": -9999}, {"type": 9999}] @@ -6075,7 +6551,11 @@ def test__deserialize_action_row_component_with_unknown_component_type( assert action_row.components == [] def test__deserialize_section_component( - self, entity_factory_impl, section_payload, button_payload, text_display_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + section_payload: dict[str, typing.Any], + button_payload: dict[str, typing.Any], + text_display_payload: dict[str, typing.Any], ): section = entity_factory_impl._deserialize_section_component(section_payload) @@ -6085,12 +6565,19 @@ def test__deserialize_section_component( assert section.accessory == entity_factory_impl._deserialize_button(button_payload) assert section.components == [entity_factory_impl._deserialize_text_display_component(text_display_payload)] - def test__deserialize_section_component_with_unknown_accessory_type(self, entity_factory_impl, section_payload): + def test__deserialize_section_component_with_unknown_accessory_type( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, section_payload: dict[str, typing.Any] + ): section_payload["accessory"] = {"type": 9999} with pytest.raises(errors.UnrecognisedEntityError, match=r"Unknown section accessory type 9999"): entity_factory_impl._deserialize_section_component(section_payload) - def test__deserialize_thumbnail_component(self, entity_factory_impl, thumbnail_payload, media_payload): + def test__deserialize_thumbnail_component( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + thumbnail_payload: dict[str, typing.Any], + media_payload: dict[str, typing.Any], + ): thumbnail = entity_factory_impl._deserialize_thumbnail_component(thumbnail_payload) assert thumbnail.type == component_models.ComponentType.THUMBNAIL @@ -6101,7 +6588,9 @@ def test__deserialize_thumbnail_component(self, entity_factory_impl, thumbnail_p assert isinstance(thumbnail, component_models.ThumbnailComponent) - def test__deserialize_thumbnail_component_with_unset_fields(self, entity_factory_impl, thumbnail_payload): + def test__deserialize_thumbnail_component_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, thumbnail_payload: dict[str, typing.Any] + ): del thumbnail_payload["description"] del thumbnail_payload["spoiler"] @@ -6110,7 +6599,9 @@ def test__deserialize_thumbnail_component_with_unset_fields(self, entity_factory assert thumbnail.description is None assert thumbnail.is_spoiler is False - def test__deserialize_text_display_component(self, entity_factory_impl, text_display_payload): + def test__deserialize_text_display_component( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, text_display_payload: dict[str, typing.Any] + ): text_display = entity_factory_impl._deserialize_text_display_component(text_display_payload) assert text_display.type == component_models.ComponentType.TEXT_DISPLAY @@ -6120,7 +6611,10 @@ def test__deserialize_text_display_component(self, entity_factory_impl, text_dis assert isinstance(text_display, component_models.TextDisplayComponent) def test__deserialize_media_gallery_component( - self, entity_factory_impl, media_gallery_payload, media_gallery_item_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + media_gallery_payload: dict[str, typing.Any], + media_gallery_item_payload: dict[str, typing.Any], ): media_gallery = entity_factory_impl._deserialize_media_gallery_component(media_gallery_payload) @@ -6130,7 +6624,12 @@ def test__deserialize_media_gallery_component( assert isinstance(media_gallery, component_models.MediaGalleryComponent) - def test__deserialize_media_gallery_item(self, entity_factory_impl, media_gallery_item_payload, media_payload): + def test__deserialize_media_gallery_item( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + media_gallery_item_payload: dict[str, typing.Any], + media_payload: dict[str, typing.Any], + ): media_gallery_item = entity_factory_impl._deserialize_media_gallery_item(media_gallery_item_payload) assert media_gallery_item.media == entity_factory_impl._deserialize_media(media_payload) @@ -6139,7 +6638,9 @@ def test__deserialize_media_gallery_item(self, entity_factory_impl, media_galler assert isinstance(media_gallery_item, component_models.MediaGalleryItem) - def test__deserialize_media_gallery_item_with_unset_fields(self, entity_factory_impl, media_gallery_item_payload): + def test__deserialize_media_gallery_item_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, media_gallery_item_payload: dict[str, typing.Any] + ): del media_gallery_item_payload["description"] del media_gallery_item_payload["spoiler"] @@ -6148,7 +6649,9 @@ def test__deserialize_media_gallery_item_with_unset_fields(self, entity_factory_ assert media_gallery_item.description is None assert media_gallery_item.is_spoiler is False - def test__deserialize_separator_component(self, entity_factory_impl, separator_payload): + def test__deserialize_separator_component( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, separator_payload: dict[str, typing.Any] + ): separator = entity_factory_impl._deserialize_separator_component(separator_payload) assert separator.type == component_models.ComponentType.SEPARATOR @@ -6158,14 +6661,21 @@ def test__deserialize_separator_component(self, entity_factory_impl, separator_p assert isinstance(separator, component_models.SeparatorComponent) - def test__deserialize_separator_component_with_unset_fields(self, entity_factory_impl, separator_payload): + def test__deserialize_separator_component_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, separator_payload: dict[str, typing.Any] + ): del separator_payload["divider"] separator = entity_factory_impl._deserialize_separator_component(separator_payload) assert separator.divider is False - def test__deserialize_file_component(self, entity_factory_impl, file_payload, media_payload): + def test__deserialize_file_component( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + file_payload: dict[str, typing.Any], + media_payload: dict[str, typing.Any], + ): file = entity_factory_impl._deserialize_file_component(file_payload) assert file.type == component_models.ComponentType.FILE @@ -6175,14 +6685,21 @@ def test__deserialize_file_component(self, entity_factory_impl, file_payload, me assert isinstance(file, component_models.FileComponent) - def test__deserialize_file_component_with_unset_fields(self, entity_factory_impl, file_payload): + def test__deserialize_file_component_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, file_payload: dict[str, typing.Any] + ): del file_payload["spoiler"] file = entity_factory_impl._deserialize_file_component(file_payload) assert file.is_spoiler is False - def test__deserialize_container_component(self, entity_factory_impl, container_payload, file_payload): + def test__deserialize_container_component( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + container_payload: dict[str, typing.Any], + file_payload: dict[str, typing.Any], + ): container = entity_factory_impl._deserialize_container_component(container_payload) assert container.type == component_models.ComponentType.CONTAINER @@ -6193,7 +6710,9 @@ def test__deserialize_container_component(self, entity_factory_impl, container_p assert isinstance(container, component_models.ContainerComponent) - def test__deserialize_container_component_with_unset_fields(self, entity_factory_impl, container_payload): + def test__deserialize_container_component_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, container_payload: dict[str, typing.Any] + ): del container_payload["accent_color"] del container_payload["spoiler"] @@ -6202,14 +6721,18 @@ def test__deserialize_container_component_with_unset_fields(self, entity_factory assert container.accent_color is None assert container.is_spoiler is False - def test__deserialize_container_component_with_nullable_fields(self, entity_factory_impl, container_payload): + def test__deserialize_container_component_with_nullable_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, container_payload: dict[str, typing.Any] + ): container_payload["accent_color"] = None container = entity_factory_impl._deserialize_container_component(container_payload) assert container.accent_color is None - def test__deserialize_container_component_with_unknown_component_type(self, entity_factory_impl, container_payload): + def test__deserialize_container_component_with_unknown_component_type( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, container_payload: dict[str, typing.Any] + ): container_payload["components"] = [{"type": 9999}] container = entity_factory_impl._deserialize_container_component(container_payload) @@ -6218,13 +6741,13 @@ def test__deserialize_container_component_with_unknown_component_type(self, enti def test__deserialize_message_components( self, - entity_factory_impl, - action_row_payload, - text_display_payload, - section_payload, - media_gallery_payload, - separator_payload, - file_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + action_row_payload: dict[str, typing.Any], + text_display_payload: dict[str, typing.Any], + section_payload: dict[str, typing.Any], + media_gallery_payload: dict[str, typing.Any], + separator_payload: dict[str, typing.Any], + file_payload: dict[str, typing.Any], ): message_components = entity_factory_impl._deserialize_top_level_components( [ @@ -6251,12 +6774,19 @@ def test__deserialize_message_components( assert message_components[5] == entity_factory_impl._deserialize_file_component(file_payload) - def test__deserialize_message_components_handles_unknown_top_component_type(self, entity_factory_impl): + def test__deserialize_message_components_handles_unknown_top_component_type( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): message_components = entity_factory_impl._deserialize_top_level_components([{"type": 9999}, {"type": -9999}]) assert len(message_components) == 0 - def test__deserialize_modal_components(self, entity_factory_impl, action_row_payload, text_input_payload): + def test__deserialize_modal_components( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + action_row_payload: dict[str, typing.Any], + text_input_payload: dict[str, typing.Any], + ): action_row_payload["components"] = [text_input_payload] modal_components = entity_factory_impl._deserialize_modal_components([action_row_payload]) @@ -6269,7 +6799,9 @@ def test__deserialize_modal_components(self, entity_factory_impl, action_row_pay components=[entity_factory_impl._deserialize_text_input(text_input_payload)], ) - def test__deserialize_modal_components_handles_unknown_top_component_type(self, entity_factory_impl): + def test__deserialize_modal_components_handles_unknown_top_component_type( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): modal_components = entity_factory_impl._deserialize_modal_components([{"type": 9999}]) assert len(modal_components) == 0 @@ -6286,16 +6818,24 @@ def test__deserialize_modal_components_handles_unknown_top_component_type(self, (8, "_deserialize_channel_select_menu", "_message_component_type_mapping"), ], ) - def test__deserialize_components(self, mock_app, type_, fn, mapping): + def test__deserialize_components( + self, + hikari_app: traits.RESTAware, + entity_factory_impl: entity_factory.EntityFactoryImpl, + type_: int, + fn: str, + mapping: str, + ): component_payload = {"type": type_} payload = [{"type": 1, "components": [component_payload]}] with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: # We need to instantiate it after the mock so that the functions that are stored in the dicts # are the ones we mock - entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) - - components = entity_factory_impl._deserialize_components(payload, getattr(entity_factory_impl, mapping)) + entity_factory_impl = entity_factory.EntityFactoryImpl(app=hikari_app) + components = entity_factory_impl._deserialize_message_components( + payload, getattr(entity_factory_impl, mapping) + ) expected_fn.assert_called_once_with(component_payload) action_row = components[0] @@ -6303,7 +6843,9 @@ def test__deserialize_components(self, mock_app, type_, fn, mapping): assert action_row.components[0] is expected_fn.return_value @pytest.mark.skip("Pending removal.") - def test__deserialize_components_handles_unknown_top_component_type(self, entity_factory_impl): + def test__deserialize_components_handles_unknown_top_component_type( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): components = entity_factory_impl._deserialize_components( [ # Unknown top-level component @@ -6328,7 +6870,7 @@ def test__deserialize_components_handles_unknown_top_component_type(self, entity ################## @pytest.fixture - def partial_application_payload(self): + def partial_application_payload(self) -> dict[str, typing.Any]: return { "id": "456", "name": "hikari", @@ -6338,7 +6880,7 @@ def partial_application_payload(self): } @pytest.fixture - def referenced_message(self, user_payload): + def referenced_message(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "12312312", "channel_id": "949494", @@ -6358,7 +6900,7 @@ def referenced_message(self, user_payload): } @pytest.fixture - def attachment_payload(self): + def attachment_payload(self) -> dict[str, typing.Any]: return { "id": "690922406474154014", "filename": "IMG.jpg", @@ -6376,7 +6918,7 @@ def attachment_payload(self): } @pytest.fixture - def partial_interaction_metadata_payload(self, user_payload): + def partial_interaction_metadata_payload(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "123456", "type": 2, @@ -6388,19 +6930,19 @@ def partial_interaction_metadata_payload(self, user_payload): @pytest.fixture def message_payload( self, - user_payload, - member_payload, - custom_emoji_payload, - partial_application_payload, - embed_payload, - poll_payload, - referenced_message, - action_row_payload, - partial_sticker_payload, - attachment_payload, - guild_public_thread_payload, - partial_interaction_metadata_payload, - ): + user_payload: dict[str, typing.Any], + member_payload: dict[str, typing.Any], + custom_emoji_payload: dict[str, typing.Any], + partial_application_payload: dict[str, typing.Any], + embed_payload: dict[str, typing.Any], + poll_payload: dict[str, typing.Any], + referenced_message: dict[str, typing.Any], + action_row_payload: dict[str, typing.Any], + partial_sticker_payload: dict[str, typing.Any], + attachment_payload: dict[str, typing.Any], + guild_public_thread_payload: dict[str, typing.Any], + partial_interaction_metadata_payload: dict[str, typing.Any], + ) -> dict[str, typing.Any]: member_payload = member_payload.copy() del member_payload["user"] @@ -6447,7 +6989,9 @@ def message_payload( "interaction_metadata": partial_interaction_metadata_payload, } - def test__deserialize_message_attachment(self, entity_factory_impl, attachment_payload): + def test__deserialize_message_attachment( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, attachment_payload: dict[str, typing.Any] + ): attachment = entity_factory_impl._deserialize_message_attachment(attachment_payload) assert attachment.id == 690922406474154014 @@ -6465,7 +7009,9 @@ def test__deserialize_message_attachment(self, entity_factory_impl, attachment_p assert attachment.waveform == "some encoded string" assert isinstance(attachment, message_models.Attachment) - def test__deserialize_message_attachment_with_null_fields(self, entity_factory_impl, attachment_payload): + def test__deserialize_message_attachment_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, attachment_payload: dict[str, typing.Any] + ): attachment_payload["height"] = None attachment_payload["width"] = None @@ -6475,7 +7021,9 @@ def test__deserialize_message_attachment_with_null_fields(self, entity_factory_i assert attachment.width is None assert isinstance(attachment, message_models.Attachment) - def test__deserialize_message_attachment_with_unset_fields(self, entity_factory_impl, attachment_payload): + def test__deserialize_message_attachment_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, attachment_payload: dict[str, typing.Any] + ): del attachment_payload["title"] del attachment_payload["description"] del attachment_payload["content_type"] @@ -6497,7 +7045,10 @@ def test__deserialize_message_attachment_with_unset_fields(self, entity_factory_ assert attachment.waveform is None def test__deserialize_partial_message_interaction_metadata( - self, entity_factory_impl, partial_interaction_metadata_payload, user_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + partial_interaction_metadata_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], ): partial_message_interaction_metadata = entity_factory_impl._deserialize_command_interaction_metadata( partial_interaction_metadata_payload @@ -6516,7 +7067,10 @@ def test__deserialize_partial_message_interaction_metadata( assert isinstance(partial_message_interaction_metadata, base_interactions.PartialInteractionMetadata) def test__deserialize_command_interaction_metadata( - self, entity_factory_impl, partial_interaction_metadata_payload, user_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + partial_interaction_metadata_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], ): partial_interaction_metadata_payload["target_user"] = user_payload partial_interaction_metadata_payload["target_message_id"] = "59332" @@ -6531,7 +7085,9 @@ def test__deserialize_command_interaction_metadata( assert command_interaction_metadata.target_message_id == snowflakes.Snowflake(59332) def test__deserialize_message_component_interaction_metadata( - self, entity_factory_impl, partial_interaction_metadata_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + partial_interaction_metadata_payload: dict[str, typing.Any], ): partial_interaction_metadata_payload["interacted_message_id"] = "684831" @@ -6545,7 +7101,10 @@ def test__deserialize_message_component_interaction_metadata( assert message_component_interaction_metadata.interacted_message_id == snowflakes.Snowflake(684831) def test__deserialize_modal_interaction_metadata_with_commmand_interaction( - self, entity_factory_impl, partial_interaction_metadata_payload, user_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + partial_interaction_metadata_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], ): component_interaction_metadata_payload = dict(partial_interaction_metadata_payload) component_interaction_metadata_payload["target_user"] = user_payload @@ -6570,7 +7129,9 @@ def test__deserialize_modal_interaction_metadata_with_commmand_interaction( ) def test__deserialize_modal_interaction_metadata_with_component_interaction( - self, entity_factory_impl, partial_interaction_metadata_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + partial_interaction_metadata_payload: dict[str, typing.Any], ): command_interaction_metadata_payload = dict(partial_interaction_metadata_payload) command_interaction_metadata_payload["type"] = 3 @@ -6592,22 +7153,21 @@ def test__deserialize_modal_interaction_metadata_with_component_interaction( def test_deserialize_partial_message( self, - entity_factory_impl, - mock_app, - message_payload, - user_payload, - member_payload, - partial_application_payload, - custom_emoji_payload, - embed_payload, - poll_payload, - referenced_message, - action_row_payload, - attachment_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + message_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + member_payload: dict[str, typing.Any], + custom_emoji_payload: dict[str, typing.Any], + embed_payload: dict[str, typing.Any], + poll_payload: dict[str, typing.Any], + referenced_message: dict[str, typing.Any], + action_row_payload: dict[str, typing.Any], + attachment_payload: dict[str, typing.Any], ): partial_message = entity_factory_impl.deserialize_partial_message(message_payload) - assert partial_message.app is mock_app + assert partial_message.app is hikari_app assert partial_message.id == 123 assert partial_message.channel_id == 456 assert partial_message.guild_id == 678 @@ -6632,6 +7192,7 @@ def test_deserialize_partial_message( expected_embed = entity_factory_impl.deserialize_embed(embed_payload) assert partial_message.embeds == [expected_embed] # Reaction + assert partial_message.reactions is not undefined.UNDEFINED reaction = partial_message.reactions[0] assert reaction.count == 100 assert reaction.is_me is True @@ -6644,11 +7205,15 @@ def test_deserialize_partial_message( assert partial_message.type == message_models.MessageType.DEFAULT # Activity + assert partial_message.activity is not undefined.UNDEFINED + assert partial_message.activity is not None assert partial_message.activity.type == message_models.MessageActivityType.JOIN_REQUEST assert partial_message.activity.party_id == "ae488379-351d-4a4f-ad32-2b9b01c91657" assert isinstance(partial_message.activity, message_models.MessageActivity) # Message Activity + assert partial_message.application is not undefined.UNDEFINED + assert partial_message.application is not None assert partial_message.application.id == 456 assert partial_message.application.name == "hikari" assert partial_message.application.description == "The best application" @@ -6656,7 +7221,9 @@ def test_deserialize_partial_message( assert partial_message.application.cover_image_hash == "58982a23790c4f22787b05d3be38a026" assert isinstance(partial_message.application, message_models.MessageApplication) # MessageReference - assert partial_message.message_reference.app is mock_app + assert partial_message.message_reference is not undefined.UNDEFINED + assert partial_message.message_reference is not None + assert partial_message.message_reference.app is hikari_app assert partial_message.message_reference.id == 306588351130107906 assert partial_message.message_reference.channel_id == 278325129692446722 assert partial_message.message_reference.guild_id == 278325129692446720 @@ -6666,6 +7233,7 @@ def test_deserialize_partial_message( assert partial_message.flags == message_models.MessageFlag.IS_CROSSPOST # Sticker + assert partial_message.stickers is not undefined.UNDEFINED assert len(partial_message.stickers) == 1 sticker = partial_message.stickers[0] assert sticker.id == 749046696482439188 @@ -6680,6 +7248,7 @@ def test_deserialize_partial_message( assert partial_message.components == entity_factory_impl._deserialize_top_level_components([action_row_payload]) # InteractionMetadata + assert isinstance(partial_message.interaction_metadata, command_interactions.CommandInteractionMetadata) assert partial_message.interaction_metadata.interaction_id == snowflakes.Snowflake(123456) assert partial_message.interaction_metadata.type == base_interactions.InteractionType.APPLICATION_COMMAND assert partial_message.interaction_metadata.user == entity_factory_impl.deserialize_user(user_payload) @@ -6692,12 +7261,13 @@ def test_deserialize_partial_message( assert partial_message.interaction_metadata.original_response_message_id == snowflakes.Snowflake(9564) assert partial_message.interaction_metadata.target_user == entity_factory_impl.deserialize_user(user_payload) assert partial_message.interaction_metadata.target_message_id == snowflakes.Snowflake(59332) - assert isinstance(partial_message.interaction_metadata, command_interactions.CommandInteractionMetadata) # Poll assert partial_message.poll == entity_factory_impl.deserialize_poll(poll_payload) - def test_deserialize_partial_message_with_snapshot(self, entity_factory_impl, message_payload): + def test_deserialize_partial_message_with_snapshot( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, message_payload: dict[str, typing.Any] + ): del message_payload["message_reference"] del message_payload["referenced_message"] @@ -6716,7 +7286,9 @@ def test_deserialize_partial_message_with_snapshot(self, entity_factory_impl, me assert snapshot.flags == message_models.MessageFlag.HAS_THREAD assert snapshot.type == message_models.MessageType.DEFAULT - def test_deserialize_partial_message_with_partial_fields(self, entity_factory_impl, message_payload): + def test_deserialize_partial_message_with_partial_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, message_payload: dict[str, typing.Any] + ): message_payload["content"] = "" message_payload["edited_timestamp"] = None message_payload["application"]["icon"] = None @@ -6733,17 +7305,23 @@ def test_deserialize_partial_message_with_partial_fields(self, entity_factory_im assert partial_message.edited_timestamp is None assert partial_message.guild_id is not None assert partial_message.member is undefined.UNDEFINED + assert partial_message.application is not undefined.UNDEFINED + assert partial_message.application is not None assert partial_message.application.icon_hash is None assert partial_message.application.cover_image_hash is None + assert partial_message.message_reference is not undefined.UNDEFINED + assert partial_message.message_reference is not None assert partial_message.message_reference.id is None assert partial_message.message_reference.guild_id is None assert partial_message.referenced_message is None assert partial_message.interaction_metadata is None - def test_deserialize_partial_message_with_unset_fields(self, entity_factory_impl, mock_app): + def test_deserialize_partial_message_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, hikari_app: traits.RESTAware + ): partial_message = entity_factory_impl.deserialize_partial_message({"id": 123, "channel_id": 456}) - assert partial_message.app is mock_app + assert partial_message.app is hikari_app assert partial_message.id == 123 assert partial_message.channel_id == 456 assert partial_message.guild_id is None @@ -6774,19 +7352,24 @@ def test_deserialize_partial_message_with_unset_fields(self, entity_factory_impl assert partial_message.application_id is undefined.UNDEFINED assert partial_message.components is undefined.UNDEFINED - def test_deserialize_partial_message_with_guild_id_but_no_author(self, entity_factory_impl): + def test_deserialize_partial_message_with_guild_id_but_no_author( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): partial_message = entity_factory_impl.deserialize_partial_message( {"id": 123, "channel_id": 456, "guild_id": 987} ) assert partial_message.member is None - def test_deserialize_partial_message_deserializes_old_stickers_field(self, entity_factory_impl, message_payload): + def test_deserialize_partial_message_deserializes_old_stickers_field( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, message_payload: dict[str, typing.Any] + ): message_payload["stickers"] = message_payload["sticker_items"] del message_payload["sticker_items"] partial_message = entity_factory_impl.deserialize_partial_message(message_payload) + assert partial_message.stickers is not undefined.UNDEFINED assert len(partial_message.stickers) == 1 sticker = partial_message.stickers[0] assert sticker.id == 749046696482439188 @@ -6796,12 +7379,12 @@ def test_deserialize_partial_message_deserializes_old_stickers_field(self, entit def test_deserialize_message_snapshot( self, - entity_factory_impl, - embed_payload, - attachment_payload, - user_payload, - custom_emoji_payload, - action_row_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + embed_payload: dict[str, typing.Any], + attachment_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + custom_emoji_payload: dict[str, typing.Any], + action_row_payload: dict[str, typing.Any], ): payload = { "message": { @@ -6842,12 +7425,12 @@ def test_deserialize_message_snapshot( def test_deserialize_message_snapshot_sticker_items_field( self, - entity_factory_impl, - embed_payload, - attachment_payload, - user_payload, - custom_emoji_payload, - action_row_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + embed_payload: dict[str, typing.Any], + attachment_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + custom_emoji_payload: dict[str, typing.Any], + action_row_payload: dict[str, typing.Any], ): payload = { "message": { @@ -6871,12 +7454,12 @@ def test_deserialize_message_snapshot_sticker_items_field( def test_deserialize_message_snapshot_no_edit( self, - entity_factory_impl, - embed_payload, - attachment_payload, - user_payload, - custom_emoji_payload, - action_row_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + embed_payload: dict[str, typing.Any], + attachment_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + custom_emoji_payload: dict[str, typing.Any], + action_row_payload: dict[str, typing.Any], ): payload = { "message": { @@ -6900,12 +7483,12 @@ def test_deserialize_message_snapshot_no_edit( def test_deserialize_message_snapshot_all_unset( self, - entity_factory_impl, - embed_payload, - attachment_payload, - user_payload, - custom_emoji_payload, - action_row_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + embed_payload: dict[str, typing.Any], + attachment_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + custom_emoji_payload: dict[str, typing.Any], + action_row_payload: dict[str, typing.Any], ): payload = {"message": {"type": 0}} message_snapshot = entity_factory_impl.deserialize_message_snapshot(payload) @@ -6922,20 +7505,20 @@ def test_deserialize_message_snapshot_all_unset( def test_deserialize_message( self, - entity_factory_impl, - mock_app, - message_payload, - user_payload, - member_payload, - custom_emoji_payload, - embed_payload, - referenced_message, - action_row_payload, - poll_payload, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + message_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + member_payload: dict[str, typing.Any], + custom_emoji_payload: dict[str, typing.Any], + embed_payload: dict[str, typing.Any], + referenced_message: dict[str, typing.Any], + action_row_payload: dict[str, typing.Any], + poll_payload: dict[str, typing.Any], ): message = entity_factory_impl.deserialize_message(message_payload) - assert message.app is mock_app + assert message.app is hikari_app assert message.id == 123 assert message.channel_id == 456 assert message.guild_id == 678 @@ -6987,11 +7570,13 @@ def test_deserialize_message( assert message.type == message_models.MessageType.DEFAULT # Activity + assert message.activity is not None assert message.activity.type == message_models.MessageActivityType.JOIN_REQUEST assert message.activity.party_id == "ae488379-351d-4a4f-ad32-2b9b01c91657" assert isinstance(message.activity, message_models.MessageActivity) # MessageApplication + assert message.application is not None assert message.application.id == 456 assert message.application.name == "hikari" assert message.application.description == "The best application" @@ -7000,7 +7585,8 @@ def test_deserialize_message( assert isinstance(message.application, message_models.MessageApplication) # MessageReference - assert message.message_reference.app is mock_app + assert message.message_reference + assert message.message_reference.app is hikari_app assert message.message_reference.id == 306588351130107906 assert message.message_reference.channel_id == 278325129692446722 assert message.message_reference.guild_id == 278325129692446720 @@ -7033,7 +7619,9 @@ def test_deserialize_message( assert message.thread.flags == channel_models.ChannelFlag.PINNED assert message.thread.name == "e" - def test_deserialize_message_with_snapshot(self, entity_factory_impl, message_payload): + def test_deserialize_message_with_snapshot( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, message_payload: dict[str, typing.Any] + ): del message_payload["message_reference"] del message_payload["referenced_message"] @@ -7052,7 +7640,9 @@ def test_deserialize_message_with_snapshot(self, entity_factory_impl, message_pa assert snapshot.flags == message_models.MessageFlag.HAS_THREAD assert snapshot.type == message_models.MessageType.DEFAULT - def test_deserialize_message_with_unset_sub_fields(self, entity_factory_impl, message_payload): + def test_deserialize_message_with_unset_sub_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, message_payload: dict[str, typing.Any] + ): del message_payload["application"]["cover_image"] del message_payload["activity"]["party_id"] del message_payload["message_reference"]["message_id"] @@ -7066,14 +7656,17 @@ def test_deserialize_message_with_unset_sub_fields(self, entity_factory_impl, me assert message.channel_mentions == {} # Activity + assert message.activity is not None assert message.activity.party_id is None assert isinstance(message.activity, message_models.MessageActivity) # MessageApplication + assert message.application is not None assert message.application.cover_image_hash is None assert isinstance(message.application, message_models.MessageApplication) # MessageReference + assert message.message_reference is not None assert message.message_reference.id is None assert message.message_reference.guild_id is None assert isinstance(message.message_reference, message_models.MessageReference) @@ -7082,18 +7675,26 @@ def test_deserialize_message_with_unset_sub_fields(self, entity_factory_impl, me assert message.thread is None # Poll - message.poll is None + assert message.poll is None - def test_deserialize_message_with_null_sub_fields(self, entity_factory_impl, message_payload): + def test_deserialize_message_with_null_sub_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, message_payload: dict[str, typing.Any] + ): message_payload["application"]["icon"] = None message = entity_factory_impl.deserialize_message(message_payload) # MessageApplication + assert message.application is not None assert message.application.icon_hash is None assert isinstance(message.application, message_models.MessageApplication) - def test_deserialize_message_with_null_and_unset_fields(self, entity_factory_impl, mock_app, user_payload): - message_payload = { + def test_deserialize_message_with_null_and_unset_fields( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + user_payload: dict[str, typing.Any], + ): + message_payload: dict[str, typing.Any] = { "id": "123", "channel_id": "456", "author": user_payload, @@ -7112,7 +7713,7 @@ def test_deserialize_message_with_null_and_unset_fields(self, entity_factory_imp } message = entity_factory_impl.deserialize_message(message_payload) - assert message.app is mock_app + assert message.app is hikari_app assert message.content is None assert message.guild_id is None assert message.member is None @@ -7135,19 +7736,24 @@ def test_deserialize_message_with_null_and_unset_fields(self, entity_factory_imp assert message.application_id is None assert message.components == [] - def test_deserialize_message_with_other_unset_fields(self, entity_factory_impl, message_payload): + def test_deserialize_message_with_other_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, message_payload: dict[str, typing.Any] + ): message_payload["application"]["icon"] = None message_payload["referenced_message"] = None del message_payload["member"] del message_payload["application"]["cover_image"] message = entity_factory_impl.deserialize_message(message_payload) + assert message.application is not None assert message.application.cover_image_hash is None assert message.application.icon_hash is None assert message.referenced_message is None assert message.member is None - def test_deserialize_message_deserializes_old_stickers_field(self, entity_factory_impl, message_payload): + def test_deserialize_message_deserializes_old_stickers_field( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, message_payload: dict[str, typing.Any] + ): message_payload["stickers"] = message_payload["sticker_items"] del message_payload["sticker_items"] @@ -7165,10 +7771,15 @@ def test_deserialize_message_deserializes_old_stickers_field(self, entity_factor ################### def test_deserialize_member_presence( - self, entity_factory_impl, mock_app, member_presence_payload, custom_emoji_payload, user_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + member_presence_payload: dict[str, typing.Any], + custom_emoji_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], ): presence = entity_factory_impl.deserialize_member_presence(member_presence_payload) - assert presence.app is mock_app + assert presence.app is hikari_app assert presence.user_id == 115590097100865541 assert presence.guild_id == 44004040 assert presence.visible_status == presence_models.Status.DO_NOT_DISTURB @@ -7180,6 +7791,7 @@ def test_deserialize_member_presence( assert activity.url == "https://69.420.owouwunyaa" assert activity.created_at == datetime.datetime(2020, 3, 23, 20, 53, 12, 798000, tzinfo=datetime.timezone.utc) # ActivityTimestamps + assert activity.timestamps is not None assert activity.timestamps.start == datetime.datetime( 2020, 3, 23, 20, 53, 12, 798000, tzinfo=datetime.timezone.utc ) @@ -7225,7 +7837,10 @@ def test_deserialize_member_presence( assert isinstance(presence, presence_models.MemberPresence) def test_deserialize_member_presence_with_unset_fields( - self, entity_factory_impl, user_payload, presence_activity_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + user_payload: dict[str, typing.Any], + presence_activity_payload: dict[str, typing.Any], ): presence = entity_factory_impl.deserialize_member_presence( { @@ -7243,7 +7858,9 @@ def test_deserialize_member_presence_with_unset_fields( assert presence.client_status.mobile is presence_models.Status.OFFLINE assert presence.client_status.web is presence_models.Status.OFFLINE - def test_deserialize_member_presence_with_unset_activity_fields(self, entity_factory_impl, user_payload): + def test_deserialize_member_presence_with_unset_activity_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, user_payload: dict[str, typing.Any] + ): presence = entity_factory_impl.deserialize_member_presence( { "user": user_payload, @@ -7269,7 +7886,9 @@ def test_deserialize_member_presence_with_unset_activity_fields(self, entity_fac assert activity.flags is None assert activity.buttons == [] - def test_deserialize_member_presence_with_null_activity_fields(self, entity_factory_impl, user_payload): + def test_deserialize_member_presence_with_null_activity_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, user_payload: dict[str, typing.Any] + ): presence = entity_factory_impl.deserialize_member_presence( { "user": user_payload, @@ -7309,7 +7928,9 @@ def test_deserialize_member_presence_with_null_activity_fields(self, entity_fact assert activity.state is None assert activity.emoji is None - def test_deserialize_member_presence_with_unset_activity_sub_fields(self, entity_factory_impl, user_payload): + def test_deserialize_member_presence_with_unset_activity_sub_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, user_payload: dict[str, typing.Any] + ): presence = entity_factory_impl.deserialize_member_presence( { "user": user_payload, @@ -7388,12 +8009,12 @@ def scheduled_external_event_payload(self, user_payload: dict[str, typing.Any]) def test_deserialize_scheduled_external_event( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: mock.Mock, + hikari_app: mock.Mock, scheduled_external_event_payload: dict[str, typing.Any], user_payload: dict[str, typing.Any], ): event = entity_factory_impl.deserialize_scheduled_external_event(scheduled_external_event_payload) - assert event.app is mock_app + assert event.app is hikari_app assert event.id == 9497609168686982223 assert event.guild_id == 1525593721265219296 assert event.name == "bleep" @@ -7412,7 +8033,7 @@ def test_deserialize_scheduled_external_event( def test_deserialize_scheduled_external_event_with_null_fields( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: mock.Mock, + hikari_app: mock.Mock, scheduled_external_event_payload: dict[str, typing.Any], ): scheduled_external_event_payload["description"] = None @@ -7426,7 +8047,7 @@ def test_deserialize_scheduled_external_event_with_null_fields( def test_deserialize_scheduled_external_event_with_undefined_fields( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: mock.Mock, + hikari_app: mock.Mock, scheduled_external_event_payload: dict[str, typing.Any], ): del scheduled_external_event_payload["creator"] @@ -7466,13 +8087,13 @@ def scheduled_stage_event_payload(self, user_payload: dict[str, typing.Any]) -> def test_deserialize_scheduled_stage_event( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: mock.Mock, + hikari_app: mock.Mock, scheduled_stage_event_payload: dict[str, typing.Any], user_payload: dict[str, typing.Any], ): event = entity_factory_impl.deserialize_scheduled_stage_event(scheduled_stage_event_payload) - assert event.app is mock_app + assert event.app is hikari_app assert event.id == 9497014470822052443 assert event.guild_id == 1525593721265192962 assert event.channel_id == 9492384510463386001 @@ -7491,7 +8112,7 @@ def test_deserialize_scheduled_stage_event( def test_deserialize_scheduled_stage_event_with_null_fields( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: mock.Mock, + hikari_app: mock.Mock, scheduled_stage_event_payload: dict[str, typing.Any], ): scheduled_stage_event_payload["description"] = None @@ -7507,7 +8128,7 @@ def test_deserialize_scheduled_stage_event_with_null_fields( def test_deserialize_scheduled_stage_event_with_undefined_fields( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: mock.Mock, + hikari_app: mock.Mock, scheduled_stage_event_payload: dict[str, typing.Any], ): del scheduled_stage_event_payload["creator"] @@ -7547,13 +8168,13 @@ def scheduled_voice_event_payload(self, user_payload: dict[str, typing.Any]) -> def test_deserialize_scheduled_voice_event( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: mock.Mock, + hikari_app: mock.Mock, scheduled_voice_event_payload: dict[str, typing.Any], user_payload: dict[str, typing.Any], ): event = entity_factory_impl.deserialize_scheduled_voice_event(scheduled_voice_event_payload) - assert event.app is mock_app + assert event.app is hikari_app assert event.id == 949760834287063133 assert event.guild_id == 152559372126519296 assert event.channel_id == 152559372126519297 @@ -7572,7 +8193,7 @@ def test_deserialize_scheduled_voice_event( def test_deserialize_scheduled_voice_event_with_null_fields( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: mock.Mock, + hikari_app: mock.Mock, scheduled_voice_event_payload: dict[str, typing.Any], ): scheduled_voice_event_payload["description"] = None @@ -7588,7 +8209,7 @@ def test_deserialize_scheduled_voice_event_with_null_fields( def test_deserialize_scheduled_voice_event_with_undefined_fields( self, entity_factory_impl: entity_factory.EntityFactoryImpl, - mock_app: mock.Mock, + hikari_app: mock.Mock, scheduled_voice_event_payload: dict[str, typing.Any], ): del scheduled_voice_event_payload["creator"] @@ -7627,6 +8248,7 @@ def test_deserialize_scheduled_event_when_unknown(self, entity_factory_impl: ent def scheduled_event_user_payload( self, user_payload: dict[str, typing.Any], member_payload: dict[str, typing.Any] ) -> dict[str, typing.Any]: + assert isinstance(member_payload, dict) member_payload = member_payload.copy() del member_payload["user"] return {"guild_scheduled_event_id": "49494949499494", "user": user_payload, "member": member_payload} @@ -7639,12 +8261,16 @@ def test_deserialize_scheduled_event_user( member_payload: dict[str, typing.Any], ): del member_payload["user"] - user = entity_factory_impl.deserialize_scheduled_event_user(scheduled_event_user_payload, guild_id=123321) + user = entity_factory_impl.deserialize_scheduled_event_user( + scheduled_event_user_payload, guild_id=snowflakes.Snowflake(123321) + ) assert user.event_id == 49494949499494 assert user.user == entity_factory_impl.deserialize_user(user_payload) assert user.member == entity_factory_impl.deserialize_member( - member_payload, user=entity_factory_impl.deserialize_user(user_payload), guild_id=123321 + member_payload, + user=entity_factory_impl.deserialize_user(user_payload), + guild_id=snowflakes.Snowflake(123321), ) assert isinstance(user, scheduled_event_models.ScheduledEventUser) @@ -7656,7 +8282,9 @@ def test_deserialize_scheduled_event_user_when_no_member( ): del scheduled_event_user_payload["member"] - event = entity_factory_impl.deserialize_scheduled_event_user(scheduled_event_user_payload, guild_id=123321) + event = entity_factory_impl.deserialize_scheduled_event_user( + scheduled_event_user_payload, guild_id=snowflakes.Snowflake(123321) + ) assert event.member is None assert event.user == entity_factory_impl.deserialize_user(user_payload) @@ -7666,7 +8294,9 @@ def test_deserialize_scheduled_event_user_when_no_member( ################### @pytest.fixture - def template_payload(self, guild_text_channel_payload, user_payload): + def template_payload( + self, guild_text_channel_payload: dict[str, typing.Any], user_payload: dict[str, typing.Any] + ) -> dict[str, typing.Any]: return { "code": "4rDaewUKeYVj", "name": "ttt", @@ -7705,10 +8335,15 @@ def template_payload(self, guild_text_channel_payload, user_payload): } def test_deserialize_template( - self, entity_factory_impl, mock_app, template_payload, user_payload, guild_text_channel_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + template_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + guild_text_channel_payload: dict[str, typing.Any], ): template = entity_factory_impl.deserialize_template(template_payload) - assert template.app is mock_app + assert template.app is hikari_app assert template.code == "4rDaewUKeYVj" assert template.name == "ttt" assert template.description == "eee" @@ -7718,7 +8353,7 @@ def test_deserialize_template( assert template.updated_at == datetime.datetime(2020, 12, 15, 1, 57, 35, tzinfo=datetime.timezone.utc) # TemplateGuild - assert template.source_guild.app is mock_app + assert template.source_guild.app is hikari_app assert template.source_guild.id == 574921006817476608 assert template.source_guild.icon_hash == "27b75989b5b42aba51346a6b69d8fcfe" assert template.source_guild.name == "hikari" @@ -7735,8 +8370,8 @@ def test_deserialize_template( # TemplateRole assert len(template.source_guild.roles) == 1 - role = template.source_guild.roles[33] - assert role.app is mock_app + role = template.source_guild.roles[snowflakes.Snowflake(33)] + assert role.app is hikari_app assert role.id == 33 assert role.name == "@everyone" assert role.permissions == permission_models.Permissions(104189505) @@ -7753,7 +8388,12 @@ def test_deserialize_template( assert template.is_unsynced is True - def test_deserialize_template_with_null_fields(self, entity_factory_impl, template_payload, user_payload): + def test_deserialize_template_with_null_fields( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + template_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + ): template = entity_factory_impl.deserialize_template( { "code": "4rDaewUKeYVj", @@ -7806,54 +8446,68 @@ def avatar_decoration_payload(self) -> dict[str, typing.Any]: return {"asset": "ahhhhhhvatardecoration", "sku_id": "789", "expires_at": 1743753661} def test_deserialize_avatar_decoration( - self, entity_factory_impl, mock_app: mock.Mock, avatar_decoration_payload: dict[str, typing.Any] + self, entity_factory_impl: entity_factory.EntityFactoryImpl, avatar_decoration_payload: dict[str, typing.Any] ): decoration = entity_factory_impl._deserialize_avatar_decoration(avatar_decoration_payload) + assert decoration is not None assert decoration.asset_hash == "ahhhhhhvatardecoration" assert decoration.sku_id == 789 assert decoration.expires_at == datetime.datetime(2025, 4, 4, 8, 1, 1, tzinfo=datetime.timezone.utc) def test_deserialize_avatar_decoration_with_no_expiry( - self, entity_factory_impl, mock_app: mock.Mock, avatar_decoration_payload: dict[str, typing.Any] + self, entity_factory_impl: entity_factory.EntityFactoryImpl, avatar_decoration_payload: dict[str, typing.Any] ): decoration = entity_factory_impl._deserialize_avatar_decoration( {**avatar_decoration_payload, "expires_at": None} ) + assert decoration is not None assert decoration.asset_hash == "ahhhhhhvatardecoration" assert decoration.sku_id == 789 assert decoration.expires_at is None def test_deserialize_avatar_decoration_with_empty_payload( - self, entity_factory_impl, mock_app: mock.Mock, avatar_decoration_payload: dict[str, typing.Any] + self, entity_factory_impl: entity_factory.EntityFactoryImpl ): decoration = entity_factory_impl._deserialize_avatar_decoration(None) assert decoration is None def test__deserialize_primary_guild( - self, entity_factory_impl, mock_app: mock.Mock, primary_guild_payload: dict[str, typing.Any] + self, entity_factory_impl: entity_factory.EntityFactoryImpl, primary_guild_payload: dict[str, typing.Any] ): primary_guild = entity_factory_impl._deserialize_primary_guild(primary_guild_payload) + assert primary_guild is not None + assert primary_guild.identity_guild_id == snowflakes.Snowflake(554454) assert primary_guild.identity_enabled is True assert primary_guild.tag == "HKRI" assert primary_guild.badge_hash == "abcd1234" - def test__deserialize_primary_guild_with_empty_payload(self, entity_factory_impl, mock_app: mock.Mock): + def test__deserialize_primary_guild_with_empty_payload(self, entity_factory_impl: entity_factory.EntityFactoryImpl): primary_guild = entity_factory_impl._deserialize_primary_guild(None) assert primary_guild is None - def test__deserialize_primary_guild_with_nullable_fields(self, entity_factory_impl, mock_app: mock.Mock): + def test__deserialize_primary_guild_with_nullable_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): primary_guild = entity_factory_impl._deserialize_primary_guild( {"identity_guild_id": None, "identity_enabled": None, "tag": None, "badge": None} ) + assert primary_guild is not None + assert primary_guild.identity_guild_id is None assert primary_guild.identity_enabled is None assert primary_guild.tag is None assert primary_guild.badge_hash is None - def test_deserialize_user(self, entity_factory_impl, mock_app, user_payload, primary_guild_payload): + def test_deserialize_user( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + user_payload: dict[str, typing.Any], + primary_guild_payload: dict[str, typing.Any], + ): user = entity_factory_impl.deserialize_user(user_payload) - assert user.app is mock_app + assert user.app is hikari_app assert user.id == 115590097100865541 assert user.username == "nyaa" assert user.avatar_hash == "b3b24c6d7cbcdec129d5d537067061a8" @@ -7866,7 +8520,12 @@ def test_deserialize_user(self, entity_factory_impl, mock_app, user_payload, pri assert isinstance(user, user_models.UserImpl) assert user.primary_guild == entity_factory_impl._deserialize_primary_guild(primary_guild_payload) - def test_deserialize_user_with_unset_fields(self, entity_factory_impl, mock_app, user_payload): + def test_deserialize_user_with_unset_fields( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + user_payload: dict[str, typing.Any], + ): user = entity_factory_impl.deserialize_user( { "id": "115590097100865541", @@ -7883,7 +8542,7 @@ def test_deserialize_user_with_unset_fields(self, entity_factory_impl, mock_app, assert user.primary_guild is None @pytest.fixture - def my_user_payload(self, primary_guild_payload: dict[str, Any]): + def my_user_payload(self, primary_guild_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "id": "379953393319542784", "username": "qt pi", @@ -7904,9 +8563,15 @@ def my_user_payload(self, primary_guild_payload: dict[str, Any]): "primary_guild": primary_guild_payload, } - def test_deserialize_my_user(self, entity_factory_impl, mock_app, my_user_payload, primary_guild_payload): + def test_deserialize_my_user( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + my_user_payload: dict[str, typing.Any], + primary_guild_payload: dict[str, typing.Any], + ): my_user = entity_factory_impl.deserialize_my_user(my_user_payload) - assert my_user.app is mock_app + assert my_user.app is hikari_app assert my_user.id == 379953393319542784 assert my_user.username == "qt pi" assert my_user.global_name == "blahaj" @@ -7926,7 +8591,12 @@ def test_deserialize_my_user(self, entity_factory_impl, mock_app, my_user_payloa assert my_user.primary_guild == entity_factory_impl._deserialize_primary_guild(primary_guild_payload) assert isinstance(my_user, user_models.OwnUser) - def test_deserialize_my_user_with_unset_fields(self, entity_factory_impl, mock_app, my_user_payload): + def test_deserialize_my_user_with_unset_fields( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + my_user_payload: dict[str, typing.Any], + ): my_user = entity_factory_impl.deserialize_my_user( { "id": "379953393319542784", @@ -7940,7 +8610,7 @@ def test_deserialize_my_user_with_unset_fields(self, entity_factory_impl, mock_a } ) assert my_user.global_name is None - assert my_user.app is mock_app + assert my_user.app is hikari_app assert my_user.banner_hash is None assert my_user.accent_color is None assert my_user.is_bot is False @@ -7955,10 +8625,14 @@ def test_deserialize_my_user_with_unset_fields(self, entity_factory_impl, mock_a ################ def test_deserialize_voice_state_with_guild_id_in_payload( - self, entity_factory_impl, mock_app, voice_state_payload, member_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + voice_state_payload: dict[str, typing.Any], + member_payload: dict[str, typing.Any], ): voice_state = entity_factory_impl.deserialize_voice_state(voice_state_payload) - assert voice_state.app is mock_app + assert voice_state.app is hikari_app assert voice_state.guild_id == 929292929292992 assert voice_state.channel_id == 157733188964188161 assert voice_state.user_id == 115590097100865541 @@ -7979,7 +8653,10 @@ def test_deserialize_voice_state_with_guild_id_in_payload( assert isinstance(voice_state, voice_models.VoiceState) def test_deserialize_voice_state_with_injected_guild_id( - self, entity_factory_impl, voice_state_payload, member_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + voice_state_payload: dict[str, typing.Any], + member_payload: dict[str, typing.Any], ): voice_state = entity_factory_impl.deserialize_voice_state( { @@ -8004,7 +8681,9 @@ def test_deserialize_voice_state_with_injected_guild_id( member_payload, guild_id=snowflakes.Snowflake(43123) ) - def test_deserialize_voice_state_with_null_and_unset_fields(self, entity_factory_impl, member_payload): + def test_deserialize_voice_state_with_null_and_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, member_payload: dict[str, typing.Any] + ): voice_state = entity_factory_impl.deserialize_voice_state( { "channel_id": None, @@ -8026,10 +8705,12 @@ def test_deserialize_voice_state_with_null_and_unset_fields(self, entity_factory assert voice_state.requested_to_speak_at is None @pytest.fixture - def voice_region_payload(self): + def voice_region_payload(self) -> dict[str, typing.Any]: return {"id": "london", "name": "LONDON", "optimal": False, "deprecated": True, "custom": False} - def test_deserialize_voice_region(self, entity_factory_impl, voice_region_payload): + def test_deserialize_voice_region( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, voice_region_payload: dict[str, typing.Any] + ): voice_region = entity_factory_impl.deserialize_voice_region(voice_region_payload) assert voice_region.id == "london" assert voice_region.name == "LONDON" @@ -8043,7 +8724,7 @@ def test_deserialize_voice_region(self, entity_factory_impl, voice_region_payloa ################## @pytest.fixture - def incoming_webhook_payload(self, user_payload): + def incoming_webhook_payload(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "name": "test webhook", "type": 1, @@ -8057,7 +8738,7 @@ def incoming_webhook_payload(self, user_payload): } @pytest.fixture - def follower_webhook_payload(self, user_payload, partial_channel_payload): + def follower_webhook_payload(self, user_payload: dict[str, typing.Any]) -> dict[str, typing.Any]: return { "type": 2, "id": "752831914402115456", @@ -8076,7 +8757,7 @@ def follower_webhook_payload(self, user_payload, partial_channel_payload): } @pytest.fixture - def application_webhook_payload(self): + def application_webhook_payload(self) -> dict[str, typing.Any]: return { "type": 3, "id": "658822586720976555", @@ -8087,10 +8768,16 @@ def application_webhook_payload(self): "application_id": "658822586720976555", } - def test_deserialize_incoming_webhook(self, entity_factory_impl, mock_app, incoming_webhook_payload, user_payload): + def test_deserialize_incoming_webhook( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + incoming_webhook_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], + ): webhook = entity_factory_impl.deserialize_incoming_webhook(incoming_webhook_payload) - assert webhook.app is mock_app + assert webhook.app is hikari_app assert webhook.name == "test webhook" assert webhook.type is webhook_models.WebhookType.INCOMING assert webhook.channel_id == 199737254929760256 @@ -8106,7 +8793,10 @@ def test_deserialize_incoming_webhook(self, entity_factory_impl, mock_app, incom assert isinstance(webhook, webhook_models.IncomingWebhook) def test_deserialize_incoming_webhook_with_null_fields( - self, entity_factory_impl, incoming_webhook_payload, user_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + incoming_webhook_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], ): del incoming_webhook_payload["user"] del incoming_webhook_payload["token"] @@ -8125,11 +8815,15 @@ def test_deserialize_incoming_webhook_with_null_fields( assert isinstance(webhook, webhook_models.IncomingWebhook) def test_deserialize_channel_follower_webhook( - self, entity_factory_impl, mock_app, follower_webhook_payload, user_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + follower_webhook_payload: dict[str, typing.Any], + user_payload: dict[str, typing.Any], ): webhook = entity_factory_impl.deserialize_channel_follower_webhook(follower_webhook_payload) - assert webhook.app is mock_app + assert webhook.app is hikari_app assert webhook.type is webhook_models.WebhookType.CHANNEL_FOLLOWER assert webhook.id == 752831914402115456 assert webhook.name == "Guildy name" @@ -8138,12 +8832,14 @@ def test_deserialize_channel_follower_webhook( assert webhook.guild_id == 56188498421443265 assert webhook.application_id == 312123123 - assert webhook.source_guild.app is mock_app + assert webhook.source_guild is not None + assert webhook.source_guild.app is hikari_app assert webhook.source_guild.id == 56188498421476534 assert webhook.source_guild.name == "Guildy name" assert webhook.source_guild.icon_hash == "bb71f469c158984e265093a81b3397fb" assert isinstance(webhook.source_guild, guild_models.PartialGuild) + assert webhook.source_channel is not None assert webhook.source_channel.id == 5618852344134324 assert webhook.source_channel.name == "announcements" assert webhook.source_channel.type == channel_models.ChannelType.GUILD_NEWS @@ -8153,7 +8849,10 @@ def test_deserialize_channel_follower_webhook( assert isinstance(webhook, webhook_models.ChannelFollowerWebhook) def test_deserialize_channel_follower_webhook_without_optional_fields( - self, entity_factory_impl, mock_app, follower_webhook_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + follower_webhook_payload: dict[str, typing.Any], ): follower_webhook_payload["avatar"] = None del follower_webhook_payload["user"] @@ -8170,18 +8869,27 @@ def test_deserialize_channel_follower_webhook_without_optional_fields( assert webhook.source_channel is None def test_deserialize_channel_follower_webhook_doesnt_set_source_channel_type_if_set( - self, entity_factory_impl, mock_app, follower_webhook_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + follower_webhook_payload: dict[str, typing.Any], ): follower_webhook_payload["source_channel"]["type"] = channel_models.ChannelType.GUILD_VOICE webhook = entity_factory_impl.deserialize_channel_follower_webhook(follower_webhook_payload) + assert webhook.source_channel is not None assert webhook.source_channel.type == channel_models.ChannelType.GUILD_VOICE - def test_deserialize_application_webhook(self, entity_factory_impl, mock_app, application_webhook_payload): + def test_deserialize_application_webhook( + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + application_webhook_payload: dict[str, typing.Any], + ): webhook = entity_factory_impl.deserialize_application_webhook(application_webhook_payload) - assert webhook.app is mock_app + assert webhook.app is hikari_app assert webhook.type is webhook_models.WebhookType.APPLICATION assert webhook.id == 658822586720976555 assert webhook.name == "Clyde" @@ -8190,7 +8898,10 @@ def test_deserialize_application_webhook(self, entity_factory_impl, mock_app, ap assert isinstance(webhook, webhook_models.ApplicationWebhook) def test_deserialize_application_webhook_without_optional_fields( - self, entity_factory_impl, mock_app, application_webhook_payload + self, + entity_factory_impl: entity_factory.EntityFactoryImpl, + hikari_app: traits.RESTAware, + application_webhook_payload: dict[str, typing.Any], ): application_webhook_payload["avatar"] = None @@ -8206,19 +8917,21 @@ def test_deserialize_application_webhook_without_optional_fields( (3, "deserialize_application_webhook"), ], ) - def test_deserialize_webhook(self, mock_app, type_, fn): + def test_deserialize_webhook(self, hikari_app: traits.RESTAware, type_: int, fn: str): payload = {"type": type_} with mock.patch.object(entity_factory.EntityFactoryImpl, fn) as expected_fn: # We need to instantiate it after the mock so that the functions that are stored in the dicts # are the ones we mock - entity_factory_impl = entity_factory.EntityFactoryImpl(app=mock_app) + entity_factory_impl = entity_factory.EntityFactoryImpl(app=hikari_app) assert entity_factory_impl.deserialize_webhook(payload) is expected_fn.return_value expected_fn.assert_called_once_with(payload) - def test_deserialize_webhook_for_unexpected_webhook_type(self, entity_factory_impl): + def test_deserialize_webhook_for_unexpected_webhook_type( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): with pytest.raises(errors.UnrecognisedEntityError): entity_factory_impl.deserialize_webhook({"type": -7999}) @@ -8227,7 +8940,7 @@ def test_deserialize_webhook_for_unexpected_webhook_type(self, entity_factory_im ################## @pytest.fixture - def entitlement_payload(self): + def entitlement_payload(self) -> dict[str, typing.Any]: return { "id": "696969696969696", "sku_id": "420420420420420", @@ -8242,22 +8955,7 @@ def entitlement_payload(self): } @pytest.fixture - def entitlement_payload_starts_ends_null(self): - return { - "id": "696969696969696", - "sku_id": "420420420420420", - "application_id": "123123123123123", - "type": 8, - "deleted": False, - "starts_at": None, - "ends_at": None, - "guild_id": "1015034326372454400", - "user_id": "115590097100865541", - "subscription_id": "1019653835926409216", - } - - @pytest.fixture - def sku_payload(self): + def sku_payload(self) -> dict[str, typing.Any]: return { "id": "420420420420420", "type": 5, @@ -8267,7 +8965,9 @@ def sku_payload(self): "flags": 1 << 2 | 1 << 7, } - def test_deserialize_entitlement(self, entity_factory_impl, entitlement_payload): + def test_deserialize_entitlement( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, entitlement_payload: dict[str, typing.Any] + ): entitlement = entity_factory_impl.deserialize_entitlement(entitlement_payload) assert entitlement.id == 696969696969696 @@ -8282,22 +8982,20 @@ def test_deserialize_entitlement(self, entity_factory_impl, entitlement_payload) assert entitlement.subscription_id == 1019653835926409216 assert isinstance(entitlement, monetization_models.Entitlement) - def test_deserialize_entitlement_starts_ends_null(self, entity_factory_impl, entitlement_payload_starts_ends_null): - entitlement = entity_factory_impl.deserialize_entitlement(entitlement_payload_starts_ends_null) + def test_deserialize_entitlement_starts_ends_null( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, entitlement_payload: dict[str, typing.Any] + ): + entitlement_payload["starts_at"] = None + entitlement_payload["ends_at"] = None + + entitlement = entity_factory_impl.deserialize_entitlement(entitlement_payload) - assert entitlement.id == 696969696969696 - assert entitlement.sku_id == 420420420420420 - assert entitlement.application_id == 123123123123123 - assert entitlement.type is monetization_models.EntitlementType.APPLICATION_SUBSCRIPTION - assert entitlement.is_deleted is False assert entitlement.starts_at is None assert entitlement.ends_at is None - assert entitlement.guild_id == 1015034326372454400 - assert entitlement.user_id == 115590097100865541 - assert entitlement.subscription_id == 1019653835926409216 - assert isinstance(entitlement, monetization_models.Entitlement) - def test_deserialize_sku(self, entity_factory_impl, sku_payload): + def test_deserialize_sku( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, sku_payload: dict[str, typing.Any] + ): sku = entity_factory_impl.deserialize_sku(sku_payload) assert sku.id == 420420420420420 @@ -8313,7 +9011,7 @@ def test_deserialize_sku(self, entity_factory_impl, sku_payload): ######################### @pytest.fixture - def stage_instance_payload(self): + def stage_instance_payload(self) -> dict[str, typing.Any]: return { "id": "840647391636226060", "guild_id": "197038439483310086", @@ -8324,10 +9022,15 @@ def stage_instance_payload(self): "discoverable_disabled": False, } - def test_deserialize_stage_instance(self, entity_factory_impl, stage_instance_payload, mock_app): + def test_deserialize_stage_instance( + self, + hikari_app: traits.RESTAware, + entity_factory_impl: entity_factory.EntityFactoryImpl, + stage_instance_payload: dict[str, typing.Any], + ): stage_instance = entity_factory_impl.deserialize_stage_instance(stage_instance_payload) - assert stage_instance.app is mock_app + assert stage_instance.app is hikari_app assert stage_instance.id == 840647391636226060 assert stage_instance.channel_id == 733488538393510049 assert stage_instance.guild_id == 197038439483310086 @@ -8340,7 +9043,7 @@ def test_deserialize_stage_instance(self, entity_factory_impl, stage_instance_pa ########### @pytest.fixture - def poll_payload(self): + def poll_payload(self) -> dict[str, typing.Any]: return { "question": {"text": "fruit"}, "answers": [ @@ -8360,7 +9063,9 @@ def poll_payload(self): }, } - def test_deserialize_poll(self, entity_factory_impl, poll_payload): + def test_deserialize_poll( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, poll_payload: dict[str, typing.Any] + ): poll = entity_factory_impl.deserialize_poll(poll_payload) assert poll.question.text == "fruit" @@ -8394,46 +9099,52 @@ def test_deserialize_poll(self, entity_factory_impl, poll_payload): assert results_answer_counts[1].count == 28347 assert results_answer_counts[1].me_voted is True - def test_deserialize_poll_with_unset_fields(self, entity_factory_impl, poll_payload): + def test_deserialize_poll_with_unset_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, poll_payload: dict[str, typing.Any] + ): poll_payload["expiry"] = None poll = entity_factory_impl.deserialize_poll(poll_payload) assert poll.expiry is None - def test_deserialize_poll_with_null_fields(self, entity_factory_impl, poll_payload): + def test_deserialize_poll_with_null_fields( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, poll_payload: dict[str, typing.Any] + ): del poll_payload["results"] poll = entity_factory_impl.deserialize_poll(poll_payload) assert poll.results is None - def test_deserialize_auto_mod_action_for_block_message(self, entity_factory_impl): + def test_deserialize_auto_mod_action_for_block_message(self, entity_factory_impl: entity_factory.EntityFactoryImpl): result = entity_factory_impl.deserialize_auto_mod_action({"type": 1}) assert result.type is auto_mod_models.AutoModActionType.BLOCK_MESSAGE assert isinstance(result, auto_mod_models.AutoModBlockMessage) - def test_deserialize_auto_mod_action_for_send_alert_message(self, entity_factory_impl): + def test_deserialize_auto_mod_action_for_send_alert_message( + self, entity_factory_impl: entity_factory.EntityFactoryImpl + ): result = entity_factory_impl.deserialize_auto_mod_action({"type": 2, "metadata": {"channel_id": "43123123"}}) assert result.type is auto_mod_models.AutoModActionType.SEND_ALERT_MESSAGE assert isinstance(result, auto_mod_models.AutoModSendAlertMessage) assert result.channel_id == 43123123 - def test_deserialize_auto_mod_action_for_timeout(self, entity_factory_impl): + def test_deserialize_auto_mod_action_for_timeout(self, entity_factory_impl: entity_factory.EntityFactoryImpl): result = entity_factory_impl.deserialize_auto_mod_action({"type": 3, "metadata": {"duration_seconds": 123321}}) assert result.type is auto_mod_models.AutoModActionType.TIMEOUT assert isinstance(result, auto_mod_models.AutoModTimeout) assert result.duration == datetime.timedelta(seconds=123321) - def test_deserialize_auto_mod_action_for_unknown_type(self, entity_factory_impl): + def test_deserialize_auto_mod_action_for_unknown_type(self, entity_factory_impl: entity_factory.EntityFactoryImpl): with pytest.raises(errors.UnrecognisedEntityError): entity_factory_impl.deserialize_auto_mod_action({"type": -696969}) @pytest.fixture - def auto_mod_rule_payload(self): + def auto_mod_rule_payload(self) -> dict[str, typing.Any]: return { "id": "94594949494", "guild_id": "9595939234", @@ -8452,7 +9163,9 @@ def auto_mod_rule_payload(self): "exempt_channels": ["95959595", "31223"], } - def test_deserialize_auto_mod_rule(self, entity_factory_impl, auto_mod_rule_payload): + def test_deserialize_auto_mod_rule( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, auto_mod_rule_payload: dict[str, typing.Any] + ): result = entity_factory_impl.deserialize_auto_mod_rule(auto_mod_rule_payload) assert result.id == 94594949494 @@ -8471,7 +9184,9 @@ def test_deserialize_auto_mod_rule(self, entity_factory_impl, auto_mod_rule_payl assert result.exempt_role_ids == [49493932, 123321] assert result.exempt_channel_ids == [95959595, 31223] - def test_deserialize_auto_mod_rule_for_keyword_trigger(self, entity_factory_impl, auto_mod_rule_payload): + def test_deserialize_auto_mod_rule_for_keyword_trigger( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, auto_mod_rule_payload: dict[str, typing.Any] + ): result = entity_factory_impl.deserialize_auto_mod_rule( { "id": "94594949494", @@ -8498,7 +9213,9 @@ def test_deserialize_auto_mod_rule_for_keyword_trigger(self, entity_factory_impl assert result.trigger.regex_patterns == ["some", "regex", "patterns"] assert result.trigger.allow_list == ["allowed", "stuff"] - def test_deserialize_auto_mod_rule_for_spam_trigger(self, entity_factory_impl, auto_mod_rule_payload): + def test_deserialize_auto_mod_rule_for_spam_trigger( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, auto_mod_rule_payload: dict[str, typing.Any] + ): result = entity_factory_impl.deserialize_auto_mod_rule( { "id": "94594949494", @@ -8517,7 +9234,9 @@ def test_deserialize_auto_mod_rule_for_spam_trigger(self, entity_factory_impl, a assert isinstance(result.trigger, auto_mod_models.SpamTrigger) assert result.trigger.type is auto_mod_models.AutoModTriggerType.SPAM - def test_deserialize_auto_mod_rule_for_keyword_preset_trigger(self, entity_factory_impl, auto_mod_rule_payload): + def test_deserialize_auto_mod_rule_for_keyword_preset_trigger( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, auto_mod_rule_payload: dict[str, typing.Any] + ): result = entity_factory_impl.deserialize_auto_mod_rule( { "id": "94594949494", @@ -8542,7 +9261,9 @@ def test_deserialize_auto_mod_rule_for_keyword_preset_trigger(self, entity_facto ] assert result.trigger.allow_list == ["allowed", "stuff"] - def test_deserialize_auto_mod_rule_for_mention_spam_trigger(self, entity_factory_impl, auto_mod_rule_payload): + def test_deserialize_auto_mod_rule_for_mention_spam_trigger( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, auto_mod_rule_payload: dict[str, typing.Any] + ): result = entity_factory_impl.deserialize_auto_mod_rule( { "id": "94594949494", @@ -8564,7 +9285,9 @@ def test_deserialize_auto_mod_rule_for_mention_spam_trigger(self, entity_factory assert result.trigger.mention_total_limit == 5 assert result.trigger.mention_raid_protection_enabled is False - def test_deserialize_auto_mod_rule_for_member_profile_trigger(self, entity_factory_impl, auto_mod_rule_payload): + def test_deserialize_auto_mod_rule_for_member_profile_trigger( + self, entity_factory_impl: entity_factory.EntityFactoryImpl, auto_mod_rule_payload: dict[str, typing.Any] + ): result = entity_factory_impl.deserialize_auto_mod_rule( { "id": "94594949494", diff --git a/tests/hikari/impl/test_event_factory.py b/tests/hikari/impl/test_event_factory.py index fa262dfa40..94f37f2a68 100644 --- a/tests/hikari/impl/test_event_factory.py +++ b/tests/hikari/impl/test_event_factory.py @@ -57,104 +57,142 @@ class TestEventFactoryImpl: @pytest.fixture - def mock_app(self): - return mock.Mock(traits.RESTAware) - - @pytest.fixture - def mock_shard(self): + def mock_shard(self) -> shard.GatewayShard: return mock.Mock(shard.GatewayShard) @pytest.fixture - def event_factory(self, mock_app): - return event_factory_.EventFactoryImpl(mock_app) + def event_factory(self, hikari_app: traits.RESTAware) -> event_factory_.EventFactoryImpl: + return event_factory_.EventFactoryImpl(hikari_app) ###################### # APPLICATION EVENTS # ###################### - def test_deserialize_application_command_permission_update_event(self, event_factory, mock_app, mock_shard): - mock_payload = object() + def test_deserialize_application_command_permission_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock() - event = event_factory.deserialize_application_command_permission_update_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_guild_command_permissions" + ) as patched_deserialize_guild_command_permissions: + event = event_factory.deserialize_application_command_permission_update_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_guild_command_permissions.assert_called_once_with(mock_payload) + patched_deserialize_guild_command_permissions.assert_called_once_with(mock_payload) assert isinstance(event, application_events.ApplicationCommandPermissionsUpdateEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard - assert event.permissions is mock_app.entity_factory.deserialize_guild_command_permissions.return_value + assert event.permissions is patched_deserialize_guild_command_permissions.return_value ################## # CHANNEL EVENTS # ################## - def test_deserialize_guild_channel_create_event(self, event_factory, mock_app, mock_shard): - mock_app.entity_factory.deserialize_channel.return_value = mock.Mock( - spec=channel_models.PermissibleGuildChannel - ) - mock_payload = mock.Mock(app=mock_app) - - event = event_factory.deserialize_guild_channel_create_event(mock_shard, mock_payload) - - mock_app.entity_factory.deserialize_channel.assert_called_once_with(mock_payload) - assert isinstance(event, channel_events.GuildChannelCreateEvent) - assert event.shard is mock_shard - assert event.channel is mock_app.entity_factory.deserialize_channel.return_value + def test_deserialize_guild_channel_create_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) - def test_deserialize_guild_channel_update_event(self, event_factory, mock_app, mock_shard): - mock_app.entity_factory.deserialize_channel.return_value = mock.Mock( - spec=channel_models.PermissibleGuildChannel - ) - mock_old_channel = object() - mock_payload = object() + with mock.patch.object( + hikari_app.entity_factory, + "deserialize_channel", + return_value=mock.Mock(spec=channel_models.PermissibleGuildChannel), + ) as patched_deserialize_channel: + event = event_factory.deserialize_guild_channel_create_event(mock_shard, mock_payload) - event = event_factory.deserialize_guild_channel_update_event( - mock_shard, mock_payload, old_channel=mock_old_channel - ) + patched_deserialize_channel.assert_called_once_with(mock_payload) + assert isinstance(event, channel_events.GuildChannelCreateEvent) + assert event.shard is mock_shard + assert event.channel is patched_deserialize_channel.return_value - mock_app.entity_factory.deserialize_channel.assert_called_once_with(mock_payload) - assert isinstance(event, channel_events.GuildChannelUpdateEvent) - assert event.shard is mock_shard - assert event.channel is mock_app.entity_factory.deserialize_channel.return_value - assert event.old_channel is mock_old_channel + def test_deserialize_guild_channel_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_old_channel = mock.Mock() + mock_payload = mock.Mock() - def test_deserialize_guild_channel_delete_event(self, event_factory, mock_app, mock_shard): - mock_app.entity_factory.deserialize_channel.return_value = mock.Mock( - spec=channel_models.PermissibleGuildChannel - ) - mock_payload = mock.Mock(app=mock_app) + with mock.patch.object( + hikari_app.entity_factory, + "deserialize_channel", + return_value=mock.Mock(spec=channel_models.PermissibleGuildChannel), + ) as patched_deserialize_channel: + event = event_factory.deserialize_guild_channel_update_event( + mock_shard, mock_payload, old_channel=mock_old_channel + ) + + patched_deserialize_channel.assert_called_once_with(mock_payload) + assert isinstance(event, channel_events.GuildChannelUpdateEvent) + assert event.shard is mock_shard + assert event.channel is patched_deserialize_channel.return_value + assert event.old_channel is mock_old_channel + + def test_deserialize_guild_channel_delete_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) - event = event_factory.deserialize_guild_channel_delete_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, + "deserialize_channel", + return_value=mock.Mock(spec=channel_models.PermissibleGuildChannel), + ) as patched_deserialize_channel: + event = event_factory.deserialize_guild_channel_delete_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_channel.assert_called_once_with(mock_payload) - assert isinstance(event, channel_events.GuildChannelDeleteEvent) - assert event.shard is mock_shard - assert event.channel is mock_app.entity_factory.deserialize_channel.return_value + patched_deserialize_channel.assert_called_once_with(mock_payload) + assert isinstance(event, channel_events.GuildChannelDeleteEvent) + assert event.shard is mock_shard + assert event.channel is patched_deserialize_channel.return_value - def test_deserialize_channel_pins_update_event_for_guild(self, event_factory, mock_app, mock_shard): + def test_deserialize_channel_pins_update_event_for_guild( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"channel_id": "123435", "last_pin_timestamp": None, "guild_id": "43123123"} event = event_factory.deserialize_channel_pins_update_event(mock_shard, mock_payload) assert isinstance(event, channel_events.GuildPinsUpdateEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.guild_id == 43123123 assert event.last_pin_timestamp is None - def test_deserialize_channel_pins_update_event_for_dm(self, event_factory, mock_app, mock_shard): + def test_deserialize_channel_pins_update_event_for_dm( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"channel_id": "123435", "last_pin_timestamp": "2020-03-15T15:23:32.686000+00:00"} event = event_factory.deserialize_channel_pins_update_event(mock_shard, mock_payload) assert isinstance(event, channel_events.DMPinsUpdateEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.last_pin_timestamp == datetime.datetime( 2020, 3, 15, 15, 23, 32, 686000, tzinfo=datetime.timezone.utc ) def test_deserialize_channel_pins_update_event_without_last_pin_timestamp( - self, event_factory, mock_app, mock_shard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = {"channel_id": "123435", "guild_id": "43123123"} @@ -163,53 +201,76 @@ def test_deserialize_channel_pins_update_event_without_last_pin_timestamp( assert event.last_pin_timestamp is None def test_deserialize_guild_thread_create_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: mock.Mock, mock_shard: shard.GatewayShard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = mock.Mock() - event = event_factory.deserialize_guild_thread_create_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_guild_thread" + ) as patched_deserialize_guild_thread: + event = event_factory.deserialize_guild_thread_create_event(mock_shard, mock_payload) - assert event.shard is mock_shard - assert event.thread is mock_app.entity_factory.deserialize_guild_thread.return_value - mock_app.entity_factory.deserialize_guild_thread.assert_called_once_with(mock_payload) - assert isinstance(event, channel_events.GuildThreadCreateEvent) + assert event.shard is mock_shard + assert event.thread is patched_deserialize_guild_thread.return_value + patched_deserialize_guild_thread.assert_called_once_with(mock_payload) + assert isinstance(event, channel_events.GuildThreadCreateEvent) def test_deserialize_guild_thread_access_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: mock.Mock, mock_shard: shard.GatewayShard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = mock.Mock() - event = event_factory.deserialize_guild_thread_access_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_guild_thread" + ) as patched_deserialize_guild_thread: + event = event_factory.deserialize_guild_thread_access_event(mock_shard, mock_payload) - assert event.shard is mock_shard - assert event.thread is mock_app.entity_factory.deserialize_guild_thread.return_value - mock_app.entity_factory.deserialize_guild_thread.assert_called_once_with(mock_payload) - assert isinstance(event, channel_events.GuildThreadAccessEvent) + assert event.shard is mock_shard + assert event.thread is patched_deserialize_guild_thread.return_value + patched_deserialize_guild_thread.assert_called_once_with(mock_payload) + assert isinstance(event, channel_events.GuildThreadAccessEvent) def test_deserialize_guild_thread_update_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: mock.Mock, mock_shard: shard.GatewayShard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): - mock_old_thread = object() + mock_old_thread = mock.Mock() mock_payload = mock.Mock() - event = event_factory.deserialize_guild_thread_update_event( - mock_shard, mock_payload, old_thread=mock_old_thread - ) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_guild_thread" + ) as patched_deserialize_guild_thread: + event = event_factory.deserialize_guild_thread_update_event( + mock_shard, mock_payload, old_thread=mock_old_thread + ) - mock_app.entity_factory.deserialize_guild_thread.assert_called_once_with(mock_payload) + assert event.shard is mock_shard + assert event.thread is patched_deserialize_guild_thread.return_value + patched_deserialize_guild_thread.assert_called_once_with(mock_payload) assert isinstance(event, channel_events.GuildThreadUpdateEvent) assert event.shard is mock_shard - assert event.thread is mock_app.entity_factory.deserialize_guild_thread.return_value + assert event.thread is patched_deserialize_guild_thread.return_value assert event.old_thread is mock_old_thread def test_deserialize_guild_thread_delete_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: mock.Mock, mock_shard: shard.GatewayShard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = {"id": "12332123321", "guild_id": "54544234342", "parent_id": "9494949", "type": 11} event = event_factory.deserialize_guild_thread_delete_event(mock_shard, mock_payload) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.thread_id == 12332123321 assert event.guild_id == 54544234342 @@ -218,13 +279,15 @@ def test_deserialize_guild_thread_delete_event( assert isinstance(event, channel_events.GuildThreadDeleteEvent) def test_deserialize_thread_members_update_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: mock.Mock, mock_shard: shard.GatewayShard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_thread_member_payload = {"id": "393939393", "user_id": "3933993"} mock_other_thread_member_payload = {"id": "393994954", "user_id": "123321123"} mock_thread_member = mock.Mock(user_id=123321123) mock_other_thread_member = mock.Mock(user_id=5454234) - mock_app.entity_factory.deserialize_thread_member.side_effect = [mock_thread_member, mock_other_thread_member] payload = { "id": "92929929", "guild_id": "92929292", @@ -233,9 +296,14 @@ def test_deserialize_thread_members_update_event( "removed_member_ids": ["4949534", "123321", "54234"], } - event = event_factory.deserialize_thread_members_update_event(mock_shard, payload) + with mock.patch.object( + hikari_app.entity_factory, + "deserialize_thread_member", + side_effect=[mock_thread_member, mock_other_thread_member], + ) as patched_deserialize_thread_member: + event = event_factory.deserialize_thread_members_update_event(mock_shard, payload) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.thread_id == 92929929 assert event.guild_id == 92929292 @@ -243,12 +311,15 @@ def test_deserialize_thread_members_update_event( assert event.removed_member_ids == [4949534, 123321, 54234] assert event.guild_members == {} assert event.guild_presences == {} - mock_app.entity_factory.deserialize_thread_member.assert_has_calls( + patched_deserialize_thread_member.assert_has_calls( [mock.call(mock_thread_member_payload), mock.call(mock_other_thread_member_payload)] ) def test_deserialize_thread_members_update_event_when_presences_and_real_members( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: mock.Mock, mock_shard: shard.GatewayShard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_presence_payload = mock.Mock() mock_other_presence_payload = mock.Mock() @@ -272,9 +343,6 @@ def test_deserialize_thread_members_update_event_when_presences_and_real_members mock_other_presence = mock.Mock() mock_guild_member = mock.Mock() mock_other_guild_member = mock.Mock() - mock_app.entity_factory.deserialize_thread_member.side_effect = [mock_thread_member, mock_other_thread_member] - mock_app.entity_factory.deserialize_member.side_effect = [mock_guild_member, mock_other_guild_member] - mock_app.entity_factory.deserialize_member_presence.side_effect = [mock_presence, mock_other_presence] payload = { "id": "92929929", "guild_id": "123321123123", @@ -283,9 +351,26 @@ def test_deserialize_thread_members_update_event_when_presences_and_real_members "removed_member_ids": ["4949534", "123321", "54234"], } - event = event_factory.deserialize_thread_members_update_event(mock_shard, payload) - - assert event.app is mock_app + with ( + mock.patch.object( + hikari_app.entity_factory, + "deserialize_thread_member", + side_effect=[mock_thread_member, mock_other_thread_member], + ) as patched_deserialize_thread_member, + mock.patch.object( + hikari_app.entity_factory, + "deserialize_member", + side_effect=[mock_guild_member, mock_other_guild_member], + ) as patched_deserialize_member, + mock.patch.object( + hikari_app.entity_factory, + "deserialize_member_presence", + side_effect=[mock_presence, mock_other_presence], + ) as patched_deserialize_member_presence, + ): + event = event_factory.deserialize_thread_members_update_event(mock_shard, payload) + + assert event.app is hikari_app assert event.shard is mock_shard assert event.thread_id == 92929929 assert event.guild_id == 123321123123 @@ -293,16 +378,16 @@ def test_deserialize_thread_members_update_event_when_presences_and_real_members assert event.removed_member_ids == [4949534, 123321, 54234] assert event.guild_members == {3933993: mock_guild_member, 123321123: mock_other_guild_member} assert event.guild_presences == {3933993: mock_presence, 123321123: mock_other_presence} - mock_app.entity_factory.deserialize_thread_member.assert_has_calls( + patched_deserialize_thread_member.assert_has_calls( [mock.call(mock_thread_member_payload), mock.call(mock_other_thread_member_payload)] ) - mock_app.entity_factory.deserialize_member.assert_has_calls( + patched_deserialize_member.assert_has_calls( [ mock.call(mock_guild_member_payload, guild_id=123321123123), mock.call(mock_other_guild_member_payload, guild_id=123321123123), ] ) - mock_app.entity_factory.deserialize_member_presence.assert_has_calls( + patched_deserialize_member_presence.assert_has_calls( [ mock.call(mock_presence_payload, guild_id=123321123123), mock.call(mock_other_presence_payload, guild_id=123321123123), @@ -322,7 +407,10 @@ def test_deserialize_thread_members_update_event_partial( assert event.guild_presences == {} def test_deserialize_thread_list_sync_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: mock.Mock, mock_shard: shard.GatewayShard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_thread_payload = {"id": "342123123", "name": "nyaa"} mock_other_thread_payload = {"id": "5454123123", "name": "meow"} @@ -334,12 +422,6 @@ def test_deserialize_thread_list_sync_event( mock_not_in_thread = mock.Mock(id=94949494) mock_member = mock.Mock(thread_id=342123123) mock_other_member = mock.Mock(thread_id=5454123123) - mock_app.entity_factory.deserialize_guild_thread.side_effect = [ - mock_thread, - mock_not_in_thread, - mock_other_thread, - ] - mock_app.entity_factory.deserialize_thread_member.side_effect = [mock_member, mock_other_member] mock_payload = { "guild_id": "43123123", "channel_ids": ["54123", "123431", "43939", "12343123"], @@ -347,17 +429,27 @@ def test_deserialize_thread_list_sync_event( "members": [mock_other_member_payload, mokc_member_payload], } - event = event_factory.deserialize_thread_list_sync_event(mock_shard, mock_payload) + with ( + mock.patch.object( + hikari_app.entity_factory, + "deserialize_guild_thread", + side_effect=[mock_thread, mock_not_in_thread, mock_other_thread], + ) as patched_deserialize_guild_thread, + mock.patch.object( + hikari_app.entity_factory, "deserialize_thread_member", side_effect=[mock_member, mock_other_member] + ) as patched_deserialize_thread_member, + ): + event = event_factory.deserialize_thread_list_sync_event(mock_shard, mock_payload) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.guild_id == 43123123 assert event.channel_ids == [54123, 123431, 43939, 12343123] assert event.threads == {342123123: mock_thread, 5454123123: mock_other_thread, 94949494: mock_not_in_thread} - mock_app.entity_factory.deserialize_thread_member.assert_has_calls( + patched_deserialize_thread_member.assert_has_calls( [mock.call(mock_other_member_payload), mock.call(mokc_member_payload)] ) - mock_app.entity_factory.deserialize_guild_thread.assert_has_calls( + patched_deserialize_guild_thread.assert_has_calls( [ mock.call(mock_thread_payload, guild_id=43123123, member=mock_member), mock.call(mock_not_in_thread_payload, guild_id=43123123, member=None), @@ -366,43 +458,64 @@ def test_deserialize_thread_list_sync_event( ) def test_deserialize_thread_list_sync_event_when_not_channel_ids( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: mock.Mock, mock_shard: shard.GatewayShard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): - mock_payload = {"guild_id": "123321", "threads": [], "members": []} + mock_payload: dict[str, typing.Any] = {"guild_id": "123321", "threads": [], "members": []} event = event_factory.deserialize_thread_list_sync_event(mock_shard, mock_payload) assert event.channel_ids is None - def test_deserialize_webhook_update_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_webhook_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"guild_id": "123123", "channel_id": "4393939"} event = event_factory.deserialize_webhook_update_event(mock_shard, mock_payload) assert isinstance(event, channel_events.WebhookUpdateEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.channel_id == 4393939 assert event.guild_id == 123123 - def test_deserialize_invite_create_event(self, event_factory, mock_app, mock_shard): - mock_payload = mock.Mock(app=mock_app) + def test_deserialize_invite_create_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) - event = event_factory.deserialize_invite_create_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_invite_with_metadata" + ) as patched_deserialize_invite_with_metadata: + event = event_factory.deserialize_invite_create_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_invite_with_metadata.assert_called_once_with(mock_payload) + patched_deserialize_invite_with_metadata.assert_called_once_with(mock_payload) assert isinstance(event, channel_events.InviteCreateEvent) assert event.shard is mock_shard - assert event.invite is mock_app.entity_factory.deserialize_invite_with_metadata.return_value + assert event.invite is patched_deserialize_invite_with_metadata.return_value - def test_deserialize_invite_delete_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_invite_delete_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"guild_id": "1231234", "channel_id": "123123", "code": "no u"} - mock_old_invite = object() + mock_old_invite = mock.Mock() event = event_factory.deserialize_invite_delete_event(mock_shard, mock_payload, old_invite=mock_old_invite) assert isinstance(event, channel_events.InviteDeleteEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.channel_id == 123123 assert event.guild_id == 1231234 @@ -413,33 +526,45 @@ def test_deserialize_invite_delete_event(self, event_factory, mock_app, mock_sha # TYPING EVENTS # ################## - def test_deserialize_typing_start_event_for_guild(self, event_factory, mock_app, mock_shard): - mock_member_payload = object() + def test_deserialize_typing_start_event_for_guild( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_member_payload = mock.Mock() mock_payload = { "guild_id": "123321", "channel_id": "48585858", "timestamp": 7634521233, "member": mock_member_payload, } - mock_app.entity_factory.deserialize_member.return_value = mock.Mock(app=mock_app) - event = event_factory.deserialize_typing_start_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_member", return_value=mock.Mock(app=hikari_app) + ) as patched_deserialize_member: + event = event_factory.deserialize_typing_start_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_member.assert_called_once_with(mock_member_payload, guild_id=123321) + patched_deserialize_member.assert_called_once_with(mock_member_payload, guild_id=123321) assert isinstance(event, typing_events.GuildTypingEvent) assert event.shard is mock_shard assert event.channel_id == 48585858 assert event.guild_id == 123321 assert event.timestamp == datetime.datetime(2211, 12, 6, 12, 20, 33, tzinfo=datetime.timezone.utc) - assert event.member == mock_app.entity_factory.deserialize_member.return_value + assert event.member == patched_deserialize_member.return_value - def test_deserialize_typing_start_event_for_dm(self, event_factory, mock_app, mock_shard): + def test_deserialize_typing_start_event_for_dm( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"channel_id": "534123", "timestamp": 7634521212, "user_id": "9494994"} event = event_factory.deserialize_typing_start_event(mock_shard, mock_payload) assert isinstance(event, typing_events.DMTypingEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.channel_id == 534123 assert event.timestamp == datetime.datetime(2211, 12, 6, 12, 20, 12, tzinfo=datetime.timezone.utc) @@ -449,17 +574,28 @@ def test_deserialize_typing_start_event_for_dm(self, event_factory, mock_app, mo # GUILD EVENTS # ################ - def test_deserialize_guild_available_event(self, event_factory, mock_app, mock_shard): - mock_payload = mock.Mock(app=mock_app) - - event = event_factory.deserialize_guild_available_event(mock_shard, mock_payload) - - mock_app.entity_factory.deserialize_gateway_guild.assert_called_once_with( - mock_payload, user_id=mock_shard.get_user_id.return_value + def test_deserialize_guild_available_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) + + with ( + mock.patch.object(mock_shard, "get_user_id") as patched_get_user_id, + mock.patch.object( + hikari_app.entity_factory, "deserialize_gateway_guild" + ) as patched_deserialize_gateway_guild, + ): + event = event_factory.deserialize_guild_available_event(mock_shard, mock_payload) + + patched_deserialize_gateway_guild.assert_called_once_with( + mock_payload, user_id=patched_get_user_id.return_value ) assert isinstance(event, guild_events.GuildAvailableEvent) assert event.shard is mock_shard - guild_definition = mock_app.entity_factory.deserialize_gateway_guild.return_value + guild_definition = patched_deserialize_gateway_guild.return_value assert event.guild is guild_definition.guild.return_value assert event.emojis is guild_definition.emojis.return_value assert event.stickers is guild_definition.stickers.return_value @@ -476,19 +612,30 @@ def test_deserialize_guild_available_event(self, event_factory, mock_app, mock_s guild_definition.members.assert_called_once_with() guild_definition.presences.assert_called_once_with() guild_definition.voice_states.assert_called_once_with() - mock_shard.get_user_id.assert_called_once_with() - - def test_deserialize_guild_join_event(self, event_factory, mock_app, mock_shard): - mock_payload = mock.Mock(app=mock_app) + patched_get_user_id.assert_called_once_with() - event = event_factory.deserialize_guild_join_event(mock_shard, mock_payload) - - mock_app.entity_factory.deserialize_gateway_guild.assert_called_once_with( - mock_payload, user_id=mock_shard.get_user_id.return_value + def test_deserialize_guild_join_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) + + with ( + mock.patch.object(mock_shard, "get_user_id") as patched_get_user_id, + mock.patch.object( + hikari_app.entity_factory, "deserialize_gateway_guild" + ) as patched_deserialize_gateway_guild, + ): + event = event_factory.deserialize_guild_join_event(mock_shard, mock_payload) + + patched_deserialize_gateway_guild.assert_called_once_with( + mock_payload, user_id=patched_get_user_id.return_value ) assert isinstance(event, guild_events.GuildJoinEvent) assert event.shard is mock_shard - guild_definition = mock_app.entity_factory.deserialize_gateway_guild.return_value + guild_definition = patched_deserialize_gateway_guild.return_value assert event.guild is guild_definition.guild.return_value assert event.emojis is guild_definition.emojis.return_value assert event.roles is guild_definition.roles.return_value @@ -496,20 +643,31 @@ def test_deserialize_guild_join_event(self, event_factory, mock_app, mock_shard) assert event.members is guild_definition.members.return_value assert event.presences is guild_definition.presences.return_value assert event.voice_states is guild_definition.voice_states.return_value - mock_shard.get_user_id.assert_called_once_with() - - def test_deserialize_guild_update_event(self, event_factory, mock_app, mock_shard): - mock_payload = mock.Mock(app=mock_app) - mock_old_guild = object() + patched_get_user_id.assert_called_once_with() - event = event_factory.deserialize_guild_update_event(mock_shard, mock_payload, old_guild=mock_old_guild) - - mock_app.entity_factory.deserialize_gateway_guild.assert_called_once_with( - mock_payload, user_id=mock_shard.get_user_id.return_value + def test_deserialize_guild_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) + mock_old_guild = mock.Mock() + + with ( + mock.patch.object(mock_shard, "get_user_id") as patched_get_user_id, + mock.patch.object( + hikari_app.entity_factory, "deserialize_gateway_guild" + ) as patched_deserialize_gateway_guild, + ): + event = event_factory.deserialize_guild_update_event(mock_shard, mock_payload, old_guild=mock_old_guild) + + patched_deserialize_gateway_guild.assert_called_once_with( + mock_payload, user_id=patched_get_user_id.return_value ) assert isinstance(event, guild_events.GuildUpdateEvent) assert event.shard is mock_shard - guild_definition = mock_app.entity_factory.deserialize_gateway_guild.return_value + guild_definition = patched_deserialize_gateway_guild.return_value assert event.guild is guild_definition.guild.return_value assert event.emojis is guild_definition.emojis.return_value assert event.roles is guild_definition.roles.return_value @@ -517,148 +675,215 @@ def test_deserialize_guild_update_event(self, event_factory, mock_app, mock_shar guild_definition.guild.assert_called_once_with() guild_definition.emojis.assert_called_once_with() guild_definition.roles.assert_called_once_with() - mock_shard.get_user_id.assert_called_once_with() + patched_get_user_id.assert_called_once_with() - def test_deserialize_guild_leave_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_guild_leave_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"id": "43123123"} - mock_old_guild = object() + mock_old_guild = mock.Mock() event = event_factory.deserialize_guild_leave_event(mock_shard, mock_payload, old_guild=mock_old_guild) assert isinstance(event, guild_events.GuildLeaveEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.guild_id == 43123123 assert event.old_guild is mock_old_guild - def test_deserialize_guild_unavailable_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_guild_unavailable_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"id": "6541233"} event = event_factory.deserialize_guild_unavailable_event(mock_shard, mock_payload) assert isinstance(event, guild_events.GuildUnavailableEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.guild_id == 6541233 - def test_deserialize_guild_ban_add_event(self, event_factory, mock_app, mock_shard): - mock_user_payload = mock.Mock(app=mock_app) + def test_deserialize_guild_ban_add_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_user_payload = mock.Mock(app=hikari_app) mock_payload = {"guild_id": "4212312", "user": mock_user_payload} - event = event_factory.deserialize_guild_ban_add_event(mock_shard, mock_payload) + with mock.patch.object(hikari_app.entity_factory, "deserialize_user") as patched_deserialize_user: + event = event_factory.deserialize_guild_ban_add_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_user.assert_called_once_with(mock_user_payload) + patched_deserialize_user.assert_called_once_with(mock_user_payload) assert isinstance(event, guild_events.BanCreateEvent) assert event.shard is mock_shard assert event.guild_id == 4212312 - assert event.user is mock_app.entity_factory.deserialize_user.return_value + assert event.user is patched_deserialize_user.return_value - def test_deserialize_guild_ban_remove_event(self, event_factory, mock_app, mock_shard): - mock_user_payload = mock.Mock(app=mock_app) + def test_deserialize_guild_ban_remove_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_user_payload = mock.Mock(app=hikari_app) mock_payload = {"guild_id": "9292929", "user": mock_user_payload} - event = event_factory.deserialize_guild_ban_remove_event(mock_shard, mock_payload) + with mock.patch.object(hikari_app.entity_factory, "deserialize_user") as patched_deserialize_user: + event = event_factory.deserialize_guild_ban_remove_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_user.assert_called_once_with(mock_user_payload) + patched_deserialize_user.assert_called_once_with(mock_user_payload) assert isinstance(event, guild_events.BanDeleteEvent) assert event.shard is mock_shard assert event.guild_id == 9292929 - assert event.user is mock_app.entity_factory.deserialize_user.return_value + assert event.user is patched_deserialize_user.return_value - def test_deserialize_guild_emojis_update_event(self, event_factory, mock_app, mock_shard): - mock_emoji_payload = object() - mock_old_emojis = object() + def test_deserialize_guild_emojis_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_emoji_payload = mock.Mock() + mock_old_emojis = mock.Mock() mock_payload = {"guild_id": "123431", "emojis": [mock_emoji_payload]} - event = event_factory.deserialize_guild_emojis_update_event( - mock_shard, mock_payload, old_emojis=mock_old_emojis - ) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_known_custom_emoji" + ) as patched_deserialize_known_custom_emoji: + event = event_factory.deserialize_guild_emojis_update_event( + mock_shard, mock_payload, old_emojis=mock_old_emojis + ) - mock_app.entity_factory.deserialize_known_custom_emoji.assert_called_once_with( - mock_emoji_payload, guild_id=123431 - ) + patched_deserialize_known_custom_emoji.assert_called_once_with(mock_emoji_payload, guild_id=123431) assert isinstance(event, guild_events.EmojisUpdateEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard - assert event.emojis == [mock_app.entity_factory.deserialize_known_custom_emoji.return_value] + assert event.emojis == [patched_deserialize_known_custom_emoji.return_value] assert event.guild_id == 123431 assert event.old_emojis is mock_old_emojis - def test_deserialize_guild_stickers_update_event(self, event_factory, mock_app, mock_shard): - mock_sticker_payload = object() - mock_old_stickers = object() + def test_deserialize_guild_stickers_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_sticker_payload = mock.Mock() + mock_old_stickers = mock.Mock() mock_payload = {"guild_id": "472", "stickers": [mock_sticker_payload]} - event = event_factory.deserialize_guild_stickers_update_event( - mock_shard, mock_payload, old_stickers=mock_old_stickers - ) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_guild_sticker" + ) as patched_deserialize_guild_sticker: + event = event_factory.deserialize_guild_stickers_update_event( + mock_shard, mock_payload, old_stickers=mock_old_stickers + ) - mock_app.entity_factory.deserialize_guild_sticker.assert_called_once_with(mock_sticker_payload) + patched_deserialize_guild_sticker.assert_called_once_with(mock_sticker_payload) assert isinstance(event, guild_events.StickersUpdateEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard - assert event.stickers == [mock_app.entity_factory.deserialize_guild_sticker.return_value] + assert event.stickers == [patched_deserialize_guild_sticker.return_value] assert event.guild_id == 472 assert event.old_stickers is mock_old_stickers - def test_deserialize_integration_create_event(self, event_factory, mock_app, mock_shard): - mock_payload = object() + def test_deserialize_integration_create_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock() - event = event_factory.deserialize_integration_create_event(mock_shard, mock_payload) + with mock.patch.object(hikari_app.entity_factory, "deserialize_integration") as patched_deserialize_integration: + event = event_factory.deserialize_integration_create_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_integration.assert_called_once_with(mock_payload) + patched_deserialize_integration.assert_called_once_with(mock_payload) assert isinstance(event, guild_events.IntegrationCreateEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard - assert event.integration is mock_app.entity_factory.deserialize_integration.return_value + assert event.integration is patched_deserialize_integration.return_value - def test_deserialize_integration_delete_event_with_application_id(self, event_factory, mock_app, mock_shard): + def test_deserialize_integration_delete_event_with_application_id( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"id": "123321", "guild_id": "59595959", "application_id": "934949494"} event = event_factory.deserialize_integration_delete_event(mock_shard, mock_payload) assert isinstance(event, guild_events.IntegrationDeleteEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.id == 123321 assert event.guild_id == 59595959 assert event.application_id == 934949494 - def test_deserialize_integration_delete_event_without_application_id(self, event_factory, mock_app, mock_shard): + def test_deserialize_integration_delete_event_without_application_id( + self, event_factory: event_factory_.EventFactoryImpl, mock_shard: shard.GatewayShard + ): mock_payload = {"id": "123321", "guild_id": "59595959"} event = event_factory.deserialize_integration_delete_event(mock_shard, mock_payload) assert event.application_id is None - def test_deserialize_integration_update_event(self, event_factory, mock_app, mock_shard): - mock_payload = object() + def test_deserialize_integration_update_event( + self, + hikari_app: traits.RESTAware, + event_factory: event_factory_.EventFactoryImpl, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock() - event = event_factory.deserialize_integration_update_event(mock_shard, mock_payload) + with mock.patch.object(hikari_app.entity_factory, "deserialize_integration") as patched_deserialize_integration: + event = event_factory.deserialize_integration_update_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_integration.assert_called_once_with(mock_payload) + patched_deserialize_integration.assert_called_once_with(mock_payload) assert isinstance(event, guild_events.IntegrationUpdateEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard - assert event.integration is mock_app.entity_factory.deserialize_integration.return_value + assert event.integration is patched_deserialize_integration.return_value - def test_deserialize_presence_update_event_with_only_user_id(self, event_factory, mock_app, mock_shard): + def test_deserialize_presence_update_event_with_only_user_id( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"user": {"id": "1231312"}} - mock_old_presence = object() - mock_app.entity_factory.deserialize_member_presence.return_value = mock.Mock(app=mock_app) + mock_old_presence = mock.Mock() - event = event_factory.deserialize_presence_update_event( - mock_shard, mock_payload, old_presence=mock_old_presence - ) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_member_presence", return_value=mock.Mock(app=hikari_app) + ) as patched_deserialize_member_presence: + event = event_factory.deserialize_presence_update_event( + mock_shard, mock_payload, old_presence=mock_old_presence + ) - mock_app.entity_factory.deserialize_member_presence.assert_called_once_with(mock_payload) + patched_deserialize_member_presence.assert_called_once_with(mock_payload) assert isinstance(event, guild_events.PresenceUpdateEvent) assert event.shard is mock_shard assert event.old_presence is mock_old_presence assert event.user is None - assert event.presence is mock_app.entity_factory.deserialize_member_presence.return_value + assert event.presence is patched_deserialize_member_presence.return_value - def test_deserialize_presence_update_event_with_full_user_object(self, event_factory, mock_app, mock_shard): + def test_deserialize_presence_update_event_with_full_user_object( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = { "user": { "id": "1231312", @@ -673,13 +898,16 @@ def test_deserialize_presence_update_event_with_full_user_object(self, event_fac "discriminator": "1231", } } - mock_old_presence = mock.Mock(app=mock_app) + mock_old_presence = mock.Mock(app=hikari_app) - event = event_factory.deserialize_presence_update_event( - mock_shard, mock_payload, old_presence=mock_old_presence - ) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_member_presence" + ) as patched_deserialize_member_presence: + event = event_factory.deserialize_presence_update_event( + mock_shard, mock_payload, old_presence=mock_old_presence + ) - mock_app.entity_factory.deserialize_member_presence.assert_called_once_with(mock_payload) + patched_deserialize_member_presence.assert_called_once_with(mock_payload) assert isinstance(event, guild_events.PresenceUpdateEvent) assert event.shard is mock_shard assert event.old_presence is mock_old_presence @@ -696,18 +924,25 @@ def test_deserialize_presence_update_event_with_full_user_object(self, event_fac assert event.user.is_system is False assert event.user.flags == 42 - assert event.presence is mock_app.entity_factory.deserialize_member_presence.return_value + assert event.presence is patched_deserialize_member_presence.return_value - def test_deserialize_presence_update_event_with_partial_user_object(self, event_factory, mock_app, mock_shard): + def test_deserialize_presence_update_event_with_partial_user_object( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"user": {"id": "1231312", "e": "OK"}} - mock_old_presence = object() - mock_app.entity_factory.deserialize_member_presence.return_value = mock.Mock(app=mock_app) + mock_old_presence = mock.Mock() - event = event_factory.deserialize_presence_update_event( - mock_shard, mock_payload, old_presence=mock_old_presence - ) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_member_presence", return_value=mock.Mock(app=hikari_app) + ) as patched_deserialize_member_presence: + event = event_factory.deserialize_presence_update_event( + mock_shard, mock_payload, old_presence=mock_old_presence + ) - mock_app.entity_factory.deserialize_member_presence.assert_called_once_with(mock_payload) + patched_deserialize_member_presence.assert_called_once_with(mock_payload) assert isinstance(event, guild_events.PresenceUpdateEvent) assert event.shard is mock_shard assert event.old_presence is mock_old_presence @@ -724,17 +959,23 @@ def test_deserialize_presence_update_event_with_partial_user_object(self, event_ assert event.user.is_system is undefined.UNDEFINED assert event.user.flags is undefined.UNDEFINED - assert event.presence is mock_app.entity_factory.deserialize_member_presence.return_value + assert event.presence is patched_deserialize_member_presence.return_value def test_deserialize_audit_log_entry_create_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app, mock_shard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): payload = {"id": "439034093490"} - result = event_factory.deserialize_audit_log_entry_create_event(mock_shard, payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_audit_log_entry" + ) as patched_deserialize_audit_log_entry: + result = event_factory.deserialize_audit_log_entry_create_event(mock_shard, payload) - mock_app.entity_factory.deserialize_audit_log_entry.assert_called_once_with(payload) - assert result.entry is mock_app.entity_factory.deserialize_audit_log_entry.return_value + patched_deserialize_audit_log_entry.assert_called_once_with(payload) + assert result.entry is patched_deserialize_audit_log_entry.return_value assert result.shard is mock_shard assert isinstance(result, guild_events.AuditLogEntryCreateEvent) @@ -753,23 +994,28 @@ def test_deserialize_audit_log_entry_create_event( ) def test_deserialize_interaction_create_event( self, - event_factory, - mock_app, - mock_shard, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, interaction_type: typing.Optional[base_interactions.InteractionType], - expected: interaction_events.InteractionCreateEvent, + expected: type[interaction_events.InteractionCreateEvent], ): payload = {"id": "1561232344"} - if interaction_type: - mock_app.entity_factory.deserialize_interaction.return_value = mock.Mock(type=interaction_type) - result = event_factory.deserialize_interaction_create_event(mock_shard, payload) - mock_app.entity_factory.deserialize_interaction.assert_called_once_with(payload) - assert result.shard is mock_shard - assert result.interaction is mock_app.entity_factory.deserialize_interaction.return_value + with mock.patch.object( + hikari_app.entity_factory, "deserialize_interaction", return_value=mock.Mock(type=interaction_type) + ) as patched_deserialize_interaction: + result = event_factory.deserialize_interaction_create_event(mock_shard, payload) + assert isinstance(result, expected) + assert result.shard is mock_shard + assert result.interaction is patched_deserialize_interaction.return_value + + patched_deserialize_interaction.assert_called_once_with(payload) - def test_deserialize_interaction_create_event_error(self, event_factory, mock_app, mock_shard): + def test_deserialize_interaction_create_event_error( + self, event_factory: event_factory_.EventFactoryImpl, mock_shard: shard.GatewayShard + ): payload = {"id": "1561232344"} with pytest.raises(KeyError): event_factory.deserialize_interaction_create_event(mock_shard, payload) @@ -778,82 +1024,117 @@ def test_deserialize_interaction_create_event_error(self, event_factory, mock_ap # MEMBER EVENTS # ################# - def test_deserialize_guild_member_add_event(self, event_factory, mock_app, mock_shard): - mock_payload = mock.Mock(app=mock_app) + def test_deserialize_guild_member_add_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) - event = event_factory.deserialize_guild_member_add_event(mock_shard, mock_payload) + with mock.patch.object(hikari_app.entity_factory, "deserialize_member") as patched_deserialize_member: + event = event_factory.deserialize_guild_member_add_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_member.assert_called_once_with(mock_payload) + patched_deserialize_member.assert_called_once_with(mock_payload) assert isinstance(event, member_events.MemberCreateEvent) assert event.shard is mock_shard - assert event.member is mock_app.entity_factory.deserialize_member.return_value + assert event.member is patched_deserialize_member.return_value - def test_deserialize_guild_member_update_event(self, event_factory, mock_app, mock_shard): - mock_payload = mock.Mock(app=mock_app) - mock_old_member = object() + def test_deserialize_guild_member_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) + mock_old_member = mock.Mock() - event = event_factory.deserialize_guild_member_update_event( - mock_shard, mock_payload, old_member=mock_old_member - ) + with mock.patch.object(hikari_app.entity_factory, "deserialize_member") as patched_deserialize_member: + event = event_factory.deserialize_guild_member_update_event( + mock_shard, mock_payload, old_member=mock_old_member + ) - mock_app.entity_factory.deserialize_member.assert_called_once_with(mock_payload) + patched_deserialize_member.assert_called_once_with(mock_payload) assert isinstance(event, member_events.MemberUpdateEvent) assert event.shard is mock_shard - assert event.member is mock_app.entity_factory.deserialize_member.return_value + assert event.member is patched_deserialize_member.return_value assert event.old_member is mock_old_member - def test_deserialize_guild_member_remove_event(self, event_factory, mock_app, mock_shard): - mock_user_payload = mock.Mock(app=mock_app) - mock_old_member = object() + def test_deserialize_guild_member_remove_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_user_payload = mock.Mock(app=hikari_app) + mock_old_member = mock.Mock() mock_payload = {"guild_id": "43123", "user": mock_user_payload} - event = event_factory.deserialize_guild_member_remove_event( - mock_shard, mock_payload, old_member=mock_old_member - ) + with mock.patch.object(hikari_app.entity_factory, "deserialize_user") as patched_deserialize_user: + event = event_factory.deserialize_guild_member_remove_event( + mock_shard, mock_payload, old_member=mock_old_member + ) - mock_app.entity_factory.deserialize_user.assert_called_once_with(mock_user_payload) + patched_deserialize_user.assert_called_once_with(mock_user_payload) assert isinstance(event, member_events.MemberDeleteEvent) assert event.shard is mock_shard assert event.guild_id == 43123 - assert event.user is mock_app.entity_factory.deserialize_user.return_value + assert event.user is patched_deserialize_user.return_value assert event.old_member is mock_old_member ############### # ROLE EVENTS # ############### - def test_deserialize_guild_role_create_event(self, event_factory, mock_app, mock_shard): - mock_role_payload = mock.Mock(app=mock_app) + def test_deserialize_guild_role_create_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_role_payload = mock.Mock(app=hikari_app) mock_payload = {"role": mock_role_payload, "guild_id": "45123"} - event = event_factory.deserialize_guild_role_create_event(mock_shard, mock_payload) + with mock.patch.object(hikari_app.entity_factory, "deserialize_role") as patched_deserialize_role: + event = event_factory.deserialize_guild_role_create_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_role.assert_called_once_with(mock_role_payload, guild_id=45123) + patched_deserialize_role.assert_called_once_with(mock_role_payload, guild_id=45123) assert isinstance(event, role_events.RoleCreateEvent) assert event.shard is mock_shard - assert event.role is mock_app.entity_factory.deserialize_role.return_value + assert event.role is patched_deserialize_role.return_value - def test_deserialize_guild_role_update_event(self, event_factory, mock_app, mock_shard): - mock_role_payload = mock.Mock(app=mock_app) - mock_old_role = object() + def test_deserialize_guild_role_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_role_payload = mock.Mock(app=hikari_app) + mock_old_role = mock.Mock() mock_payload = {"role": mock_role_payload, "guild_id": "45123"} - event = event_factory.deserialize_guild_role_update_event(mock_shard, mock_payload, old_role=mock_old_role) + with mock.patch.object(hikari_app.entity_factory, "deserialize_role") as patched_deserialize_role: + event = event_factory.deserialize_guild_role_update_event(mock_shard, mock_payload, old_role=mock_old_role) - mock_app.entity_factory.deserialize_role.assert_called_once_with(mock_role_payload, guild_id=45123) + patched_deserialize_role.assert_called_once_with(mock_role_payload, guild_id=45123) assert isinstance(event, role_events.RoleUpdateEvent) assert event.shard is mock_shard - assert event.role is mock_app.entity_factory.deserialize_role.return_value + assert event.role is patched_deserialize_role.return_value assert event.old_role is mock_old_role - def test_deserialize_guild_role_delete_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_guild_role_delete_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"guild_id": "432123", "role_id": "848484"} - mock_old_role = object() + mock_old_role = mock.Mock() event = event_factory.deserialize_guild_role_delete_event(mock_shard, mock_payload, old_role=mock_old_role) assert isinstance(event, role_events.RoleDeleteEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.guild_id == 432123 assert event.role_id == 848484 @@ -864,43 +1145,61 @@ def test_deserialize_guild_role_delete_event(self, event_factory, mock_app, mock ########################## def test_deserialize_scheduled_event_create_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: traits.RESTAware, mock_shard: mock.Mock + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = mock.Mock() - event = event_factory.deserialize_scheduled_event_create_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_scheduled_event" + ) as patched_deserialize_scheduled_event: + event = event_factory.deserialize_scheduled_event_create_event(mock_shard, mock_payload) assert event.shard is mock_shard - assert event.event is mock_app.entity_factory.deserialize_scheduled_event.return_value + assert event.event is patched_deserialize_scheduled_event.return_value assert isinstance(event, scheduled_events.ScheduledEventCreateEvent) - mock_app.entity_factory.deserialize_scheduled_event.assert_called_once_with(mock_payload) + patched_deserialize_scheduled_event.assert_called_once_with(mock_payload) def test_deserialize_scheduled_event_update_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: traits.RESTAware, mock_shard: mock.Mock + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = mock.Mock() - event = event_factory.deserialize_scheduled_event_update_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_scheduled_event" + ) as patched_deserialize_scheduled_event: + event = event_factory.deserialize_scheduled_event_update_event(mock_shard, mock_payload) assert event.shard is mock_shard - assert event.event is mock_app.entity_factory.deserialize_scheduled_event.return_value + assert event.event is patched_deserialize_scheduled_event.return_value assert isinstance(event, scheduled_events.ScheduledEventUpdateEvent) - mock_app.entity_factory.deserialize_scheduled_event.assert_called_once_with(mock_payload) + patched_deserialize_scheduled_event.assert_called_once_with(mock_payload) def test_deserialize_scheduled_event_delete_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: traits.RESTAware, mock_shard: mock.Mock + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = mock.Mock() - event = event_factory.deserialize_scheduled_event_delete_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_scheduled_event" + ) as patched_deserialize_scheduled_event: + event = event_factory.deserialize_scheduled_event_delete_event(mock_shard, mock_payload) assert event.shard is mock_shard - assert event.event is mock_app.entity_factory.deserialize_scheduled_event.return_value + assert event.event is patched_deserialize_scheduled_event.return_value assert isinstance(event, scheduled_events.ScheduledEventDeleteEvent) - mock_app.entity_factory.deserialize_scheduled_event.assert_called_once_with(mock_payload) + patched_deserialize_scheduled_event.assert_called_once_with(mock_payload) def test_deserialize_scheduled_event_user_add_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: mock.Mock, mock_shard: mock.Mock + self, event_factory: event_factory_.EventFactoryImpl, mock_shard: shard.GatewayShard ): mock_payload = {"guild_id": "494949", "user_id": "123123123", "guild_scheduled_event_id": "49494944"} @@ -913,7 +1212,7 @@ def test_deserialize_scheduled_event_user_add_event( assert isinstance(event, scheduled_events.ScheduledEventUserAddEvent) def test_deserialize_scheduled_event_user_remove_event( - self, event_factory: event_factory_.EventFactoryImpl, mock_app: mock.Mock, mock_shard: mock.Mock + self, event_factory: event_factory_.EventFactoryImpl, mock_shard: shard.GatewayShard ): mock_payload = {"guild_id": "3244321", "user_id": "56423", "guild_scheduled_event_id": "1234312"} @@ -929,115 +1228,178 @@ def test_deserialize_scheduled_event_user_remove_event( # LIFETIME EVENTS # ################### - def test_deserialize_starting_event(self, event_factory, mock_app): + def test_deserialize_starting_event( + self, event_factory: event_factory_.EventFactoryImpl, hikari_app: traits.RESTAware + ): event = event_factory.deserialize_starting_event() assert isinstance(event, lifetime_events.StartingEvent) - assert event.app is mock_app + assert event.app is hikari_app - def test_deserialize_started_event(self, event_factory, mock_app): + def test_deserialize_started_event( + self, event_factory: event_factory_.EventFactoryImpl, hikari_app: traits.RESTAware + ): event = event_factory.deserialize_started_event() assert isinstance(event, lifetime_events.StartedEvent) - assert event.app is mock_app + assert event.app is hikari_app - def test_deserialize_stopping_event(self, event_factory, mock_app): + def test_deserialize_stopping_event( + self, event_factory: event_factory_.EventFactoryImpl, hikari_app: traits.RESTAware + ): event = event_factory.deserialize_stopping_event() assert isinstance(event, lifetime_events.StoppingEvent) - assert event.app is mock_app + assert event.app is hikari_app - def test_deserialize_stopped_event(self, event_factory, mock_app): + def test_deserialize_stopped_event( + self, event_factory: event_factory_.EventFactoryImpl, hikari_app: traits.RESTAware + ): event = event_factory.deserialize_stopped_event() assert isinstance(event, lifetime_events.StoppedEvent) - assert event.app is mock_app + assert event.app is hikari_app ################## # MESSAGE EVENTS # ################## - def test_deserialize_message_create_event_in_guild(self, event_factory, mock_app, mock_shard): - mock_payload = mock.Mock(app=mock_app) - mock_app.entity_factory.deserialize_message.return_value = mock.Mock(guild_id=123321) + def test_deserialize_message_create_event_in_guild( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) - event = event_factory.deserialize_message_create_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_message", mock.Mock(guild_id=123321) + ) as patched_deserialize_message: + event = event_factory.deserialize_message_create_event(mock_shard, mock_payload) assert isinstance(event, message_events.GuildMessageCreateEvent) assert event.shard is mock_shard - assert event.message is mock_app.entity_factory.deserialize_message.return_value + assert event.message is patched_deserialize_message.return_value + patched_deserialize_message.assert_called_once_with(mock_payload) - def test_deserialize_message_create_event_in_dm(self, event_factory, mock_app, mock_shard): - mock_payload = mock.Mock(app=mock_app) - mock_app.entity_factory.deserialize_message.return_value = mock.Mock(guild_id=None) + def test_deserialize_message_create_event_in_dm( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) - event = event_factory.deserialize_message_create_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_message", return_value=mock.Mock(guild_id=None) + ) as patched_deserialize_message: + event = event_factory.deserialize_message_create_event(mock_shard, mock_payload) assert isinstance(event, message_events.DMMessageCreateEvent) assert event.shard is mock_shard - assert event.message is mock_app.entity_factory.deserialize_message.return_value + assert event.message is patched_deserialize_message.return_value + patched_deserialize_message.assert_called_once_with(mock_payload) - def test_deserialize_message_update_event_in_guild(self, event_factory, mock_app, mock_shard): - mock_payload = mock.Mock(app=mock_app) - mock_old_message = object() - mock_app.entity_factory.deserialize_partial_message.return_value = mock.Mock(guild_id=123321, app=mock_app) - - event = event_factory.deserialize_message_update_event(mock_shard, mock_payload, old_message=mock_old_message) + def test_deserialize_message_update_event_in_guild( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) + mock_old_message = mock.Mock() + + with mock.patch.object( + hikari_app.entity_factory, + "deserialize_partial_message", + return_value=mock.Mock(guild_id=123321, app=hikari_app), + ) as patched_deserialize_partial_message: + event = event_factory.deserialize_message_update_event( + mock_shard, mock_payload, old_message=mock_old_message + ) assert isinstance(event, message_events.GuildMessageUpdateEvent) assert event.shard is mock_shard - assert event.message is mock_app.entity_factory.deserialize_partial_message.return_value + assert event.message is patched_deserialize_partial_message.return_value assert event.old_message is mock_old_message + patched_deserialize_partial_message.assert_called_once_with(mock_payload) - def test_deserialize_message_update_event_in_dm(self, event_factory, mock_app, mock_shard): - mock_payload = mock.Mock(app=mock_app) - mock_old_message = object() - mock_app.entity_factory.deserialize_partial_message.return_value = mock.Mock(guild_id=None) - - event = event_factory.deserialize_message_update_event(mock_shard, mock_payload, old_message=mock_old_message) + def test_deserialize_message_update_event_in_dm( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) + mock_old_message = mock.Mock() + + with mock.patch.object( + hikari_app.entity_factory, + "deserialize_partial_message", + return_value=mock.Mock(guild_id=None, app=hikari_app), + ) as patched_deserialize_partial_message: + event = event_factory.deserialize_message_update_event( + mock_shard, mock_payload, old_message=mock_old_message + ) assert isinstance(event, message_events.DMMessageUpdateEvent) assert event.shard is mock_shard - assert event.message is mock_app.entity_factory.deserialize_partial_message.return_value + assert event.message is patched_deserialize_partial_message.return_value assert event.old_message is mock_old_message + patched_deserialize_partial_message.assert_called_once_with(mock_payload) - def test_deserialize_message_delete_event_in_guild(self, event_factory, mock_app, mock_shard): + def test_deserialize_message_delete_event_in_guild( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"id": "5412", "channel_id": "541123", "guild_id": "9494949"} - old_message = object() + old_message = mock.Mock() event = event_factory.deserialize_message_delete_event(mock_shard, mock_payload, old_message=old_message) assert isinstance(event, message_events.GuildMessageDeleteEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.old_message is old_message assert event.channel_id == 541123 assert event.message_id == 5412 assert event.guild_id == 9494949 - def test_deserialize_message_delete_event_in_dm(self, event_factory, mock_app, mock_shard): + def test_deserialize_message_delete_event_in_dm( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"id": "5412", "channel_id": "541123"} - old_message = object() + old_message = mock.Mock() event = event_factory.deserialize_message_delete_event(mock_shard, mock_payload, old_message=old_message) assert isinstance(event, message_events.DMMessageDeleteEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.old_message is old_message assert event.channel_id == 541123 assert event.message_id == 5412 - def test_deserialize_guild_message_delete_bulk_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_guild_message_delete_bulk_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"ids": ["6523423", "345123"], "channel_id": "564123", "guild_id": "4394949"} - old_messages = object() + old_messages = mock.Mock() event = event_factory.deserialize_guild_message_delete_bulk_event( mock_shard, mock_payload, old_messages=old_messages ) assert isinstance(event, message_events.GuildBulkMessageDeleteEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.old_messages is old_messages assert event.channel_id == 564123 @@ -1045,7 +1407,10 @@ def test_deserialize_guild_message_delete_bulk_event(self, event_factory, mock_a assert event.guild_id == 4394949 def test_deserialize_guild_message_delete_bulk_event_when_old_messages_is_none( - self, event_factory, mock_app, mock_shard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = {"ids": ["6523423", "345123"], "channel_id": "564123", "guild_id": "4394949"} @@ -1058,8 +1423,13 @@ def test_deserialize_guild_message_delete_bulk_event_when_old_messages_is_none( # REACTION EVENTS # ################### - def test_deserialize_message_reaction_add_event_in_guild(self, event_factory, mock_shard, mock_app): - mock_member_payload = mock.Mock(app=mock_app) + def test_deserialize_message_reaction_add_event_in_guild( + self, + event_factory: event_factory_.EventFactoryImpl, + mock_shard: shard.GatewayShard, + hikari_app: traits.RESTAware, + ): + mock_member_payload = mock.Mock(app=hikari_app) mock_payload = { "member": mock_member_payload, "channel_id": "34123", @@ -1068,23 +1438,29 @@ def test_deserialize_message_reaction_add_event_in_guild(self, event_factory, mo "emoji": {"id": "123312", "name": "okok", "animated": True}, } - event = event_factory.deserialize_message_reaction_add_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_member", return_value=mock.Mock(guild_id=None, app=hikari_app) + ) as patched_deserialize_member: + event = event_factory.deserialize_message_reaction_add_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_member.assert_called_once_with(mock_member_payload, guild_id=43949494) + patched_deserialize_member.assert_called_once_with(mock_member_payload, guild_id=43949494) assert isinstance(event, reaction_events.GuildReactionAddEvent) assert event.shard is mock_shard assert event.channel_id == 34123 assert event.message_id == 43123123 - assert event.member is mock_app.entity_factory.deserialize_member.return_value + assert event.member is patched_deserialize_member.return_value assert not isinstance(event.emoji_name, emoji_models.UnicodeEmoji) assert event.emoji_name == "okok" assert event.emoji_id == 123312 assert event.is_animated is True def test_deserialize_message_reaction_add_event_in_guild_when_partial_custom( - self, event_factory, mock_shard, mock_app + self, + event_factory: event_factory_.EventFactoryImpl, + mock_shard: shard.GatewayShard, + hikari_app: traits.RESTAware, ): - mock_member_payload = object() + mock_member_payload = mock.Mock() mock_payload = { "member": mock_member_payload, "channel_id": "34123", @@ -1099,8 +1475,13 @@ def test_deserialize_message_reaction_add_event_in_guild_when_partial_custom( assert event.emoji_id == 123312 assert event.emoji_name is None - def test_deserialize_message_reaction_add_event_in_guild_when_unicode(self, event_factory, mock_shard, mock_app): - mock_member_payload = object() + def test_deserialize_message_reaction_add_event_in_guild_when_unicode( + self, + event_factory: event_factory_.EventFactoryImpl, + mock_shard: shard.GatewayShard, + hikari_app: traits.RESTAware, + ): + mock_member_payload = mock.Mock() mock_payload = { "member": mock_member_payload, "channel_id": "34123", @@ -1116,7 +1497,12 @@ def test_deserialize_message_reaction_add_event_in_guild_when_unicode(self, even assert event.emoji_id is None assert event.is_animated is False - def test_deserialize_message_reaction_add_event_in_dm(self, event_factory, mock_shard, mock_app): + def test_deserialize_message_reaction_add_event_in_dm( + self, + event_factory: event_factory_.EventFactoryImpl, + mock_shard: shard.GatewayShard, + hikari_app: traits.RESTAware, + ): mock_payload = { "channel_id": "34123", "message_id": "43123123", @@ -1127,7 +1513,7 @@ def test_deserialize_message_reaction_add_event_in_dm(self, event_factory, mock_ event = event_factory.deserialize_message_reaction_add_event(mock_shard, mock_payload) assert isinstance(event, reaction_events.DMReactionAddEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.channel_id == 34123 assert event.message_id == 43123123 @@ -1138,7 +1524,10 @@ def test_deserialize_message_reaction_add_event_in_dm(self, event_factory, mock_ assert event.is_animated is True def test_deserialize_message_reaction_add_event_in_dm_when_partial_custom( - self, event_factory, mock_shard, mock_app + self, + event_factory: event_factory_.EventFactoryImpl, + mock_shard: shard.GatewayShard, + hikari_app: traits.RESTAware, ): mock_payload = { "channel_id": "34123", @@ -1153,7 +1542,12 @@ def test_deserialize_message_reaction_add_event_in_dm_when_partial_custom( assert event.emoji_id == 3293939 assert event.is_animated is False - def test_deserialize_message_reaction_add_event_in_dm_when_unicode(self, event_factory, mock_shard, mock_app): + def test_deserialize_message_reaction_add_event_in_dm_when_unicode( + self, + event_factory: event_factory_.EventFactoryImpl, + mock_shard: shard.GatewayShard, + hikari_app: traits.RESTAware, + ): mock_payload = { "channel_id": "34123", "message_id": "43123123", @@ -1164,7 +1558,7 @@ def test_deserialize_message_reaction_add_event_in_dm_when_unicode(self, event_f event = event_factory.deserialize_message_reaction_add_event(mock_shard, mock_payload) assert isinstance(event, reaction_events.DMReactionAddEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.channel_id == 34123 assert event.message_id == 43123123 @@ -1174,7 +1568,12 @@ def test_deserialize_message_reaction_add_event_in_dm_when_unicode(self, event_f assert event.emoji_id is None assert event.is_animated is False - def test_deserialize_message_reaction_remove_event_in_guild(self, event_factory, mock_app, mock_shard): + def test_deserialize_message_reaction_remove_event_in_guild( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = { "user_id": "43123", "channel_id": "484848", @@ -1186,7 +1585,7 @@ def test_deserialize_message_reaction_remove_event_in_guild(self, event_factory, event = event_factory.deserialize_message_reaction_remove_event(mock_shard, mock_payload) assert isinstance(event, reaction_events.GuildReactionDeleteEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.user_id == 43123 assert event.channel_id == 484848 @@ -1197,7 +1596,10 @@ def test_deserialize_message_reaction_remove_event_in_guild(self, event_factory, assert not isinstance(event.emoji_name, emoji_models.UnicodeEmoji) def test_deserialize_message_reaction_remove_event_in_guild_with_unicode_emoji( - self, event_factory, mock_app, mock_shard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = { "user_id": "43123", @@ -1210,7 +1612,7 @@ def test_deserialize_message_reaction_remove_event_in_guild_with_unicode_emoji( event = event_factory.deserialize_message_reaction_remove_event(mock_shard, mock_payload) assert isinstance(event, reaction_events.GuildReactionDeleteEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.user_id == 43123 assert event.channel_id == 484848 @@ -1220,7 +1622,12 @@ def test_deserialize_message_reaction_remove_event_in_guild_with_unicode_emoji( assert event.emoji_name == "o" assert isinstance(event.emoji_name, emoji_models.UnicodeEmoji) - def test_deserialize_message_reaction_remove_event_in_dm(self, event_factory, mock_app, mock_shard): + def test_deserialize_message_reaction_remove_event_in_dm( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = { "user_id": "43123", "channel_id": "484848", @@ -1231,7 +1638,7 @@ def test_deserialize_message_reaction_remove_event_in_dm(self, event_factory, mo event = event_factory.deserialize_message_reaction_remove_event(mock_shard, mock_payload) assert isinstance(event, reaction_events.DMReactionDeleteEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.user_id == 43123 assert event.channel_id == 484848 @@ -1241,14 +1648,17 @@ def test_deserialize_message_reaction_remove_event_in_dm(self, event_factory, mo assert event.emoji_id == 123123 def test_deserialize_message_reaction_remove_event_in_dm_with_unicode_emoji( - self, event_factory, mock_app, mock_shard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = {"user_id": "43123", "channel_id": "484848", "message_id": "43234", "emoji": {"name": "wwww"}} event = event_factory.deserialize_message_reaction_remove_event(mock_shard, mock_payload) assert isinstance(event, reaction_events.DMReactionDeleteEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.user_id == 43123 assert event.channel_id == 484848 @@ -1257,30 +1667,45 @@ def test_deserialize_message_reaction_remove_event_in_dm_with_unicode_emoji( assert event.emoji_name == "wwww" assert event.emoji_id is None - def test_deserialize_message_reaction_remove_all_event_in_guild(self, event_factory, mock_app, mock_shard): + def test_deserialize_message_reaction_remove_all_event_in_guild( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"channel_id": "312312", "message_id": "34323", "guild_id": "393939"} event = event_factory.deserialize_message_reaction_remove_all_event(mock_shard, mock_payload) assert isinstance(event, reaction_events.GuildReactionDeleteAllEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.channel_id == 312312 assert event.message_id == 34323 assert event.guild_id == 393939 - def test_deserialize_message_reaction_remove_all_event_in_dm(self, event_factory, mock_app, mock_shard): + def test_deserialize_message_reaction_remove_all_event_in_dm( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"channel_id": "312312", "message_id": "34323"} event = event_factory.deserialize_message_reaction_remove_all_event(mock_shard, mock_payload) assert isinstance(event, reaction_events.DMReactionDeleteAllEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.channel_id == 312312 assert event.message_id == 34323 - def test_deserialize_message_reaction_remove_emoji_event_in_guild(self, event_factory, mock_app, mock_shard): + def test_deserialize_message_reaction_remove_emoji_event_in_guild( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = { "channel_id": "123123", "guild_id": "423412", @@ -1291,7 +1716,7 @@ def test_deserialize_message_reaction_remove_emoji_event_in_guild(self, event_fa event = event_factory.deserialize_message_reaction_remove_emoji_event(mock_shard, mock_payload) assert isinstance(event, reaction_events.GuildReactionDeleteEmojiEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.channel_id == 123123 assert event.guild_id == 423412 @@ -1301,7 +1726,10 @@ def test_deserialize_message_reaction_remove_emoji_event_in_guild(self, event_fa assert not isinstance(event.emoji_name, emoji_models.UnicodeEmoji) def test_deserialize_message_reaction_remove_emoji_event_in_guild_with_unicode_emoji( - self, event_factory, mock_app, mock_shard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = { "channel_id": "123123", @@ -1315,13 +1743,18 @@ def test_deserialize_message_reaction_remove_emoji_event_in_guild_with_unicode_e assert event.emoji_name == "okokok" assert isinstance(event.emoji_name, emoji_models.UnicodeEmoji) - def test_deserialize_message_reaction_remove_emoji_event_in_dm(self, event_factory, mock_app, mock_shard): + def test_deserialize_message_reaction_remove_emoji_event_in_dm( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"channel_id": "123123", "message_id": "99999", "emoji": {"id": "123321", "name": "nom"}} event = event_factory.deserialize_message_reaction_remove_emoji_event(mock_shard, mock_payload) assert isinstance(event, reaction_events.DMReactionDeleteEmojiEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.channel_id == 123123 assert event.message_id == 99999 @@ -1330,14 +1763,17 @@ def test_deserialize_message_reaction_remove_emoji_event_in_dm(self, event_facto assert not isinstance(event.emoji_name, emoji_models.UnicodeEmoji) def test_deserialize_message_reaction_remove_emoji_event_in_dm_with_unicode_emoji( - self, event_factory, mock_app, mock_shard + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, ): mock_payload = {"channel_id": "123123", "message_id": "99999", "emoji": {"name": "gg"}} event = event_factory.deserialize_message_reaction_remove_emoji_event(mock_shard, mock_payload) assert isinstance(event, reaction_events.DMReactionDeleteEmojiEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.channel_id == 123123 assert event.message_id == 99999 @@ -1349,18 +1785,28 @@ def test_deserialize_message_reaction_remove_emoji_event_in_dm_with_unicode_emoj # SHARD EVENTS # ################ - def test_deserialize_shard_payload_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_shard_payload_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"id": "123123"} event = event_factory.deserialize_shard_payload_event(mock_shard, mock_payload, name="ooga booga") - assert event.app is mock_app + assert event.app is hikari_app assert event.name == "ooga booga" assert event.payload == mock_payload assert event.shard is mock_shard - def test_deserialize_ready_event(self, event_factory, mock_app, mock_shard): - mock_user_payload = object() + def test_deserialize_ready_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_user_payload = mock.Mock() mock_payload = { "v": "69", "resume_gateway_url": "testing.com", @@ -1369,43 +1815,65 @@ def test_deserialize_ready_event(self, event_factory, mock_app, mock_shard): "session_id": "kjsdjiodsaiosad", "application": {"id": "4123212", "flags": "4949494"}, } - mock_app.entity_factory.deserialize_my_user.return_value = mock.Mock(app=mock_app) - event = event_factory.deserialize_ready_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_my_user", return_value=mock.Mock(app=hikari_app) + ) as patched_deserialize_my_user: + event = event_factory.deserialize_ready_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_my_user.assert_called_once_with(mock_user_payload) + patched_deserialize_my_user.assert_called_once_with(mock_user_payload) assert isinstance(event, shard_events.ShardReadyEvent) assert event.shard is mock_shard assert event.actual_gateway_version == 69 assert event.resume_gateway_url == "testing.com" - assert event.my_user is mock_app.entity_factory.deserialize_my_user.return_value + assert event.my_user is patched_deserialize_my_user.return_value assert event.unavailable_guilds == [432123, 949494] assert event.session_id == "kjsdjiodsaiosad" assert event.application_id == 4123212 assert event.application_flags == 4949494 - def test_deserialize_connected_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_connected_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): event = event_factory.deserialize_connected_event(mock_shard) assert isinstance(event, shard_events.ShardConnectedEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard - def test_deserialize_disconnected_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_disconnected_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): event = event_factory.deserialize_disconnected_event(mock_shard) assert isinstance(event, shard_events.ShardDisconnectedEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard - def test_deserialize_resumed_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_resumed_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): event = event_factory.deserialize_resumed_event(mock_shard) assert isinstance(event, shard_events.ShardResumedEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard - def test_deserialize_guild_member_chunk_event_with_optional_fields(self, event_factory, mock_app, mock_shard): + def test_deserialize_guild_member_chunk_event_with_optional_fields( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_member_payload = {"user": {"id": "4222222"}} mock_presence_payload = {"user": {"id": "43123123"}} mock_payload = { @@ -1418,24 +1886,33 @@ def test_deserialize_guild_member_chunk_event_with_optional_fields(self, event_f "nonce": "OKOKOKOK", } - event = event_factory.deserialize_guild_member_chunk_event(mock_shard, mock_payload) + with ( + mock.patch.object(hikari_app.entity_factory, "deserialize_member") as patched_deserialize_member, + mock.patch.object( + hikari_app.entity_factory, "deserialize_member_presence" + ) as patched_deserialize_member_presence, + ): + event = event_factory.deserialize_guild_member_chunk_event(mock_shard, mock_payload) - mock_app.entity_factory.deserialize_member.assert_called_once_with(mock_member_payload, guild_id=123432123) - mock_app.entity_factory.deserialize_member_presence.assert_called_once_with( - mock_presence_payload, guild_id=123432123 - ) + patched_deserialize_member.assert_called_once_with(mock_member_payload, guild_id=123432123) + patched_deserialize_member_presence.assert_called_once_with(mock_presence_payload, guild_id=123432123) assert isinstance(event, shard_events.MemberChunkEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.guild_id == 123432123 - assert event.members == {4222222: mock_app.entity_factory.deserialize_member.return_value} + assert event.members == {4222222: patched_deserialize_member.return_value} assert event.chunk_count == 54 assert event.chunk_index == 3 assert event.not_found == [34212312312, 323123123] - assert event.presences == {43123123: mock_app.entity_factory.deserialize_member_presence.return_value} + assert event.presences == {43123123: patched_deserialize_member_presence.return_value} assert event.nonce == "OKOKOKOK" - def test_deserialize_guild_member_chunk_event_without_optional_fields(self, event_factory, mock_app, mock_shard): + def test_deserialize_guild_member_chunk_event_without_optional_fields( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_member_payload = {"user": {"id": "4222222"}} mock_payload = {"guild_id": "123432123", "members": [mock_member_payload], "chunk_index": 3, "chunk_count": 54} @@ -1449,45 +1926,64 @@ def test_deserialize_guild_member_chunk_event_without_optional_fields(self, even # USER EVENTS # ############### - def test_deserialize_own_user_update_event(self, event_factory, mock_app, mock_shard): - mock_payload = mock.Mock(app=mock_app) - mock_old_user = object() - mock_app.entity_factory.deserialize_my_user.return_value = mock.Mock(app=mock_app) + def test_deserialize_own_user_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock(app=hikari_app) + mock_old_user = mock.Mock() - event = event_factory.deserialize_own_user_update_event(mock_shard, mock_payload, old_user=mock_old_user) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_my_user", return_value=mock.Mock(app=hikari_app) + ) as patched_deserialize_my_user: + event = event_factory.deserialize_own_user_update_event(mock_shard, mock_payload, old_user=mock_old_user) - mock_app.entity_factory.deserialize_my_user.assert_called_once_with(mock_payload) + patched_deserialize_my_user.assert_called_once_with(mock_payload) assert isinstance(event, user_events.OwnUserUpdateEvent) assert event.shard is mock_shard - assert event.user is mock_app.entity_factory.deserialize_my_user.return_value + assert event.user is patched_deserialize_my_user.return_value assert event.old_user is mock_old_user ################ # VOICE EVENTS # ################ - def test_deserialize_voice_state_update_event(self, event_factory, mock_app, mock_shard): - mock_payload = object() - mock_old_voice_state = object() - mock_app.entity_factory.deserialize_voice_state.return_value = mock.Mock(app=mock_app) + def test_deserialize_voice_state_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): + mock_payload = mock.Mock() + mock_old_voice_state = mock.Mock() - event = event_factory.deserialize_voice_state_update_event( - mock_shard, mock_payload, old_state=mock_old_voice_state - ) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_voice_state", return_value=mock.Mock(app=hikari_app) + ) as patched_deserialize_voice_state: + event = event_factory.deserialize_voice_state_update_event( + mock_shard, mock_payload, old_state=mock_old_voice_state + ) - mock_app.entity_factory.deserialize_voice_state.assert_called_once_with(mock_payload) + patched_deserialize_voice_state.assert_called_once_with(mock_payload) assert isinstance(event, voice_events.VoiceStateUpdateEvent) assert event.shard is mock_shard - assert event.state is mock_app.entity_factory.deserialize_voice_state.return_value + assert event.state is patched_deserialize_voice_state.return_value assert event.old_state is mock_old_voice_state - def test_deserialize_voice_server_update_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_voice_server_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"token": "okokok", "guild_id": "3122312", "endpoint": "httppppppp"} event = event_factory.deserialize_voice_server_update_event(mock_shard, mock_payload) assert isinstance(event, voice_events.VoiceServerUpdateEvent) - assert event.app is mock_app + assert event.app is hikari_app assert event.shard is mock_shard assert event.token == "okokok" assert event.guild_id == 3122312 @@ -1497,7 +1993,12 @@ def test_deserialize_voice_server_update_event(self, event_factory, mock_app, mo # MONETIZATION # ################## - def test_deserialize_entitlement_create_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_entitlement_create_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): payload = { "id": "696969696969696", "sku_id": "420420420420420", @@ -1515,7 +2016,12 @@ def test_deserialize_entitlement_create_event(self, event_factory, mock_app, moc assert isinstance(event, monetization_events.EntitlementCreateEvent) - def test_deserialize_entitlement_update_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_entitlement_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): payload = { "id": "696969696969696", "sku_id": "420420420420420", @@ -1533,7 +2039,12 @@ def test_deserialize_entitlement_update_event(self, event_factory, mock_app, moc assert isinstance(event, monetization_events.EntitlementUpdateEvent) - def test_deserialize_entitlement_delete_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_entitlement_delete_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): payload = { "id": "696969696969696", "sku_id": "420420420420420", @@ -1555,7 +2066,12 @@ def test_deserialize_entitlement_delete_event(self, event_factory, mock_app, moc # STAGE INSTANCE EVENTS # ######################### - def test_deserialize_stage_instance_create_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_stage_instance_create_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = { "id": "840647391636226060", "guild_id": "197038439483310086", @@ -1564,14 +2080,24 @@ def test_deserialize_stage_instance_create_event(self, event_factory, mock_app, "privacy_level": 1, "discoverable_disabled": False, } - event = event_factory.deserialize_stage_instance_create_event(mock_shard, mock_payload) + + with mock.patch.object( + hikari_app.entity_factory, "deserialize_stage_instance" + ) as patched_deserialize_stage_instance: + event = event_factory.deserialize_stage_instance_create_event(mock_shard, mock_payload) + assert isinstance(event, stage_events.StageInstanceCreateEvent) assert event.shard is mock_shard assert event.app is event.stage_instance.app - assert event.stage_instance == mock_app.entity_factory.deserialize_stage_instance.return_value + assert event.stage_instance == patched_deserialize_stage_instance.return_value - def test_deserialize_stage_instance_update_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_stage_instance_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = { "id": "840647391636226060", "guild_id": "197038439483310086", @@ -1580,14 +2106,23 @@ def test_deserialize_stage_instance_update_event(self, event_factory, mock_app, "privacy_level": 2, "discoverable_disabled": True, } - event = event_factory.deserialize_stage_instance_update_event(mock_shard, mock_payload) + + with mock.patch.object( + hikari_app.entity_factory, "deserialize_stage_instance" + ) as patched_deserialize_stage_instance: + event = event_factory.deserialize_stage_instance_update_event(mock_shard, mock_payload) assert isinstance(event, stage_events.StageInstanceUpdateEvent) assert event.shard is mock_shard assert event.app is event.stage_instance.app - assert event.stage_instance == mock_app.entity_factory.deserialize_stage_instance.return_value + assert event.stage_instance == patched_deserialize_stage_instance.return_value - def test_deserialize_stage_instance_delete_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_stage_instance_delete_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = { "id": "840647391636226060", "guild_id": "197038439483310086", @@ -1596,18 +2131,25 @@ def test_deserialize_stage_instance_delete_event(self, event_factory, mock_app, "privacy_level": 2, "discoverable_disabled": True, } - event = event_factory.deserialize_stage_instance_delete_event(mock_shard, mock_payload) + + with mock.patch.object( + hikari_app.entity_factory, "deserialize_stage_instance" + ) as patched_deserialize_stage_instance: + event = event_factory.deserialize_stage_instance_delete_event(mock_shard, mock_payload) + assert isinstance(event, stage_events.StageInstanceDeleteEvent) assert event.shard is mock_shard assert event.app is event.stage_instance.app - assert event.stage_instance == mock_app.entity_factory.deserialize_stage_instance.return_value + assert event.stage_instance == patched_deserialize_stage_instance.return_value ########### # POLLS # ########### - def test_deserialize_poll_vote_create_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_poll_vote_create_event( + self, event_factory: event_factory_.EventFactoryImpl, mock_shard: shard.GatewayShard + ): payload = { "user_id": "3847382", "channel_id": "4598743", @@ -1620,7 +2162,9 @@ def test_deserialize_poll_vote_create_event(self, event_factory, mock_app, mock_ assert isinstance(event, poll_events.PollVoteCreateEvent) - def test_deserialize_poll_vote_delete_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_poll_vote_delete_event( + self, event_factory: event_factory_.EventFactoryImpl, mock_shard: shard.GatewayShard + ): payload = { "user_id": "3847382", "channel_id": "4598743", @@ -1633,60 +2177,92 @@ def test_deserialize_poll_vote_delete_event(self, event_factory, mock_app, mock_ assert isinstance(event, poll_events.PollVoteDeleteEvent) - def test_deserialize_auto_mod_rule_create_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_auto_mod_rule_create_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"id": "49499494"} - event = event_factory.deserialize_auto_mod_rule_create_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_auto_mod_rule" + ) as patched_deserialize_auto_mod_rule: + event = event_factory.deserialize_auto_mod_rule_create_event(mock_shard, mock_payload) assert isinstance(event, auto_mod_events.AutoModRuleCreateEvent) assert event.shard is mock_shard - assert event.rule is mock_app.entity_factory.deserialize_auto_mod_rule.return_value - mock_app.entity_factory.deserialize_auto_mod_rule.assert_called_once_with(mock_payload) + assert event.rule is patched_deserialize_auto_mod_rule.return_value + patched_deserialize_auto_mod_rule.assert_called_once_with(mock_payload) - def test_deserialize_auto_mod_rule_update_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_auto_mod_rule_update_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"id": "49499494"} - event = event_factory.deserialize_auto_mod_rule_update_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_auto_mod_rule" + ) as patched_deserialize_auto_mod_rule: + event = event_factory.deserialize_auto_mod_rule_update_event(mock_shard, mock_payload) assert isinstance(event, auto_mod_events.AutoModRuleUpdateEvent) assert event.shard is mock_shard - assert event.rule is mock_app.entity_factory.deserialize_auto_mod_rule.return_value - mock_app.entity_factory.deserialize_auto_mod_rule.assert_called_once_with(mock_payload) + assert event.rule is patched_deserialize_auto_mod_rule.return_value + patched_deserialize_auto_mod_rule.assert_called_once_with(mock_payload) - def test_deserialize_auto_mod_rule_delete_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_auto_mod_rule_delete_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_payload = {"id": "49499494"} - event = event_factory.deserialize_auto_mod_rule_delete_event(mock_shard, mock_payload) + with mock.patch.object( + hikari_app.entity_factory, "deserialize_auto_mod_rule" + ) as patched_deserialize_auto_mod_rule: + event = event_factory.deserialize_auto_mod_rule_delete_event(mock_shard, mock_payload) assert isinstance(event, auto_mod_events.AutoModRuleDeleteEvent) assert event.shard is mock_shard - assert event.rule is mock_app.entity_factory.deserialize_auto_mod_rule.return_value - mock_app.entity_factory.deserialize_auto_mod_rule.assert_called_once_with(mock_payload) + assert event.rule is patched_deserialize_auto_mod_rule.return_value + patched_deserialize_auto_mod_rule.assert_called_once_with(mock_payload) - def test_deserialize_auto_mod_action_execution_event(self, event_factory, mock_app, mock_shard): + def test_deserialize_auto_mod_action_execution_event( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_action_payload = {"type": "69"} - event = event_factory.deserialize_auto_mod_action_execution_event( - mock_shard, - { - "guild_id": "123321", - "action": mock_action_payload, - "rule_id": "4959595", - "rule_trigger_type": 3, - "user_id": "4949494", - "channel_id": "5423234", - "message_id": "49343292", - "alert_system_message_id": "49211123", - "content": "meow", - "matched_keyword": "fredf", - "matched_content": "dfofodofdodf", - }, - ) - - assert event.app is mock_app + with mock.patch.object( + hikari_app.entity_factory, "deserialize_auto_mod_action" + ) as patched_deserialize_auto_mod_action: + event = event_factory.deserialize_auto_mod_action_execution_event( + mock_shard, + { + "guild_id": "123321", + "action": mock_action_payload, + "rule_id": "4959595", + "rule_trigger_type": 3, + "user_id": "4949494", + "channel_id": "5423234", + "message_id": "49343292", + "alert_system_message_id": "49211123", + "content": "meow", + "matched_keyword": "fredf", + "matched_content": "dfofodofdodf", + }, + ) + + assert event.app is hikari_app assert event.shard is mock_shard assert event.guild_id == 123321 - assert event.action is mock_app.entity_factory.deserialize_auto_mod_action.return_value + assert event.action is patched_deserialize_auto_mod_action.return_value assert event.rule_id == 4959595 assert event.rule_trigger_type is auto_mod.AutoModTriggerType.SPAM assert event.user_id == 4949494 @@ -1696,29 +2272,37 @@ def test_deserialize_auto_mod_action_execution_event(self, event_factory, mock_a assert event.content == "meow" assert event.matched_keyword == "fredf" assert event.matched_content == "dfofodofdodf" - mock_app.entity_factory.deserialize_auto_mod_action.assert_called_once_with(mock_action_payload) + patched_deserialize_auto_mod_action.assert_called_once_with(mock_action_payload) - def test_deserialize_auto_mod_action_execution_event_when_partial(self, event_factory, mock_app, mock_shard): + def test_deserialize_auto_mod_action_execution_event_when_partial( + self, + event_factory: event_factory_.EventFactoryImpl, + hikari_app: traits.RESTAware, + mock_shard: shard.GatewayShard, + ): mock_action_payload = {"type": "69"} - event = event_factory.deserialize_auto_mod_action_execution_event( - mock_shard, - { - "guild_id": "123321", - "action": mock_action_payload, - "rule_id": "4959595", - "rule_trigger_type": 3, - "user_id": "4949494", - "content": "", - "matched_keyword": None, - "matched_content": None, - }, - ) - - assert event.app is mock_app + with mock.patch.object( + hikari_app.entity_factory, "deserialize_auto_mod_action" + ) as patched_deserialize_auto_mod_action: + event = event_factory.deserialize_auto_mod_action_execution_event( + mock_shard, + { + "guild_id": "123321", + "action": mock_action_payload, + "rule_id": "4959595", + "rule_trigger_type": 3, + "user_id": "4949494", + "content": "", + "matched_keyword": None, + "matched_content": None, + }, + ) + + assert event.app is hikari_app assert event.shard is mock_shard assert event.guild_id == 123321 - assert event.action is mock_app.entity_factory.deserialize_auto_mod_action.return_value + assert event.action is patched_deserialize_auto_mod_action.return_value assert event.rule_id == 4959595 assert event.rule_trigger_type is auto_mod.AutoModTriggerType.SPAM assert event.user_id == 4949494 @@ -1728,4 +2312,4 @@ def test_deserialize_auto_mod_action_execution_event_when_partial(self, event_fa assert event.content is None assert event.matched_keyword is None assert event.matched_content is None - mock_app.entity_factory.deserialize_auto_mod_action.assert_called_once_with(mock_action_payload) + patched_deserialize_auto_mod_action.assert_called_once_with(mock_action_payload) diff --git a/tests/hikari/impl/test_event_manager.py b/tests/hikari/impl/test_event_manager.py index ede420467b..d774b7f794 100644 --- a/tests/hikari/impl/test_event_manager.py +++ b/tests/hikari/impl/test_event_manager.py @@ -22,8 +22,8 @@ import asyncio import base64 -import contextlib import random +import typing import mock import pytest @@ -33,8 +33,11 @@ from hikari import intents from hikari import presences from hikari.api import event_factory as event_factory_ +from hikari.api import shard as shard_api from hikari.events import guild_events from hikari.impl import config +from hikari.impl import entity_factory as entity_factory_impl +from hikari.impl import event_factory as event_factory_impl from hikari.impl import event_manager from hikari.internal import time from tests.hikari import hikari_test_helpers @@ -49,12 +52,12 @@ def test_fixed_size_nonce(): @pytest.fixture -def shard(): +def shard() -> shard_api.GatewayShard: return mock.Mock(id=987) @pytest.mark.asyncio -async def test__request_guild_members(shard): +async def test__request_guild_members(shard: shard_api.GatewayShard): shard.request_guild_members = mock.AsyncMock() await event_manager._request_guild_members(shard, 123, include_presences=True, nonce="okokok") @@ -63,7 +66,7 @@ async def test__request_guild_members(shard): @pytest.mark.asyncio -async def test__request_guild_members_handles_state_conflict_error(shard): +async def test__request_guild_members_handles_state_conflict_error(shard: shard_api.GatewayShard): shard.request_guild_members = mock.AsyncMock(side_effect=errors.ComponentStateConflictError(reason="OK")) await event_manager._request_guild_members(shard, 123, include_presences=True, nonce="okokok") @@ -73,15 +76,17 @@ async def test__request_guild_members_handles_state_conflict_error(shard): class TestEventManagerImpl: @pytest.fixture - def entity_factory(self): + def entity_factory(self) -> entity_factory_impl.EntityFactoryImpl: return mock.Mock() @pytest.fixture - def event_factory(self): + def event_factory(self) -> event_factory_impl.EventFactoryImpl: return mock.Mock() @pytest.fixture - def event_manager_impl(self, entity_factory, event_factory): + def event_manager_impl( + self, entity_factory: entity_factory_impl.EntityFactoryImpl, event_factory: event_factory_impl.EventFactoryImpl + ) -> event_manager.EventManagerImpl: obj = hikari_test_helpers.mock_class_namespace(event_manager.EventManagerImpl, slots_=False)( entity_factory, event_factory, intents.Intents.ALL, cache=mock.Mock(settings=config.CacheSettings()) ) @@ -90,7 +95,9 @@ def event_manager_impl(self, entity_factory, event_factory): return obj @pytest.fixture - def stateless_event_manager_impl(self, event_factory, entity_factory): + def stateless_event_manager_impl( + self, event_factory: event_factory_impl.EventFactoryImpl, entity_factory: entity_factory_impl.EntityFactoryImpl + ) -> event_manager.EventManagerImpl: obj = hikari_test_helpers.mock_class_namespace(event_manager.EventManagerImpl, slots_=False)( entity_factory, event_factory, intents.Intents.ALL, cache=None ) @@ -98,170 +105,305 @@ def stateless_event_manager_impl(self, event_factory, entity_factory): obj.dispatch = mock.Mock() return obj - def test_on_ready_stateful(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_ready_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock(my_user=mock.Mock()) - event_factory.deserialize_ready_event.return_value = event - - event_manager_impl.on_ready(shard, payload) - - event_manager_impl._cache.update_me.assert_called_once_with(event.my_user) - event_factory.deserialize_ready_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_ready_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "update_me") as patched_update_me, + mock.patch.object( + event_factory, "deserialize_ready_event", return_value=event + ) as patched_deserialize_ready_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_ready(shard, payload) + + patched_update_me.assert_called_once_with(event.my_user) + patched_deserialize_ready_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) + + def test_on_ready_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_ready(shard, payload) + with ( + mock.patch.object(event_factory, "deserialize_ready_event") as patched_deserialize_ready_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_ready(shard, payload) - event_factory.deserialize_ready_event.assert_called_once_with(shard, payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_ready_event.return_value - ) + patched_deserialize_ready_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_ready_event.return_value) - def test_on_resumed(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_resumed( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - event_manager_impl.on_resumed(shard, payload) + with ( + mock.patch.object(event_factory, "deserialize_resumed_event") as patched_deserialize_resumed_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_resumed(shard, payload) - event_factory.deserialize_resumed_event.assert_called_once_with(shard) - event_manager_impl.dispatch.assert_called_once_with(event_factory.deserialize_resumed_event.return_value) + patched_deserialize_resumed_event.assert_called_once_with(shard) + patched_dispatch.assert_called_once_with(patched_deserialize_resumed_event.return_value) - def test_on_application_command_permissions_update(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_application_command_permissions_update( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - event_manager_impl.on_application_command_permissions_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_application_command_permission_update_event" + ) as patched_deserialize_application_command_permission_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_application_command_permissions_update(shard, payload) - event_factory.deserialize_application_command_permission_update_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_application_command_permission_update_event.return_value + patched_deserialize_application_command_permission_update_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with( + patched_deserialize_application_command_permission_update_event.return_value ) - def test_on_channel_create_stateful(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_channel_create_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock(channel=mock.Mock(channels.GuildChannel)) - event_factory.deserialize_guild_channel_create_event.return_value = event - - event_manager_impl.on_channel_create(shard, payload) - - event_manager_impl._cache.set_guild_channel.assert_called_once_with(event.channel) - event_factory.deserialize_guild_channel_create_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_guild_channel") as patched_set_guild_channel, + mock.patch.object( + event_factory, "deserialize_guild_channel_create_event", return_value=event + ) as patched_deserialize_guild_channel_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_channel_create(shard, payload) + + patched_set_guild_channel.assert_called_once_with(event.channel) + patched_deserialize_guild_channel_create_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) + + def test_on_channel_create_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - def test_on_channel_create_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + with ( + mock.patch.object( + event_factory, "deserialize_guild_channel_create_event" + ) as patched_deserialize_guild_channel_create_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_channel_create(shard, payload) - stateless_event_manager_impl.on_channel_create(shard, payload) + patched_deserialize_guild_channel_create_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_channel_create_event.return_value) - event_factory.deserialize_guild_channel_create_event.assert_called_once_with(shard, payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_channel_create_event.return_value - ) - - def test_on_channel_update_stateful(self, event_manager_impl, shard, event_factory): - payload = {"id": 123} + def test_on_channel_update_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {"id": 123} old_channel = object() event = mock.Mock(channel=mock.Mock(channels.GuildChannel)) - event_factory.deserialize_guild_channel_update_event.return_value = event - event_manager_impl._cache.get_guild_channel.return_value = old_channel - - event_manager_impl.on_channel_update(shard, payload) - - event_manager_impl._cache.get_guild_channel.assert_called_once_with(123) - event_manager_impl._cache.update_guild_channel.assert_called_once_with(event.channel) - event_factory.deserialize_guild_channel_update_event.assert_called_once_with( - shard, payload, old_channel=old_channel - ) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_channel_update_stateless(self, stateless_event_manager_impl, shard, event_factory): + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object( + patched__cache, "get_guild_channel", return_value=old_channel + ) as patched_get_guild_channel, + mock.patch.object(patched__cache, "update_guild_channel") as patched_update_guild_channel, + mock.patch.object( + event_factory, "deserialize_guild_channel_update_event", return_value=event + ) as patched_deserialize_guild_channel_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_channel_update(shard, payload) + + patched_get_guild_channel.assert_called_once_with(123) + patched_update_guild_channel.assert_called_once_with(event.channel) + patched_deserialize_guild_channel_update_event.assert_called_once_with(shard, payload, old_channel=old_channel) + patched_dispatch.assert_called_once_with(event) + + def test_on_channel_update_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"id": 123} - stateless_event_manager_impl.on_channel_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_channel_update_event" + ) as patched_deserialize_guild_channel_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_channel_update(shard, payload) - event_factory.deserialize_guild_channel_update_event.assert_called_once_with(shard, payload, old_channel=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_channel_update_event.return_value - ) + patched_deserialize_guild_channel_update_event.assert_called_once_with(shard, payload, old_channel=None) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_channel_update_event.return_value) - def test_on_channel_delete_stateful(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_channel_delete_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock(channel=mock.Mock(id=123)) - event_factory.deserialize_guild_channel_delete_event.return_value = event - - event_manager_impl.on_channel_delete(shard, payload) - - event_manager_impl._cache.delete_guild_channel.assert_called_once_with(123) - event_factory.deserialize_guild_channel_delete_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_channel_delete_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "delete_guild_channel") as patched_delete_guild_channel, + mock.patch.object( + event_factory, "deserialize_guild_channel_delete_event", return_value=event + ) as patched_deserialize_guild_channel_delete_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_channel_delete(shard, payload) + + patched_delete_guild_channel.assert_called_once_with(123) + patched_deserialize_guild_channel_delete_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) + + def test_on_channel_delete_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_channel_delete(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_channel_delete_event" + ) as patched_deserialize_guild_channel_delete_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_channel_delete(shard, payload) - event_factory.deserialize_guild_channel_delete_event.assert_called_once_with(shard, payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_channel_delete_event.return_value - ) + patched_deserialize_guild_channel_delete_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_channel_delete_event.return_value) - def test_on_channel_pins_update(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + def test_on_channel_pins_update( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_channel_pins_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_channel_pins_update_event" + ) as patched_deserialize_channel_pins_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_channel_pins_update(shard, payload) - event_factory.deserialize_channel_pins_update_event.assert_called_once_with(shard, payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_channel_pins_update_event.return_value - ) + patched_deserialize_channel_pins_update_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_channel_pins_update_event.return_value) def test_on_thread_create_when_create_stateful( self, event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): mock_payload = {"id": "123321", "newly_created": True} - event_manager_impl.on_thread_create(shard, mock_payload) - event = event_factory.deserialize_guild_thread_create_event.return_value - event_manager_impl._cache.set_thread.assert_called_once_with(event.thread) - event_manager_impl.dispatch.assert_called_once_with(event) - event_factory.deserialize_guild_thread_create_event.assert_called_once_with(shard, mock_payload) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_thread") as patched_set_thread, + mock.patch.object( + event_factory, "deserialize_guild_thread_create_event" + ) as patched_deserialize_guild_thread_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_thread_create(shard, mock_payload) + + event = patched_deserialize_guild_thread_create_event.return_value + patched_set_thread.assert_called_once_with(event.thread) + patched_dispatch.assert_called_once_with(event) + patched_deserialize_guild_thread_create_event.assert_called_once_with(shard, mock_payload) def test_on_thread_create_stateless( self, stateless_event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): mock_payload = {"id": "123321", "newly_created": True} - stateless_event_manager_impl.on_thread_create(shard, mock_payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_thread_create_event.return_value - ) - event_factory.deserialize_guild_thread_create_event.assert_called_once_with(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_thread_create_event" + ) as patched_deserialize_guild_thread_create_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_thread_create(shard, mock_payload) + + patched_dispatch.assert_called_once_with(patched_deserialize_guild_thread_create_event.return_value) + patched_deserialize_guild_thread_create_event.assert_called_once_with(shard, mock_payload) def test_on_thread_create_for_access_stateful( self, event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): mock_payload = {"id": "123321"} - event_manager_impl.on_thread_create(shard, mock_payload) - event = event_factory.deserialize_guild_thread_access_event.return_value - event_manager_impl._cache.set_thread.assert_called_once_with(event.thread) - event_manager_impl.dispatch.assert_called_once_with(event) - event_factory.deserialize_guild_thread_access_event.assert_called_once_with(shard, mock_payload) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_thread") as patched_set_thread, + mock.patch.object( + event_factory, "deserialize_guild_thread_access_event" + ) as patched_deserialize_guild_thread_access_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_thread_create(shard, mock_payload) + + event = patched_deserialize_guild_thread_access_event.return_value + patched_set_thread.assert_called_once_with(event.thread) + patched_dispatch.assert_called_once_with(event) + patched_deserialize_guild_thread_access_event.assert_called_once_with(shard, mock_payload) def test_on_thread_create_for_access_stateless( self, stateless_event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): mock_payload = {"id": "123321"} - stateless_event_manager_impl.on_thread_create(shard, mock_payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_thread_access_event.return_value - ) - event_factory.deserialize_guild_thread_access_event.assert_called_once_with(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_thread_access_event" + ) as patched_deserialize_guild_thread_access_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_thread_create(shard, mock_payload) + + patched_dispatch.assert_called_once_with(patched_deserialize_guild_thread_access_event.return_value) + patched_deserialize_guild_thread_access_event.assert_called_once_with(shard, mock_payload) def test_on_thread_update_stateful( self, event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock @@ -271,167 +413,259 @@ def test_on_thread_update_stateful( event = mock.Mock(thread=mock.Mock(channels.GuildThreadChannel)) event_factory.deserialize_guild_thread_update_event.return_value = event - event_manager_impl._cache.get_thread.return_value = old_thread - - event_manager_impl.on_thread_update(shard, payload) - event_manager_impl._cache.get_thread.assert_called_once_with(123) - event_manager_impl._cache.update_thread.assert_called_once_with(event.thread) - event_factory.deserialize_guild_thread_update_event.assert_called_once_with( - shard, payload, old_thread=old_thread - ) - event_manager_impl.dispatch.assert_called_once_with(event) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "get_thread", return_value=old_thread) as patched_get_thread, + mock.patch.object(patched__cache, "update_thread") as patched_update_thread, + mock.patch.object( + event_factory, "deserialize_guild_thread_update_event", return_value=event + ) as patched_deserialize_guild_thread_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_thread_update(shard, payload) + + patched_get_thread.assert_called_once_with(123) + patched_update_thread.assert_called_once_with(event.thread) + patched_deserialize_guild_thread_update_event.assert_called_once_with(shard, payload, old_thread=old_thread) + patched_dispatch.assert_called_once_with(event) def test_on_thread_update_stateless( self, stateless_event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): payload = {"id": 123} - stateless_event_manager_impl.on_thread_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_thread_update_event" + ) as patched_deserialize_guild_thread_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_thread_update(shard, payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_thread_update_event.return_value - ) - event_factory.deserialize_guild_thread_update_event.assert_called_once_with(shard, payload, old_thread=None) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_thread_update_event.return_value) + patched_deserialize_guild_thread_update_event.assert_called_once_with(shard, payload, old_thread=None) def test_on_thread_delete_stateful( self, event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): mock_payload = mock.Mock() - event_manager_impl.on_thread_delete(shard, mock_payload) - event = event_factory.deserialize_guild_thread_delete_event.return_value - event_manager_impl._cache.delete_thread.assert_called_once_with(event.thread_id) - event_manager_impl.dispatch.assert_called_once_with(event) - event_factory.deserialize_guild_thread_delete_event.assert_called_once_with(shard, mock_payload) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "delete_thread") as patched_delete_thread, + mock.patch.object( + event_factory, "deserialize_guild_thread_delete_event" + ) as patched_deserialize_guild_thread_delete_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_thread_delete(shard, mock_payload) + + event = patched_deserialize_guild_thread_delete_event.return_value + patched_delete_thread.assert_called_once_with(event.thread_id) + patched_dispatch.assert_called_once_with(event) + patched_deserialize_guild_thread_delete_event.assert_called_once_with(shard, mock_payload) def test_on_thread_delete_stateless( self, stateless_event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): mock_payload = mock.Mock() - stateless_event_manager_impl.on_thread_delete(shard, mock_payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_thread_delete_event.return_value - ) - event_factory.deserialize_guild_thread_delete_event.assert_called_once_with(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_thread_delete_event" + ) as patched_deserialize_guild_thread_delete_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_thread_delete(shard, mock_payload) + + patched_dispatch.assert_called_once_with(patched_deserialize_guild_thread_delete_event.return_value) + patched_deserialize_guild_thread_delete_event.assert_called_once_with(shard, mock_payload) def test_on_thread_list_sync_stateful_when_channel_ids( self, event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): - event = event_factory.deserialize_thread_list_sync_event.return_value + event = mock.Mock() event.channel_ids = ["1", "2"] event.threads = {1: "thread1"} mock_payload = mock.Mock() - event_manager_impl.on_thread_list_sync(shard, mock_payload) - assert event_manager_impl._cache.clear_threads_for_channel.call_count == 2 - event_manager_impl._cache.clear_threads_for_channel.assert_has_calls( + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_thread") as patched_set_thread, + mock.patch.object(patched__cache, "clear_threads_for_channel") as patched_clear_threads_for_channel, + mock.patch.object( + event_factory, "deserialize_thread_list_sync_event", return_value=event + ) as patched_deserialize_thread_list_sync_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_thread_list_sync(shard, mock_payload) + + assert patched_clear_threads_for_channel.call_count == 2 + patched_clear_threads_for_channel.assert_has_calls( [mock.call(event.guild_id, "1"), mock.call(event.guild_id, "2")] ) - event_manager_impl._cache.set_thread("thread1") - event_manager_impl.dispatch.assert_called_once_with(event) - event_factory.deserialize_thread_list_sync_event.assert_called_once_with(shard, mock_payload) + patched_set_thread("thread1") + patched_dispatch.assert_called_once_with(event) + patched_deserialize_thread_list_sync_event.assert_called_once_with(shard, mock_payload) def test_on_thread_list_sync_stateful_when_not_channel_ids( self, event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): - event = event_factory.deserialize_thread_list_sync_event.return_value + event = mock.Mock() event.channel_ids = None event.threads = {1: "thread1"} mock_payload = mock.Mock() - event_manager_impl.on_thread_list_sync(shard, mock_payload) - event_manager_impl._cache.clear_threads_for_guild.assert_called_once_with(event.guild_id) - event_manager_impl._cache.set_thread("thread1") - event_manager_impl.dispatch.assert_called_once_with(event) - event_factory.deserialize_thread_list_sync_event.assert_called_once_with(shard, mock_payload) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_thread") as patched_set_thread, + mock.patch.object(patched__cache, "clear_threads_for_guild") as patched_clear_threads_for_guild, + mock.patch.object( + event_factory, "deserialize_thread_list_sync_event", return_value=event + ) as patched_deserialize_thread_list_sync_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_thread_list_sync(shard, mock_payload) + + patched_clear_threads_for_guild.assert_called_once_with(event.guild_id) + patched_set_thread("thread1") + patched_dispatch.assert_called_once_with(event) + patched_deserialize_thread_list_sync_event.assert_called_once_with(shard, mock_payload) def test_on_thread_list_sync_stateless( self, stateless_event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): mock_payload = mock.Mock() - stateless_event_manager_impl.on_thread_list_sync(shard, mock_payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_thread_list_sync_event.return_value - ) - event_factory.deserialize_thread_list_sync_event.assert_called_once_with(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_thread_list_sync_event" + ) as patched_deserialize_thread_list_sync_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_thread_list_sync(shard, mock_payload) + + patched_dispatch.assert_called_once_with(patched_deserialize_thread_list_sync_event.return_value) + patched_deserialize_thread_list_sync_event.assert_called_once_with(shard, mock_payload) def test_on_thread_members_update_stateful_when_id_in_removed( self, event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): - event = event_factory.deserialize_thread_members_update_event.return_value + event = mock.Mock() event.removed_member_ids = [1, 2, 3] event.shard.get_user_id.return_value = 1 mock_payload = mock.Mock() - event_manager_impl.on_thread_members_update(shard, mock_payload) - event_manager_impl._cache.delete_thread.assert_called_once_with(event.thread_id) - event_manager_impl.dispatch.assert_called_once_with(event) - event_factory.deserialize_thread_members_update_event.assert_called_once_with(shard, mock_payload) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "delete_thread") as patched_delete_thread, + mock.patch.object( + event_factory, "deserialize_thread_members_update_event", return_value=event + ) as patched_deserialize_thread_members_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_thread_members_update(shard, mock_payload) + + patched_delete_thread.assert_called_once_with(event.thread_id) + patched_dispatch.assert_called_once_with(event) + patched_deserialize_thread_members_update_event.assert_called_once_with(shard, mock_payload) def test_on_thread_members_update_stateful_when_id_not_in_removed( self, event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): - event = event_factory.deserialize_thread_members_update_event.return_value + event = mock.Mock() event.removed_member_ids = [1, 2, 3] event.shard.get_user_id.return_value = 69 mock_payload = mock.Mock() - event_manager_impl.on_thread_members_update(shard, mock_payload) - event_manager_impl._cache.delete_thread.assert_not_called() - event_manager_impl.dispatch.assert_called_once_with(event) - event_factory.deserialize_thread_members_update_event.assert_called_once_with(shard, mock_payload) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "delete_thread") as patched_delete_thread, + mock.patch.object( + event_factory, "deserialize_thread_members_update_event", return_value=event + ) as patched_deserialize_thread_members_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_thread_members_update(shard, mock_payload) + + patched_delete_thread.assert_not_called() + patched_dispatch.assert_called_once_with(event) + patched_deserialize_thread_members_update_event.assert_called_once_with(shard, mock_payload) def test_on_thread_members_update_stateless( self, stateless_event_manager_impl: event_manager.EventManagerImpl, shard: mock.Mock, event_factory: mock.Mock ): mock_payload = mock.Mock() - stateless_event_manager_impl.on_thread_members_update(shard, mock_payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_thread_members_update_event.return_value - ) - event_factory.deserialize_thread_members_update_event.assert_called_once_with(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_thread_members_update_event" + ) as patched_deserialize_thread_members_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_thread_members_update(shard, mock_payload) - def test_on_guild_create_when_unavailable_guild(self, event_manager_impl, shard, event_factory, entity_factory): + patched_dispatch.assert_called_once_with(patched_deserialize_thread_members_update_event.return_value) + patched_deserialize_thread_members_update_event.assert_called_once_with(shard, mock_payload) + + def test_on_guild_create_when_unavailable_guild( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, + ): payload = {"unavailable": True} - event_manager_impl._cache_enabled_for = mock.Mock(return_value=True) - event_manager_impl._enabled_for_event = mock.Mock(return_value=True) - with mock.patch.object(event_manager, "_request_guild_members") as request_guild_members: + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + mock.patch.object( + event_manager_impl, "_enabled_for_event", return_value=True + ) as patched__enabled_for_event, + mock.patch.object(event_manager_impl, "_cache_enabled_for", return_value=True), + mock.patch.object( + entity_factory, "deserialize_guild_available_event" + ) as patched_deserialize_guild_available_event, + mock.patch.object(entity_factory, "deserialized_guild_join_event") as patched_deserialize_guild_join_event, + mock.patch.object(event_manager, "_request_guild_members") as patched__request_guild_members, + ): event_manager_impl.on_guild_create(shard, payload) - event_manager_impl._enabled_for_event.assert_not_called() - event_factory.deserialize_guild_available_event.assert_not_called() - event_factory.deserialize_guild_join_event.assert_not_called() - - event_manager_impl._cache.update_guild.assert_not_called() - event_manager_impl._cache.clear_guild_channels_for_guild.assert_not_called() - event_manager_impl._cache.set_guild_channel.assert_not_called() - event_manager_impl._cache.clear_emojis_for_guild.assert_not_called() - event_manager_impl._cache.set_emoji.assert_not_called() - event_manager_impl._cache.clear_stickers_for_guild.assert_not_called() - event_manager_impl._cache.set_sticker.assert_not_called() - event_manager_impl._cache.clear_roles_for_guild.assert_not_called() - event_manager_impl._cache.set_role.assert_not_called() - event_manager_impl._cache.clear_members_for_guild.assert_not_called() - event_manager_impl._cache.set_member.assert_not_called() - event_manager_impl._cache.clear_presences_for_guild.assert_not_called() - event_manager_impl._cache.set_presence.assert_not_called() - event_manager_impl._cache.clear_voice_states_for_guild.assert_not_called() - event_manager_impl._cache.set_voice_state.assert_not_called() - request_guild_members.assert_not_called() - - event_manager_impl.dispatch.assert_not_called() + patched__enabled_for_event.assert_not_called() + patched_deserialize_guild_available_event.assert_not_called() + patched_deserialize_guild_join_event.assert_not_called() + + patched__cache.update_guild.assert_not_called() + patched__cache.clear_guild_channels_for_guild.assert_not_called() + patched__cache.set_guild_channel.assert_not_called() + patched__cache.clear_emojis_for_guild.assert_not_called() + patched__cache.set_emoji.assert_not_called() + patched__cache.clear_stickers_for_guild.assert_not_called() + patched__cache.set_sticker.assert_not_called() + patched__cache.clear_roles_for_guild.assert_not_called() + patched__cache.set_role.assert_not_called() + patched__cache.clear_members_for_guild.assert_not_called() + patched__cache.set_member.assert_not_called() + patched__cache.clear_presences_for_guild.assert_not_called() + patched__cache.set_presence.assert_not_called() + patched__cache.clear_voice_states_for_guild.assert_not_called() + patched__cache.set_voice_state.assert_not_called() + patched__request_guild_members.assert_not_called() + + patched_dispatch.assert_not_called() @pytest.mark.asyncio @pytest.mark.parametrize("include_unavailable", [True, False]) async def test_on_guild_create_when_dispatching_and_not_caching( - self, event_manager_impl, shard, event_factory, entity_factory, include_unavailable + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + include_unavailable: bool, ): payload = {"unavailable": False} if include_unavailable else {} event_manager_impl._intents = intents.Intents.NONE @@ -473,7 +707,12 @@ async def test_on_guild_create_when_dispatching_and_not_caching( @pytest.mark.parametrize("include_unavailable", [True, False]) def test_on_guild_create_when_not_dispatching_and_not_caching( - self, event_manager_impl, shard, event_factory, entity_factory, include_unavailable + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, + include_unavailable: bool, ): payload = {"unavailable": False} if include_unavailable else {} event_manager_impl._intents = intents.Intents.NONE @@ -519,7 +758,13 @@ def test_on_guild_create_when_not_dispatching_and_not_caching( ("include_unavailable", "only_my_member"), [(True, True), (True, False), (False, True), (False, False)] ) def test_on_guild_create_when_not_dispatching_and_caching( - self, event_manager_impl, shard, event_factory, entity_factory, include_unavailable, only_my_member + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, + include_unavailable: bool, + only_my_member: bool, ): payload = {"unavailable": False} if include_unavailable else {} event_manager_impl._intents = intents.Intents.NONE @@ -578,50 +823,72 @@ def test_on_guild_create_when_not_dispatching_and_caching( @pytest.mark.parametrize("include_unavailable", [True, False]) def test_on_guild_create_when_stateless( - self, stateless_event_manager_impl, shard, event_factory, entity_factory, include_unavailable + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, + include_unavailable: bool, ): - payload = {"id": 123} + payload: dict[str, typing.Any] = {"id": 123} if include_unavailable: payload["unavailable"] = False - stateless_event_manager_impl._intents = intents.Intents.NONE - stateless_event_manager_impl._cache_enabled_for = mock.Mock(return_value=True) - stateless_event_manager_impl._enabled_for_event = mock.Mock(return_value=False) + with ( + mock.patch.object(event_factory, "deserialize_guild_join_event") as patched_deserialize_guild_join_event, + mock.patch.object( + event_factory, "deserialize_guild_available_event" + ) as patched_deserialize_guild_available_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl._intents = intents.Intents.NONE + stateless_event_manager_impl._cache_enabled_for = mock.Mock(return_value=True) + stateless_event_manager_impl._enabled_for_event = mock.Mock(return_value=False) with mock.patch.object(event_manager, "_request_guild_members") as request_guild_members: stateless_event_manager_impl.on_guild_create(shard, payload) - if include_unavailable: - stateless_event_manager_impl._enabled_for_event.assert_called_once_with(guild_events.GuildAvailableEvent) - else: - stateless_event_manager_impl._enabled_for_event.assert_called_once_with(guild_events.GuildJoinEvent) + if include_unavailable: + stateless_event_manager_impl._enabled_for_event.assert_called_once_with( + guild_events.GuildAvailableEvent + ) + else: + stateless_event_manager_impl._enabled_for_event.assert_called_once_with(guild_events.GuildJoinEvent) - event_factory.deserialize_guild_join_event.assert_not_called() - event_factory.deserialize_guild_available_event.assert_not_called() - request_guild_members.assert_not_called() + patched_deserialize_guild_join_event.assert_not_called() + patched_deserialize_guild_available_event.assert_not_called() + request_guild_members.assert_not_called() - stateless_event_manager_impl.dispatch.assert_not_called() + patched_dispatch.assert_not_called() def test_on_guild_create_when_members_declared_and_member_cache_enabled_but_only_my_member_not_enabled( - self, event_manager_impl, shard, event_factory, entity_factory + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, ): - def cache_enabled_for_members_only(component): + def cache_enabled_for_members_only(component: typing.Any): return component == config.CacheComponents.MEMBERS - shard.id = 123 - event_manager_impl._cache.settings.only_my_member = False event_manager_impl._intents = intents.Intents.GUILD_MEMBERS event_manager_impl._cache_enabled_for = cache_enabled_for_members_only event_manager_impl._enabled_for_event = mock.Mock(return_value=False) - gateway_guild = entity_factory.deserialize_gateway_guild.return_value + gateway_guild = mock.Mock() gateway_guild.id = 456 gateway_guild.members.return_value = {1: "member1", 2: "member2"} mock_request_guild_members = mock.Mock() - with mock.patch.object(asyncio, "create_task") as create_task: - with mock.patch.object(event_manager, "_fixed_size_nonce", return_value="abc"): - with mock.patch.object(event_manager, "_request_guild_members", new=mock_request_guild_members): - event_manager_impl.on_guild_create(shard, {"id": 456, "large": False}) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache.settings, "only_my_member", False), + mock.patch.object(asyncio, "create_task") as create_task, + mock.patch.object(event_manager, "_fixed_size_nonce", return_value="abc"), + mock.patch.object(event_manager, "_request_guild_members", new=mock_request_guild_members), + mock.patch.object(entity_factory, "deserialize_gateway_guild", return_value=gateway_guild), + mock.patch.object(shard, "id", 123), + ): + event_manager_impl.on_guild_create(shard, {"id": 456, "large": False}) mock_request_guild_members.assert_called_once_with(shard, 456, include_presences=False, nonce="123.abc") create_task.assert_called_once_with( @@ -629,85 +896,123 @@ def cache_enabled_for_members_only(component): ) def test_on_guild_create_when_members_declared_and_member_cache_but_only_my_member_enabled( - self, event_manager_impl, shard, event_factory, entity_factory + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, ): - def cache_enabled_for_members_only(component): + def cache_enabled_for_members_only(component: typing.Any): return component == config.CacheComponents.MEMBERS - shard.id = 123 - shard.get_user_id.return_value = 1 - event_manager_impl._cache.settings.only_my_member = True event_manager_impl._intents = intents.Intents.GUILD_MEMBERS event_manager_impl._cache_enabled_for = cache_enabled_for_members_only event_manager_impl._enabled_for_event = mock.Mock(return_value=False) - gateway_guild = entity_factory.deserialize_gateway_guild.return_value + gateway_guild = mock.Mock() gateway_guild.members.return_value = {1: "member1", 2: "member2"} mock_request_guild_members = mock.Mock() - with mock.patch.object(asyncio, "create_task") as create_task: - with mock.patch.object(event_manager, "_fixed_size_nonce", return_value="abc"): - with mock.patch.object(event_manager, "_request_guild_members", new=mock_request_guild_members): - event_manager_impl.on_guild_create(shard, {"id": 456, "large": False}) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache.settings, "only_my_member", True), + mock.patch.object(asyncio, "create_task") as create_task, + mock.patch.object(event_manager, "_fixed_size_nonce", return_value="abc"), + mock.patch.object(event_manager, "_request_guild_members", new=mock_request_guild_members), + mock.patch.object(entity_factory, "deserialize_gateway_guild", return_value=gateway_guild), + mock.patch.object(shard, "id", 123), + mock.patch.object(shard, "get_user_id", return_value=1), + ): + event_manager_impl.on_guild_create(shard, {"id": 456, "large": False}) mock_request_guild_members.assert_not_called() create_task.assert_not_called() def test_on_guild_create_when_members_declared_and_enabled_for_member_chunk_event( - self, stateless_event_manager_impl, shard, event_factory, entity_factory + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, ): - shard.id = 123 stateless_event_manager_impl._intents = intents.Intents.GUILD_MEMBERS stateless_event_manager_impl._cache_enabled_for = mock.Mock(return_value=False) stateless_event_manager_impl._enabled_for_event = mock.Mock(return_value=True) mock_event = mock.Mock() mock_event.guild.id = 456 - event_factory.deserialize_guild_join_event.return_value = mock_event mock_request_guild_members = mock.Mock() - with mock.patch.object(asyncio, "create_task") as create_task: - with mock.patch.object(event_manager, "_fixed_size_nonce", return_value="abc"): - with mock.patch.object(event_manager, "_request_guild_members", new=mock_request_guild_members): - stateless_event_manager_impl.on_guild_create(shard, {"large": True}) + with ( + mock.patch.object(asyncio, "create_task") as create_task, + mock.patch.object(event_manager, "_fixed_size_nonce", return_value="abc"), + mock.patch.object(event_manager, "_request_guild_members", new=mock_request_guild_members), + mock.patch.object(event_factory, "deserialize_guild_join_event", return_value=mock_event), + mock.patch.object(shard, "id", 123), + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_create(shard, {"large": True}) mock_request_guild_members.assert_called_once_with(shard, 456, include_presences=False, nonce="123.abc") create_task.assert_called_once_with( mock_request_guild_members.return_value, name="123:456 guild create members request" ) assert mock_event.chunk_nonce == "123.abc" - stateless_event_manager_impl.dispatch.assert_called_once_with(mock_event) + patched_dispatch.assert_called_once_with(mock_event) @pytest.mark.parametrize("cache_enabled", [True, False]) @pytest.mark.parametrize("large", [True, False]) @pytest.mark.parametrize("enabled_for_event", [True, False]) def test_on_guild_create_when_chunk_members_disabled( - self, stateless_event_manager_impl, shard, large, cache_enabled, enabled_for_event + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + large: bool, + cache_enabled: bool, + enabled_for_event: bool, ): - shard.id = 123 - stateless_event_manager_impl._intents = intents.Intents.GUILD_MEMBERS - stateless_event_manager_impl._cache_enabled_for = mock.Mock(return_value=cache_enabled) - stateless_event_manager_impl._enabled_for_event = mock.Mock(return_value=enabled_for_event) - stateless_event_manager_impl._auto_chunk_members = False + with mock.patch.object(shard, "id", 123): + stateless_event_manager_impl._intents = intents.Intents.GUILD_MEMBERS + stateless_event_manager_impl._cache_enabled_for = mock.Mock(return_value=cache_enabled) + stateless_event_manager_impl._enabled_for_event = mock.Mock(return_value=enabled_for_event) + stateless_event_manager_impl._auto_chunk_members = False with mock.patch.object(event_manager, "_request_guild_members") as request_guild_members: stateless_event_manager_impl.on_guild_create(shard, {"id": 456, "large": large}) - request_guild_members.assert_not_called() + request_guild_members.assert_not_called() - def test_on_guild_update_when_stateless(self, stateless_event_manager_impl, shard, event_factory, entity_factory): + def test_on_guild_update_when_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, + ): stateless_event_manager_impl._intents = intents.Intents.NONE stateless_event_manager_impl._cache_enabled_for = mock.Mock(return_value=True) stateless_event_manager_impl._enabled_for_event = mock.Mock(return_value=False) - stateless_event_manager_impl.on_guild_update(shard, {}) + with ( + mock.patch.object( + event_factory, "deserialize_guild_update_event" + ) as patched_deserialize_guild_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_update(shard, {}) stateless_event_manager_impl._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) - event_factory.deserialize_guild_update_event.assert_not_called() + patched_deserialize_guild_update_event.assert_not_called() - stateless_event_manager_impl.dispatch.assert_not_called() + patched_dispatch.assert_not_called() - def test_on_guild_update_stateful_and_dispatching(self, event_manager_impl, shard, event_factory, entity_factory): + def test_on_guild_update_stateful_and_dispatching( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, + ): payload = {"id": 123} old_guild = object() mock_role = object() @@ -738,7 +1043,11 @@ def test_on_guild_update_stateful_and_dispatching(self, event_manager_impl, shar shard.get_user_id.assert_not_called() def test_on_guild_update_all_cache_components_and_not_dispatching( - self, event_manager_impl, shard, event_factory, entity_factory + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, ): payload = {"id": 123} mock_role = object() @@ -772,7 +1081,11 @@ def test_on_guild_update_all_cache_components_and_not_dispatching( guild_definition.guild.assert_called_once_with() def test_on_guild_update_no_cache_components_and_not_dispatching( - self, event_manager_impl, shard, event_factory, entity_factory + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, ): payload = {"id": 123} event_manager_impl._cache_enabled_for = mock.Mock(return_value=False) @@ -800,22 +1113,35 @@ def test_on_guild_update_no_cache_components_and_not_dispatching( shard.get_user_id.assert_called_once_with() def test_on_guild_update_stateless_and_dispatching( - self, stateless_event_manager_impl, shard, event_factory, entity_factory + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, ): - payload = {"id": 123} + payload: dict[str, typing.Any] = {"id": 123} stateless_event_manager_impl._enabled_for_event = mock.Mock(return_value=True) - stateless_event_manager_impl.on_guild_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_update_event" + ) as patched_deserialize_guild_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_update(shard, payload) stateless_event_manager_impl._enabled_for_event.assert_called_once_with(guild_events.GuildUpdateEvent) - shard.get_user_id.deserialize_gateway_guild.assert_not_called() - shard.user_id.assert_not_called() - event_factory.deserialize_guild_update_event.assert_called_once_with(shard, payload, old_guild=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_update_event.return_value - ) + shard.get_user_id.deserialize_gateway_guild.assert_not_called() # FIXME: I don't think this is logically correct. + shard.user_id.assert_not_called() # FIXME: This does not seem to even exist. + patched_deserialize_guild_update_event.assert_called_once_with(shard, payload, old_guild=None) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_update_event.return_value) - def test_on_guild_delete_stateful_when_available(self, event_manager_impl, shard, event_factory): + def test_on_guild_delete_stateful_when_available( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"unavailable": False, "id": "123"} event = mock.Mock(guild_id=123) @@ -838,753 +1164,1303 @@ def test_on_guild_delete_stateful_when_available(self, event_manager_impl, shard ) event_manager_impl.dispatch.assert_called_once_with(event) - def test_on_guild_delete_stateful_when_unavailable(self, event_manager_impl, shard, event_factory): + def test_on_guild_delete_stateful_when_unavailable( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"unavailable": True, "id": "123"} event = mock.Mock(guild_id=123) - event_factory.deserialize_guild_unavailable_event.return_value = event - - event_manager_impl.on_guild_delete(shard, payload) - - event_manager_impl._cache.set_guild_availability.assert_called_once_with(event.guild_id, False) - event_factory.deserialize_guild_unavailable_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_guild_delete_stateless_when_available(self, stateless_event_manager_impl, shard, event_factory): + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_guild_availability") as patched_set_guild_availability, + mock.patch.object( + event_factory, "deserialize_guild_unavailable_event", return_value=event + ) as patched_deserialize_guild_unavailable_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_delete(shard, payload) + + patched_set_guild_availability.assert_called_once_with(event.guild_id, False) + patched_deserialize_guild_unavailable_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) + + def test_on_guild_delete_stateless_when_available( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"unavailable": False, "id": "123"} - stateless_event_manager_impl.on_guild_delete(shard, payload) + with ( + mock.patch.object(event_factory, "deserialize_guild_leave_event") as patched_deserialize_guild_leave_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_delete(shard, payload) - event_factory.deserialize_guild_leave_event.assert_called_once_with(shard, payload, old_guild=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_leave_event.return_value - ) + patched_deserialize_guild_leave_event.assert_called_once_with(shard, payload, old_guild=None) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_leave_event.return_value) - def test_on_guild_delete_stateless_when_unavailable(self, stateless_event_manager_impl, shard, event_factory): + def test_on_guild_delete_stateless_when_unavailable( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"unavailable": True} - stateless_event_manager_impl.on_guild_delete(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_unavailable_event" + ) as patched_deserialize_guild_unavailable_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_delete(shard, payload) - event_factory.deserialize_guild_unavailable_event.assert_called_once_with(shard, payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_unavailable_event.return_value - ) + patched_deserialize_guild_unavailable_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_unavailable_event.return_value) - def test_on_guild_ban_add(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_guild_ban_add( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_guild_ban_add_event.return_value = event + with ( + mock.patch.object( + event_factory, "deserialize_guild_ban_add_event", return_value=event + ) as patched_deserialize_guild_ban_add_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_ban_add(shard, payload) - event_manager_impl.on_guild_ban_add(shard, payload) + patched_deserialize_guild_ban_add_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - event_factory.deserialize_guild_ban_add_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_guild_ban_remove(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_guild_ban_remove( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_guild_ban_remove_event.return_value = event + with ( + mock.patch.object( + event_factory, "deserialize_guild_ban_remove_event", return_value=event + ) as patched_deserialize_guild_ban_remove_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_ban_remove(shard, payload) - event_manager_impl.on_guild_ban_remove(shard, payload) - - event_factory.deserialize_guild_ban_remove_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) + patched_deserialize_guild_ban_remove_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - def test_on_guild_emojis_update_stateful(self, event_manager_impl, shard, event_factory): + def test_on_guild_emojis_update_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"guild_id": 123} old_emojis = {"Test": 123} - mock_emoji = object() + mock_emoji = mock.Mock() event = mock.Mock(emojis=[mock_emoji], guild_id=123) - event_factory.deserialize_guild_emojis_update_event.return_value = event - event_manager_impl._cache.clear_emojis_for_guild.return_value = old_emojis - - event_manager_impl.on_guild_emojis_update(shard, payload) - - event_manager_impl._cache.clear_emojis_for_guild.assert_called_once_with(123) - event_manager_impl._cache.set_emoji.assert_called_once_with(mock_emoji) - event_factory.deserialize_guild_emojis_update_event.assert_called_once_with(shard, payload, old_emojis=[123]) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_guild_emojis_update_stateless(self, stateless_event_manager_impl, shard, event_factory): + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_emoji") as patched_set_emoji, + mock.patch.object( + patched__cache, "clear_emojis_for_guild", return_value=old_emojis + ) as patched_clear_emojis_for_guild, + mock.patch.object( + event_factory, "deserialize_guild_emojis_update_event", return_value=event + ) as patched_deserialize_guild_emojis_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_emojis_update(shard, payload) + + patched_clear_emojis_for_guild.assert_called_once_with(123) + patched_set_emoji.assert_called_once_with(mock_emoji) + patched_deserialize_guild_emojis_update_event.assert_called_once_with(shard, payload, old_emojis=[123]) + patched_dispatch.assert_called_once_with(event) + + def test_on_guild_emojis_update_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"guild_id": 123} - stateless_event_manager_impl.on_guild_emojis_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_emojis_update_event" + ) as patched_deserialize_guild_emojis_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_emojis_update(shard, payload) - event_factory.deserialize_guild_emojis_update_event.assert_called_once_with(shard, payload, old_emojis=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_emojis_update_event.return_value - ) + patched_deserialize_guild_emojis_update_event.assert_called_once_with(shard, payload, old_emojis=None) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_emojis_update_event.return_value) - def test_on_guild_stickers_update_stateful(self, event_manager_impl, shard, event_factory): + def test_on_guild_stickers_update_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"guild_id": 720} old_stickers = {700: 123} - mock_sticker = object() + mock_sticker = mock.Mock() event = mock.Mock(stickers=[mock_sticker], guild_id=123) - event_factory.deserialize_guild_stickers_update_event.return_value = event - event_manager_impl._cache.clear_stickers_for_guild.return_value = old_stickers - - event_manager_impl.on_guild_stickers_update(shard, payload) - - event_manager_impl._cache.clear_stickers_for_guild.assert_called_once_with(720) - event_manager_impl._cache.set_sticker.assert_called_once_with(mock_sticker) - event_factory.deserialize_guild_stickers_update_event.assert_called_once_with( - shard, payload, old_stickers=[123] - ) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_guild_stickers_update_stateless(self, stateless_event_manager_impl, shard, event_factory): + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_sticker") as patched_set_sticker, + mock.patch.object( + patched__cache, "clear_stickers_for_guild", return_value=old_stickers + ) as patched_clear_stickers_for_guild, + mock.patch.object( + event_factory, "deserialize_guild_stickers_update_event", return_value=event + ) as patched_deserialize_guild_stickers_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_stickers_update(shard, payload) + + patched_clear_stickers_for_guild.assert_called_once_with(720) + patched_set_sticker.assert_called_once_with(mock_sticker) + patched_deserialize_guild_stickers_update_event.assert_called_once_with(shard, payload, old_stickers=[123]) + patched_dispatch.assert_called_once_with(event) + + def test_on_guild_stickers_update_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"guild_id": 123} - stateless_event_manager_impl.on_guild_stickers_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_stickers_update_event" + ) as patched_deserialize_guild_stickers_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_stickers_update(shard, payload) - event_factory.deserialize_guild_stickers_update_event.assert_called_once_with(shard, payload, old_stickers=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_stickers_update_event.return_value - ) + patched_deserialize_guild_stickers_update_event.assert_called_once_with(shard, payload, old_stickers=None) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_stickers_update_event.return_value) - def test_on_guild_integrations_update(self, event_manager_impl, shard): - with pytest.raises(NotImplementedError): + def test_on_guild_integrations_update( + self, event_manager_impl: event_manager.EventManagerImpl, shard: shard_api.GatewayShard + ): + with mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, pytest.raises(NotImplementedError): event_manager_impl.on_guild_integrations_update(shard, {}) - event_manager_impl.dispatch.assert_not_called() + patched_dispatch.assert_not_called() - def test_on_integration_create(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_integration_create( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_integration_create_event.return_value = event + with ( + mock.patch.object( + event_factory, "deserialize_integration_create_event", return_value=event + ) as patched_deserialize_integration_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_integration_create(shard, payload) - event_manager_impl.on_integration_create(shard, payload) + patched_deserialize_integration_create_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - event_factory.deserialize_integration_create_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_integration_delete(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_integration_delete( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_integration_delete_event.return_value = event - - event_manager_impl.on_integration_delete(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_integration_delete_event", return_value=event + ) as patched_deserialize_integration_delete_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_integration_delete(shard, payload) - event_factory.deserialize_integration_delete_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) + patched_deserialize_integration_delete_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - def test_on_integration_update(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_integration_update( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_integration_update_event.return_value = event + with ( + mock.patch.object( + event_factory, "deserialize_integration_update_event", return_value=event + ) as patched_deserialize_integration_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_integration_update(shard, payload) - event_manager_impl.on_integration_update(shard, payload) + patched_deserialize_integration_update_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - event_factory.deserialize_integration_update_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_guild_member_add_stateful(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_guild_member_add_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock(user=object(), member=object()) - event_factory.deserialize_guild_member_add_event.return_value = event - - event_manager_impl.on_guild_member_add(shard, payload) - - event_manager_impl._cache.update_member.assert_called_once_with(event.member) - event_factory.deserialize_guild_member_add_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "update_member") as patched_update_member, + mock.patch.object( + event_factory, "deserialize_guild_member_add_event", return_value=event + ) as patched_deserialize_guild_member_add_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_member_add(shard, payload) + + patched_update_member.assert_called_once_with(event.member) + patched_deserialize_guild_member_add_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) + + def test_on_guild_member_add_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - def test_on_guild_member_add_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + with ( + mock.patch.object( + event_factory, "deserialize_guild_member_add_event" + ) as patched_deserialize_guild_member_add_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_member_add(shard, payload) - stateless_event_manager_impl.on_guild_member_add(shard, payload) + patched_deserialize_guild_member_add_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_member_add_event.return_value) - event_factory.deserialize_guild_member_add_event.assert_called_once_with(shard, payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_member_add_event.return_value - ) - - def test_on_guild_member_remove_stateful(self, event_manager_impl, shard, event_factory): + def test_on_guild_member_remove_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"guild_id": "456", "user": {"id": "123"}} - event_manager_impl.on_guild_member_remove(shard, payload) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "delete_member") as patched_delete_member, + mock.patch.object( + event_factory, "deserialize_guild_member_remove_event" + ) as patched_deserialize_guild_member_remove_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_member_remove(shard, payload) - event_manager_impl._cache.delete_member.assert_called_once_with(456, 123) - event_factory.deserialize_guild_member_remove_event.assert_called_once_with( - shard, payload, old_member=event_manager_impl._cache.delete_member.return_value - ) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_member_remove_event.return_value + patched_delete_member.assert_called_once_with(456, 123) + patched_deserialize_guild_member_remove_event.assert_called_once_with( + shard, payload, old_member=patched_delete_member.return_value ) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_member_remove_event.return_value) - def test_on_guild_member_remove_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + def test_on_guild_member_remove_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_guild_member_remove(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_member_remove_event" + ) as patched_deserialize_guild_member_remove_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_member_remove(shard, payload) - event_factory.deserialize_guild_member_remove_event.assert_called_once_with(shard, payload, old_member=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_member_remove_event.return_value - ) + patched_deserialize_guild_member_remove_event.assert_called_once_with(shard, payload, old_member=None) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_member_remove_event.return_value) - def test_on_guild_member_update_stateful(self, event_manager_impl, shard, event_factory): + def test_on_guild_member_update_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"user": {"id": 123}, "guild_id": 456} old_member = object() event = mock.Mock(member=mock.Mock()) - event_factory.deserialize_guild_member_update_event.return_value = event - event_manager_impl._cache.get_member.return_value = old_member - - event_manager_impl.on_guild_member_update(shard, payload) - - event_manager_impl._cache.get_member.assert_called_once_with(456, 123) - event_manager_impl._cache.update_member.assert_called_once_with(event.member) - event_factory.deserialize_guild_member_update_event.assert_called_once_with( - shard, payload, old_member=old_member - ) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_guild_member_update_stateless(self, stateless_event_manager_impl, shard, event_factory): + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "update_member") as patched_update_member, + mock.patch.object(patched__cache, "get_member", return_value=old_member) as patched_get_member, + mock.patch.object( + event_factory, "deserialize_guild_member_update_event", return_value=event + ) as patched_deserialize_guild_member_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_member_update(shard, payload) + + patched_get_member.assert_called_once_with(456, 123) + patched_update_member.assert_called_once_with(event.member) + patched_deserialize_guild_member_update_event.assert_called_once_with(shard, payload, old_member=old_member) + patched_dispatch.assert_called_once_with(event) + + def test_on_guild_member_update_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"user": {"id": 123}, "guild_id": 456} - stateless_event_manager_impl.on_guild_member_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_member_update_event" + ) as patched_deserialize_guild_member_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_member_update(shard, payload) - event_factory.deserialize_guild_member_update_event.assert_called_once_with(shard, payload, old_member=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_member_update_event.return_value - ) + patched_deserialize_guild_member_update_event.assert_called_once_with(shard, payload, old_member=None) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_member_update_event.return_value) - def test_on_guild_members_chunk_stateful(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_guild_members_chunk_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock(members={"TestMember": 123}, presences={"TestPresences": 456}) - event_factory.deserialize_guild_member_chunk_event.return_value = event - - event_manager_impl.on_guild_members_chunk(shard, payload) - - event_manager_impl._cache.set_member.assert_called_once_with(123) - event_manager_impl._cache.set_presence.assert_called_once_with(456) - event_factory.deserialize_guild_member_chunk_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - def test_on_guild_members_chunk_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_member") as patched_set_member, + mock.patch.object(patched__cache, "set_presence") as patched_set_presence, + mock.patch.object( + event_factory, "deserialize_guild_member_chunk_event", return_value=event + ) as patched_deserialize_guild_member_chunk_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_members_chunk(shard, payload) + + patched_set_member.assert_called_once_with(123) + patched_set_presence.assert_called_once_with(456) + patched_deserialize_guild_member_chunk_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) + + def test_on_guild_members_chunk_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_guild_members_chunk(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_member_chunk_event" + ) as patched_deserialize_guild_member_chunk_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_members_chunk(shard, payload) - event_factory.deserialize_guild_member_chunk_event.assert_called_once_with(shard, payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_member_chunk_event.return_value - ) + patched_deserialize_guild_member_chunk_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_member_chunk_event.return_value) - def test_on_guild_role_create_stateful(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_guild_role_create_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock(role=object()) - event_factory.deserialize_guild_role_create_event.return_value = event - - event_manager_impl.on_guild_role_create(shard, payload) - - event_manager_impl._cache.set_role.assert_called_once_with(event.role) - event_factory.deserialize_guild_role_create_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_guild_role_create_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_role") as patched_set_role, + mock.patch.object( + event_factory, "deserialize_guild_role_create_event", return_value=event + ) as patched_deserialize_guild_role_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_role_create(shard, payload) + + patched_set_role.assert_called_once_with(event.role) + patched_deserialize_guild_role_create_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) + + def test_on_guild_role_create_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_guild_role_create(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_role_create_event" + ) as patched_deserialize_guild_role_create_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_role_create(shard, payload) - event_factory.deserialize_guild_role_create_event.assert_called_once_with(shard, payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_role_create_event.return_value - ) + patched_deserialize_guild_role_create_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_role_create_event.return_value) - def test_on_guild_role_update_stateful(self, event_manager_impl, shard, event_factory): + def test_on_guild_role_update_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"role": {"id": 123}} old_role = object() event = mock.Mock(role=mock.Mock()) - event_factory.deserialize_guild_role_update_event.return_value = event - event_manager_impl._cache.get_role.return_value = old_role - - event_manager_impl.on_guild_role_update(shard, payload) - - event_manager_impl._cache.get_role.assert_called_once_with(123) - event_manager_impl._cache.update_role.assert_called_once_with(event.role) - event_factory.deserialize_guild_role_update_event.assert_called_once_with(shard, payload, old_role=old_role) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_guild_role_update_stateless(self, stateless_event_manager_impl, shard, event_factory): + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "get_role", return_value=old_role) as patched_get_role, + mock.patch.object(patched__cache, "update_role") as patched_update_role, + mock.patch.object( + event_factory, "deserialize_guild_role_update_event", return_value=event + ) as patched_deserialize_guild_role_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_role_update(shard, payload) + + patched_get_role.assert_called_once_with(123) + patched_update_role.assert_called_once_with(event.role) + patched_deserialize_guild_role_update_event.assert_called_once_with(shard, payload, old_role=old_role) + patched_dispatch.assert_called_once_with(event) + + def test_on_guild_role_update_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"role": {"id": 123}} - stateless_event_manager_impl.on_guild_role_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_role_update_event" + ) as patched_deserialize_guild_role_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_role_update(shard, payload) - event_factory.deserialize_guild_role_update_event.assert_called_once_with(shard, payload, old_role=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_role_update_event.return_value - ) + patched_deserialize_guild_role_update_event.assert_called_once_with(shard, payload, old_role=None) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_role_update_event.return_value) - def test_on_guild_role_delete_stateful(self, event_manager_impl, shard, event_factory): + def test_on_guild_role_delete_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"role_id": "123"} - event_manager_impl.on_guild_role_delete(shard, payload) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "delete_role") as patched_delete_role, + mock.patch.object( + event_factory, "deserialize_guild_role_delete_event" + ) as patched_deserialize_guild_role_delete_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_role_delete(shard, payload) - event_manager_impl._cache.delete_role.assert_called_once_with(123) - event_factory.deserialize_guild_role_delete_event.assert_called_once_with( - shard, payload, old_role=event_manager_impl._cache.delete_role.return_value - ) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_role_delete_event.return_value + patched_delete_role.assert_called_once_with(123) + patched_deserialize_guild_role_delete_event.assert_called_once_with( + shard, payload, old_role=patched_delete_role.return_value ) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_role_delete_event.return_value) - def test_on_guild_role_delete_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + def test_on_guild_role_delete_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_guild_role_delete(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_role_delete_event" + ) as patched_deserialize_guild_role_delete_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_guild_role_delete(shard, payload) - event_factory.deserialize_guild_role_delete_event.assert_called_once_with(shard, payload, old_role=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_role_delete_event.return_value - ) + patched_deserialize_guild_role_delete_event.assert_called_once_with(shard, payload, old_role=None) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_role_delete_event.return_value) - def test_on_invite_create_stateful(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_invite_create_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock(invite="qwerty") - event_factory.deserialize_invite_create_event.return_value = event - - event_manager_impl.on_invite_create(shard, payload) - - event_manager_impl._cache.set_invite.assert_called_once_with("qwerty") - event_factory.deserialize_invite_create_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_invite_create_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_invite") as patched_set_invite, + mock.patch.object( + event_factory, "deserialize_invite_create_event", return_value=event + ) as patched_deserialize_invite_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_invite_create(shard, payload) + + patched_set_invite.assert_called_once_with("qwerty") + patched_deserialize_invite_create_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) + + def test_on_invite_create_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_invite_create(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_invite_create_event" + ) as patched_deserialize_invite_create_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_invite_create(shard, payload) - event_factory.deserialize_invite_create_event.assert_called_once_with(shard, payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_invite_create_event.return_value - ) + patched_deserialize_invite_create_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_invite_create_event.return_value) - def test_on_invite_delete_stateful(self, event_manager_impl, shard, event_factory): + def test_on_invite_delete_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"code": "qwerty"} - event_manager_impl.on_invite_delete(shard, payload) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "delete_invite") as patched_delete_invite, + mock.patch.object( + event_factory, "deserialize_invite_delete_event" + ) as patched_deserialize_invite_delete_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_invite_delete(shard, payload) - event_manager_impl._cache.delete_invite.assert_called_once_with("qwerty") - event_factory.deserialize_invite_delete_event.assert_called_once_with( - shard, payload, old_invite=event_manager_impl._cache.delete_invite.return_value + patched_delete_invite.assert_called_once_with("qwerty") + patched_deserialize_invite_delete_event.assert_called_once_with( + shard, payload, old_invite=patched_delete_invite.return_value ) - event_manager_impl.dispatch.assert_called_once_with(event_factory.deserialize_invite_delete_event.return_value) + patched_dispatch.assert_called_once_with(patched_deserialize_invite_delete_event.return_value) - def test_on_invite_delete_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + def test_on_invite_delete_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_invite_delete(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_invite_delete_event" + ) as patched_deserialize_invite_delete_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_invite_delete(shard, payload) - event_factory.deserialize_invite_delete_event.assert_called_once_with(shard, payload, old_invite=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_invite_delete_event.return_value - ) + patched_deserialize_invite_delete_event.assert_called_once_with(shard, payload, old_invite=None) + patched_dispatch.assert_called_once_with(patched_deserialize_invite_delete_event.return_value) - def test_on_message_create_stateful(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_message_create_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock(message=object()) - event_factory.deserialize_message_create_event.return_value = event - - event_manager_impl.on_message_create(shard, payload) - - event_manager_impl._cache.set_message.assert_called_once_with(event.message) - event_factory.deserialize_message_create_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_message_create_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "set_message") as patched_set_message, + mock.patch.object( + event_factory, "deserialize_message_create_event", return_value=event + ) as patched_deserialize_message_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_message_create(shard, payload) + + patched_set_message.assert_called_once_with(event.message) + patched_deserialize_message_create_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) + + def test_on_message_create_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_message_create(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_message_create_event" + ) as patched_deserialize_message_create_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_message_create(shard, payload) - event_factory.deserialize_message_create_event.assert_called_once_with(shard, payload) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_message_create_event.return_value - ) + patched_deserialize_message_create_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_message_create_event.return_value) - def test_on_message_update_stateful(self, event_manager_impl, shard, event_factory): + def test_on_message_update_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"id": 123} old_message = object() event = mock.Mock(message=mock.Mock()) - event_factory.deserialize_message_update_event.return_value = event - event_manager_impl._cache.get_message.return_value = old_message - - event_manager_impl.on_message_update(shard, payload) - - event_manager_impl._cache.get_message.assert_called_once_with(123) - event_manager_impl._cache.update_message.assert_called_once_with(event.message) - event_factory.deserialize_message_update_event.assert_called_once_with(shard, payload, old_message=old_message) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_message_update_stateless(self, stateless_event_manager_impl, shard, event_factory): + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "update_message") as patched_update_message, + mock.patch.object(patched__cache, "get_message", return_value=old_message) as patched_get_message, + mock.patch.object( + event_factory, "deserialize_message_update_event", return_value=event + ) as patched_deserialize_message_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_message_update(shard, payload) + + patched_get_message.assert_called_once_with(123) + patched_update_message.assert_called_once_with(event.message) + patched_deserialize_message_update_event.assert_called_once_with(shard, payload, old_message=old_message) + patched_dispatch.assert_called_once_with(event) + + def test_on_message_update_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"id": 123} - stateless_event_manager_impl.on_message_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_message_update_event" + ) as patched_deserialize_message_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_message_update(shard, payload) - event_factory.deserialize_message_update_event.assert_called_once_with(shard, payload, old_message=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_message_update_event.return_value - ) + patched_deserialize_message_update_event.assert_called_once_with(shard, payload, old_message=None) + patched_dispatch.assert_called_once_with(patched_deserialize_message_update_event.return_value) - def test_on_message_delete_stateful(self, event_manager_impl, shard, event_factory): + def test_on_message_delete_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"id": 123} - event_manager_impl.on_message_delete(shard, payload) + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "delete_message") as patched_delete_message, + mock.patch.object( + event_factory, "deserialize_message_delete_event" + ) as patched_deserialize_message_delete_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_message_delete(shard, payload) - event_manager_impl._cache.delete_message.assert_called_once_with(123) - event_factory.deserialize_message_delete_event.assert_called_once_with( - shard, payload, old_message=event_manager_impl._cache.delete_message.return_value + patched_delete_message.assert_called_once_with(123) + patched_deserialize_message_delete_event.assert_called_once_with( + shard, payload, old_message=patched_delete_message.return_value ) - event_manager_impl.dispatch.assert_called_once_with(event_factory.deserialize_message_delete_event.return_value) + patched_dispatch.assert_called_once_with(patched_deserialize_message_delete_event.return_value) - def test_on_message_delete_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + def test_on_message_delete_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_message_delete(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_message_delete_event" + ) as patched_deserialize_message_delete_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_message_delete(shard, payload) - event_factory.deserialize_message_delete_event.assert_called_once_with(shard, payload, old_message=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_message_delete_event.return_value - ) + patched_deserialize_message_delete_event.assert_called_once_with(shard, payload, old_message=None) + patched_dispatch.assert_called_once_with(patched_deserialize_message_delete_event.return_value) - def test_on_message_delete_bulk_stateful(self, event_manager_impl, shard, event_factory): + def test_on_message_delete_bulk_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"ids": [123, 456, 789, 987]} message1 = object() message2 = object() message3 = object() - event_manager_impl._cache.delete_message.side_effect = [message1, message2, message3, None] - event_manager_impl.on_message_delete_bulk(shard, payload) - - event_manager_impl._cache.delete_message.assert_has_calls( - [mock.call(123), mock.call(456), mock.call(789), mock.call(987)] - ) - event_factory.deserialize_guild_message_delete_bulk_event.assert_called_once_with( + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object( + patched__cache, "delete_message", side_effect=[message1, message2, message3, None] + ) as patched_delete_message, + mock.patch.object( + event_factory, "deserialize_guild_message_delete_bulk_event" + ) as patched_deserialize_guild_message_delete_bulk_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_message_delete_bulk(shard, payload) + + patched_delete_message.assert_has_calls([mock.call(123), mock.call(456), mock.call(789), mock.call(987)]) + patched_deserialize_guild_message_delete_bulk_event.assert_called_once_with( shard, payload, old_messages={123: message1, 456: message2, 789: message3} ) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_message_delete_bulk_event.return_value - ) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_message_delete_bulk_event.return_value) - def test_on_message_delete_bulk_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + def test_on_message_delete_bulk_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_message_delete_bulk(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_guild_message_delete_bulk_event" + ) as patched_deserialize_guild_message_delete_bulk_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_message_delete_bulk(shard, payload) - event_factory.deserialize_guild_message_delete_bulk_event.assert_called_once_with( - shard, payload, old_messages={} - ) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_guild_message_delete_bulk_event.return_value - ) + patched_deserialize_guild_message_delete_bulk_event.assert_called_once_with(shard, payload, old_messages={}) + patched_dispatch.assert_called_once_with(patched_deserialize_guild_message_delete_bulk_event.return_value) - def test_on_message_reaction_add(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_message_reaction_add( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_message_reaction_add_event.return_value = event - - event_manager_impl.on_message_reaction_add(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_message_reaction_add_event", return_value=event + ) as patched_deserialize_message_reaction_add_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_message_reaction_add(shard, payload) - event_factory.deserialize_message_reaction_add_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) + patched_deserialize_message_reaction_add_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - def test_on_message_reaction_remove(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_message_reaction_remove( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_message_reaction_remove_event.return_value = event + with ( + mock.patch.object( + event_factory, "deserialize_message_reaction_remove_event", return_value=event + ) as patched_deserialize_message_reaction_remove_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_message_reaction_remove(shard, payload) - event_manager_impl.on_message_reaction_remove(shard, payload) + patched_deserialize_message_reaction_remove_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - event_factory.deserialize_message_reaction_remove_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_message_reaction_remove_all(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_message_reaction_remove_all( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_message_reaction_remove_all_event.return_value = event - - event_manager_impl.on_message_reaction_remove_all(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_message_reaction_remove_all_event", return_value=event + ) as patched_deserialize_message_reaction_remove_all_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_message_reaction_remove_all(shard, payload) - event_factory.deserialize_message_reaction_remove_all_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) + patched_deserialize_message_reaction_remove_all_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - def test_on_message_reaction_remove_emoji(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_message_reaction_remove_emoji( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_message_reaction_remove_emoji_event.return_value = event - - event_manager_impl.on_message_reaction_remove_emoji(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_message_reaction_remove_emoji_event", return_value=event + ) as patched_deserialize_message_reaction_remove_emoji_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_message_reaction_remove_emoji(shard, payload) - event_factory.deserialize_message_reaction_remove_emoji_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) + patched_deserialize_message_reaction_remove_emoji_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - def test_on_presence_update_stateful_update(self, event_manager_impl, shard, event_factory): + def test_on_presence_update_stateful_update( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"user": {"id": 123}, "guild_id": 456} old_presence = object() event = mock.Mock(presence=mock.Mock(visible_status=presences.Status.ONLINE)) - event_factory.deserialize_presence_update_event.return_value = event - event_manager_impl._cache.get_presence.return_value = old_presence - - event_manager_impl.on_presence_update(shard, payload) - - event_manager_impl._cache.get_presence.assert_called_once_with(456, 123) - event_manager_impl._cache.update_presence.assert_called_once_with(event.presence) - event_factory.deserialize_presence_update_event.assert_called_once_with( - shard, payload, old_presence=old_presence - ) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_presence_update_stateful_delete(self, event_manager_impl, shard, event_factory): + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "update_presence") as patched_update_presence, + mock.patch.object(patched__cache, "get_presence", return_value=old_presence) as patched_get_presence, + mock.patch.object( + event_factory, "deserialize_presence_update_event", return_value=event + ) as patched_deserialize_presence_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_presence_update(shard, payload) + + patched_get_presence.assert_called_once_with(456, 123) + patched_update_presence.assert_called_once_with(event.presence) + patched_deserialize_presence_update_event.assert_called_once_with(shard, payload, old_presence=old_presence) + patched_dispatch.assert_called_once_with(event) + + def test_on_presence_update_stateful_delete( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"user": {"id": 123}, "guild_id": 456} old_presence = object() event = mock.Mock(presence=mock.Mock(visible_status=presences.Status.OFFLINE)) - event_factory.deserialize_presence_update_event.return_value = event - event_manager_impl._cache.get_presence.return_value = old_presence - - event_manager_impl.on_presence_update(shard, payload) - - event_manager_impl._cache.get_presence.assert_called_once_with(456, 123) - event_manager_impl._cache.delete_presence.assert_called_once_with( - event.presence.guild_id, event.presence.user_id - ) - event_factory.deserialize_presence_update_event.assert_called_once_with( - shard, payload, old_presence=old_presence - ) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_presence_update_stateless(self, stateless_event_manager_impl, shard, event_factory): + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "delete_presence") as patched_delete_presence, + mock.patch.object(patched__cache, "get_presence", return_value=old_presence) as patched_get_presence, + mock.patch.object( + event_factory, "deserialize_presence_update_event", return_value=event + ) as patched_deserialize_presence_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_presence_update(shard, payload) + + patched_get_presence.assert_called_once_with(456, 123) + patched_delete_presence.assert_called_once_with(event.presence.guild_id, event.presence.user_id) + patched_deserialize_presence_update_event.assert_called_once_with(shard, payload, old_presence=old_presence) + patched_dispatch.assert_called_once_with(event) + + def test_on_presence_update_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"user": {"id": 123}, "guild_id": 456} - stateless_event_manager_impl.on_presence_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_presence_update_event" + ) as patched_deserialize_presence_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_presence_update(shard, payload) - event_factory.deserialize_presence_update_event.assert_called_once_with(shard, payload, old_presence=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_presence_update_event.return_value - ) + patched_deserialize_presence_update_event.assert_called_once_with(shard, payload, old_presence=None) + patched_dispatch.assert_called_once_with(patched_deserialize_presence_update_event.return_value) - def test_on_typing_start(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_typing_start( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_typing_start_event.return_value = event + with ( + mock.patch.object( + event_factory, "deserialize_typing_start_event", return_value=event + ) as patched_deserialize_typing_start_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_typing_start(shard, payload) - event_manager_impl.on_typing_start(shard, payload) + patched_deserialize_typing_start_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - event_factory.deserialize_typing_start_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_user_update_stateful(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_user_update_stateful( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} old_user = object() event = mock.Mock(user=mock.Mock()) - event_factory.deserialize_own_user_update_event.return_value = event - event_manager_impl._cache.get_me.return_value = old_user - - event_manager_impl.on_user_update(shard, payload) - - event_manager_impl._cache.update_me.assert_called_once_with(event.user) - event_factory.deserialize_own_user_update_event.assert_called_once_with(shard, payload, old_user=old_user) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_user_update_stateless(self, stateless_event_manager_impl, shard, event_factory): - payload = {} + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "update_me") as patched_update_me, + mock.patch.object(patched__cache, "get_me", return_value=old_user) as patched_get_me, + mock.patch.object( + event_factory, "deserialize_own_user_update_event", return_value=event + ) as patched_deserialize_own_user_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_user_update(shard, payload) + + patched_get_me.assert_called_once() + patched_update_me.assert_called_once_with(event.user) + patched_deserialize_own_user_update_event.assert_called_once_with(shard, payload, old_user=old_user) + patched_dispatch.assert_called_once_with(event) + + def test_on_user_update_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} - stateless_event_manager_impl.on_user_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_own_user_update_event" + ) as patched_deserialize_own_user_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_user_update(shard, payload) - event_factory.deserialize_own_user_update_event.assert_called_once_with(shard, payload, old_user=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_own_user_update_event.return_value - ) + patched_deserialize_own_user_update_event.assert_called_once_with(shard, payload, old_user=None) + patched_dispatch.assert_called_once_with(patched_deserialize_own_user_update_event.return_value) - def test_on_voice_state_update_stateful_update(self, event_manager_impl, shard, event_factory): + def test_on_voice_state_update_stateful_update( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"user_id": 123, "guild_id": 456} old_state = object() event = mock.Mock(state=mock.Mock(channel_id=123)) - event_factory.deserialize_voice_state_update_event.return_value = event - event_manager_impl._cache.get_voice_state.return_value = old_state - - event_manager_impl.on_voice_state_update(shard, payload) - - event_manager_impl._cache.get_voice_state.assert_called_once_with(456, 123) - event_manager_impl._cache.update_voice_state.assert_called_once_with(event.state) - event_factory.deserialize_voice_state_update_event.assert_called_once_with(shard, payload, old_state=old_state) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_voice_state_update_stateful_delete(self, event_manager_impl, shard, event_factory): + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "update_voice_state") as patched_update_voice_state, + mock.patch.object(patched__cache, "get_voice_state", return_value=old_state) as patched_get_voice_state, + mock.patch.object( + event_factory, "deserialize_voice_state_update_event", return_value=event + ) as patched_deserialize_voice_state_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_voice_state_update(shard, payload) + + patched_get_voice_state.assert_called_once_with(456, 123) + patched_update_voice_state.assert_called_once_with(event.state) + patched_deserialize_voice_state_update_event.assert_called_once_with(shard, payload, old_state=old_state) + patched_dispatch.assert_called_once_with(event) + + def test_on_voice_state_update_stateful_delete( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"user_id": 123, "guild_id": 456} old_state = object() event = mock.Mock(state=mock.Mock(channel_id=None)) - event_factory.deserialize_voice_state_update_event.return_value = event - event_manager_impl._cache.get_voice_state.return_value = old_state - - event_manager_impl.on_voice_state_update(shard, payload) - - event_manager_impl._cache.get_voice_state.assert_called_once_with(456, 123) - event_manager_impl._cache.delete_voice_state.assert_called_once_with(event.state.guild_id, event.state.user_id) - event_factory.deserialize_voice_state_update_event.assert_called_once_with(shard, payload, old_state=old_state) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_voice_state_update_stateless(self, stateless_event_manager_impl, shard, event_factory): + with ( + mock.patch.object(event_manager_impl, "_cache") as patched__cache, + mock.patch.object(patched__cache, "delete_voice_state") as patched_delete_voice_state, + mock.patch.object(patched__cache, "get_voice_state", return_value=old_state) as patched_get_voice_state, + mock.patch.object( + event_factory, "deserialize_voice_state_update_event", return_value=event + ) as patched_deserialize_voice_state_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_voice_state_update(shard, payload) + + patched_get_voice_state.assert_called_once_with(456, 123) + patched_delete_voice_state.assert_called_once_with(event.state.guild_id, event.state.user_id) + patched_deserialize_voice_state_update_event.assert_called_once_with(shard, payload, old_state=old_state) + patched_dispatch.assert_called_once_with(event) + + def test_on_voice_state_update_stateless( + self, + stateless_event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"user_id": 123, "guild_id": 456} - stateless_event_manager_impl.on_voice_state_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_voice_state_update_event" + ) as patched_deserialize_voice_state_update_event, + mock.patch.object(stateless_event_manager_impl, "dispatch") as patched_dispatch, + ): + stateless_event_manager_impl.on_voice_state_update(shard, payload) - event_factory.deserialize_voice_state_update_event.assert_called_once_with(shard, payload, old_state=None) - stateless_event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_voice_state_update_event.return_value - ) + patched_deserialize_voice_state_update_event.assert_called_once_with(shard, payload, old_state=None) + patched_dispatch.assert_called_once_with(patched_deserialize_voice_state_update_event.return_value) - def test_on_voice_server_update(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_voice_server_update( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_voice_server_update_event.return_value = event + with ( + mock.patch.object( + event_factory, "deserialize_voice_server_update_event", return_value=event + ) as patched_deserialize_voice_server_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_voice_server_update(shard, payload) - event_manager_impl.on_voice_server_update(shard, payload) + patched_deserialize_voice_server_update_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - event_factory.deserialize_voice_server_update_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) - - def test_on_webhooks_update(self, event_manager_impl, shard, event_factory): - payload = {} + def test_on_webhooks_update( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): + payload: dict[str, typing.Any] = {} event = mock.Mock() - event_factory.deserialize_webhook_update_event.return_value = event - - event_manager_impl.on_webhooks_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_webhook_update_event", return_value=event + ) as patched_deserialize_webhook_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_webhooks_update(shard, payload) - event_factory.deserialize_webhook_update_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with(event) + patched_deserialize_webhook_update_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(event) - def test_on_interaction_create(self, event_manager_impl, shard, event_factory): + def test_on_interaction_create( + self, + event_manager_impl: event_manager.EventManagerImpl, + shard: shard_api.GatewayShard, + event_factory: event_factory_impl.EventFactoryImpl, + ): payload = {"id": "123"} - event_manager_impl.on_interaction_create(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_interaction_create_event" + ) as patched_deserialize_interaction_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_interaction_create(shard, payload) - event_factory.deserialize_interaction_create_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_interaction_create_event.return_value - ) + patched_deserialize_interaction_create_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_interaction_create_event.return_value) def test_on_guild_scheduled_event_create( self, event_manager_impl: event_manager.EventManagerImpl, - shard: mock.Mock, + shard: shard_api.GatewayShard, event_factory: event_factory_.EventFactory, ): - mock_payload = mock.Mock() + mock_payload: dict[str, typing.Any] = mock.Mock() - event_manager_impl.on_guild_scheduled_event_create(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_scheduled_event_create_event" + ) as patched_deserialize_scheduled_event_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_scheduled_event_create(shard, mock_payload) - event_factory.deserialize_scheduled_event_create_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_scheduled_event_create_event.return_value - ) + patched_deserialize_scheduled_event_create_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_scheduled_event_create_event.return_value) def test_on_guild_scheduled_event_delete( self, event_manager_impl: event_manager.EventManagerImpl, - shard: mock.Mock, + shard: shard_api.GatewayShard, event_factory: event_factory_.EventFactory, ): - mock_payload = mock.Mock() + mock_payload: dict[str, typing.Any] = mock.Mock() - event_manager_impl.on_guild_scheduled_event_delete(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_scheduled_event_delete_event" + ) as patched_deserialize_scheduled_event_delete_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_scheduled_event_delete(shard, mock_payload) - event_factory.deserialize_scheduled_event_delete_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_scheduled_event_delete_event.return_value - ) + patched_deserialize_scheduled_event_delete_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_scheduled_event_delete_event.return_value) def test_on_guild_scheduled_event_update( self, event_manager_impl: event_manager.EventManagerImpl, - shard: mock.Mock, + shard: shard_api.GatewayShard, event_factory: event_factory_.EventFactory, ): - mock_payload = mock.Mock() + mock_payload: dict[str, typing.Any] = mock.Mock() - event_manager_impl.on_guild_scheduled_event_update(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_scheduled_event_update_event" + ) as patched_deserialize_scheduled_event_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_scheduled_event_update(shard, mock_payload) - event_factory.deserialize_scheduled_event_update_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_scheduled_event_update_event.return_value - ) + patched_deserialize_scheduled_event_update_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_scheduled_event_update_event.return_value) def test_on_guild_scheduled_event_user_add( self, event_manager_impl: event_manager.EventManagerImpl, - shard: mock.Mock, + shard: shard_api.GatewayShard, event_factory: event_factory_.EventFactory, ): - mock_payload = mock.Mock() + mock_payload: dict[str, typing.Any] = mock.Mock() - event_manager_impl.on_guild_scheduled_event_user_add(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_scheduled_event_user_add_event" + ) as patched_deserialize_scheduled_event_user_add_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_scheduled_event_user_add(shard, mock_payload) - event_factory.deserialize_scheduled_event_user_add_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_scheduled_event_user_add_event.return_value - ) + patched_deserialize_scheduled_event_user_add_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_scheduled_event_user_add_event.return_value) def test_on_guild_scheduled_event_user_remove( self, event_manager_impl: event_manager.EventManagerImpl, - shard: mock.Mock, + shard: shard_api.GatewayShard, event_factory: event_factory_.EventFactory, ): - mock_payload = mock.Mock() + mock_payload: dict[str, typing.Any] = mock.Mock() - event_manager_impl.on_guild_scheduled_event_user_remove(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_scheduled_event_user_remove_event" + ) as patched_deserialize_scheduled_event_user_remove_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_scheduled_event_user_remove(shard, mock_payload) - event_factory.deserialize_scheduled_event_user_remove_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_scheduled_event_user_remove_event.return_value - ) + patched_deserialize_scheduled_event_user_remove_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_scheduled_event_user_remove_event.return_value) def test_on_guild_audit_log_entry_create( self, event_manager_impl: event_manager.EventManagerImpl, - shard: mock.Mock, + shard: shard_api.GatewayShard, event_factory: event_factory_.EventFactory, ): - mock_payload = mock.Mock() + mock_payload: dict[str, typing.Any] = mock.Mock() - event_manager_impl.on_guild_audit_log_entry_create(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_audit_log_entry_create_event" + ) as patched_deserialize_audit_log_entry_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_guild_audit_log_entry_create(shard, mock_payload) - event_factory.deserialize_audit_log_entry_create_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_audit_log_entry_create_event.return_value - ) + patched_deserialize_audit_log_entry_create_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_audit_log_entry_create_event.return_value) def test_on_stage_instance_create( self, event_manager_impl: event_manager.EventManagerImpl, - shard: mock.Mock, + shard: shard_api.GatewayShard, event_factory: event_factory_.EventFactory, ): - payload = { + payload: dict[str, typing.Any] = { "id": "840647391636226060", "guild_id": "197038439483310086", "channel_id": "733488538393510049", @@ -1593,20 +2469,24 @@ def test_on_stage_instance_create( "discoverable_disabled": False, } - event_manager_impl.on_stage_instance_create(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_stage_instance_create_event" + ) as patched_deserialize_stage_instance_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_stage_instance_create(shard, payload) - event_factory.deserialize_stage_instance_create_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_stage_instance_create_event.return_value - ) + patched_deserialize_stage_instance_create_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_stage_instance_create_event.return_value) def test_on_stage_instance_update( self, event_manager_impl: event_manager.EventManagerImpl, - shard: mock.Mock, + shard: shard_api.GatewayShard, event_factory: event_factory_.EventFactory, ): - payload = { + payload: dict[str, typing.Any] = { "id": "840647391636226060", "guild_id": "197038439483310086", "channel_id": "733488538393510049", @@ -1615,20 +2495,24 @@ def test_on_stage_instance_update( "discoverable_disabled": False, } - event_manager_impl.on_stage_instance_update(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_stage_instance_update_event" + ) as patched_deserialize_stage_instance_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_stage_instance_update(shard, payload) - event_factory.deserialize_stage_instance_update_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_stage_instance_update_event.return_value - ) + patched_deserialize_stage_instance_update_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_stage_instance_update_event.return_value) def test_on_stage_instance_delete( self, event_manager_impl: event_manager.EventManagerImpl, - shard: mock.Mock, + shard: shard_api.GatewayShard, event_factory: event_factory_.EventFactory, ): - payload = { + payload: dict[str, typing.Any] = { "id": "840647391636226060", "guild_id": "197038439483310086", "channel_id": "733488538393510049", @@ -1637,12 +2521,16 @@ def test_on_stage_instance_delete( "discoverable_disabled": False, } - event_manager_impl.on_stage_instance_delete(shard, payload) + with ( + mock.patch.object( + event_factory, "deserialize_stage_instance_delete_event" + ) as patched_deserialize_stage_instance_delete_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_stage_instance_delete(shard, payload) - event_factory.deserialize_stage_instance_delete_event.assert_called_once_with(shard, payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_stage_instance_delete_event.return_value - ) + patched_deserialize_stage_instance_delete_event.assert_called_once_with(shard, payload) + patched_dispatch.assert_called_once_with(patched_deserialize_stage_instance_delete_event.return_value) def test_on_message_poll_vote_create( self, @@ -1652,12 +2540,16 @@ def test_on_message_poll_vote_create( ): mock_payload = mock.Mock() - event_manager_impl.on_message_poll_vote_add(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_poll_vote_create_event" + ) as patched_deserialize_poll_vote_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_message_poll_vote_add(shard, mock_payload) - event_factory.deserialize_poll_vote_create_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_poll_vote_create_event.return_value - ) + patched_deserialize_poll_vote_create_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_poll_vote_create_event.return_value) def test_on_message_poll_vote_delete( self, @@ -1667,12 +2559,16 @@ def test_on_message_poll_vote_delete( ): mock_payload = mock.Mock() - event_manager_impl.on_message_poll_vote_remove(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_poll_vote_delete_event" + ) as patched_deserialize_poll_vote_delete_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_message_poll_vote_remove(shard, mock_payload) - event_factory.deserialize_poll_vote_delete_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_poll_vote_delete_event.return_value - ) + patched_deserialize_poll_vote_delete_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_poll_vote_delete_event.return_value) def test_on_auto_moderation_rule_create( self, @@ -1682,12 +2578,16 @@ def test_on_auto_moderation_rule_create( ): mock_payload = mock.Mock() - event_manager_impl.on_auto_moderation_rule_create(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_auto_mod_rule_create_event" + ) as patched_deserialize_auto_mod_rule_create_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_auto_moderation_rule_create(shard, mock_payload) - event_factory.deserialize_auto_mod_rule_create_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_auto_mod_rule_create_event.return_value - ) + patched_deserialize_auto_mod_rule_create_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_auto_mod_rule_create_event.return_value) def test_on_auto_moderation_rule_update( self, @@ -1697,12 +2597,16 @@ def test_on_auto_moderation_rule_update( ): mock_payload = mock.Mock() - event_manager_impl.on_auto_moderation_rule_update(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_auto_mod_rule_update_event" + ) as patched_deserialize_auto_mod_rule_update_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_auto_moderation_rule_update(shard, mock_payload) - event_factory.deserialize_auto_mod_rule_update_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_auto_mod_rule_update_event.return_value - ) + patched_deserialize_auto_mod_rule_update_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_auto_mod_rule_update_event.return_value) def test_on_auto_moderation_rule_delete( self, @@ -1712,12 +2616,16 @@ def test_on_auto_moderation_rule_delete( ): mock_payload = mock.Mock() - event_manager_impl.on_auto_moderation_rule_delete(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_auto_mod_rule_delete_event" + ) as patched_deserialize_auto_mod_rule_delete_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_auto_moderation_rule_delete(shard, mock_payload) - event_factory.deserialize_auto_mod_rule_delete_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_auto_mod_rule_delete_event.return_value - ) + patched_deserialize_auto_mod_rule_delete_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_auto_mod_rule_delete_event.return_value) def test_on_auto_moderation_action_execution( self, @@ -1727,9 +2635,13 @@ def test_on_auto_moderation_action_execution( ): mock_payload = mock.Mock() - event_manager_impl.on_auto_moderation_action_execution(shard, mock_payload) + with ( + mock.patch.object( + event_factory, "deserialize_auto_mod_action_execution_event" + ) as patched_deserialize_auto_mod_action_execution_event, + mock.patch.object(event_manager_impl, "dispatch") as patched_dispatch, + ): + event_manager_impl.on_auto_moderation_action_execution(shard, mock_payload) - event_factory.deserialize_auto_mod_action_execution_event.assert_called_once_with(shard, mock_payload) - event_manager_impl.dispatch.assert_called_once_with( - event_factory.deserialize_auto_mod_action_execution_event.return_value - ) + patched_deserialize_auto_mod_action_execution_event.assert_called_once_with(shard, mock_payload) + patched_dispatch.assert_called_once_with(patched_deserialize_auto_mod_action_execution_event.return_value) diff --git a/tests/hikari/impl/test_event_manager_base.py b/tests/hikari/impl/test_event_manager_base.py index 7b9e661cc6..6d3a5e5895 100644 --- a/tests/hikari/impl/test_event_manager_base.py +++ b/tests/hikari/impl/test_event_manager_base.py @@ -21,7 +21,6 @@ from __future__ import annotations import asyncio -import contextlib import gc import sys import typing @@ -55,12 +54,12 @@ def test(): TypeError, match=r"dead weak referenced subscriber method cannot be executed, try actually closing your event streamers", ): - await call_weak_method(None) + await call_weak_method(mock.Mock()) @pytest.mark.asyncio async def test__generate_weak_listener(self): mock_listener = mock.AsyncMock() - mock_event = object() + mock_event = mock.Mock() def test(): return mock_listener @@ -72,16 +71,11 @@ def test(): mock_listener.assert_awaited_once_with(mock_event) -@pytest.fixture -def mock_app(): - return mock.Mock() - - class TestEventStream: def test___enter___and___exit__(self): stub_stream = hikari_test_helpers.mock_class_namespace( event_manager_base.EventStream, open=mock.Mock(), close=mock.Mock() - )(mock_app, base_events.Event, timeout=None) + )(mock.Mock(), base_events.Event, timeout=None) with stub_stream: stub_stream.open.assert_called_once_with() @@ -91,22 +85,22 @@ def test___enter___and___exit__(self): stub_stream.close.assert_called_once_with() @pytest.mark.asyncio - async def test__listener_when_filter_returns_false(self, mock_app): - stream = event_manager_base.EventStream(mock_app, base_events.Event, timeout=None) + async def test__listener_when_filter_returns_false(self): + stream = event_manager_base.EventStream(mock.Mock(), base_events.Event, timeout=None) stream.filter(lambda _: False) - mock_event = object() + mock_event = mock.Mock() assert await stream._listener(mock_event) is None assert not stream._queue @hikari_test_helpers.timeout() @pytest.mark.asyncio - async def test__listener_when_filter_passes_and_queue_full(self, mock_app): - stream = event_manager_base.EventStream(mock_app, base_events.Event, timeout=None, limit=2) - stream._queue.append(object()) - stream._queue.append(object()) + async def test__listener_when_filter_passes_and_queue_full(self): + stream = event_manager_base.EventStream(mock.Mock(), base_events.Event, timeout=None, limit=2) + stream._queue.append(mock.Mock()) + stream._queue.append(mock.Mock()) stream.filter(lambda _: True) - mock_event = object() + mock_event = mock.Mock() with stream: assert await stream._listener(mock_event) is None @@ -116,12 +110,12 @@ async def test__listener_when_filter_passes_and_queue_full(self, mock_app): @hikari_test_helpers.timeout() @pytest.mark.asyncio - async def test__listener_when_filter_passes_and_queue_not_full(self, mock_app): - stream = event_manager_base.EventStream(mock_app, base_events.Event, timeout=None, limit=None) - stream._queue.append(object()) - stream._queue.append(object()) + async def test__listener_when_filter_passes_and_queue_not_full(self): + stream = event_manager_base.EventStream(mock.Mock(), base_events.Event, timeout=None, limit=None) + stream._queue.append(mock.Mock()) + stream._queue.append(mock.Mock()) stream.filter(lambda _: True) - mock_event = object() + mock_event = mock.Mock() with stream: assert await stream._listener(mock_event) is None @@ -150,7 +144,7 @@ async def test___anext___times_out(self): @pytest.mark.asyncio @hikari_test_helpers.timeout() async def test___anext___waits_for_next_event(self): - mock_event = object() + mock_event = mock.Mock() streamer = event_manager_base.EventStream(mock.Mock(), event_type=base_events.Event, timeout=None) async def quickly_run_task(task): @@ -171,7 +165,7 @@ async def quickly_run_task(task): @pytest.mark.asyncio @hikari_test_helpers.timeout() async def test___anext__(self): - mock_event = object() + mock_event = mock.Mock() streamer = event_manager_base.EventStream( event_manager=mock.Mock(), event_type=base_events.Event, @@ -184,9 +178,9 @@ async def test___anext__(self): @pytest.mark.asyncio async def test___await__(self): - mock_event_0 = object() - mock_event_1 = object() - mock_event_2 = object() + mock_event_0 = mock.Mock() + mock_event_1 = mock.Mock() + mock_event_2 = mock.Mock() streamer = hikari_test_helpers.mock_class_namespace( event_manager_base.EventStream, close=mock.Mock(), @@ -201,7 +195,7 @@ async def test___await__(self): streamer.close.assert_called_once_with() def test___del___for_active_stream(self): - mock_coroutine = object() + mock_coroutine = mock.Mock() close_method = mock.Mock(return_value=mock_coroutine) streamer = hikari_test_helpers.mock_class_namespace( event_manager_base.EventStream, close=close_method, init_=False @@ -227,13 +221,15 @@ def test___del___for_inactive_stream(self): del streamer close_method.assert_not_called() - def test_close_for_inactive_stream(self, mock_app): - stream = event_manager_base.EventStream(mock_app, base_events.Event, timeout=None, limit=None) + def test_close_for_inactive_stream(self): + app = mock.Mock() + + stream = event_manager_base.EventStream(app.event_manager, base_events.Event, timeout=None, limit=None) stream.close() - mock_app.event_manager.unsubscribe.assert_not_called() + app.event_manager.unsubscribe.assert_not_called() def test_close_for_active_stream(self): - mock_registered_listener = object() + mock_registered_listener = mock.Mock() mock_manager = mock.Mock() stream = hikari_test_helpers.mock_class_namespace(event_manager_base.EventStream)( event_manager=mock_manager, event_type=base_events.Event, timeout=float("inf") @@ -247,7 +243,7 @@ def test_close_for_active_stream(self): assert stream._registered_listener is None def test_close_for_active_stream_handles_value_error(self): - mock_registered_listener = object() + mock_registered_listener = mock.Mock() mock_manager = mock.Mock() mock_manager.unsubscribe.side_effect = ValueError stream = hikari_test_helpers.mock_class_namespace(event_manager_base.EventStream)( @@ -272,7 +268,7 @@ async def test_filter(self): first_fails = mock.Mock(attr=True) second_fail = mock.Mock(attr=False) - def predicate(obj): + def predicate(obj: typing.Any): return obj in (first_pass, second_pass) stream.filter(predicate, attr=True) @@ -299,7 +295,7 @@ async def test_filter_handles_calls_while_active(self): await stream._listener(second_pass) await stream._listener(second_fail) - def predicate(obj): + def predicate(obj: typing.Any): return obj in (first_pass, second_pass) with stream: @@ -308,7 +304,7 @@ def predicate(obj): assert await stream == [first_pass, second_pass] def test_open_for_inactive_stream(self): - mock_listener = object() + mock_listener = mock.Mock() mock_manager = mock.Mock() stream = hikari_test_helpers.mock_class_namespace(event_manager_base.EventStream)( event_manager=mock_manager, event_type=base_events.Event, timeout=float("inf") @@ -337,8 +333,8 @@ def test_open_for_active_stream(self): event_manager=mock_manager, event_type=base_events.Event, timeout=float("inf") ) stream._active = False - mock_listener = object() - mock_listener_ref = object() + mock_listener = mock.Mock() + mock_listener_ref = mock.Mock() with mock.patch.object(event_manager_base, "_generate_weak_listener", return_value=mock_listener): with mock.patch.object(weakref, "WeakMethod", return_value=mock_listener_ref): @@ -360,20 +356,23 @@ class TestConsumer: ("is_caching", "listener_group_count", "waiter_group_count", "expected_result"), [(True, -10000, -10000, True), (False, 0, 1, True), (False, 1, 0, True), (False, 0, 0, False)], ) - def test_is_enabled(self, is_caching, listener_group_count, waiter_group_count, expected_result): - consumer = event_manager_base._Consumer(object(), 123, is_caching) + def test_is_enabled( + self, is_caching: bool, listener_group_count: int, waiter_group_count: int, expected_result: bool + ): + consumer = event_manager_base._Consumer(mock.Mock(), 123, is_caching) consumer.listener_group_count = listener_group_count consumer.waiter_group_count = waiter_group_count assert consumer.is_enabled is expected_result +class EventManagerBaseImpl(event_manager_base.EventManagerBase): + on_existing_event = None + + class TestEventManagerBase: @pytest.fixture - def event_manager(self): - class EventManagerBaseImpl(event_manager_base.EventManagerBase): - on_existing_event = None - + def event_manager(self) -> EventManagerBaseImpl: return EventManagerBaseImpl(mock.Mock(), mock.Mock()) def test___init___loads_consumers(self): @@ -426,7 +425,7 @@ async def on_not_decorated(self, event): async def not_a_listener(self): raise NotImplementedError - manager = StubManager(mock.Mock(), 0, cache_components=config.CacheComponents.NONE) + manager = StubManager(mock.Mock(), intents.Intents.NONE, cache_components=config.CacheComponents.NONE) assert manager._consumers == { "foo": event_manager_base._Consumer(manager.on_foo, 9, False), "bar": event_manager_base._Consumer(manager.on_bar, 105, False), @@ -434,10 +433,10 @@ async def not_a_listener(self): "not_decorated": event_manager_base._Consumer(manager.on_not_decorated, -1, False), } - def test__increment_listener_group_count(self, event_manager): - on_foo_consumer = event_manager_base._Consumer(None, 9, False) - on_bar_consumer = event_manager_base._Consumer(None, 105, False) - on_bat_consumer = event_manager_base._Consumer(None, 1, False) + def test__increment_listener_group_count(self, event_manager: EventManagerBaseImpl): + on_foo_consumer = event_manager_base._Consumer(mock.Mock(), 9, False) + on_bar_consumer = event_manager_base._Consumer(mock.Mock(), 105, False) + on_bat_consumer = event_manager_base._Consumer(mock.Mock(), 1, False) event_manager._consumers = {"foo": on_foo_consumer, "bar": on_bar_consumer, "bat": on_bat_consumer} event_manager._increment_listener_group_count(shard_events.ShardEvent, 1) @@ -446,10 +445,10 @@ def test__increment_listener_group_count(self, event_manager): assert on_bar_consumer.listener_group_count == 1 assert on_bat_consumer.listener_group_count == 0 - def test__increment_waiter_group_count(self, event_manager): - on_foo_consumer = event_manager_base._Consumer(None, 9, False) - on_bar_consumer = event_manager_base._Consumer(None, 105, False) - on_bat_consumer = event_manager_base._Consumer(None, 1, False) + def test__increment_waiter_group_count(self, event_manager: EventManagerBaseImpl): + on_foo_consumer = event_manager_base._Consumer(mock.Mock(), 9, False) + on_bar_consumer = event_manager_base._Consumer(mock.Mock(), 105, False) + on_bat_consumer = event_manager_base._Consumer(mock.Mock(), 1, False) event_manager._consumers = {"foo": on_foo_consumer, "bar": on_bar_consumer, "bat": on_bat_consumer} event_manager._increment_waiter_group_count(shard_events.ShardEvent, 1) @@ -458,26 +457,26 @@ def test__increment_waiter_group_count(self, event_manager): assert on_bar_consumer.waiter_group_count == 1 assert on_bat_consumer.waiter_group_count == 0 - def test__enabled_for_event_when_listener_registered(self, event_manager): + def test__enabled_for_event_when_listener_registered(self, event_manager: EventManagerBaseImpl): event_manager._listeners = {shard_events.ShardStateEvent: [], shard_events.MemberChunkEvent: []} event_manager._waiters = {} assert event_manager._enabled_for_event(shard_events.ShardStateEvent) is True - def test__enabled_for_event_when_waiter_registered(self, event_manager): + def test__enabled_for_event_when_waiter_registered(self, event_manager: EventManagerBaseImpl): event_manager._listeners = {} event_manager._waiters = {shard_events.ShardStateEvent: [], shard_events.MemberChunkEvent: []} assert event_manager._enabled_for_event(shard_events.ShardStateEvent) is True - def test__enabled_for_event_when_not_registered(self, event_manager): + def test__enabled_for_event_when_not_registered(self, event_manager: EventManagerBaseImpl): event_manager._listeners = {shard_events.ShardPayloadEvent: [], shard_events.MemberChunkEvent: []} event_manager._waiters = {shard_events.ShardPayloadEvent: [], shard_events.MemberChunkEvent: []} assert event_manager._enabled_for_event(shard_events.ShardStateEvent) is False @pytest.mark.asyncio - async def test_consume_raw_event_when_KeyError(self, event_manager): + async def test_consume_raw_event_when_KeyError(self, event_manager: EventManagerBaseImpl): event_manager._enabled_for_event = mock.Mock(return_value=True) mock_payload = {"id": "3123123123"} mock_shard = mock.Mock(id=123) @@ -495,12 +494,12 @@ async def test_consume_raw_event_when_KeyError(self, event_manager): event_manager._enabled_for_event.assert_called_once_with(shard_events.ShardPayloadEvent) @pytest.mark.asyncio - async def test_consume_raw_event_when_found(self, event_manager): + async def test_consume_raw_event_when_found(self, event_manager: EventManagerBaseImpl): event_manager._enabled_for_event = mock.Mock(return_value=True) event_manager.dispatch = mock.Mock() on_existing_event = mock.Mock(is_enabled=True, callback=mock.Mock(__name__="testing")) event_manager._consumers = {"existing_event": on_existing_event} - shard = object() + shard = mock.Mock() payload = {"berp": "baz"} event_manager.consume_raw_event("EXISTING_EVENT", shard, payload) @@ -515,12 +514,14 @@ async def test_consume_raw_event_when_found(self, event_manager): event_manager._enabled_for_event.assert_called_once_with(shard_events.ShardPayloadEvent) @pytest.mark.asyncio - async def test_consume_raw_event_skips_consumer_callback_when_not_enabled(self, event_manager): + async def test_consume_raw_event_skips_consumer_callback_when_not_enabled( + self, event_manager: EventManagerBaseImpl + ): event_manager._enabled_for_event = mock.Mock(return_value=True) event_manager.dispatch = mock.Mock() on_existing_event = mock.Mock(is_enabled=False, callback=mock.Mock(__name__="testing")) event_manager._consumers = {"existing_event": on_existing_event} - shard = object() + shard = mock.Mock() payload = {"berp": "baz"} event_manager.consume_raw_event("EXISTING_EVENT", shard, payload) @@ -535,7 +536,7 @@ async def test_consume_raw_event_skips_consumer_callback_when_not_enabled(self, event_manager._enabled_for_event.assert_called_once_with(shard_events.ShardPayloadEvent) @pytest.mark.asyncio - async def test_consume_raw_event_skips_raw_dispatch_when_not_enabled(self, event_manager): + async def test_consume_raw_event_skips_raw_dispatch_when_not_enabled(self, event_manager: EventManagerBaseImpl): event_manager._enabled_for_event = mock.Mock(return_value=False) event_manager.dispatch = mock.Mock() on_existing_event = mock.Mock(is_enabled=False, callback=mock.Mock(__name__="testing")) @@ -550,7 +551,7 @@ async def test_consume_raw_event_skips_raw_dispatch_when_not_enabled(self, event event_manager._enabled_for_event.assert_called_once_with(shard_events.ShardPayloadEvent) @pytest.mark.asyncio - async def test_consume_raw_event_handles_exceptions(self, event_manager): + async def test_consume_raw_event_handles_exceptions(self, event_manager: EventManagerBaseImpl): mock_task = mock.Mock() # On Python 3.12+ Asyncio uses this to get the task's context if set to call the # error handler in. We want to avoid for this test for simplicity. @@ -562,7 +563,7 @@ async def test_consume_raw_event_handles_exceptions(self, event_manager): error_handler = mock.MagicMock() event_loop = asyncio.get_running_loop() event_loop.set_exception_handler(error_handler) - shard = object() + shard = mock.Mock() pl = {"i like": "cats"} with mock.patch.object(asyncio, "current_task", return_value=mock_task): @@ -578,7 +579,7 @@ async def test_consume_raw_event_handles_exceptions(self, event_manager): }, ) - def test_subscribe_when_class_call(self, event_manager): + def test_subscribe_when_class_call(self, event_manager: EventManagerBaseImpl): class Foo: async def __call__(self) -> None: ... @@ -589,13 +590,13 @@ async def __call__(self) -> None: ... assert event_manager._listeners[member_events.MemberCreateEvent] == [foo] - def test_subscribe_when_callback_is_not_coroutine(self, event_manager): + def test_subscribe_when_callback_is_not_coroutine(self, event_manager: EventManagerBaseImpl): def test(): ... with pytest.raises(TypeError, match=r"Cannot subscribe a non-coroutine function callback"): event_manager.subscribe(member_events.MemberCreateEvent, test) - def test_subscribe_when_event_type_not_in_listeners(self, event_manager): + def test_subscribe_when_event_type_not_in_listeners(self, event_manager: EventManagerBaseImpl): async def test(): ... event_manager._increment_listener_group_count = mock.Mock() @@ -607,7 +608,7 @@ async def test(): ... event_manager._check_event.assert_called_once_with(member_events.MemberCreateEvent, 1) event_manager._increment_listener_group_count.assert_called_once_with(member_events.MemberCreateEvent, 1) - def test_subscribe_when_event_type_in_listeners(self, event_manager): + def test_subscribe_when_event_type_in_listeners(self, event_manager: EventManagerBaseImpl): async def test(): ... async def test2(): ... @@ -623,11 +624,13 @@ async def test2(): ... event_manager._increment_listener_group_count.assert_not_called() @pytest.mark.parametrize("obj", ["test", event_manager_base.EventManagerBase]) - def test__check_event_when_event_type_does_not_subclass_Event(self, event_manager, obj): + def test__check_event_when_event_type_does_not_subclass_Event( + self, event_manager: EventManagerBaseImpl, obj: typing.Any + ): with pytest.raises(TypeError, match=r"'event_type' is a non-Event type"): event_manager._check_event(obj, 0) - def test__check_event_when_no_intents_required(self, event_manager): + def test__check_event_when_no_intents_required(self, event_manager: EventManagerBaseImpl): event_manager._intents = intents.Intents.ALL with mock.patch.object(base_events, "get_required_intents_for", return_value=None) as get_intents: @@ -637,7 +640,7 @@ def test__check_event_when_no_intents_required(self, event_manager): get_intents.assert_called_once_with(member_events.MemberCreateEvent) warn.assert_not_called() - def test__check_event_when_generic_event(self, event_manager): + def test__check_event_when_generic_event(self, event_manager: EventManagerBaseImpl): T = typing.TypeVar("T") class GenericEvent(typing.Generic[T], base_events.Event): ... @@ -653,7 +656,7 @@ class GenericEvent(typing.Generic[T], base_events.Event): ... get_intents.assert_called_once_with(GenericEvent) warn.assert_not_called() - def test__check_event_when_intents_correct(self, event_manager): + def test__check_event_when_intents_correct(self, event_manager: EventManagerBaseImpl): event_manager._intents = intents.Intents.GUILD_EMOJIS | intents.Intents.GUILD_MEMBERS with mock.patch.object( @@ -665,7 +668,7 @@ def test__check_event_when_intents_correct(self, event_manager): get_intents.assert_called_once_with(member_events.MemberCreateEvent) warn.assert_not_called() - def test__check_event_when_intents_incorrect(self, event_manager): + def test__check_event_when_intents_incorrect(self, event_manager: EventManagerBaseImpl): event_manager._intents = intents.Intents.GUILD_EMOJIS with mock.patch.object( @@ -682,13 +685,13 @@ def test__check_event_when_intents_incorrect(self, event_manager): stacklevel=3, ) - def test_get_listeners_when_not_event(self, event_manager): + def test_get_listeners_when_not_event(self, event_manager: EventManagerBaseImpl): event_manager._listeners = {} assert event_manager.get_listeners(base_events.Event) == [] - def test_get_listeners_polymorphic(self, event_manager): - event_manager._listeners = { + def test_get_listeners_polymorphic(self, event_manager: EventManagerBaseImpl): + listeners = { base_events.Event: ["coroutine0"], member_events.MemberEvent: ["coroutine1"], member_events.MemberCreateEvent: ["hi", "i am"], @@ -696,28 +699,31 @@ def test_get_listeners_polymorphic(self, event_manager): base_events.ExceptionEvent: ["so you won't see me"], } - assert event_manager.get_listeners(member_events.MemberEvent) == ["coroutine1", "coroutine0"] + with mock.patch.object(event_manager, "_listeners", listeners): + assert event_manager.get_listeners(member_events.MemberEvent) == ["coroutine1", "coroutine0"] - def test_get_listeners_monomorphic_and_no_results(self, event_manager): - event_manager._listeners = { + def test_get_listeners_monomorphic_and_no_results(self, event_manager: EventManagerBaseImpl): + listeners = { member_events.MemberCreateEvent: ["coroutine1", "coroutine2"], member_events.MemberUpdateEvent: ["coroutine3"], member_events.MemberDeleteEvent: ["coroutine4", "coroutine5"], } - assert event_manager.get_listeners(member_events.MemberEvent, polymorphic=False) == () + with mock.patch.object(event_manager, "_listeners", listeners): + assert event_manager.get_listeners(member_events.MemberEvent, polymorphic=False) == () - def test_get_listeners_monomorphic_and_results(self, event_manager): - event_manager._listeners = { + def test_get_listeners_monomorphic_and_results(self, event_manager: EventManagerBaseImpl): + listeners = { member_events.MemberEvent: ["coroutine0"], member_events.MemberCreateEvent: ["coroutine1", "coroutine2"], member_events.MemberUpdateEvent: ["coroutine3"], member_events.MemberDeleteEvent: ["coroutine4", "coroutine5"], } - assert event_manager.get_listeners(member_events.MemberEvent, polymorphic=False) == ["coroutine0"] + with mock.patch.object(event_manager, "_listeners", listeners): + assert event_manager.get_listeners(member_events.MemberEvent, polymorphic=False) == ["coroutine0"] - def test_unsubscribe_when_event_type_not_in_listeners(self, event_manager): + def test_unsubscribe_when_event_type_not_in_listeners(self, event_manager: EventManagerBaseImpl): async def test(): ... event_manager._increment_listener_group_count = mock.Mock() @@ -728,67 +734,70 @@ async def test(): ... assert event_manager._listeners == {} event_manager._increment_listener_group_count.assert_not_called() - def test_unsubscribe_when_event_type_when_list_not_empty_after_delete(self, event_manager): + def test_unsubscribe_when_event_type_when_list_not_empty_after_delete(self, event_manager: EventManagerBaseImpl): async def test(): ... async def test2(): ... event_manager._increment_listener_group_count = mock.Mock() - event_manager._listeners = { - member_events.MemberCreateEvent: [test, test2], - member_events.MemberDeleteEvent: [test], - } + listeners = {member_events.MemberCreateEvent: [test, test2], member_events.MemberDeleteEvent: [test]} - event_manager.unsubscribe(member_events.MemberCreateEvent, test) + with ( + mock.patch.object(event_manager, "_listeners", listeners) as patched__listeners, + mock.patch.object( + event_manager, "_increment_listener_group_count" + ) as patched__increment_listener_group_count, + ): + event_manager.unsubscribe(member_events.MemberCreateEvent, test) - assert event_manager._listeners == { - member_events.MemberCreateEvent: [test2], - member_events.MemberDeleteEvent: [test], - } - event_manager._increment_listener_group_count.assert_not_called() + assert patched__listeners == {member_events.MemberCreateEvent: [test2], member_events.MemberDeleteEvent: [test]} + patched__increment_listener_group_count.assert_not_called() - def test_unsubscribe_when_event_type_when_list_empty_after_delete(self, event_manager): + def test_unsubscribe_when_event_type_when_list_empty_after_delete(self, event_manager: EventManagerBaseImpl): async def test(): ... - event_manager._increment_listener_group_count = mock.Mock() - event_manager._listeners = {member_events.MemberCreateEvent: [test], member_events.MemberDeleteEvent: [test]} + listeners = {member_events.MemberCreateEvent: [test], member_events.MemberDeleteEvent: [test]} - event_manager.unsubscribe(member_events.MemberCreateEvent, test) + with ( + mock.patch.object(event_manager, "_listeners", listeners) as patched__listeners, + mock.patch.object( + event_manager, "_increment_listener_group_count" + ) as patched__increment_listener_group_count, + ): + event_manager.unsubscribe(member_events.MemberCreateEvent, test) - assert event_manager._listeners == {member_events.MemberDeleteEvent: [test]} - event_manager._increment_listener_group_count.assert_called_once_with(member_events.MemberCreateEvent, -1) + assert patched__listeners == {member_events.MemberDeleteEvent: [test]} + patched__increment_listener_group_count.assert_called_once_with(member_events.MemberCreateEvent, -1) - def test_listen_when_no_params(self, event_manager): + def test_listen_when_no_params(self, event_manager: EventManagerBaseImpl): with pytest.raises(TypeError): @event_manager.listen() async def test(): ... - def test_listen_when_more_then_one_param_when_provided_in_typehint(self, event_manager): + def test_listen_when_more_then_one_param_when_provided_in_typehint(self, event_manager: EventManagerBaseImpl): with pytest.raises(TypeError): @event_manager.listen() async def test(a, b, c): ... - def test_listen_when_more_then_one_param_when_provided_in_decorator(self, event_manager): + def test_listen_when_more_then_one_param_when_provided_in_decorator(self, event_manager: EventManagerBaseImpl): with pytest.raises(TypeError): @event_manager.listen(object) async def test(a, b, c): ... - def test_listen_when_param_not_provided_in_decorator_nor_typehint(self, event_manager): + def test_listen_when_param_not_provided_in_decorator_nor_typehint(self, event_manager: EventManagerBaseImpl): with pytest.raises(TypeError): @event_manager.listen() async def test(event): ... - def test_listen_when_param_provided_in_decorator(self, event_manager): - stack = contextlib.ExitStack() - - subscribe = stack.enter_context(mock.patch.object(event_manager_base.EventManagerBase, "subscribe")) - resolve_signature = stack.enter_context(mock.patch.object(reflect, "resolve_signature")) - - with stack: + def test_listen_when_param_provided_in_decorator(self, event_manager: EventManagerBaseImpl): + with ( + mock.patch.object(event_manager_base.EventManagerBase, "subscribe") as subscribe, + mock.patch.object(reflect, "resolve_signature") as resolve_signature, + ): @event_manager.listen(member_events.MemberCreateEvent) async def test(event): ... @@ -796,13 +805,11 @@ async def test(event): ... resolve_signature.assert_not_called() subscribe.assert_called_once_with(member_events.MemberCreateEvent, test, _nested=1) - def test_listen_when_multiple_params_provided_in_decorator(self, event_manager): - stack = contextlib.ExitStack() - - subscribe = stack.enter_context(mock.patch.object(event_manager_base.EventManagerBase, "subscribe")) - resolve_signature = stack.enter_context(mock.patch.object(reflect, "resolve_signature")) - - with stack: + def test_listen_when_multiple_params_provided_in_decorator(self, event_manager: EventManagerBaseImpl): + with ( + mock.patch.object(event_manager_base.EventManagerBase, "subscribe") as subscribe, + mock.patch.object(reflect, "resolve_signature") as resolve_signature, + ): @event_manager.listen(member_events.MemberCreateEvent, member_events.MemberDeleteEvent) async def test(event): ... @@ -816,7 +823,7 @@ async def test(event): ... ] ) - def test_listen_when_param_provided_in_typehint(self, event_manager): + def test_listen_when_param_provided_in_typehint(self, event_manager: EventManagerBaseImpl): with mock.patch.object(event_manager_base.EventManagerBase, "subscribe") as subscribe: @event_manager.listen() @@ -824,7 +831,9 @@ async def test(event: member_events.MemberCreateEvent): ... subscribe.assert_called_once_with(member_events.MemberCreateEvent, test, _nested=1) - def test_listen_when_multiple_params_provided_as_typing_union_in_typehint(self, event_manager): + def test_listen_when_multiple_params_provided_as_typing_union_in_typehint( + self, event_manager: EventManagerBaseImpl + ): with mock.patch.object(event_manager_base.EventManagerBase, "subscribe") as subscribe: @event_manager.listen() @@ -839,7 +848,9 @@ async def test(event: typing.Union[member_events.MemberCreateEvent, member_event ) @pytest.mark.skipif(sys.version_info < (3, 10), reason="Bitwise union only available on 3.10+") - def test_listen_when_multiple_params_provided_as_bitwise_union_in_typehint(self, event_manager): + def test_listen_when_multiple_params_provided_as_bitwise_union_in_typehint( + self, event_manager: EventManagerBaseImpl + ): with mock.patch.object(event_manager_base.EventManagerBase, "subscribe") as subscribe: @event_manager.listen() @@ -853,7 +864,7 @@ async def test(event: member_events.MemberCreateEvent | member_events.MemberDele ] ) - def test_listen_when_incorrect_type_in_typehint(self, event_manager): + def test_listen_when_incorrect_type_in_typehint(self, event_manager: EventManagerBaseImpl): with pytest.raises(TypeError): @event_manager.listen() diff --git a/tests/hikari/impl/test_gateway_bot.py b/tests/hikari/impl/test_gateway_bot.py index a9afd390b9..74e417c56f 100644 --- a/tests/hikari/impl/test_gateway_bot.py +++ b/tests/hikari/impl/test_gateway_bot.py @@ -21,8 +21,10 @@ from __future__ import annotations import asyncio -import contextlib +import concurrent.futures +import datetime import sys +import typing import warnings import mock @@ -30,6 +32,7 @@ from hikari import applications from hikari import errors +from hikari import intents as intents_ from hikari import presences from hikari import snowflakes from hikari import undefined @@ -49,7 +52,7 @@ @pytest.mark.parametrize("activity", [undefined.UNDEFINED, None]) -def test_validate_activity_when_no_activity(activity): +def test_validate_activity_when_no_activity(activity: undefined.UndefinedType | None): with mock.patch.object(warnings, "warn") as warn: bot_impl._validate_activity(activity) @@ -81,71 +84,70 @@ def test_validate_activity_when_no_warning(): class TestGatewayBot: @pytest.fixture - def cache(self): + def cache(self) -> cache_impl.CacheImpl: return mock.Mock() @pytest.fixture - def entity_factory(self): + def entity_factory(self) -> entity_factory_impl.EntityFactoryImpl: return mock.Mock() @pytest.fixture - def event_factory(self): + def event_factory(self) -> event_factory_impl.EventFactoryImpl: return mock.Mock() @pytest.fixture - def event_manager(self): + def event_manager(self) -> event_manager_impl.EventManagerImpl: return mock.Mock() @pytest.fixture - def rest(self): + def rest(self) -> rest_impl.RESTClientImpl: return mock.Mock() @pytest.fixture - def voice(self): + def voice(self) -> voice_impl.VoiceComponentImpl: return mock.Mock() @pytest.fixture - def executor(self): + def executor(self) -> concurrent.futures.Executor: return mock.Mock() @pytest.fixture - def intents(self): + def intents(self) -> intents_.Intents: return mock.Mock() @pytest.fixture - def proxy_settings(self): + def proxy_settings(self) -> config.ProxySettings: return mock.Mock() @pytest.fixture - def http_settings(self): + def http_settings(self) -> config.HTTPSettings: return mock.Mock() @pytest.fixture def bot( self, - cache, - entity_factory, - event_factory, - event_manager, - rest, - voice, - executor, - intents, - proxy_settings, - http_settings, + cache: cache_impl.CacheImpl, + entity_factory: entity_factory_impl.EntityFactoryImpl, + event_factory: event_factory_impl.EventFactoryImpl, + event_manager: event_manager_impl.EventManagerImpl, + rest: rest_impl.RESTClientImpl, + voice: voice_impl.VoiceComponentImpl, + executor: concurrent.futures.Executor, + intents: intents_.Intents, + proxy_settings: config.ProxySettings, + http_settings: config.HTTPSettings, ): - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(cache_impl, "CacheImpl", return_value=cache)) - stack.enter_context(mock.patch.object(entity_factory_impl, "EntityFactoryImpl", return_value=entity_factory)) - stack.enter_context(mock.patch.object(event_factory_impl, "EventFactoryImpl", return_value=event_factory)) - stack.enter_context(mock.patch.object(event_manager_impl, "EventManagerImpl", return_value=event_manager)) - stack.enter_context(mock.patch.object(voice_impl, "VoiceComponentImpl", return_value=voice)) - stack.enter_context(mock.patch.object(rest_impl, "RESTClientImpl", return_value=rest)) - stack.enter_context(mock.patch.object(ux, "init_logging")) - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "print_banner")) - stack.enter_context(mock.patch.object(ux, "warn_if_not_optimized")) - - with stack: + with ( + mock.patch.object(cache_impl, "CacheImpl", return_value=cache), + mock.patch.object(entity_factory_impl, "EntityFactoryImpl", return_value=entity_factory), + mock.patch.object(event_factory_impl, "EventFactoryImpl", return_value=event_factory), + mock.patch.object(event_manager_impl, "EventManagerImpl", return_value=event_manager), + mock.patch.object(voice_impl, "VoiceComponentImpl", return_value=voice), + mock.patch.object(rest_impl, "RESTClientImpl", return_value=rest), + mock.patch.object(ux, "init_logging"), + mock.patch.object(bot_impl.GatewayBot, "print_banner"), + mock.patch.object(ux, "warn_if_not_optimized"), + ): return bot_impl.GatewayBot( "token", executor=executor, @@ -156,23 +158,23 @@ def bot( ) def test_init(self): - stack = contextlib.ExitStack() - cache = stack.enter_context(mock.patch.object(cache_impl, "CacheImpl")) - entity_factory = stack.enter_context(mock.patch.object(entity_factory_impl, "EntityFactoryImpl")) - event_factory = stack.enter_context(mock.patch.object(event_factory_impl, "EventFactoryImpl")) - event_manager = stack.enter_context(mock.patch.object(event_manager_impl, "EventManagerImpl")) - voice = stack.enter_context(mock.patch.object(voice_impl, "VoiceComponentImpl")) - rest = stack.enter_context(mock.patch.object(rest_impl, "RESTClientImpl")) - init_logging = stack.enter_context(mock.patch.object(ux, "init_logging")) - warn_if_not_optimized = stack.enter_context(mock.patch.object(ux, "warn_if_not_optimized")) - print_banner = stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "print_banner")) - executor = object() - cache_settings = object() - http_settings = object() - proxy_settings = object() - intents = object() - - with stack: + executor = mock.Mock() + cache_settings = mock.Mock() + http_settings = mock.Mock() + proxy_settings = mock.Mock() + intents = mock.Mock() + + with ( + mock.patch.object(cache_impl, "CacheImpl") as cache, + mock.patch.object(entity_factory_impl, "EntityFactoryImpl") as entity_factory, + mock.patch.object(event_factory_impl, "EventFactoryImpl") as event_factory, + mock.patch.object(event_manager_impl, "EventManagerImpl") as event_manager, + mock.patch.object(voice_impl, "VoiceComponentImpl") as voice, + mock.patch.object(rest_impl, "RESTClientImpl") as rest, + mock.patch.object(ux, "init_logging") as init_logging, + mock.patch.object(ux, "warn_if_not_optimized") as warn_if_not_optimized, + mock.patch.object(bot_impl.GatewayBot, "print_banner") as print_banner, + ): bot = bot_impl.GatewayBot( "token", allow_color=False, @@ -230,21 +232,20 @@ def test_init(self): warn_if_not_optimized.assert_called_once_with(suppress=True) def test_init_when_no_settings(self): - stack = contextlib.ExitStack() - cache = stack.enter_context(mock.patch.object(cache_impl, "CacheImpl")) - stack.enter_context(mock.patch.object(entity_factory_impl, "EntityFactoryImpl")) - stack.enter_context(mock.patch.object(event_factory_impl, "EventFactoryImpl")) - stack.enter_context(mock.patch.object(event_manager_impl, "EventManagerImpl")) - stack.enter_context(mock.patch.object(voice_impl, "VoiceComponentImpl")) - stack.enter_context(mock.patch.object(rest_impl, "RESTClientImpl")) - stack.enter_context(mock.patch.object(ux, "init_logging")) - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "print_banner")) - stack.enter_context(mock.patch.object(ux, "warn_if_not_optimized")) - http_settings = stack.enter_context(mock.patch.object(config, "HTTPSettings")) - proxy_settings = stack.enter_context(mock.patch.object(config, "ProxySettings")) - cache_settings = stack.enter_context(mock.patch.object(config, "CacheSettings")) - - with stack: + with ( + mock.patch.object(cache_impl, "CacheImpl") as cache, + mock.patch.object(entity_factory_impl, "EntityFactoryImpl"), + mock.patch.object(event_factory_impl, "EventFactoryImpl"), + mock.patch.object(event_manager_impl, "EventManagerImpl"), + mock.patch.object(voice_impl, "VoiceComponentImpl"), + mock.patch.object(rest_impl, "RESTClientImpl"), + mock.patch.object(ux, "init_logging"), + mock.patch.object(bot_impl.GatewayBot, "print_banner"), + mock.patch.object(ux, "warn_if_not_optimized"), + mock.patch.object(config, "HTTPSettings") as http_settings, + mock.patch.object(config, "ProxySettings") as proxy_settings, + mock.patch.object(config, "CacheSettings") as cache_settings, + ): bot = bot_impl.GatewayBot("token", cache_settings=None, http_settings=None, proxy_settings=None) assert bot._http_settings is http_settings.return_value @@ -255,34 +256,33 @@ def test_init_when_no_settings(self): cache_settings.assert_called_once_with() def test_init_strips_token(self): - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(ux, "init_logging")) - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "print_banner")) - stack.enter_context(mock.patch.object(ux, "warn_if_not_optimized")) - - with stack: + with ( + mock.patch.object(ux, "init_logging"), + mock.patch.object(bot_impl.GatewayBot, "print_banner"), + mock.patch.object(ux, "warn_if_not_optimized"), + ): bot = bot_impl.GatewayBot( "\n\r token yeet \r\n", cache_settings=None, http_settings=None, proxy_settings=None ) assert bot._token == "token yeet" - def test_cache(self, bot, cache): + def test_cache(self, bot: bot_impl.GatewayBot, cache: cache_impl.CacheImpl): assert bot.cache is cache - def test_event_manager(self, bot, event_manager): + def test_event_manager(self, bot: bot_impl.GatewayBot, event_manager: event_manager_impl.EventManagerImpl): assert bot.event_manager is event_manager - def test_entity_factory(self, bot, entity_factory): + def test_entity_factory(self, bot: bot_impl.GatewayBot, entity_factory: entity_factory_impl.EntityFactoryImpl): assert bot.entity_factory is entity_factory - def test_event_factory(self, bot, event_factory): + def test_event_factory(self, bot: bot_impl.GatewayBot, event_factory: event_factory_impl.EventFactoryImpl): assert bot.event_factory is event_factory - def test_executor(self, bot, executor): + def test_executor(self, bot: bot_impl.GatewayBot, executor: concurrent.futures.Executor): assert bot.executor is executor - def test_heartbeat_latencies(self, bot): + def test_heartbeat_latencies(self, bot: bot_impl.GatewayBot): bot._shards = { 0: mock.Mock(id=0, heartbeat_latency=96), 1: mock.Mock(id=1, heartbeat_latency=123), @@ -291,7 +291,7 @@ def test_heartbeat_latencies(self, bot): assert bot.heartbeat_latencies == {0: 96, 1: 123, 2: 456} - def test_heartbeat_latency(self, bot): + def test_heartbeat_latency(self, bot: bot_impl.GatewayBot): bot._shards = { 0: mock.Mock(heartbeat_latency=96), 1: mock.Mock(heartbeat_latency=123), @@ -300,60 +300,61 @@ def test_heartbeat_latency(self, bot): assert bot.heartbeat_latency == 109.5 - def test_http_settings(self, bot, http_settings): + def test_http_settings(self, bot: bot_impl.GatewayBot, http_settings: config.HTTPSettings): assert bot.http_settings is http_settings - def test_intents(self, bot, intents): + def test_intents(self, bot: bot_impl.GatewayBot, intents: intents_.Intents): assert bot.intents is intents - def test_get_me(self, bot, cache): - assert bot.get_me() is cache.get_me.return_value + def test_get_me(self, bot: bot_impl.GatewayBot, cache: cache_impl.CacheImpl): + with mock.patch.object(cache, "get_me") as patched_get_me: + assert bot.get_me() is patched_get_me.return_value - def test_proxy_settings(self, bot, proxy_settings): + def test_proxy_settings(self, bot: bot_impl.GatewayBot, proxy_settings: config.ProxySettings): assert bot.proxy_settings is proxy_settings - def test_shard_count_when_no_shards(self, bot): + def test_shard_count_when_no_shards(self, bot: bot_impl.GatewayBot): bot._shards = {} assert bot.shard_count == 0 - def test_shard_count(self, bot): + def test_shard_count(self, bot: bot_impl.GatewayBot): bot._shards = {0: mock.Mock(shard_count=96), 1: mock.Mock(shard_count=123)} assert bot.shard_count == 96 - def test_voice(self, bot, voice): + def test_voice(self, bot: bot_impl.GatewayBot, voice: voice_impl.VoiceComponentImpl): assert bot.voice is voice - def test_rest(self, bot, rest): + def test_rest(self, bot: bot_impl.GatewayBot, rest: rest_impl.RESTClientImpl): assert bot.rest is rest - @pytest.mark.parametrize(("closed_event", "expected"), [("something", True), (None, False)]) - def test_is_alive(self, bot, closed_event, expected): + @pytest.mark.parametrize(("closed_event", "expected"), [(mock.Mock(asyncio.Event), True), (None, False)]) + def test_is_alive(self, bot: bot_impl.GatewayBot, closed_event: asyncio.Event | None, expected: bool): bot._closed_event = closed_event assert bot.is_alive is expected - def test_check_if_alive(self, bot): - bot._closed_event = object() + def test_check_if_alive(self, bot: bot_impl.GatewayBot): + bot._closed_event = mock.Mock() bot._check_if_alive() - def test_check_if_alive_when_False(self, bot): + def test_check_if_alive_when_False(self, bot: bot_impl.GatewayBot): bot._closed_event = None with pytest.raises(errors.ComponentStateConflictError): bot._check_if_alive() @pytest.mark.asyncio - async def test_close_when_already_closed(self, bot): + async def test_close_when_already_closed(self, bot: bot_impl.GatewayBot): bot._closed_event = mock.Mock() with pytest.raises(errors.ComponentStateConflictError): await bot.close() @pytest.mark.asyncio - async def test_close_when_already_closing(self, bot): + async def test_close_when_already_closing(self, bot: bot_impl.GatewayBot): bot._closed_event = mock.Mock() bot._closing_event = mock.Mock(is_set=mock.Mock(return_value=True)) @@ -364,12 +365,20 @@ async def test_close_when_already_closing(self, bot): bot._closed_event.set.assert_not_called() @pytest.mark.asyncio - async def test_close(self, bot, event_manager, event_factory, rest, voice, cache): - def null_call(arg): + async def test_close( + self, + bot: bot_impl.GatewayBot, + event_manager: event_manager_impl.EventManagerImpl, + event_factory: event_factory_impl.EventFactoryImpl, + rest: rest_impl.RESTClientImpl, + voice: voice_impl.VoiceComponentImpl, + cache: cache_impl.CacheImpl, + ): + def null_call(arg: typing.Any): return arg class AwaitableMock: - def __init__(self, error=None): + def __init__(self, error: typing.Any = None): self._awaited_count = 0 self._error = error @@ -388,13 +397,6 @@ def __call__(self): def assert_awaited_once(self): assert self._awaited_count == 1 - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(asyncio, "as_completed", side_effect=null_call)) - ensure_future = stack.enter_context(mock.patch.object(asyncio, "ensure_future", side_effect=null_call)) - get_running_loop = stack.enter_context(mock.patch.object(asyncio, "get_running_loop")) - mock_future = mock.Mock() - get_running_loop.return_value.create_future.return_value = mock_future - event_manager.dispatch = mock.AsyncMock() rest.close = AwaitableMock() voice.close = AwaitableMock() @@ -406,7 +408,15 @@ def assert_awaited_once(self): shard2 = mock.Mock(id=2, close=AwaitableMock()) bot._shards = {0: shard0, 1: shard1, 2: shard2} - with stack: + mock_future = mock.Mock() + + with ( + mock.patch.object(asyncio, "as_completed", side_effect=null_call), + mock.patch.object(asyncio, "ensure_future", side_effect=null_call) as ensure_future, + mock.patch.object( + asyncio, "get_running_loop", return_value=mock.Mock(create_future=mock.Mock(return_value=mock_future)) + ) as get_running_loop, + ): await bot.close() # Events and args @@ -450,22 +460,24 @@ def assert_awaited_once(self): ] ) - def test_dispatch(self, bot, event_manager): + def test_dispatch(self, bot: bot_impl.GatewayBot, event_manager: event_manager_impl.EventManagerImpl): event = mock.Mock() - assert bot.dispatch(event, return_tasks=True) is event_manager.dispatch.return_value + with mock.patch.object(event_manager, "dispatch") as patched_dispatch: + assert bot.dispatch(event, return_tasks=True) is patched_dispatch.return_value - event_manager.dispatch.assert_called_once_with(event, return_tasks=True) + patched_dispatch.assert_called_once_with(event, return_tasks=True) - def test_get_listeners(self, bot, event_manager): - event = object() + def test_get_listeners(self, bot: bot_impl.GatewayBot, event_manager: event_manager_impl.EventManagerImpl): + event = mock.Mock - assert bot.get_listeners(event, polymorphic=False) is event_manager.get_listeners.return_value + with mock.patch.object(event_manager, "get_listeners") as patched_get_listeners: + assert bot.get_listeners(event, polymorphic=False) is patched_get_listeners.return_value - event_manager.get_listeners.assert_called_once_with(event, polymorphic=False) + patched_get_listeners.assert_called_once_with(event, polymorphic=False) @pytest.mark.asyncio - async def test_join(self, bot, event_manager): + async def test_join(self, bot: bot_impl.GatewayBot): bot._closed_event = mock.AsyncMock() await bot.join() @@ -473,20 +485,20 @@ async def test_join(self, bot, event_manager): bot._closed_event.wait.assert_awaited_once_with() @pytest.mark.asyncio - async def test_join_when_not_running(self, bot, event_manager): + async def test_join_when_not_running(self, bot: bot_impl.GatewayBot): bot._closed_event = None with pytest.raises(errors.ComponentStateConflictError): await bot.join() - def test_listen(self, bot, event_manager): - event = object() + def test_listen(self, bot: bot_impl.GatewayBot, event_manager: event_manager_impl.EventManagerImpl): + event = mock.Mock assert bot.listen(event) is event_manager.listen.return_value event_manager.listen.assert_called_once_with(event) - def test_print_banner(self, bot): + def test_print_banner(self, bot: bot_impl.GatewayBot): with mock.patch.object(ux, "print_banner") as print_banner: bot.print_banner("testing", allow_color=True, force_color=False, extra_args={"test_key": "test_value"}) @@ -494,100 +506,92 @@ def test_print_banner(self, bot): "testing", allow_color=True, force_color=False, extra_args={"test_key": "test_value"} ) - def test_run_when_already_running(self, bot): - bot._closed_event = object() + def test_run_when_already_running(self, bot: bot_impl.GatewayBot): + bot._closed_event = mock.Mock() with pytest.raises(errors.ComponentStateConflictError): bot.run() - def test_run_when_shard_ids_specified_without_shard_count(self, bot): + def test_run_when_shard_ids_specified_without_shard_count(self, bot: bot_impl.GatewayBot): with pytest.raises(TypeError, match=r"'shard_ids' must be passed with 'shard_count'"): - bot.run(shard_ids={1}) - - def test_run_with_asyncio_debug(self, bot): - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "start", new=mock.Mock())) - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "join", new=mock.Mock())) - stack.enter_context( - mock.patch.object(signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock()) - ) - loop = stack.enter_context(mock.patch.object(aio, "get_or_make_loop")).return_value + bot.run(shard_ids=[1]) + + def test_run_with_asyncio_debug(self, bot: bot_impl.GatewayBot): + loop = mock.Mock() - with stack: + with ( + mock.patch.object(bot_impl.GatewayBot, "start", new=mock.Mock()), + mock.patch.object(bot_impl.GatewayBot, "join", new=mock.Mock()), + mock.patch.object(signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock()), + mock.patch.object(aio, "get_or_make_loop", return_value=loop), + ): bot.run(close_loop=False, asyncio_debug=True) loop.set_debug.assert_called_once_with(True) - def test_run_with_coroutine_tracking_depth(self, bot): - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "start", new=mock.Mock())) - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "join", new=mock.Mock())) - stack.enter_context( - mock.patch.object(signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock()) - ) - stack.enter_context(mock.patch.object(aio, "get_or_make_loop")) - coroutine_tracking_depth = stack.enter_context( - mock.patch.object(sys, "set_coroutine_origin_tracking_depth", create=True, side_effect=AttributeError) - ) - - with stack: + def test_run_with_coroutine_tracking_depth(self, bot: bot_impl.GatewayBot): + with ( + mock.patch.object(bot_impl.GatewayBot, "start", new=mock.Mock()), + mock.patch.object(bot_impl.GatewayBot, "join", new=mock.Mock()), + mock.patch.object(signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock()), + mock.patch.object(aio, "get_or_make_loop"), + mock.patch.object( + sys, "set_coroutine_origin_tracking_depth", create=True, side_effect=AttributeError + ) as coroutine_tracking_depth, + ): bot.run(close_loop=False, coroutine_tracking_depth=100) coroutine_tracking_depth.assert_called_once_with(100) - def test_run_with_close_passed_executor(self, bot): - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "start", new=mock.Mock())) - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "join", new=mock.Mock())) - stack.enter_context( - mock.patch.object(signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock()) - ) - stack.enter_context(mock.patch.object(aio, "get_or_make_loop")) + def test_run_with_close_passed_executor(self, bot: bot_impl.GatewayBot): executor = mock.Mock() bot._executor = executor - with stack: + with ( + mock.patch.object(bot_impl.GatewayBot, "start", new=mock.Mock()), + mock.patch.object(bot_impl.GatewayBot, "join", new=mock.Mock()), + mock.patch.object(signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock()), + mock.patch.object(aio, "get_or_make_loop"), + ): bot.run(close_loop=False, close_passed_executor=True) executor.shutdown.assert_called_once_with(wait=True) assert bot._executor is None - def test_run_when_close_loop(self, bot): - stack = contextlib.ExitStack() - logger = stack.enter_context(mock.patch.object(bot_impl, "_LOGGER")) - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "start", new=mock.Mock())) - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "join", new=mock.Mock())) - stack.enter_context( - mock.patch.object(signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock()) - ) - destroy_loop = stack.enter_context(mock.patch.object(aio, "destroy_loop")) - loop = stack.enter_context(mock.patch.object(aio, "get_or_make_loop")).return_value - - with stack: + def test_run_when_close_loop(self, bot: bot_impl.GatewayBot): + loop = mock.Mock() + with ( + mock.patch.object(bot_impl, "_LOGGER") as logger, + mock.patch.object(bot_impl.GatewayBot, "start", new=mock.Mock()), + mock.patch.object(bot_impl.GatewayBot, "join", new=mock.Mock()), + mock.patch.object(signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock()), + mock.patch.object(aio, "destroy_loop") as destroy_loop, + mock.patch.object(aio, "get_or_make_loop", return_value=loop), + ): bot.run(close_loop=True) destroy_loop.assert_called_once_with(loop, logger) - def test_run(self, bot): - activity = object() - afk = object() - check_for_updates = object() - idle_since = object() - ignore_session_start_limit = object() - large_threshold = object() - shard_ids = object() - shard_count = object() - status = object() - - stack = contextlib.ExitStack() - start_function = stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "start", new=mock.Mock())) - join_function = stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "join", new=mock.Mock())) - handle_interrupts = stack.enter_context( - mock.patch.object(signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock()) - ) - loop = stack.enter_context(mock.patch.object(aio, "get_or_make_loop")).return_value - - with stack: + def test_run(self, bot: bot_impl.GatewayBot): + activity = mock.Mock() + afk = mock.Mock() + check_for_updates = mock.Mock() + idle_since = mock.Mock() + ignore_session_start_limit = mock.Mock() + large_threshold = mock.Mock() + shard_ids = mock.Mock() + shard_count = mock.Mock() + status = mock.Mock() + loop = mock.Mock() + + with ( + mock.patch.object(bot_impl.GatewayBot, "start", new=mock.Mock()) as start_function, + mock.patch.object(bot_impl.GatewayBot, "join", new=mock.Mock()) as join_function, + mock.patch.object( + signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock() + ) as handle_interrupts, + mock.patch.object(aio, "get_or_make_loop", return_value=loop), + ): bot.run( activity=activity, afk=afk, @@ -626,19 +630,28 @@ def test_run(self, bot): handle_interrupts.return_value.assert_used_once() @pytest.mark.asyncio - async def test_start_when_shard_ids_specified_without_shard_count(self, bot): + async def test_start_when_shard_ids_specified_without_shard_count(self, bot: bot_impl.GatewayBot): with pytest.raises(TypeError, match=r"'shard_ids' must be passed with 'shard_count'"): await bot.start(shard_ids=(1,)) @pytest.mark.asyncio - async def test_start_when_already_running(self, bot): - bot._closed_event = object() + async def test_start_when_already_running(self, bot: bot_impl.GatewayBot): + bot._closed_event = mock.Mock() with pytest.raises(errors.ComponentStateConflictError): await bot.start() @pytest.mark.asyncio - async def test_start(self, bot, rest, voice, event_manager, event_factory, http_settings, proxy_settings): + async def test_start( + self, + bot: bot_impl.GatewayBot, + rest: rest_impl.RESTClientImpl, + voice: voice_impl.VoiceComponentImpl, + event_manager: event_manager_impl.EventManagerImpl, + event_factory: event_factory_impl.EventFactoryImpl, + http_settings: config.HTTPSettings, + proxy_settings: config.ProxySettings, + ): class MockSessionStartLimit: remaining = 10 reset_at = "now" @@ -655,33 +668,32 @@ class MockInfo: mock_start_one_shard = mock.Mock() - def _mock_start_one_shard(*args, **kwargs): + def _mock_start_one_shard(*args: typing.Any, **kwargs: typing.Any): bot._shards[kwargs["shard_id"]] = next(shards_iter) return mock_start_one_shard(*args, **kwargs) - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(bot_impl, "_validate_activity")) - stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "_start_one_shard", new=_mock_start_one_shard)) - create_task = stack.enter_context(mock.patch.object(asyncio, "create_task")) - gather = stack.enter_context(mock.patch.object(asyncio, "gather")) - event = stack.enter_context(mock.patch.object(asyncio, "Event")) - first_completed = stack.enter_context( - mock.patch.object(aio, "first_completed", side_effect=[None, asyncio.TimeoutError, None]) - ) - check_for_updates = stack.enter_context(mock.patch.object(ux, "check_for_updates", new=mock.Mock())) - event_manager.dispatch = mock.AsyncMock() rest.fetch_gateway_bot_info = mock.AsyncMock(return_value=MockInfo()) - with stack: + with ( + mock.patch.object(bot_impl, "_validate_activity"), + mock.patch.object(bot_impl.GatewayBot, "_start_one_shard", new=_mock_start_one_shard), + mock.patch.object(asyncio, "create_task") as create_task, + mock.patch.object(asyncio, "gather") as gather, + mock.patch.object(asyncio, "Event") as event, + mock.patch.object( + aio, "first_completed", side_effect=[None, asyncio.TimeoutError, None] + ) as first_completed, + mock.patch.object(ux, "check_for_updates", new=mock.Mock()) as check_for_updates, + ): await bot.start( check_for_updates=True, shard_ids=(2, 10), shard_count=20, - activity="some activity", + activity=presences.Activity(name="some activity"), afk=True, - idle_since="some idle since", - status="some status", + idle_since=datetime.datetime.fromtimestamp(0), + status=presences.Status.IDLE, large_threshold=500, startup_window_delay=10, ) @@ -709,10 +721,10 @@ def _mock_start_one_shard(*args, **kwargs): [ mock.call( bot, - activity="some activity", + activity=presences.Activity(name="some activity"), afk=True, - idle_since="some idle since", - status="some status", + idle_since=datetime.datetime.fromtimestamp(0), + status=presences.Status.IDLE, large_threshold=500, shard_id=i, shard_count=20, @@ -736,7 +748,12 @@ def _mock_start_one_shard(*args, **kwargs): ) @pytest.mark.asyncio - async def test_start_when_request_close_mid_startup(self, bot, rest, voice, event_manager, event_factory): + async def test_start_when_request_close_mid_startup( + self, + bot: bot_impl.GatewayBot, + rest: rest_impl.RESTClientImpl, + event_manager: event_manager_impl.EventManagerImpl, + ): class MockSessionStartLimit: remaining = 10 reset_at = "now" @@ -749,19 +766,16 @@ class MockInfo: # Assume that we already started one shard shard1 = mock.Mock() - bot._shards = {"1": shard1} - - stack = contextlib.ExitStack() - start_one_shard = stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "_start_one_shard")) - first_completed = stack.enter_context(mock.patch.object(aio, "first_completed")) - event = stack.enter_context( - mock.patch.object(asyncio, "Event", return_value=mock.Mock(is_set=mock.Mock(return_value=True))) - ) + bot._shards = {1: shard1} event_manager.dispatch = mock.AsyncMock() rest.fetch_gateway_bot_info = mock.AsyncMock(return_value=MockInfo()) - with stack: + with ( + mock.patch.object(bot_impl.GatewayBot, "_start_one_shard") as start_one_shard, + mock.patch.object(aio, "first_completed") as first_completed, + mock.patch.object(asyncio, "Event", return_value=mock.Mock(is_set=mock.Mock(return_value=True))) as event, + ): await bot.start(shard_ids=(2, 10), shard_count=20, check_for_updates=False) start_one_shard.assert_not_called() @@ -771,7 +785,14 @@ class MockInfo: ) @pytest.mark.asyncio - async def test_start_when_shard_closed_mid_startup(self, bot, rest, voice, event_manager, event_factory): + async def test_start_when_shard_closed_mid_startup( + self, + bot: bot_impl.GatewayBot, + rest: rest_impl.RESTClientImpl, + voice: voice_impl.VoiceComponentImpl, + event_manager: event_manager_impl.EventManagerImpl, + event_factory: event_factory_impl.EventFactoryImpl, + ): class MockSessionStartLimit: remaining = 10 reset_at = "now" @@ -784,20 +805,17 @@ class MockInfo: # Assume that we already started one shard shard1 = mock.Mock() - bot._shards = {"1": shard1} - - stack = contextlib.ExitStack() - start_one_shard = stack.enter_context(mock.patch.object(bot_impl.GatewayBot, "_start_one_shard")) - first_completed = stack.enter_context(mock.patch.object(aio, "first_completed")) - event = stack.enter_context( - mock.patch.object(asyncio, "Event", return_value=mock.Mock(is_set=mock.Mock(return_value=False))) - ) - stack.enter_context(pytest.raises(RuntimeError, match="One or more shards closed while starting")) + bot._shards = {1: shard1} event_manager.dispatch = mock.AsyncMock() rest.fetch_gateway_bot_info = mock.AsyncMock(return_value=MockInfo()) - with stack: + with ( + mock.patch.object(bot_impl.GatewayBot, "_start_one_shard") as start_one_shard, + mock.patch.object(aio, "first_completed") as first_completed, + mock.patch.object(asyncio, "Event", return_value=mock.Mock(is_set=mock.Mock(return_value=False))) as event, + pytest.raises(RuntimeError, match="One or more shards closed while starting"), + ): await bot.start(shard_ids=(2, 10), shard_count=20, check_for_updates=False) start_one_shard.assert_not_called() @@ -806,35 +824,41 @@ class MockInfo: event.return_value.wait.return_value, shard1.join.return_value, timeout=5 ) - def test_stream(self, bot): - event_type = object() + def test_stream(self, bot: bot_impl.GatewayBot): + event_type = mock.Mock - with mock.patch.object(bot_impl.GatewayBot, "_check_if_alive") as check_if_alive: + with ( + mock.patch.object(bot, "_event_manager") as patched__event_manager, + mock.patch.object(patched__event_manager, "stream") as patched_stream, + mock.patch.object(bot_impl.GatewayBot, "_check_if_alive") as patched_check_if_alive, + ): bot.stream(event_type, timeout=100, limit=400) - check_if_alive.assert_called_once_with() - bot._event_manager.stream.assert_called_once_with(event_type, timeout=100, limit=400) + patched_check_if_alive.assert_called_once_with() + patched_stream.assert_called_once_with(event_type, timeout=100, limit=400) - def test_subscribe(self, bot): - event_type = object() - callback = object() + def test_subscribe(self, bot: bot_impl.GatewayBot): + event_type = mock.Mock() + callback = mock.Mock() - bot.subscribe(event_type, callback) + with mock.patch.object(bot._event_manager, "subscribe") as patched_subscribe: + bot.subscribe(event_type, callback) - bot._event_manager.subscribe.assert_called_once_with(event_type, callback) + patched_subscribe.assert_called_once_with(event_type, callback) - def test_unsubscribe(self, bot): - event_type = object() - callback = object() + def test_unsubscribe(self, bot: bot_impl.GatewayBot): + event_type = mock.Mock() + callback = mock.Mock() - bot.unsubscribe(event_type, callback) + with mock.patch.object(bot._event_manager, "unsubscribe") as patched_unsubscribe: + bot.unsubscribe(event_type, callback) - bot._event_manager.unsubscribe.assert_called_once_with(event_type, callback) + patched_unsubscribe.assert_called_once_with(event_type, callback) @pytest.mark.asyncio - async def test_wait_for(self, bot): - event_type = object() - predicate = object() + async def test_wait_for(self, bot: bot_impl.GatewayBot): + event_type = mock.Mock + predicate = mock.Mock() bot._event_manager.wait_for = mock.AsyncMock() with mock.patch.object(bot_impl.GatewayBot, "_check_if_alive") as check_if_alive: @@ -843,7 +867,7 @@ async def test_wait_for(self, bot): check_if_alive.assert_called_once_with() bot._event_manager.wait_for.assert_awaited_once_with(event_type, timeout=100, predicate=predicate) - def test_get_shard_when_not_present(self, bot): + def test_get_shard_when_not_present(self, bot: bot_impl.GatewayBot): shard = mock.Mock(shard_count=96) bot._shards = {96: shard} @@ -855,7 +879,7 @@ def test_get_shard_when_not_present(self, bot): calculate_shard_id.assert_called_once_with(96, 702763150025556029) - def test_get_shard(self, bot): + def test_get_shard(self, bot: bot_impl.GatewayBot): shard = mock.Mock(shard_count=96) bot._shards = {96: shard} @@ -865,11 +889,11 @@ def test_get_shard(self, bot): calculate_shard_id.assert_called_once_with(96, 702763150025556029) @pytest.mark.asyncio - async def test_update_presence(self, bot): - status = object() - activity = object() - idle_since = object() - afk = object() + async def test_update_presence(self, bot: bot_impl.GatewayBot): + status = mock.Mock() + activity = mock.Mock() + idle_since = mock.Mock() + afk = mock.Mock() shard0 = mock.Mock() shard1 = mock.Mock() @@ -893,7 +917,7 @@ async def test_update_presence(self, bot): shard2.update_presence.assert_called_once_with(status=status, activity=activity, idle_since=idle_since, afk=afk) @pytest.mark.asyncio - async def test_update_voice_state(self, bot): + async def test_update_voice_state(self, bot: bot_impl.GatewayBot): shard = mock.Mock() shard.update_voice_state = mock.AsyncMock() @@ -908,7 +932,7 @@ async def test_update_voice_state(self, bot): ) @pytest.mark.asyncio - async def test_request_guild_members(self, bot): + async def test_request_guild_members(self, bot: bot_impl.GatewayBot): shard = mock.Mock(shard_count=3) shard.request_guild_members = mock.AsyncMock() @@ -925,9 +949,9 @@ async def test_request_guild_members(self, bot): ) @pytest.mark.asyncio - async def test_start_one_shard(self, bot): - activity = object() - status = object() + async def test_start_one_shard(self, bot: bot_impl.GatewayBot): + activity = mock.Mock() + status = mock.Mock() bot._shards = {} shard_obj = mock.Mock(is_alive=True, start=mock.AsyncMock()) @@ -965,9 +989,9 @@ async def test_start_one_shard(self, bot): assert bot._shards == {1: shard_obj} @pytest.mark.asyncio - async def test_start_one_shard_when_not_alive(self, bot): - activity = object() - status = object() + async def test_start_one_shard_when_not_alive(self, bot: bot_impl.GatewayBot): + activity = mock.Mock() + status = mock.Mock() bot._shards = {} shard_obj = mock.Mock(is_alive=False, start=mock.AsyncMock()) @@ -989,9 +1013,9 @@ async def test_start_one_shard_when_not_alive(self, bot): @pytest.mark.parametrize("is_alive", [True, False]) @pytest.mark.asyncio - async def test_start_one_shard_when_exception(self, bot, is_alive): - activity = object() - status = object() + async def test_start_one_shard_when_exception(self, bot: bot_impl.GatewayBot, is_alive: bool): + activity = mock.Mock() + status = mock.Mock() bot._shards = {} shard_obj = mock.Mock( is_alive=is_alive, start=mock.AsyncMock(side_effect=RuntimeError("exit in tests")), close=mock.AsyncMock() diff --git a/tests/hikari/impl/test_interaction_server.py b/tests/hikari/impl/test_interaction_server.py index b5c816f588..ac7837949b 100644 --- a/tests/hikari/impl/test_interaction_server.py +++ b/tests/hikari/impl/test_interaction_server.py @@ -21,9 +21,9 @@ from __future__ import annotations import asyncio -import contextlib import re import threading +import typing import aiohttp import aiohttp.abc @@ -34,6 +34,8 @@ from hikari import applications from hikari import files +from hikari import snowflakes +from hikari.api import entity_factory try: import nacl.exceptions @@ -167,7 +169,7 @@ def valid_edd25519(): @pytest.fixture -def valid_payload(): +def valid_payload() -> dict[str, typing.Any]: return { "application_id": "658822586720976907", "channel_id": "938391701561679903", @@ -207,7 +209,7 @@ def valid_payload(): @pytest.fixture -def invalid_ed25519(): +def invalid_ed25519() -> tuple[bytes, bytes, bytes]: body = ( b'{"application_id":"658822586720976907","id":"838085779104202754","token":"aW50ZXJhY3Rpb246ODM4MDg1Nzc5MTA0MjA' b"yNzU0OmNhSk9QUU4wa1BKV21nTjFvSGhIbUp0QnQ1NjNGZFRtMlJVRlNjR0ttaDhtUGJrWUNvcmxYZnd2NTRLeUQ2c0hGS1YzTU03dFJ0V0s5" @@ -224,7 +226,7 @@ def invalid_ed25519(): @pytest.fixture -def public_key(): +def public_key() -> bytes: return b"\x12-\xdfX\xa8\x95\xd7\xe1\xb7o\xf5\xd0q\xb0\xaa\xc9\xb7v^*\xb5\x15\xe1\x1b\x7f\xca\xf9d\xdbT\x90\xc6" @@ -242,7 +244,7 @@ def test_interaction_server_init_when_no_pynacl(): @pytest.mark.skipif(not nacl_present, reason="PyNacl not present") class TestInteractionServer: @pytest.fixture - def mock_entity_factory(self): + def mock_entity_factory(self) -> entity_factory.EntityFactory: return mock.Mock(entity_factory_impl.EntityFactoryImpl) @pytest.fixture @@ -251,25 +253,20 @@ def mock_rest_client(self): @pytest.fixture def mock_interaction_server( - self, mock_entity_factory: entity_factory_impl.EntityFactoryImpl, mock_rest_client: rest_impl.RESTClientImpl + self, mock_entity_factory: entity_factory.EntityFactory, mock_rest_client: rest_impl.RESTClientImpl ): cls = hikari_test_helpers.mock_class_namespace(interaction_server_impl.InteractionServer, slots_=False) - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(rest_impl, "RESTClientImpl", return_value=mock_rest_client)) - with stack: + with mock.patch.object(rest_impl, "RESTClientImpl", return_value=mock_rest_client): return cls(entity_factory=mock_entity_factory, rest_client=mock_rest_client) def test___init__( self, mock_rest_client: rest_impl.RESTClientImpl, mock_entity_factory: entity_factory_impl.EntityFactoryImpl ): - mock_dumps = object() - mock_loads = object() + mock_dumps = mock.Mock() + mock_loads = mock.Mock() - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(aiohttp.web, "Application")) - - with stack: + with mock.patch.object(aiohttp.web, "Application"): result = interaction_server_impl.InteractionServer( dumps=mock_dumps, entity_factory=mock_entity_factory, @@ -287,13 +284,10 @@ def test___init__( def test___init___with_public_key( self, mock_rest_client: rest_impl.RESTClientImpl, mock_entity_factory: entity_factory_impl.EntityFactoryImpl ): - mock_dumps = object() - mock_loads = object() - - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(aiohttp.web, "Application")) + mock_dumps = mock.Mock() + mock_loads = mock.Mock() - with stack: + with mock.patch.object(aiohttp.web, "Application"): result = interaction_server_impl.InteractionServer( dumps=mock_dumps, entity_factory=mock_entity_factory, @@ -308,7 +302,7 @@ def test_is_alive_property_when_inactive(self, mock_interaction_server: interact assert mock_interaction_server.is_alive is False def test_is_alive_property_when_active(self, mock_interaction_server: interaction_server_impl.InteractionServer): - mock_interaction_server._server = object() + mock_interaction_server._server = mock.Mock() assert mock_interaction_server.is_alive is True @@ -318,14 +312,13 @@ async def test___fetch_public_key_when_lock_is_None_gets_new_lock_and_doesnt_ove mock_interaction_server: interaction_server_impl.InteractionServer, mock_rest_client: rest_impl.RESTClientImpl, ): - mock_rest_client.token_type = "Bot" mock_interaction_server._application_fetch_lock = None mock_rest_client.fetch_application.return_value.public_key = ( b"e\xb9\xf8\xac]eH\xb1\xe1D\xafaW\xdd\x1c.\xc1s\xfd<\x82\t\xeaO\xd4w\xaf\xc4\x1b\xd0\x8f\xc5" ) results = [] - with mock.patch.object(asyncio, "Lock") as lock_class: + with mock.patch.object(mock_rest_client, "token_type", "Bot"), mock.patch.object(asyncio, "Lock") as lock_class: # Run some times to make sure it does not overwrite it for _ in range(5): results.append(await mock_interaction_server._fetch_public_key()) @@ -384,7 +377,7 @@ async def test__fetch_public_key_when_public_key_already_set( self, mock_interaction_server: interaction_server_impl.InteractionServer ): mock_lock = mock.AsyncMock() - mock_public_key = object() + mock_public_key = mock.Mock() mock_interaction_server._application_fetch_lock = mock_lock mock_interaction_server._public_key = mock_public_key @@ -623,7 +616,7 @@ async def test_close_when_closing(self, mock_interaction_server: interaction_ser mock_interaction_server._close_event = mock_event mock_interaction_server._is_closing = True mock_interaction_server.join = mock.AsyncMock() - mock_listener = object() + mock_listener = mock.Mock() mock_interaction_server._running_generator_listeners = [mock_listener] await mock_interaction_server.close() @@ -640,9 +633,9 @@ async def test_close_when_not_running(self, mock_interaction_server: interaction await mock_interaction_server.close() @pytest.mark.asyncio - async def test_join(self, mock_interaction_server): + async def test_join(self, mock_interaction_server: interaction_server_impl.InteractionServer): mock_event = mock.AsyncMock() - mock_interaction_server._server = object() + mock_interaction_server._server = mock.Mock() mock_interaction_server._close_event = mock_event await mock_interaction_server.join() @@ -668,8 +661,8 @@ async def test_on_interaction( mock_file_2 = mock.Mock() mock_entity_factory.deserialize_interaction.return_value = base_interactions.PartialInteraction( app=None, - id=123, - application_id=541324, + id=snowflakes.Snowflake(123), + application_id=snowflakes.Snowflake(541324), type=2, token="ok", version=1, @@ -723,9 +716,9 @@ async def mock_generator_listener(event): mock_file_1 = mock.Mock() mock_file_2 = mock.Mock() mock_entity_factory.deserialize_interaction.return_value = base_interactions.PartialInteraction( - app=None, - id=123, - application_id=541324, + app=mock.Mock(), + id=snowflakes.Snowflake(123), + application_id=snowflakes.Snowflake(541324), type=2, token="ok", version=1, @@ -894,9 +887,11 @@ async def test_on_interaction_on_deserialize_unrecognised_entity_error( mock_entity_factory: entity_factory_impl.EntityFactoryImpl, ): mock_interaction_server._public_key = mock.Mock() - mock_entity_factory.deserialize_interaction.side_effect = errors.UnrecognisedEntityError("blah") - result = await mock_interaction_server.on_interaction(b'{"type": 2}', b"signature", b"timestamp") + with mock.patch.object( + mock_entity_factory, "deserialize_interaction", side_effect=errors.UnrecognisedEntityError("blah") + ): + result = await mock_interaction_server.on_interaction(b'{"type": 2}', b"signature", b"timestamp") assert result.content_type == "text/plain" assert result.charset == "UTF-8" @@ -913,19 +908,13 @@ async def test_on_interaction_on_failed_deserialize( ): mock_interaction_server._public_key = mock.Mock() mock_exception = TypeError("OK") - mock_entity_factory.deserialize_interaction.side_effect = mock_exception - with mock.patch.object(asyncio, "get_running_loop") as get_running_loop: + with ( + mock.patch.object(mock_entity_factory, "deserialize_interaction", side_effect=mock_exception), + mock.patch.object(asyncio, "get_running_loop") as get_running_loop, + ): result = await mock_interaction_server.on_interaction(b'{"type": 2}', b"signature", b"timestamp") - get_running_loop.return_value.call_exception_handler.assert_called_once_with( - { - "message": "Exception occurred during interaction deserialization", - "payload": {"type": 2}, - "exception": mock_exception, - } - ) - assert result.content_type == "text/plain" assert result.charset == "UTF-8" assert result.files == () @@ -933,6 +922,14 @@ async def test_on_interaction_on_failed_deserialize( assert result.payload == b"Exception occurred during interaction deserialization" assert result.status_code == 500 + get_running_loop.return_value.call_exception_handler.assert_called_once_with( + { + "message": "Exception occurred during interaction deserialization", + "payload": {"type": 2}, + "exception": mock_exception, + } + ) + @pytest.mark.asyncio async def test_on_interaction_on_dispatch_error( self, @@ -940,7 +937,6 @@ async def test_on_interaction_on_dispatch_error( mock_entity_factory: entity_factory_impl.EntityFactoryImpl, ): mock_interaction_server._public_key = mock.Mock() - mock_exception = TypeError("OK") mock_entity_factory.deserialize_interaction.return_value = base_interactions.PartialInteraction( app=None, id=123, @@ -959,6 +955,7 @@ async def test_on_interaction_on_dispatch_error( locale="es-ES", entitlements=[], ) + mock_exception = TypeError("OK") mock_interaction_server.set_listener( base_interactions.PartialInteraction, mock.Mock(side_effect=mock_exception) ) @@ -966,10 +963,6 @@ async def test_on_interaction_on_dispatch_error( with mock.patch.object(asyncio, "get_running_loop") as get_running_loop: result = await mock_interaction_server.on_interaction(b'{"type": 2}', b"signature", b"timestamp") - get_running_loop.return_value.call_exception_handler.assert_called_once_with( - {"message": "Exception occurred during interaction dispatch", "exception": mock_exception} - ) - assert result.content_type == "text/plain" assert result.charset == "UTF-8" assert result.files == () @@ -977,6 +970,10 @@ async def test_on_interaction_on_dispatch_error( assert result.payload == b"Exception occurred during interaction dispatch" assert result.status_code == 500 + get_running_loop.return_value.call_exception_handler.assert_called_once_with( + {"message": "Exception occurred during interaction dispatch", "exception": mock_exception} + ) + @pytest.mark.asyncio async def test_on_interaction_when_response_builder_error( self, @@ -984,7 +981,6 @@ async def test_on_interaction_when_response_builder_error( mock_entity_factory: entity_factory_impl.EntityFactoryImpl, ): mock_interaction_server._public_key = mock.Mock() - mock_exception = TypeError("OK") mock_entity_factory.deserialize_interaction.return_value = base_interactions.PartialInteraction( app=None, id=123, @@ -1003,6 +999,7 @@ async def test_on_interaction_when_response_builder_error( locale="es-ES", entitlements=[], ) + mock_exception = TypeError("OK") mock_builder = mock.Mock(build=mock.Mock(side_effect=mock_exception)) mock_interaction_server.set_listener( base_interactions.PartialInteraction, mock.AsyncMock(return_value=mock_builder) @@ -1011,10 +1008,6 @@ async def test_on_interaction_when_response_builder_error( with mock.patch.object(asyncio, "get_running_loop") as get_running_loop: result = await mock_interaction_server.on_interaction(b'{"type": 2}', b"signature", b"timestamp") - get_running_loop.return_value.call_exception_handler.assert_called_once_with( - {"message": "Exception occurred during interaction dispatch", "exception": mock_exception} - ) - assert result.content_type == "text/plain" assert result.charset == "UTF-8" assert result.files == () @@ -1022,6 +1015,10 @@ async def test_on_interaction_when_response_builder_error( assert result.payload == b"Exception occurred during interaction dispatch" assert result.status_code == 500 + get_running_loop.return_value.call_exception_handler.assert_called_once_with( + {"message": "Exception occurred during interaction dispatch", "exception": mock_exception} + ) + @pytest.mark.asyncio async def test_on_interaction_when_json_encode_fails( self, @@ -1029,8 +1026,6 @@ async def test_on_interaction_when_json_encode_fails( mock_entity_factory: entity_factory_impl.EntityFactoryImpl, ): mock_interaction_server._public_key = mock.Mock() - mock_exception = TypeError("OK") - mock_interaction_server._dumps = mock.Mock(side_effect=mock_exception) mock_entity_factory.deserialize_interaction.return_value = base_interactions.PartialInteraction( app=None, id=123, @@ -1049,6 +1044,8 @@ async def test_on_interaction_when_json_encode_fails( locale="es-ES", entitlements=[], ) + mock_exception = TypeError("OK") + mock_interaction_server._dumps = mock.Mock(side_effect=mock_exception) mock_builder = mock.Mock(build=mock.Mock(return_value=({"ok": "No"}, []))) mock_interaction_server.set_listener( base_interactions.PartialInteraction, mock.AsyncMock(return_value=mock_builder) @@ -1057,10 +1054,6 @@ async def test_on_interaction_when_json_encode_fails( with mock.patch.object(asyncio, "get_running_loop") as get_running_loop: result = await mock_interaction_server.on_interaction(b'{"type": 2}', b"signature", b"timestamp") - get_running_loop.return_value.call_exception_handler.assert_called_once_with( - {"message": "Exception occurred during interaction dispatch", "exception": mock_exception} - ) - assert result.content_type == "text/plain" assert result.charset == "UTF-8" assert result.files == () @@ -1068,11 +1061,13 @@ async def test_on_interaction_when_json_encode_fails( assert result.payload == b"Exception occurred during interaction dispatch" assert result.status_code == 500 + get_running_loop.return_value.call_exception_handler.assert_called_once_with( + {"message": "Exception occurred during interaction dispatch", "exception": mock_exception} + ) + @pytest.mark.asyncio async def test_on_interaction_when_no_registered_listener( - self, - mock_interaction_server: interaction_server_impl.InteractionServer, - mock_entity_factory: entity_factory_impl.EntityFactoryImpl, + self, mock_interaction_server: interaction_server_impl.InteractionServer ): mock_interaction_server._public_key = mock.Mock() @@ -1087,19 +1082,19 @@ async def test_on_interaction_when_no_registered_listener( @pytest.mark.asyncio async def test_start(self, mock_interaction_server: interaction_server_impl.InteractionServer): - mock_context = object() - mock_socket = object() + mock_context = mock.Mock() + mock_socket = mock.Mock() mock_interaction_server._is_closing = True mock_interaction_server._fetch_public_key = mock.AsyncMock() - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(aiohttp.web, "TCPSite", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web, "UnixSite", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web, "SockSite", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web_runner, "AppRunner", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web, "Application")) - stack.enter_context(mock.patch.object(asyncio, "Event")) - - with stack: + + with ( + mock.patch.object(aiohttp.web, "TCPSite", return_value=mock.AsyncMock()) as web_tcp_site, + mock.patch.object(aiohttp.web, "UnixSite", return_value=mock.AsyncMock()) as web_unix_site, + mock.patch.object(aiohttp.web, "SockSite", return_value=mock.AsyncMock()) as web_sock_site, + mock.patch.object(aiohttp.web_runner, "AppRunner", return_value=mock.AsyncMock()) as web_app_runner, + mock.patch.object(aiohttp.web, "Application") as web_application, + mock.patch.object(asyncio, "Event") as event, + ): await mock_interaction_server.start( backlog=123123, host="hoototototo", @@ -1112,204 +1107,190 @@ async def test_start(self, mock_interaction_server: interaction_server_impl.Inte ssl_context=mock_context, ) - mock_interaction_server._fetch_public_key.assert_awaited_once_with() + mock_interaction_server._fetch_public_key.assert_awaited_once_with() - aiohttp.web.Application.assert_called_once_with() - aiohttp.web.Application.return_value.add_routes.assert_called_once_with( - [aiohttp.web.post("/", mock_interaction_server.aiohttp_hook)] - ) - aiohttp.web_runner.AppRunner.assert_called_once_with( - aiohttp.web.Application.return_value, access_log=interaction_server_impl._LOGGER - ) - aiohttp.web_runner.AppRunner.return_value.setup.assert_awaited_once() - aiohttp.web.TCPSite.assert_called_once_with( - aiohttp.web_runner.AppRunner.return_value, - "hoototototo", - port=123123123, - shutdown_timeout=3232.3232, - ssl_context=mock_context, - backlog=123123, - reuse_address=True, - reuse_port=False, - ) - aiohttp.web.UnixSite.assert_called_once_with( - aiohttp.web_runner.AppRunner.return_value, - "hshshshshsh", - shutdown_timeout=3232.3232, - ssl_context=mock_context, - backlog=123123, - ) - aiohttp.web.SockSite.assert_called_once_with( - aiohttp.web_runner.AppRunner.return_value, - mock_socket, - shutdown_timeout=3232.3232, - ssl_context=mock_context, - backlog=123123, - ) - aiohttp.web.TCPSite.return_value.start.assert_awaited_once() - aiohttp.web.UnixSite.return_value.start.assert_awaited_once() - aiohttp.web.SockSite.return_value.start.assert_awaited_once() - assert mock_interaction_server._close_event is asyncio.Event.return_value - assert mock_interaction_server._is_closing is False + web_application.assert_called_once_with() + web_application.return_value.add_routes.assert_called_once_with( + [aiohttp.web.post("/", mock_interaction_server.aiohttp_hook)] + ) + web_app_runner.assert_called_once_with(web_application.return_value, access_log=interaction_server_impl._LOGGER) + web_app_runner.return_value.setup.assert_awaited_once() + web_tcp_site.assert_called_once_with( + web_app_runner.return_value, + "hoototototo", + port=123123123, + shutdown_timeout=3232.3232, + ssl_context=mock_context, + backlog=123123, + reuse_address=True, + reuse_port=False, + ) + web_unix_site.assert_called_once_with( + web_app_runner.return_value, + "hshshshshsh", + shutdown_timeout=3232.3232, + ssl_context=mock_context, + backlog=123123, + ) + web_sock_site.assert_called_once_with( + web_app_runner.return_value, + mock_socket, + shutdown_timeout=3232.3232, + ssl_context=mock_context, + backlog=123123, + ) + web_tcp_site.return_value.start.assert_awaited_once() + web_unix_site.return_value.start.assert_awaited_once() + web_sock_site.return_value.start.assert_awaited_once() + assert mock_interaction_server._close_event is event.return_value + assert mock_interaction_server._is_closing is False @pytest.mark.asyncio async def test_start_with_default_behaviour( self, mock_interaction_server: interaction_server_impl.InteractionServer ): - mock_context = object() + mock_context = mock.Mock() mock_interaction_server._fetch_public_key = mock.AsyncMock() - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(aiohttp.web, "TCPSite", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web_runner, "AppRunner", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web, "Application")) - with stack: + with ( + mock.patch.object(aiohttp.web, "TCPSite", return_value=mock.AsyncMock()) as web_tcp_site, + mock.patch.object(aiohttp.web_runner, "AppRunner", return_value=mock.AsyncMock()) as web_app_runner, + mock.patch.object(aiohttp.web, "Application") as web_application, + ): await mock_interaction_server.start(ssl_context=mock_context) - mock_interaction_server._fetch_public_key.assert_awaited_once_with() + mock_interaction_server._fetch_public_key.assert_awaited_once_with() - aiohttp.web.Application.assert_called_once_with() - aiohttp.web.Application.return_value.add_routes.assert_called_once_with( - [aiohttp.web.post("/", mock_interaction_server.aiohttp_hook)] - ) - aiohttp.web_runner.AppRunner.assert_called_once_with( - aiohttp.web.Application.return_value, access_log=interaction_server_impl._LOGGER - ) - aiohttp.web_runner.AppRunner.return_value.setup.assert_awaited_once() - aiohttp.web.TCPSite.assert_called_once_with( - aiohttp.web_runner.AppRunner.return_value, - port=None, - shutdown_timeout=60.0, - ssl_context=mock_context, - backlog=128, - reuse_address=None, - reuse_port=None, - ) - aiohttp.web.TCPSite.return_value.start.assert_awaited_once() + web_application.assert_called_once_with() + web_application.return_value.add_routes.assert_called_once_with( + [aiohttp.web.post("/", mock_interaction_server.aiohttp_hook)] + ) + web_app_runner.assert_called_once_with(web_application.return_value, access_log=interaction_server_impl._LOGGER) + web_app_runner.return_value.setup.assert_awaited_once() + web_tcp_site.assert_called_once_with( + web_app_runner.return_value, + port=None, + shutdown_timeout=60.0, + ssl_context=mock_context, + backlog=128, + reuse_address=None, + reuse_port=None, + ) + web_tcp_site.return_value.start.assert_awaited_once() @pytest.mark.asyncio async def test_start_with_default_behaviour_and_not_main_thread( self, mock_interaction_server: interaction_server_impl.InteractionServer ): - mock_context = object() + mock_context = mock.Mock() mock_interaction_server._fetch_public_key = mock.AsyncMock() - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(aiohttp.web, "TCPSite", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web_runner, "AppRunner", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web, "Application")) - stack.enter_context(mock.patch.object(threading, "current_thread")) - with stack: + with ( + mock.patch.object(aiohttp.web, "TCPSite", return_value=mock.AsyncMock()) as web_tcp_site, + mock.patch.object(aiohttp.web_runner, "AppRunner", return_value=mock.AsyncMock()) as web_app_runner, + mock.patch.object(aiohttp.web, "Application") as web_application, + mock.patch.object(threading, "current_thread"), + ): await mock_interaction_server.start(ssl_context=mock_context) - mock_interaction_server._fetch_public_key.assert_awaited_once_with() + mock_interaction_server._fetch_public_key.assert_awaited_once_with() - aiohttp.web.Application.assert_called_once_with() - aiohttp.web.Application.return_value.add_routes.assert_called_once_with( - [aiohttp.web.post("/", mock_interaction_server.aiohttp_hook)] - ) - aiohttp.web_runner.AppRunner.assert_called_once_with( - aiohttp.web.Application.return_value, access_log=interaction_server_impl._LOGGER - ) - aiohttp.web_runner.AppRunner.return_value.setup.assert_awaited_once() - aiohttp.web.TCPSite.assert_called_once_with( - aiohttp.web_runner.AppRunner.return_value, - port=None, - shutdown_timeout=60.0, - ssl_context=mock_context, - backlog=128, - reuse_address=None, - reuse_port=None, - ) - aiohttp.web.TCPSite.return_value.start.assert_awaited_once() + web_application.assert_called_once_with() + web_application.return_value.add_routes.assert_called_once_with( + [aiohttp.web.post("/", mock_interaction_server.aiohttp_hook)] + ) + web_app_runner.assert_called_once_with(web_application.return_value, access_log=interaction_server_impl._LOGGER) + web_app_runner.return_value.setup.assert_awaited_once() + web_tcp_site.assert_called_once_with( + web_app_runner.return_value, + port=None, + shutdown_timeout=60.0, + ssl_context=mock_context, + backlog=128, + reuse_address=None, + reuse_port=None, + ) + web_tcp_site.return_value.start.assert_awaited_once() @pytest.mark.asyncio async def test_start_with_multiple_hosts(self, mock_interaction_server: interaction_server_impl.InteractionServer): - mock_context = object() + mock_context = mock.Mock() mock_interaction_server._fetch_public_key = mock.AsyncMock() - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(aiohttp.web, "TCPSite", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web_runner, "AppRunner", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web, "Application")) - with stack: + with ( + mock.patch.object(aiohttp.web, "TCPSite", return_value=mock.AsyncMock()) as web_tcp_site, + mock.patch.object(aiohttp.web_runner, "AppRunner", return_value=mock.AsyncMock()) as web_app_runner, + mock.patch.object(aiohttp.web, "Application") as web_application, + ): await mock_interaction_server.start(ssl_context=mock_context, host=["123", "4312"], port=453123) - mock_interaction_server._fetch_public_key.assert_awaited_once_with() + mock_interaction_server._fetch_public_key.assert_awaited_once_with() - aiohttp.web.Application.assert_called_once_with() - aiohttp.web.Application.return_value.add_routes.assert_called_once_with( - [aiohttp.web.post("/", mock_interaction_server.aiohttp_hook)] - ) - aiohttp.web_runner.AppRunner.assert_called_once_with( - aiohttp.web.Application.return_value, access_log=interaction_server_impl._LOGGER - ) - aiohttp.web_runner.AppRunner.return_value.setup.assert_awaited_once() - aiohttp.web.TCPSite.assert_has_calls( - [ - mock.call( - aiohttp.web_runner.AppRunner.return_value, - "123", - port=453123, - shutdown_timeout=60.0, - ssl_context=mock_context, - backlog=128, - reuse_address=None, - reuse_port=None, - ), - mock.call( - aiohttp.web_runner.AppRunner.return_value, - "4312", - port=453123, - shutdown_timeout=60.0, - ssl_context=mock_context, - backlog=128, - reuse_address=None, - reuse_port=None, - ), - ] - ) - aiohttp.web.TCPSite.return_value.start.assert_has_awaits([mock.call(), mock.call()]) + web_application.assert_called_once_with() + web_application.return_value.add_routes.assert_called_once_with( + [aiohttp.web.post("/", mock_interaction_server.aiohttp_hook)] + ) + web_app_runner.assert_called_once_with(web_application.return_value, access_log=interaction_server_impl._LOGGER) + web_app_runner.return_value.setup.assert_awaited_once() + web_tcp_site.assert_has_calls( + [ + mock.call( + web_app_runner.return_value, + "123", + port=453123, + shutdown_timeout=60.0, + ssl_context=mock_context, + backlog=128, + reuse_address=None, + reuse_port=None, + ), + mock.call( + web_app_runner.return_value, + "4312", + port=453123, + shutdown_timeout=60.0, + ssl_context=mock_context, + backlog=128, + reuse_address=None, + reuse_port=None, + ), + ] + ) + web_tcp_site.return_value.start.assert_has_awaits([mock.call(), mock.call()]) @pytest.mark.asyncio async def test_start_when_no_tcp_sites(self, mock_interaction_server: interaction_server_impl.InteractionServer): - mock_socket = object() - mock_context = object() + mock_socket = mock.Mock() + mock_context = mock.Mock() mock_interaction_server._fetch_public_key = mock.AsyncMock() - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(aiohttp.web, "TCPSite", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web_runner, "AppRunner", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web, "Application")) - stack.enter_context(mock.patch.object(aiohttp.web, "SockSite", return_value=mock.AsyncMock())) - stack.enter_context(mock.patch.object(aiohttp.web, "UnixSite", return_value=mock.AsyncMock())) - - with stack: + + with ( + mock.patch.object(aiohttp.web, "SockSite", return_value=mock.AsyncMock()) as web_sock_site, + mock.patch.object(aiohttp.web, "UnixSite", return_value=mock.AsyncMock()) as web_unix_site, + mock.patch.object(aiohttp.web, "TCPSite", return_value=mock.AsyncMock()) as web_tcp_site, + mock.patch.object(aiohttp.web_runner, "AppRunner", return_value=mock.AsyncMock()) as web_app_runner, + mock.patch.object(aiohttp.web, "Application") as web_application, + ): await mock_interaction_server.start(ssl_context=mock_context, socket=mock_socket) - mock_interaction_server._fetch_public_key.assert_awaited_once_with() + mock_interaction_server._fetch_public_key.assert_awaited_once_with() - aiohttp.web.Application.assert_called_once_with() - aiohttp.web.Application.return_value.add_routes.assert_called_once_with( - [aiohttp.web.post("/", mock_interaction_server.aiohttp_hook)] - ) - aiohttp.web_runner.AppRunner.assert_called_once_with( - aiohttp.web.Application.return_value, access_log=interaction_server_impl._LOGGER - ) - aiohttp.web_runner.AppRunner.return_value.setup.assert_awaited_once() - aiohttp.web.TCPSite.assert_not_called() - aiohttp.web.UnixSite.assert_not_called() - aiohttp.web.SockSite.assert_called_once_with( - aiohttp.web_runner.AppRunner.return_value, - mock_socket, - shutdown_timeout=60.0, - ssl_context=mock_context, - backlog=128, - ) - aiohttp.web.SockSite.return_value.start.assert_awaited_once() + web_application.assert_called_once_with() + web_application.return_value.add_routes.assert_called_once_with( + [aiohttp.web.post("/", mock_interaction_server.aiohttp_hook)] + ) + web_app_runner.assert_called_once_with(web_application.return_value, access_log=interaction_server_impl._LOGGER) + web_app_runner.return_value.setup.assert_awaited_once() + web_tcp_site.assert_not_called() + web_unix_site.assert_not_called() + web_sock_site.assert_called_once_with( + web_app_runner.return_value, mock_socket, shutdown_timeout=60.0, ssl_context=mock_context, backlog=128 + ) + web_sock_site.return_value.start.assert_awaited_once() @pytest.mark.asyncio async def test_start_when_already_running(self, mock_interaction_server: interaction_server_impl.InteractionServer): - mock_interaction_server._server = object() + mock_interaction_server._server = mock.Mock() with pytest.raises(errors.ComponentStateConflictError): await mock_interaction_server.start() @@ -1318,13 +1299,13 @@ def test_get_listener_when_unknown(self, mock_interaction_server: interaction_se assert mock_interaction_server.get_listener(base_interactions.PartialInteraction) is None def test_get_listener_when_registered(self, mock_interaction_server: interaction_server_impl.InteractionServer): - mock_listener = object() + mock_listener = mock.Mock() mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock_listener) assert mock_interaction_server.get_listener(base_interactions.PartialInteraction) is mock_listener def test_set_listener(self, mock_interaction_server: interaction_server_impl.InteractionServer): - mock_listener = object() + mock_listener = mock.Mock() mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock_listener) @@ -1333,16 +1314,16 @@ def test_set_listener(self, mock_interaction_server: interaction_server_impl.Int def test_set_listener_when_already_registered_without_replace( self, mock_interaction_server: interaction_server_impl.InteractionServer ): - mock_interaction_server.set_listener(base_interactions.PartialInteraction, object()) + mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock.Mock()) with pytest.raises(TypeError): - mock_interaction_server.set_listener(base_interactions.PartialInteraction, object()) + mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock.Mock()) def test_set_listener_when_already_registered_with_replace( self, mock_interaction_server: interaction_server_impl.InteractionServer ): - mock_listener = object() - mock_interaction_server.set_listener(base_interactions.PartialInteraction, object()) + mock_listener = mock.Mock() + mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock.Mock()) mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock_listener, replace=True) @@ -1351,7 +1332,7 @@ def test_set_listener_when_already_registered_with_replace( def test_set_listener_when_removing_listener( self, mock_interaction_server: interaction_server_impl.InteractionServer ): - mock_interaction_server.set_listener(base_interactions.PartialInteraction, object()) + mock_interaction_server.set_listener(base_interactions.PartialInteraction, mock.Mock()) mock_interaction_server.set_listener(base_interactions.PartialInteraction, None) assert mock_interaction_server.get_listener(base_interactions.PartialInteraction) is None diff --git a/tests/hikari/impl/test_rate_limits.py b/tests/hikari/impl/test_rate_limits.py index 8bf16c089b..1381120880 100644 --- a/tests/hikari/impl/test_rate_limits.py +++ b/tests/hikari/impl/test_rate_limits.py @@ -25,6 +25,7 @@ import math import sys import time +import typing import mock import pytest @@ -52,22 +53,25 @@ class MockedBaseRateLimiter(rate_limits.BaseRateLimiter): m.close.assert_called_once() +class MockBurstLimiterImpl(rate_limits.BurstRateLimiter): + async def acquire(self, *args: typing.Any, **kwargs: typing.Any) -> None: + raise NotImplementedError + + class TestBurstRateLimiter: @pytest.fixture - def mock_burst_limiter(self): - class Impl(rate_limits.BurstRateLimiter): - async def acquire(self, *args, **kwargs) -> None: - raise NotImplementedError - - return Impl(__name__) + def mock_burst_limiter(self) -> MockBurstLimiterImpl: + return MockBurstLimiterImpl(__name__) @pytest.mark.parametrize(("queue", "is_empty"), [(["foo", "bar", "baz"], False), ([], True)]) - def test_is_empty(self, queue, is_empty, mock_burst_limiter): + def test_is_empty( + self, queue: list[asyncio.Future[typing.Any]], is_empty: bool, mock_burst_limiter: MockBurstLimiterImpl + ): mock_burst_limiter.queue = queue assert mock_burst_limiter.is_empty is is_empty @pytest.mark.asyncio - async def test_close_removes_all_futures_from_queue(self, mock_burst_limiter): + async def test_close_removes_all_futures_from_queue(self, mock_burst_limiter: MockBurstLimiterImpl): event_loop = asyncio.get_running_loop() mock_burst_limiter.throttle_task = None futures = [event_loop.create_future() for _ in range(10)] @@ -76,7 +80,9 @@ async def test_close_removes_all_futures_from_queue(self, mock_burst_limiter): assert len(mock_burst_limiter.queue) == 0 @pytest.mark.asyncio - async def test_close_cancels_all_futures_pending_when_futures_pending(self, mock_burst_limiter): + async def test_close_cancels_all_futures_pending_when_futures_pending( + self, mock_burst_limiter: MockBurstLimiterImpl + ): event_loop = asyncio.get_running_loop() mock_burst_limiter.throttle_task = None futures = [event_loop.create_future() for _ in range(10)] @@ -86,14 +92,14 @@ async def test_close_cancels_all_futures_pending_when_futures_pending(self, mock assert future.cancelled(), f"future {i} was not cancelled" @pytest.mark.asyncio - async def test_close_is_silent_when_no_futures_pending(self, mock_burst_limiter): + async def test_close_is_silent_when_no_futures_pending(self, mock_burst_limiter: MockBurstLimiterImpl): mock_burst_limiter.throttle_task = None mock_burst_limiter.queue = [] mock_burst_limiter.close() assert True, "passed successfully" @pytest.mark.asyncio - async def test_close_cancels_throttle_task_if_running(self, mock_burst_limiter): + async def test_close_cancels_throttle_task_if_running(self, mock_burst_limiter: MockBurstLimiterImpl): event_loop = asyncio.get_running_loop() task = event_loop.create_future() mock_burst_limiter.throttle_task = task @@ -102,7 +108,7 @@ async def test_close_cancels_throttle_task_if_running(self, mock_burst_limiter): assert task.cancelled(), "throttle_task is not cancelled" @pytest.mark.asyncio - async def test_close_when_closed(self, mock_burst_limiter): + async def test_close_when_closed(self, mock_burst_limiter: MockBurstLimiterImpl): # Double-running shouldn't do anything adverse. mock_burst_limiter.close() mock_burst_limiter.close() @@ -217,7 +223,7 @@ async def test_throttle_clears_throttle_task(self): class TestWindowedBurstRateLimiter: @pytest.fixture - def ratelimiter(self): + def ratelimiter(self) -> typing.Generator[rate_limits.WindowedBurstRateLimiter, typing.Any, None]: inst = hikari_test_helpers.mock_class_namespace(rate_limits.WindowedBurstRateLimiter, slots_=False)( __name__, 3, 3 ) @@ -226,7 +232,7 @@ def ratelimiter(self): inst.close() @pytest.mark.asyncio - async def test_drip_if_not_throttled_and_not_ratelimited(self, ratelimiter): + async def test_drip_if_not_throttled_and_not_ratelimited(self, ratelimiter: rate_limits.WindowedBurstRateLimiter): event_loop = asyncio.get_running_loop() ratelimiter.remaining = 10 @@ -240,7 +246,7 @@ async def test_drip_if_not_throttled_and_not_ratelimited(self, ratelimiter): event_loop.create_future.assert_not_called() @pytest.mark.asyncio - async def test_no_drip_if_throttle_task_is_not_None(self, ratelimiter): + async def test_no_drip_if_throttle_task_is_not_None(self, ratelimiter: rate_limits.WindowedBurstRateLimiter): event_loop = asyncio.get_running_loop() ratelimiter.remaining = 10 @@ -254,7 +260,7 @@ async def test_no_drip_if_throttle_task_is_not_None(self, ratelimiter): assert ratelimiter.remaining == 10 @pytest.mark.asyncio - async def test_no_drip_if_rate_limited(self, ratelimiter): + async def test_no_drip_if_rate_limited(self, ratelimiter: rate_limits.WindowedBurstRateLimiter): event_loop = asyncio.get_running_loop() ratelimiter.remaining = 10 @@ -268,7 +274,9 @@ async def test_no_drip_if_rate_limited(self, ratelimiter): assert ratelimiter.remaining == 10 @pytest.mark.asyncio - async def test_task_scheduled_if_rate_limited_and_throttle_task_is_None(self, ratelimiter): + async def test_task_scheduled_if_rate_limited_and_throttle_task_is_None( + self, ratelimiter: rate_limits.WindowedBurstRateLimiter + ): event_loop = asyncio.get_running_loop() ratelimiter.throttle_task = None @@ -283,7 +291,9 @@ async def test_task_scheduled_if_rate_limited_and_throttle_task_is_None(self, ra ratelimiter.throttle.assert_called() @pytest.mark.asyncio - async def test_task_not_scheduled_if_rate_limited_and_throttle_task_not_None(self, ratelimiter): + async def test_task_not_scheduled_if_rate_limited_and_throttle_task_not_None( + self, ratelimiter: rate_limits.WindowedBurstRateLimiter + ): event_loop = asyncio.get_running_loop() ratelimiter.throttle_task = event_loop.create_future() @@ -296,7 +306,9 @@ async def test_task_not_scheduled_if_rate_limited_and_throttle_task_not_None(sel assert old_task is ratelimiter.throttle_task, "task was rescheduled, that shouldn't happen :(" @pytest.mark.asyncio - async def test_future_is_added_to_queue_if_throttle_task_is_not_None(self, ratelimiter): + async def test_future_is_added_to_queue_if_throttle_task_is_not_None( + self, ratelimiter: rate_limits.WindowedBurstRateLimiter + ): event_loop = asyncio.get_running_loop() ratelimiter.throttle_task = asyncio.get_running_loop().create_future() @@ -310,7 +322,7 @@ async def test_future_is_added_to_queue_if_throttle_task_is_not_None(self, ratel assert ratelimiter.queue[-1:] == [future] @pytest.mark.asyncio - async def test_future_is_added_to_queue_if_rate_limited(self, ratelimiter): + async def test_future_is_added_to_queue_if_rate_limited(self, ratelimiter: rate_limits.WindowedBurstRateLimiter): event_loop = asyncio.get_running_loop() ratelimiter.throttle_task = None @@ -421,7 +433,7 @@ def test_is_rate_limited_when_rate_limit_expired(self): assert rl.remaining == 27 @pytest.mark.parametrize("remaining", [-1, 0, 1]) - def test_is_rate_limited_when_rate_limit_not_expired(self, remaining): + def test_is_rate_limited_when_rate_limit_not_expired(self, remaining: int): with rate_limits.WindowedBurstRateLimiter(__name__, 403, 27) as rl: now = 420 rl.move_at = now + 69 @@ -464,7 +476,7 @@ def test_reset(self): assert eb.increment == 0 @pytest.mark.parametrize(("iteration", "backoff"), enumerate((1, 2, 4, 8, 16, 32))) - def test_increment_linear(self, iteration, backoff): + def test_increment_linear(self, iteration: int, backoff: int): eb = rate_limits.ExponentialBackOff(2, 64, 0) for _ in range(iteration): @@ -499,7 +511,7 @@ def test_increment_does_not_increment_when_on_maximum(self): assert eb.increment == 5 @pytest.mark.parametrize(("iteration", "backoff"), enumerate((1, 2, 4, 8, 16, 32))) - def test_increment_jitter(self, iteration, backoff): + def test_increment_jitter(self, iteration: int, backoff: int): abs_tol = 1 eb = rate_limits.ExponentialBackOff(2, 64, abs_tol) diff --git a/tests/hikari/impl/test_rest.py b/tests/hikari/impl/test_rest.py index 65b92634eb..1c89f363c4 100644 --- a/tests/hikari/impl/test_rest.py +++ b/tests/hikari/impl/test_rest.py @@ -21,6 +21,7 @@ from __future__ import annotations import asyncio +import concurrent.futures import contextlib import datetime import http @@ -43,18 +44,23 @@ from hikari import errors from hikari import files from hikari import guilds +from hikari import interactions from hikari import invites from hikari import iterators from hikari import locales -from hikari import messages as message_models +from hikari import messages from hikari import permissions from hikari import scheduled_events +from hikari import sessions from hikari import snowflakes from hikari import stage_instances +from hikari import stickers from hikari import undefined from hikari import urls from hikari import users +from hikari import voices from hikari import webhooks +from hikari.api import cache from hikari.api import rest as rest_api from hikari.impl import config from hikari.impl import entity_factory @@ -75,42 +81,45 @@ class TestRestProvider: @pytest.fixture - def rest_client(self): - class StubRestClient: - http_settings = object() - proxy_settings = object() - - return StubRestClient() + def rest_client(self) -> rest.RESTClientImpl: + return mock.Mock() @pytest.fixture - def executor(self): + def executor(self) -> concurrent.futures.Executor: return mock.Mock() @pytest.fixture - def entity_factory(self): + def entity_factory(self) -> entity_factory.EntityFactoryImpl: return mock.Mock() @pytest.fixture - def rest_provider(self, rest_client, executor, entity_factory): + def rest_provider( + self, + rest_client: rest.RESTClientImpl, + executor: concurrent.futures.Executor, + entity_factory: entity_factory.EntityFactoryImpl, + ): provider = rest._RESTProvider(executor) provider.update(rest_client, entity_factory) return provider - def test_rest_property(self, rest_provider, rest_client): + def test_rest_property(self, rest_provider: rest._RESTProvider, rest_client: rest.RESTClientImpl): assert rest_provider.rest == rest_client - def test_http_settings_property(self, rest_provider, rest_client): + def test_http_settings_property(self, rest_provider: rest._RESTProvider, rest_client: rest.RESTClientImpl): assert rest_provider.http_settings == rest_client.http_settings - def test_proxy_settings_property(self, rest_provider, rest_client): + def test_proxy_settings_property(self, rest_provider: rest._RESTProvider, rest_client: rest.RESTClientImpl): assert rest_provider.proxy_settings == rest_client.proxy_settings - def test_entity_factory_property(self, rest_provider, entity_factory): + def test_entity_factory_property( + self, rest_provider: rest._RESTProvider, entity_factory: entity_factory.EntityFactoryImpl + ): assert rest_provider.entity_factory == entity_factory - def test_executor_property(self, rest_provider, executor): + def test_executor_property(self, rest_provider: rest._RESTProvider, executor: concurrent.futures.Executor): assert rest_provider.executor == executor @@ -121,7 +130,7 @@ def test_executor_property(self, rest_provider, executor): class TestClientCredentialsStrategy: @pytest.fixture - def mock_token(self): + def mock_token(self) -> applications.PartialOAuth2Token: return mock.Mock( applications.PartialOAuth2Token, expires_in=datetime.timedelta(weeks=1), @@ -129,16 +138,16 @@ def mock_token(self): access_token="okokok.fofofo.ddd", ) - def test_client_id_property(self): - mock_client = hikari_test_helpers.mock_class_namespace(applications.Application, id=43123, init_=False)() - token = rest.ClientCredentialsStrategy(client=mock_client, client_secret="123123123") + def test_client_id_property(self, mock_application: applications.Application): + token = rest.ClientCredentialsStrategy(client=mock_application, client_secret="123123123") - assert token.client_id == 43123 + assert token.client_id == 111 def test_scopes_property(self): - token = rest.ClientCredentialsStrategy(client=123, client_secret="123123123", scopes=[123, 5643]) + scopes = [applications.OAuth2Scope.BOT, applications.OAuth2Scope.APPLICATIONS_ENTITLEMENTS] + token = rest.ClientCredentialsStrategy(client=123, client_secret="123123123", scopes=scopes) - assert token.scopes == (123, 5643) + assert token.scopes == ("bot", "applications.entitlements") def test_token_type_property(self): token = rest.ClientCredentialsStrategy(client=123, client_secret="123123123", scopes=[]) @@ -146,7 +155,7 @@ def test_token_type_property(self): assert token.token_type is applications.TokenType.BEARER @pytest.mark.asyncio - async def test_acquire_on_new_instance(self, mock_token): + async def test_acquire_on_new_instance(self, mock_token: applications.PartialOAuth2Token): mock_rest = mock.Mock(authorize_client_credentials_token=mock.AsyncMock(return_value=mock_token)) result = await rest.ClientCredentialsStrategy(client=54123123, client_secret="123123123").acquire(mock_rest) @@ -158,7 +167,7 @@ async def test_acquire_on_new_instance(self, mock_token): ) @pytest.mark.asyncio - async def test_acquire_handles_out_of_date_token(self, mock_token): + async def test_acquire_handles_out_of_date_token(self, mock_token: applications.PartialOAuth2Token): mock_old_token = mock.Mock( applications.PartialOAuth2Token, expires_in=datetime.timedelta(weeks=1), @@ -181,7 +190,9 @@ async def test_acquire_handles_out_of_date_token(self, mock_token): assert new_token == "Bearer okokok.fofofo.ddd" @pytest.mark.asyncio - async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, mock_token): + async def test_acquire_handles_token_being_set_before_lock_is_acquired( + self, mock_token: applications.PartialOAuth2Token + ): lock = asyncio.Lock() mock_rest = mock.Mock(authorize_client_credentials_token=mock.AsyncMock(side_effect=[mock_token])) @@ -201,7 +212,7 @@ async def test_acquire_handles_token_being_set_before_lock_is_acquired(self, moc assert results == ["Bearer okokok.fofofo.ddd", "Bearer okokok.fofofo.ddd", "Bearer okokok.fofofo.ddd"] @pytest.mark.asyncio - async def test_acquire_after_invalidation(self, mock_token): + async def test_acquire_after_invalidation(self, mock_token: applications.PartialOAuth2Token): mock_old_token = mock.Mock( applications.PartialOAuth2Token, expires_in=datetime.timedelta(weeks=1), @@ -226,7 +237,7 @@ async def test_acquire_after_invalidation(self, mock_token): @pytest.mark.asyncio async def test_acquire_uses_newly_cached_token_after_acquiring_lock(self): class MockLock: - def __init__(self, strategy): + def __init__(self, strategy: rest.ClientCredentialsStrategy): self._strategy = strategy async def __aenter__(self): @@ -305,7 +316,7 @@ def test_invalidate_when_token_is_stored_token(self): class TestRESTApp: @pytest.fixture - def rest_app(self): + def rest_app(self) -> rest.RESTApp: return hikari_test_helpers.mock_class_namespace(rest.RESTApp, slots_=False)( executor=None, http_settings=mock.Mock(spec_set=config.HTTPSettings), @@ -315,29 +326,29 @@ def rest_app(self): url="https://some.url", ) - def test_executor_property(self, rest_app): - mock_executor = object() + def test_executor_property(self, rest_app: rest.RESTApp): + mock_executor = mock.Mock() rest_app._executor = mock_executor assert rest_app.executor is mock_executor - def test_http_settings_property(self, rest_app): - mock_http_settings = object() + def test_http_settings_property(self, rest_app: rest.RESTApp): + mock_http_settings = mock.Mock() rest_app._http_settings = mock_http_settings assert rest_app.http_settings is mock_http_settings - def test_proxy_settings(self, rest_app): - mock_proxy_settings = object() + def test_proxy_settings(self, rest_app: rest.RESTApp): + mock_proxy_settings = mock.Mock() rest_app._proxy_settings = mock_proxy_settings assert rest_app.proxy_settings is mock_proxy_settings - def test_acquire(self, rest_app): - rest_app._client_session = object() - rest_app._bucket_manager = object() - stack = contextlib.ExitStack() - mock_entity_factory = stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl")) - mock_client = stack.enter_context(mock.patch.object(rest, "RESTClientImpl")) + def test_acquire(self, rest_app: rest.RESTApp): + rest_app._client_session = mock.Mock() + rest_app._bucket_manager = mock.Mock() - with stack: + with ( + mock.patch.object(entity_factory, "EntityFactoryImpl") as mock_entity_factory, + mock.patch.object(rest, "RESTClientImpl") as mock_client, + ): rest_app.acquire(token="token", token_type="Type") mock_client.assert_called_once_with( @@ -363,14 +374,14 @@ def test_acquire(self, rest_app): assert rest_provider.rest is mock_client.return_value assert rest_provider.executor is rest_app._executor - def test_acquire_defaults_to_bearer_for_a_string_token(self, rest_app): - rest_app._client_session = object() - rest_app._bucket_manager = object() - stack = contextlib.ExitStack() - mock_entity_factory = stack.enter_context(mock.patch.object(entity_factory, "EntityFactoryImpl")) - mock_client = stack.enter_context(mock.patch.object(rest, "RESTClientImpl")) + def test_acquire_defaults_to_bearer_for_a_string_token(self, rest_app: rest.RESTApp): + rest_app._client_session = mock.Mock() + rest_app._bucket_manager = mock.Mock() - with stack: + with ( + mock.patch.object(entity_factory, "EntityFactoryImpl") as mock_entity_factory, + mock.patch.object(rest, "RESTClientImpl") as mock_client, + ): rest_app.acquire(token="token") mock_client.assert_called_once_with( @@ -403,17 +414,19 @@ def test_acquire_defaults_to_bearer_for_a_string_token(self, rest_app): @pytest.fixture(scope="module") -def rest_client_class(): +def rest_client_class() -> typing.Type[rest.RESTClientImpl]: return hikari_test_helpers.mock_class_namespace(rest.RESTClientImpl, slots_=False) @pytest.fixture -def mock_cache(): +def mock_cache() -> cache.MutableCache: return mock.Mock() @pytest.fixture -def rest_client(rest_client_class, mock_cache): +def rest_client( + rest_client_class: typing.Type[rest.RESTClientImpl], mock_cache: cache.MutableCache +) -> rest.RESTClientImpl: obj = rest_client_class( cache=mock_cache, http_settings=mock.Mock(spec=config.HTTPSettings), @@ -423,7 +436,7 @@ def rest_client(rest_client_class, mock_cache): token_type="tYpe", max_retries=0, rest_url="https://some.where/api/v3", - executor=object(), + executor=mock.Mock(), entity_factory=mock.Mock(), bucket_manager=mock.Mock( acquire_bucket=mock.Mock(return_value=hikari_test_helpers.AsyncContextManagerMock()), @@ -431,66 +444,483 @@ def rest_client(rest_client_class, mock_cache): ), client_session=mock.Mock(request=mock.AsyncMock()), ) - obj._close_event = object() + obj._close_event = mock.Mock() return obj -@pytest.fixture -def file_resource(): - class Stream: - def __init__(self, data): - self.open = False - self.data = data +class MockStream: + def __init__(self, data: str): + self.open = False + self.data = data - async def data_uri(self): - if not self.open: - raise RuntimeError("Tried to read off a closed stream") + async def data_uri(self): + if not self.open: + raise RuntimeError("Tried to read off a closed stream") - return self.data + return self.data - async def __aenter__(self): - self.open = True - return self + async def __aenter__(self): + self.open = True + return self - async def __aexit__(self, exc_type, exc, exc_tb) -> None: - self.open = False + async def __aexit__(self, exc_type: type[Exception], exc: Exception, exc_tb: typing.Any) -> None: + self.open = False - class FileResource(files.Resource): - filename = None - url = None - def __init__(self, stream_data): - self._stream = Stream(data=stream_data) +class MockFileResource(files.Resource[typing.Any]): + @property + def filename(self) -> str: + return "" - def stream(self, executor): - return self._stream + @property + def url(self) -> str: + return "" - return FileResource + def __init__(self, stream_data: str): + self._stream = MockStream(data=stream_data) + + def stream(self, executor: concurrent.futures.Executor): + return self._stream @pytest.fixture -def file_resource_patch(file_resource): - resource = file_resource("some data") +def file_resource_patch() -> typing.Generator[files.Resource[typing.Any], typing.Any, None]: + resource = MockFileResource("some data") with mock.patch.object(files, "ensure_resource", return_value=resource): yield resource -class StubModel(snowflakes.Unique): - id = None +# There is a naming scheme to everything. +# Partial objects have a unique identifier. guild=123, channel=456, user=789 and message=101 +# Sub objects, use their Type to identify them further. For example, a guild stage channel would be 45613 +# Objects that do not have partial, go in increments of a plural of 3 numbers. +# for example, applications start at 111, the next item would be 222 and so on. + + +@pytest.fixture +def mock_partial_guild() -> guilds.PartialGuild: + return guilds.PartialGuild( + app=mock.Mock(), id=snowflakes.Snowflake(123), icon_hash="partial_guild_icon_hash", name="partial_guild" + ) + + +def make_guild_text_channel(id: int) -> channels.GuildTextChannel: + return channels.GuildTextChannel( + app=mock.Mock(), + id=snowflakes.Snowflake(id), + name="guild_text_channel_name", + type=channels.ChannelType.GUILD_TEXT, + guild_id=mock.Mock(), # FIXME: Can this be pulled from the actual fixture? + parent_id=mock.Mock(), # FIXME: Can this be pulled from the actual fixture? + position=0, + is_nsfw=False, + permission_overwrites={}, + topic=None, + last_message_id=None, + rate_limit_per_user=datetime.timedelta(seconds=10), + last_pin_timestamp=None, + default_auto_archive_duration=datetime.timedelta(seconds=10), + ) + + +@pytest.fixture +def mock_guild_text_channel() -> channels.GuildTextChannel: + return make_guild_text_channel(4560) + + +@pytest.fixture +def mock_dm_channel() -> channels.DMChannel: + return channels.DMChannel( + app=mock.Mock(), + id=snowflakes.Snowflake(4561), + name="dm_channel_name", + type=channels.ChannelType.DM, + last_message_id=None, + recipient=mock.Mock(), + ) + + +@pytest.fixture +def mock_guild_voice_channel( + mock_guild_category: channels.GuildCategory, mock_partial_guild: guilds.PartialGuild +) -> channels.GuildVoiceChannel: + return channels.GuildVoiceChannel( + app=mock.Mock(), + id=snowflakes.Snowflake(4562), + name="guild_voice_channel_name", + type=channels.ChannelType.GUILD_VOICE, + guild_id=mock_partial_guild.id, + parent_id=mock_guild_category.id, + position=0, + is_nsfw=False, + permission_overwrites={}, + bitrate=1, + region=None, + user_limit=0, + video_quality_mode=0, + last_message_id=None, + ) + + +@pytest.fixture +def mock_guild_category(mock_partial_guild: guilds.PartialGuild) -> channels.GuildCategory: + return channels.GuildCategory( + app=mock.Mock(), + id=snowflakes.Snowflake(4564), + name="guild_category_name", + type=channels.ChannelType.GUILD_CATEGORY, + guild_id=mock_partial_guild.id, + parent_id=None, + position=0, + is_nsfw=False, + permission_overwrites={}, + ) + + +@pytest.fixture +def mock_guild_news_channel( + mock_partial_guild: guilds.PartialGuild, mock_guild_category: channels.GuildCategory +) -> channels.GuildNewsChannel: + return channels.GuildNewsChannel( + app=mock.Mock(), + id=snowflakes.Snowflake(4565), + name="guild_news_channel_name", + type=channels.ChannelType.GUILD_NEWS, + guild_id=mock_partial_guild.id, + parent_id=mock_guild_category.id, + position=1, + is_nsfw=False, + permission_overwrites={}, + topic="guild_news_channel_topic", + last_message_id=None, + last_pin_timestamp=datetime.datetime.fromtimestamp(1), + default_auto_archive_duration=datetime.timedelta(1), + ) + + +@pytest.fixture +def mock_guild_public_thread_channel( + mock_partial_guild: guilds.PartialGuild, mock_guild_text_channel: channels.GuildTextChannel, mock_user: users.User +) -> channels.GuildThreadChannel: + return channels.GuildThreadChannel( + app=mock.Mock(), + id=snowflakes.Snowflake(45611), + name="guild_public_thread_channel_name", + type=channels.ChannelType.GUILD_PUBLIC_THREAD, + guild_id=mock_partial_guild.id, + parent_id=mock_guild_text_channel.id, + last_message_id=None, + last_pin_timestamp=datetime.datetime.fromtimestamp(1), + rate_limit_per_user=datetime.timedelta(1), + approximate_message_count=1, + approximate_member_count=1, + member=None, + owner_id=mock_user.id, + metadata=mock.Mock(), + ) + + +@pytest.fixture +def mock_guild_stage_channel( + mock_partial_guild: guilds.PartialGuild, mock_guild_category: channels.GuildCategory, mock_user: users.User +) -> channels.GuildStageChannel: + return channels.GuildStageChannel( + app=mock.Mock(), + id=snowflakes.Snowflake(45613), + name="guild_news_channel_name", + type=channels.ChannelType.GUILD_STAGE, + guild_id=mock_partial_guild.id, + parent_id=mock_guild_category.id, + position=1, + is_nsfw=False, + permission_overwrites={}, + last_message_id=None, + bitrate=1, + region=None, + user_limit=1, + video_quality_mode=channels.VideoQualityMode.FULL, + ) + + +def make_user(id: int) -> users.User: + return users.UserImpl( + id=snowflakes.Snowflake(id), + app=mock.Mock(), + discriminator="0", + username="user_username", + global_name="user_global_name", + avatar_hash="user_avatar_hash", + banner_hash="user_banner_hash", + avatar_decoration=None, + accent_color=None, + is_bot=False, + is_system=False, + flags=users.UserFlag.NONE, + primary_guild=None, + ) + + +@pytest.fixture +def mock_user() -> users.User: + return make_user(789) + + +def make_mock_message(id: int) -> messages.Message: + return messages.Message( + id=snowflakes.Snowflake(id), + app=mock.Mock(), + channel_id=snowflakes.Snowflake(456), + guild_id=None, + author=mock.Mock(), + member=mock.Mock(), + content=None, + timestamp=datetime.datetime.fromtimestamp(6000), + edited_timestamp=None, + is_tts=False, + user_mentions={}, + role_mention_ids=[], + channel_mentions={}, + mentions_everyone=False, + attachments=[], + embeds=[], + poll=None, + reactions=[], + is_pinned=False, + webhook_id=snowflakes.Snowflake(432), + type=messages.MessageType.DEFAULT, + activity=None, + application=None, + message_reference=None, + flags=messages.MessageFlag.NONE, + stickers=[], + nonce=None, + referenced_message=None, + message_snapshots=[], + application_id=None, + components=[], + thread=None, + interaction_metadata=None, + ) + + +@pytest.fixture +def mock_message() -> messages.Message: + return make_mock_message(101) + + +def make_partial_webhook(id: int) -> webhooks.PartialWebhook: + return webhooks.PartialWebhook( + app=mock.Mock(), + id=snowflakes.Snowflake(id), + type=webhooks.WebhookType.APPLICATION, + name="partial_webhook_name", + avatar_hash="partial_webhook_avatar_hash", + application_id=None, + ) + + +@pytest.fixture +def mock_partial_webhook() -> webhooks.PartialWebhook: + return make_partial_webhook(112) + + +@pytest.fixture +def mock_application() -> applications.Application: + return applications.Application( + id=snowflakes.Snowflake(111), + name="application_name", + description="application_description", + icon_hash="application_icon_hash", + app=mock.Mock(), + is_bot_public=False, + is_bot_code_grant_required=False, + owner=mock.Mock(), + rpc_origins=None, + flags=applications.ApplicationFlags.EMBEDDED, + public_key=b"application_key", + team=None, + cover_image_hash="application_cover_image_hash", + terms_of_service_url=None, + privacy_policy_url=None, + role_connections_verification_url=None, + custom_install_url=None, + tags=[], + install_parameters=None, + approximate_guild_count=0, + approximate_user_install_count=0, + integration_types_config={}, + ) + + +@pytest.fixture +def mock_partial_sticker() -> stickers.PartialSticker: + return stickers.PartialSticker( + id=snowflakes.Snowflake(222), name="sticker_name", format_type=stickers.StickerFormatType.PNG + ) + + +def make_invite_with_metadata(code: str) -> invites.InviteWithMetadata: + return invites.InviteWithMetadata( + app=mock.Mock(), + code=code, + guild=None, + guild_id=None, + channel_id=snowflakes.Snowflake(456), + inviter=None, + channel=None, + target_type=invites.TargetType.STREAM, + target_user=None, + target_application=None, + approximate_active_member_count=None, + approximate_member_count=None, + expires_at=None, + uses=1, + max_uses=None, + max_age=None, + is_temporary=False, + created_at=datetime.datetime.fromtimestamp(0), + ) + + +@pytest.fixture +def mock_invite_with_metadata() -> invites.InviteWithMetadata: + return make_invite_with_metadata("invite_with_metadata_name") + + +def make_partial_role(id: int) -> guilds.PartialRole: + return guilds.PartialRole(app=mock.Mock(), id=snowflakes.Snowflake(id), name="partial_role_name") + + +@pytest.fixture +def mock_partial_role() -> guilds.PartialRole: + return make_partial_role(333) + + +def make_custom_emoji(id: int) -> emojis.CustomEmoji: + return emojis.CustomEmoji(id=snowflakes.Snowflake(id), name="custom_emoji_name", is_animated=False) + + +@pytest.fixture +def mock_custom_emoji() -> emojis.CustomEmoji: + return make_custom_emoji(4440) + + +def make_unicode_emoji(emoji: str) -> emojis.UnicodeEmoji: + return emojis.UnicodeEmoji(emoji) + + +@pytest.fixture +def mock_unicode_emoji() -> emojis.UnicodeEmoji: + return make_unicode_emoji("🙂") + + +def make_permission_overwrite(id: int) -> channels.PermissionOverwrite: + return channels.PermissionOverwrite(id=snowflakes.Snowflake(id), type=channels.PermissionOverwriteType.MEMBER) - def __init__(self, id=0): - self.id = snowflakes.Snowflake(id) + +@pytest.fixture +def mock_permission_overwrite() -> channels.PermissionOverwrite: + return make_permission_overwrite(555) + + +@pytest.fixture +def mock_partial_command(mock_application: applications.Application) -> commands.PartialCommand: + return commands.PartialCommand( + app=mock.Mock(), + id=snowflakes.Snowflake(666), + type=commands.CommandType.SLASH, + application_id=mock_application.id, + name="partial_command_name", + default_member_permissions=permissions.Permissions.NONE, + is_nsfw=False, + guild_id=None, + version=snowflakes.Snowflake(1), + name_localizations={}, + integration_types=[], + context_types=[], + ) + + +@pytest.fixture +def mock_partial_interaction(mock_application: applications.Application) -> interactions.PartialInteraction: + return interactions.PartialInteraction( + app=mock.Mock(), + id=snowflakes.Snowflake(777), + application_id=mock_application.id, + type=interactions.InteractionType.APPLICATION_COMMAND, + token="partial_interaction_token", + version=1, + app_permissions=None, + user=mock.Mock(), + member=None, + channel=mock.Mock(), + guild_id=None, + guild_locale=None, + locale="platy", + entitlements=[], + context=applications.ApplicationContextType.GUILD, + authorizing_integration_owners={}, + ) + + +@pytest.fixture +def mock_scheduled_event(mock_partial_guild: guilds.PartialGuild) -> scheduled_events.ScheduledEvent: + return scheduled_events.ScheduledEvent( + app=mock.Mock(), + id=snowflakes.Snowflake(888), + guild_id=mock_partial_guild.id, + name="scheduled_event_name", + description="scheduled_event_description", + start_time=datetime.datetime.fromtimestamp(1), + end_time=None, + privacy_level=scheduled_events.EventPrivacyLevel.GUILD_ONLY, + status=scheduled_events.ScheduledEventStatus.ACTIVE, + entity_type=scheduled_events.ScheduledEventType.VOICE, + creator=None, + user_count=None, + image_hash="scheduled_event_image_hash", + ) + + +@pytest.fixture +def mock_auto_mod_rule(mock_partial_guild: guilds.PartialGuild, mock_user: users.User) -> auto_mod.AutoModRule: + return auto_mod.AutoModRule( + app=mock.Mock(), + id=snowflakes.Snowflake(999), + guild_id=mock_partial_guild.id, + name="auto_mod_rule", + creator_id=mock_user.id, + event_type=auto_mod.AutoModEventType.MESSAGE_SEND, + trigger=auto_mod.KeywordTrigger( + type=auto_mod.AutoModTriggerType.KEYWORD, + keyword_filter=["hello", "world"], + regex_patterns=["[abc]+", "[123]+"], + allow_list=["goodbye", "planet"], + ), + actions=[], + is_enabled=False, + exempt_role_ids=[snowflakes.Snowflake(968583), snowflakes.Snowflake(534908), snowflakes.Snowflake(345873)], + exempt_channel_ids=[snowflakes.Snowflake(834754), snowflakes.Snowflake(192438)], + ) + + +def make_forum_tag(id: int) -> channels.ForumTag: + return channels.ForumTag(id=id, name="forum_tag", moderated=True, emoji=None) + + +def make_vanity_url(code: str) -> invites.VanityURL: + return invites.VanityURL(app=mock.Mock(), code=code, uses=0) class TestStringifyHttpMessage: - def test_when_body_is_str(self, rest_client): + def test_when_body_is_str(self): headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"} returned = rest._stringify_http_message(headers, None) assert returned == " HEADER1: value1\n HEADER2: value2\n Authorization: **REDACTED TOKEN**" - def test_when_body_is_not_None(self, rest_client): + def test_when_body_is_not_None(self): headers = {"HEADER1": "value1", "HEADER2": "value2", "Authorization": "this will never see the light of day"} returned = rest._stringify_http_message(headers, bytes("hello :)", "ascii")) @@ -504,21 +934,25 @@ class TestTransformEmojiToUrlFormat: @pytest.mark.parametrize( ("emoji", "expected_return"), [ - (emojis.CustomEmoji(id=123, name="rooYay", is_animated=False), "rooYay:123"), + (emojis.CustomEmoji(id=snowflakes.Snowflake(123), name="rooYay", is_animated=False), "rooYay:123"), ("\N{OK HAND SIGN}", "\N{OK HAND SIGN}"), (emojis.UnicodeEmoji("\N{OK HAND SIGN}"), "\N{OK HAND SIGN}"), ], ) - def test_expected(self, rest_client, emoji, expected_return): + def test_expected(self, emoji: emojis.Emoji, expected_return: str): assert rest._transform_emoji_to_url_format(emoji, undefined.UNDEFINED) == expected_return - def test_with_id(self, rest_client): + def test_with_id(self): assert rest._transform_emoji_to_url_format("rooYay", 123) == "rooYay:123" @pytest.mark.parametrize( - "emoji", [emojis.CustomEmoji(id=123, name="rooYay", is_animated=False), emojis.UnicodeEmoji("\N{OK HAND SIGN}")] + "emoji", + [ + emojis.CustomEmoji(id=snowflakes.Snowflake(123), name="rooYay", is_animated=False), + emojis.UnicodeEmoji("\N{OK HAND SIGN}"), + ], ) - def test_when_id_passed_with_emoji_object(self, rest_client, emoji): + def test_when_id_passed_with_emoji_object(self, emoji: emojis.Emoji): with pytest.raises(ValueError, match="emoji_id shouldn't be passed when an Emoji object is passed for emoji"): rest._transform_emoji_to_url_format(emoji, 123) @@ -536,7 +970,7 @@ def test__init__when_max_retries_over_5(self): token_type="ooga booga", rest_url=None, executor=None, - entity_factory=None, + entity_factory=mock.Mock(), ) def test__init__when_token_strategy_passed_with_token_type(self): @@ -550,7 +984,7 @@ def test__init__when_token_strategy_passed_with_token_type(self): token_type="ooga booga", rest_url=None, executor=None, - entity_factory=None, + entity_factory=mock.Mock(), ) def test__init__when_token_strategy_passed(self): @@ -564,7 +998,7 @@ def test__init__when_token_strategy_passed(self): token_type=None, rest_url=None, executor=None, - entity_factory=None, + entity_factory=mock.Mock(), ) assert obj._token is mock_strategy @@ -580,7 +1014,7 @@ def test__init__when_token_is_None_sets_token_to_None(self): token_type=None, rest_url=None, executor=None, - entity_factory=None, + entity_factory=mock.Mock(), ) assert obj._token is None assert obj._token_type is None @@ -595,7 +1029,7 @@ def test__init__when_token_and_token_type_is_not_None_generates_token_with_type( token_type="tYpe", rest_url=None, executor=None, - entity_factory=None, + entity_factory=mock.Mock(), ) assert obj._token == "Type some_token" assert obj._token_type == "Type" @@ -611,7 +1045,7 @@ def test__init__when_token_provided_as_string_without_type(self): token_type=None, rest_url=None, executor=None, - entity_factory=None, + entity_factory=mock.Mock(), ) def test__init__when_rest_url_is_None_generates_url_using_default_url(self): @@ -624,7 +1058,7 @@ def test__init__when_rest_url_is_None_generates_url_using_default_url(self): token_type=None, rest_url=None, executor=None, - entity_factory=None, + entity_factory=mock.Mock(), ) assert obj._rest_url == urls.REST_API_URL @@ -638,86 +1072,99 @@ def test__init__when_rest_url_is_not_None_generates_url_using_given_url(self): token_type=None, rest_url="https://some.where/api/v2", executor=None, - entity_factory=None, + entity_factory=mock.Mock(), ) assert obj._rest_url == "https://some.where/api/v2" - def test___enter__(self, rest_client): - # ruff gets annoyed if we use "with" here so here's a hacky alternative + def test___enter__(self, rest_client: rest.RESTClientImpl): + # flake8 gets annoyed if we use "with" here so here's a hacky alternative with pytest.raises(TypeError, match=" is async-only, did you mean 'async with'?"): rest_client.__enter__() - def test___exit__(self, rest_client): + def test___exit__(self, rest_client: rest.RESTClientImpl): try: rest_client.__exit__(None, None, None) except AttributeError as exc: pytest.fail(exc) - @pytest.mark.parametrize(("attributes", "expected_result"), [(None, False), (object(), True)]) - def test_is_alive_property(self, rest_client, attributes, expected_result): - rest_client._close_event = attributes - - assert rest_client.is_alive is expected_result + @pytest.mark.parametrize(("attributes", "expected_result"), [(None, False), (mock.Mock(), True)]) + def test_is_alive_property( + self, rest_client: rest.RESTClientImpl, attributes: object | None, expected_result: bool + ): + with mock.patch.object(rest_client, "_close_event", attributes): + assert rest_client.is_alive is expected_result - def test_entity_factory_property(self, rest_client): + def test_entity_factory_property(self, rest_client: rest.RESTClientImpl): assert rest_client.entity_factory is rest_client._entity_factory - def test_http_settings_property(self, rest_client): - mock_http_settings = object() - rest_client._http_settings = mock_http_settings - assert rest_client.http_settings is mock_http_settings + def test_http_settings_property(self, rest_client: rest.RESTClientImpl): + mock_http_settings = mock.Mock() + + with mock.patch.object(rest_client, "_http_settings", mock_http_settings): + assert rest_client.http_settings is mock_http_settings - def test_proxy_settings_property(self, rest_client): - mock_proxy_settings = object() - rest_client._proxy_settings = mock_proxy_settings - assert rest_client.proxy_settings is mock_proxy_settings + def test_proxy_settings_property(self, rest_client: rest.RESTClientImpl): + mock_proxy_settings = mock.Mock() - def test_token_type_property(self, rest_client): - mock_type = object() - rest_client._token_type = mock_type - assert rest_client.token_type is mock_type + with mock.patch.object(rest_client, "_proxy_settings", mock_proxy_settings): + assert rest_client.proxy_settings is mock_proxy_settings + + def test_token_type_property(self, rest_client: rest.RESTClientImpl): + mock_type = mock.Mock() + + with mock.patch.object(rest_client, "_token_type", mock_type): + assert rest_client.token_type is mock_type @pytest.mark.parametrize("client_session_owner", [True, False]) @pytest.mark.parametrize("bucket_manager_owner", [True, False]) @pytest.mark.asyncio - async def test_close(self, rest_client, client_session_owner, bucket_manager_owner): + async def test_close( + self, rest_client: rest.RESTClientImpl, client_session_owner: bool, bucket_manager_owner: bool + ): rest_client._close_event = mock_close_event = mock.Mock() - rest_client._client_session.close = client_close = mock.AsyncMock() - rest_client._bucket_manager.close = bucket_close = mock.AsyncMock() rest_client._client_session_owner = client_session_owner rest_client._bucket_manager_owner = bucket_manager_owner - await rest_client.close() + with ( + mock.patch.object(rest_client._client_session, "close", mock.AsyncMock()) as patched__client_session_close, + mock.patch.object(rest_client, "_bucket_manager") as patched__bucket_manager, + mock.patch.object(patched__bucket_manager, "close", mock.AsyncMock()) as patched__bucket_manager_close, + ): + await rest_client.close() mock_close_event.set.assert_called_once_with() assert rest_client._close_event is None if client_session_owner: - client_close.assert_awaited_once_with() + patched__client_session_close.assert_awaited_once_with() assert rest_client._client_session is None else: - client_close.assert_not_called() + patched__client_session_close.assert_not_called() assert rest_client._client_session is not None if bucket_manager_owner: - bucket_close.assert_awaited_once_with() + patched__bucket_manager_close.assert_awaited_once_with() else: - rest_client._bucket_manager.assert_not_called() + patched__bucket_manager.assert_not_called() @pytest.mark.parametrize("client_session_owner", [True, False]) @pytest.mark.parametrize("bucket_manager_owner", [True, False]) @pytest.mark.asyncio # Function needs to be executed in a running loop - async def test_start(self, rest_client, client_session_owner, bucket_manager_owner): + async def test_start( + self, rest_client: rest.RESTClientImpl, client_session_owner: bool, bucket_manager_owner: bool + ): rest_client._client_session = None rest_client._close_event = None rest_client._bucket_manager = mock.Mock() rest_client._client_session_owner = client_session_owner rest_client._bucket_manager_owner = bucket_manager_owner - with mock.patch.object(net, "create_client_session") as create_client_session: - with mock.patch.object(net, "create_tcp_connector") as create_tcp_connector: - with mock.patch.object(asyncio, "Event") as event: - rest_client.start() + with ( + mock.patch.object(net, "create_client_session") as create_client_session, + mock.patch.object(net, "create_tcp_connector") as create_tcp_connector, + mock.patch.object(asyncio, "Event") as event, + ): + rest_client.start() assert rest_client._close_event is event.return_value @@ -739,45 +1186,48 @@ async def test_start(self, rest_client, client_session_owner, bucket_manager_own else: rest_client._bucket_manager.start.assert_not_called() - def test_start_when_active(self, rest_client): - rest_client._close_event = object() - - with pytest.raises(errors.ComponentStateConflictError): + def test_start_when_active(self, rest_client: rest.RESTClientImpl): + with mock.patch.object(rest_client, "_close_event"), pytest.raises(errors.ComponentStateConflictError): rest_client.start() ####################### # Non-async endpoints # ####################### - def test_trigger_typing(self, rest_client): - channel = StubModel(123) + def test_trigger_typing(self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel): stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "TypingIndicator", return_value=stub_iterator) as typing_indicator: - assert rest_client.trigger_typing(channel) == stub_iterator + assert rest_client.trigger_typing(mock_guild_text_channel) == stub_iterator typing_indicator.assert_called_once_with( - request_call=rest_client._request, channel=channel, rest_close_event=rest_client._close_event + request_call=rest_client._request, + channel=mock_guild_text_channel, + rest_close_event=rest_client._close_event, ) @pytest.mark.parametrize( "before", [ datetime.datetime(2020, 7, 23, 7, 18, 11, 554023, tzinfo=datetime.timezone.utc), - StubModel(735757641938108416), + make_user(735757641938108416), ], ) - def test_fetch_messages_with_before(self, rest_client, before): - channel = StubModel(123) + def test_fetch_messages_with_before( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + before: datetime.datetime | users.User, + ): stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "MessageIterator", return_value=stub_iterator) as iterator: - assert rest_client.fetch_messages(channel, before=before) == stub_iterator + assert rest_client.fetch_messages(mock_guild_text_channel, before=before) == stub_iterator iterator.assert_called_once_with( entity_factory=rest_client._entity_factory, request_call=rest_client._request, - channel=channel, + channel=mock_guild_text_channel, direction="before", first_id="735757641938108416", ) @@ -786,20 +1236,24 @@ def test_fetch_messages_with_before(self, rest_client, before): "after", [ datetime.datetime(2020, 7, 23, 7, 18, 11, 554023, tzinfo=datetime.timezone.utc), - StubModel(735757641938108416), + make_user(735757641938108416), ], ) - def test_fetch_messages_with_after(self, rest_client, after): - channel = StubModel(123) + def test_fetch_messages_with_after( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + after: datetime.datetime | users.User, + ): stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "MessageIterator", return_value=stub_iterator) as iterator: - assert rest_client.fetch_messages(channel, after=after) == stub_iterator + assert rest_client.fetch_messages(mock_guild_text_channel, after=after) == stub_iterator iterator.assert_called_once_with( entity_factory=rest_client._entity_factory, request_call=rest_client._request, - channel=channel, + channel=mock_guild_text_channel, direction="after", first_id="735757641938108416", ) @@ -808,35 +1262,40 @@ def test_fetch_messages_with_after(self, rest_client, after): "around", [ datetime.datetime(2020, 7, 23, 7, 18, 11, 554023, tzinfo=datetime.timezone.utc), - StubModel(735757641938108416), + make_user(735757641938108416), ], ) - def test_fetch_messages_with_around(self, rest_client, around): - channel = StubModel(123) + def test_fetch_messages_with_around( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + around: datetime.datetime | users.User, + ): stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "MessageIterator", return_value=stub_iterator) as iterator: - assert rest_client.fetch_messages(channel, around=around) == stub_iterator + assert rest_client.fetch_messages(mock_guild_text_channel, around=around) == stub_iterator iterator.assert_called_once_with( entity_factory=rest_client._entity_factory, request_call=rest_client._request, - channel=channel, + channel=mock_guild_text_channel, direction="around", first_id="735757641938108416", ) - def test_fetch_messages_with_default(self, rest_client): - channel = StubModel(123) + def test_fetch_messages_with_default( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "MessageIterator", return_value=stub_iterator) as iterator: - assert rest_client.fetch_messages(channel) == stub_iterator + assert rest_client.fetch_messages(mock_guild_text_channel) == stub_iterator iterator.assert_called_once_with( entity_factory=rest_client._entity_factory, request_call=rest_client._request, - channel=channel, + channel=mock_guild_text_channel, direction="before", first_id=undefined.UNDEFINED, ) @@ -850,43 +1309,55 @@ def test_fetch_messages_with_default(self, rest_client): {"before": 1234, "after": 1234, "around": 1234}, ], ) - def test_fetch_messages_when_more_than_one_kwarg_passed(self, rest_client, kwargs): + def test_fetch_messages_when_more_than_one_kwarg_passed( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + kwargs: dict[str, int], + ): with pytest.raises(TypeError): - rest_client.fetch_messages(StubModel(123), **kwargs) + rest_client.fetch_messages(mock_guild_text_channel, **kwargs) - def test_fetch_reactions_for_emoji(self, rest_client): - channel = StubModel(123) - message = StubModel(456) + def test_fetch_reactions_for_emoji( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "ReactorIterator", return_value=stub_iterator) as iterator: with mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"): - assert rest_client.fetch_reactions_for_emoji(channel, message, "<:rooYay:123>") == stub_iterator + assert ( + rest_client.fetch_reactions_for_emoji(mock_guild_text_channel, mock_message, "<:rooYay:123>") + == stub_iterator + ) iterator.assert_called_once_with( entity_factory=rest_client._entity_factory, request_call=rest_client._request, - channel=channel, - message=message, + channel=mock_guild_text_channel, + message=mock_message, emoji="rooYay:123", ) - def test_fetch_pins_with_before(self, rest_client): - channel = StubModel(123) + def test_fetch_pins_with_before( + self, rest_client: rest.RESTClientImpl, hikari_guild_text_channel: channels.GuildTextChannel + ): time = datetime.datetime(2020, 7, 23, 7, 18, 11, 554023, tzinfo=datetime.timezone.utc) stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "PinnedMessageIterator", return_value=stub_iterator) as iterator: - assert rest_client.fetch_pins(channel, before=time) == stub_iterator + assert rest_client.fetch_pins(hikari_guild_text_channel, before=time) == stub_iterator iterator.assert_called_once_with( entity_factory=rest_client._entity_factory, request_call=rest_client._request, - channel=channel, + channel=hikari_guild_text_channel, first_id=str(int(time.timestamp())), ) - def test_fetch_my_guilds_when_start_at_is_undefined(self, rest_client): + def test_fetch_my_guilds_when_start_at_is_undefined(self, rest_client: rest.RESTClientImpl): stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "OwnGuildIterator", return_value=stub_iterator) as iterator: @@ -899,7 +1370,7 @@ def test_fetch_my_guilds_when_start_at_is_undefined(self, rest_client): first_id="0", ) - def test_fetch_my_guilds_when_start_at_is_datetime(self, rest_client): + def test_fetch_my_guilds_when_start_at_is_datetime(self, rest_client: rest.RESTClientImpl): stub_iterator = mock.Mock() datetime_obj = datetime.datetime(2020, 7, 23, 7, 18, 11, 554023, tzinfo=datetime.timezone.utc) @@ -913,11 +1384,13 @@ def test_fetch_my_guilds_when_start_at_is_datetime(self, rest_client): first_id="735757641938108416", ) - def test_fetch_my_guilds_when_start_at_is_else(self, rest_client): + def test_fetch_my_guilds_when_start_at_is_else( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "OwnGuildIterator", return_value=stub_iterator) as iterator: - assert rest_client.fetch_my_guilds(newest_first=True, start_at=StubModel(123)) == stub_iterator + assert rest_client.fetch_my_guilds(newest_first=True, start_at=mock_partial_guild) == stub_iterator iterator.assert_called_once_with( entity_factory=rest_client._entity_factory, @@ -926,113 +1399,126 @@ def test_fetch_my_guilds_when_start_at_is_else(self, rest_client): first_id="123", ) - def test_fetch_audit_log_when_before_is_undefined(self, rest_client): - guild = StubModel(123) + def test_fetch_audit_log_when_before_is_undefined( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "AuditLogIterator", return_value=stub_iterator) as iterator: - assert rest_client.fetch_audit_log(guild) == stub_iterator + assert rest_client.fetch_audit_log(mock_partial_guild) == stub_iterator iterator.assert_called_once_with( entity_factory=rest_client._entity_factory, request_call=rest_client._request, - guild=guild, + guild=mock_partial_guild, before=undefined.UNDEFINED, user=undefined.UNDEFINED, action_type=undefined.UNDEFINED, ) - def test_fetch_audit_log_when_before_datetime(self, rest_client): - guild = StubModel(123) - user = StubModel(456) + def test_fetch_audit_log_when_before_datetime( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): stub_iterator = mock.Mock() datetime_obj = datetime.datetime(2020, 7, 23, 7, 18, 11, 554023, tzinfo=datetime.timezone.utc) with mock.patch.object(special_endpoints, "AuditLogIterator", return_value=stub_iterator) as iterator: returned = rest_client.fetch_audit_log( - guild, user=user, before=datetime_obj, event_type=audit_logs.AuditLogEventType.GUILD_UPDATE + mock_partial_guild, + user=mock_user, + before=datetime_obj, + event_type=audit_logs.AuditLogEventType.GUILD_UPDATE, ) assert returned is stub_iterator iterator.assert_called_once_with( entity_factory=rest_client._entity_factory, request_call=rest_client._request, - guild=guild, + guild=mock_partial_guild, before="735757641938108416", - user=user, + user=mock_user, action_type=audit_logs.AuditLogEventType.GUILD_UPDATE, ) - def test_fetch_audit_log_when_before_is_else(self, rest_client): - guild = StubModel(123) + def test_fetch_audit_log_when_before_is_else( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "AuditLogIterator", return_value=stub_iterator) as iterator: - assert rest_client.fetch_audit_log(guild, before=StubModel(456)) == stub_iterator + assert rest_client.fetch_audit_log(mock_partial_guild, before=mock_user) == stub_iterator iterator.assert_called_once_with( entity_factory=rest_client._entity_factory, request_call=rest_client._request, - guild=guild, - before="456", + guild=mock_partial_guild, + before="789", user=undefined.UNDEFINED, action_type=undefined.UNDEFINED, ) - def test_fetch_public_archived_threads(self, rest_client: rest.RESTClientImpl): + def test_fetch_public_archived_threads( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): mock_datetime = time.utc_datetime() with mock.patch.object(special_endpoints, "GuildThreadIterator") as iterator: - result = rest_client.fetch_public_archived_threads(StubModel(54123123), before=mock_datetime) + result = rest_client.fetch_public_archived_threads(mock_guild_text_channel, before=mock_datetime) assert result is iterator.return_value iterator.assert_called_once_with( deserialize=rest_client._deserialize_public_thread, entity_factory=rest_client.entity_factory, request_call=rest_client._request, - route=routes.GET_PUBLIC_ARCHIVED_THREADS.compile(channel=54123123), + route=routes.GET_PUBLIC_ARCHIVED_THREADS.compile(channel=4560), before=mock_datetime.isoformat(), before_is_timestamp=True, ) - def test_fetch_public_archived_threads_when_before_not_specified(self, rest_client: rest.RESTClientImpl): + def test_fetch_public_archived_threads_when_before_not_specified( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): with mock.patch.object(special_endpoints, "GuildThreadIterator") as iterator: - result = rest_client.fetch_public_archived_threads(StubModel(432234)) + result = rest_client.fetch_public_archived_threads(mock_guild_text_channel) assert result is iterator.return_value iterator.assert_called_once_with( deserialize=rest_client._deserialize_public_thread, entity_factory=rest_client.entity_factory, request_call=rest_client._request, - route=routes.GET_PUBLIC_ARCHIVED_THREADS.compile(channel=432234), + route=routes.GET_PUBLIC_ARCHIVED_THREADS.compile(channel=4560), before=undefined.UNDEFINED, before_is_timestamp=True, ) - def test_fetch_private_archived_threads(self, rest_client: rest.RESTClientImpl): + def test_fetch_private_archived_threads( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): mock_datetime = time.utc_datetime() with mock.patch.object(special_endpoints, "GuildThreadIterator") as iterator: - result = rest_client.fetch_private_archived_threads(StubModel(432234432), before=mock_datetime) + result = rest_client.fetch_private_archived_threads(mock_guild_text_channel, before=mock_datetime) assert result is iterator.return_value iterator.assert_called_once_with( deserialize=rest_client.entity_factory.deserialize_guild_private_thread, entity_factory=rest_client.entity_factory, request_call=rest_client._request, - route=routes.GET_PRIVATE_ARCHIVED_THREADS.compile(channel=432234432), + route=routes.GET_PRIVATE_ARCHIVED_THREADS.compile(channel=4560), before=mock_datetime.isoformat(), before_is_timestamp=True, ) - def test_fetch_private_archived_threads_when_before_not_specified(self, rest_client: rest.RESTClientImpl): + def test_fetch_private_archived_threads_when_before_not_specified( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): with mock.patch.object(special_endpoints, "GuildThreadIterator") as iterator: - result = rest_client.fetch_private_archived_threads(StubModel(543345543)) + result = rest_client.fetch_private_archived_threads(mock_guild_text_channel) assert result is iterator.return_value iterator.assert_called_once_with( deserialize=rest_client.entity_factory.deserialize_guild_private_thread, entity_factory=rest_client.entity_factory, request_call=rest_client._request, - route=routes.GET_PRIVATE_ARCHIVED_THREADS.compile(channel=543345543), + route=routes.GET_PRIVATE_ARCHIVED_THREADS.compile(channel=4560), before=undefined.UNDEFINED, before_is_timestamp=True, ) @@ -1041,47 +1527,51 @@ def test_fetch_private_archived_threads_when_before_not_specified(self, rest_cli "before", [datetime.datetime(2022, 2, 28, 10, 58, 30, 987193, tzinfo=datetime.timezone.utc), 947809989634818048] ) def test_fetch_joined_private_archived_threads( - self, rest_client: rest.RESTClientImpl, before: typing.Union[datetime.datetime, snowflakes.Snowflake] + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + before: typing.Union[datetime.datetime, snowflakes.Snowflake], ): with mock.patch.object(special_endpoints, "GuildThreadIterator") as iterator: - result = rest_client.fetch_joined_private_archived_threads(StubModel(543123), before=before) + result = rest_client.fetch_joined_private_archived_threads(mock_guild_text_channel, before=before) assert result is iterator.return_value iterator.assert_called_once_with( deserialize=rest_client.entity_factory.deserialize_guild_private_thread, entity_factory=rest_client.entity_factory, request_call=rest_client._request, - route=routes.GET_JOINED_PRIVATE_ARCHIVED_THREADS.compile(channel=543123), + route=routes.GET_JOINED_PRIVATE_ARCHIVED_THREADS.compile(channel=4560), before="947809989634818048", before_is_timestamp=False, ) - def test_fetch_joined_private_archived_threads_when_before_not_specified(self, rest_client: rest.RESTClientImpl): + def test_fetch_joined_private_archived_threads_when_before_not_specified( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): with mock.patch.object(special_endpoints, "GuildThreadIterator") as iterator: - result = rest_client.fetch_joined_private_archived_threads(StubModel(323232)) + result = rest_client.fetch_joined_private_archived_threads(mock_guild_text_channel) assert result is iterator.return_value iterator.assert_called_once_with( deserialize=rest_client.entity_factory.deserialize_guild_private_thread, entity_factory=rest_client.entity_factory, request_call=rest_client._request, - route=routes.GET_JOINED_PRIVATE_ARCHIVED_THREADS.compile(channel=323232), + route=routes.GET_JOINED_PRIVATE_ARCHIVED_THREADS.compile(channel=4560), before=undefined.UNDEFINED, before_is_timestamp=False, ) - def test_fetch_members(self, rest_client): - guild = StubModel(123) + def test_fetch_members(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): stub_iterator = mock.Mock() with mock.patch.object(special_endpoints, "MemberIterator", return_value=stub_iterator) as iterator: - assert rest_client.fetch_members(guild) == stub_iterator + assert rest_client.fetch_members(mock_partial_guild) == stub_iterator iterator.assert_called_once_with( - entity_factory=rest_client._entity_factory, request_call=rest_client._request, guild=guild + entity_factory=rest_client._entity_factory, request_call=rest_client._request, guild=mock_partial_guild ) - def test_kick_member(self, rest_client): + def test_kick_member(self, rest_client: rest.RESTClientImpl): mock_kick_user = mock.Mock() rest_client.kick_user = mock_kick_user @@ -1090,7 +1580,7 @@ def test_kick_member(self, rest_client): assert result is mock_kick_user.return_value mock_kick_user.assert_called_once_with(123, 5423, reason="oewkwkwk") - def test_ban_member(self, rest_client): + def test_ban_member(self, rest_client: rest.RESTClientImpl): mock_ban_user = mock.Mock() rest_client.ban_user = mock_ban_user @@ -1099,7 +1589,7 @@ def test_ban_member(self, rest_client): assert result is mock_ban_user.return_value mock_ban_user.assert_called_once_with(43123, 54123, delete_message_seconds=518400, reason="wowowowo") - def test_unban_member(self, rest_client): + def test_unban_member(self, rest_client: rest.RESTClientImpl): mock_unban_user = mock.Mock() rest_client.unban_user = mock_unban_user @@ -1108,12 +1598,12 @@ def test_unban_member(self, rest_client): assert reason is mock_unban_user.return_value mock_unban_user.assert_called_once_with(123, 321, reason="ayaya") - def test_fetch_bans(self, rest_client: rest.RESTClientImpl): + def test_fetch_bans(self, rest_client: rest.RESTClientImpl, mock_user: users.User): with mock.patch.object(special_endpoints, "GuildBanIterator") as iterator_cls: - iterator = rest_client.fetch_bans(187, newest_first=True, start_at=StubModel(65652342134)) + iterator = rest_client.fetch_bans(123, newest_first=True, start_at=mock_user) iterator_cls.assert_called_once_with( - rest_client._entity_factory, rest_client._request, 187, newest_first=True, first_id="65652342134" + rest_client._entity_factory, rest_client._request, 123, newest_first=True, first_id="789" ) assert iterator is iterator_cls.return_value @@ -1153,28 +1643,28 @@ def test_fetch_bans_when_start_at_undefined_and_newest_first(self, rest_client: ) assert iterator is iterator_cls.return_value - def test_slash_command_builder(self, rest_client): + def test_slash_command_builder(self, rest_client: rest.RESTClientImpl): result = rest_client.slash_command_builder("a name", "a description") assert isinstance(result, special_endpoints.SlashCommandBuilder) - def test_context_menu_command_command_builder(self, rest_client): + def test_context_menu_command_command_builder(self, rest_client: rest.RESTClientImpl): result = rest_client.context_menu_command_builder(3, "a name") assert isinstance(result, special_endpoints.ContextMenuCommandBuilder) assert result.type == commands.CommandType.MESSAGE - def test_build_message_action_row(self, rest_client): + def test_build_message_action_row(self, rest_client: rest.RESTClientImpl): with mock.patch.object(special_endpoints, "MessageActionRowBuilder") as action_row_builder: assert rest_client.build_message_action_row() is action_row_builder.return_value action_row_builder.assert_called_once_with() - def test_build_modal_action_row(self, rest_client): + def test_build_modal_action_row(self, rest_client: rest.RESTClientImpl): with mock.patch.object(special_endpoints, "ModalActionRowBuilder") as action_row_builder: assert rest_client.build_modal_action_row() is action_row_builder.return_value action_row_builder.assert_called_once_with() - def test__build_message_payload_with_undefined_args(self, rest_client): + def test__build_message_payload_with_undefined_args(self, rest_client: rest.RESTClientImpl): with mock.patch.object( mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1} ) as generate_allowed_mentions: @@ -1188,8 +1678,8 @@ def test__build_message_payload_with_undefined_args(self, rest_client): ) @pytest.mark.parametrize("args", [("embeds", "components", "attachments"), ("embed", "component", "attachment")]) - def test__build_message_payload_with_None_args(self, rest_client, args): - kwargs = {} + def test__build_message_payload_with_None_args(self, rest_client: rest.RESTClientImpl, args: tuple[str, str, str]): + kwargs: dict[str, typing.Any] = {} for arg in args: kwargs[arg] = None @@ -1205,7 +1695,7 @@ def test__build_message_payload_with_None_args(self, rest_client, args): undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED ) - def test__build_message_payload_with_edit_and_all_mentions_undefined(self, rest_client): + def test__build_message_payload_with_edit_and_all_mentions_undefined(self, rest_client: rest.RESTClientImpl): with mock.patch.object(mentions, "generate_allowed_mentions") as generate_allowed_mentions: body, form = rest_client._build_message_payload(edit=True) @@ -1216,16 +1706,17 @@ def test__build_message_payload_with_edit_and_all_mentions_undefined(self, rest_ undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED ) - def test__build_message_payload_embed_content_syntactic_sugar(self, rest_client): + def test__build_message_payload_embed_content_syntactic_sugar(self, rest_client: rest.RESTClientImpl): embed = mock.Mock(embeds.Embed) - stack = contextlib.ExitStack() - generate_allowed_mentions = stack.enter_context( - mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) - ) - rest_client._entity_factory.serialize_embed.return_value = ({"embed": 1}, []) - - with stack: + with ( + mock.patch.object( + rest_client.entity_factory, "serialize_embed", return_value=({"embed": 1}, []) + ) as patched_serialize_embed, + mock.patch.object( + mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1} + ) as generate_allowed_mentions, + ): body, form = rest_client._build_message_payload(content=embed) # Returned @@ -1233,27 +1724,24 @@ def test__build_message_payload_embed_content_syntactic_sugar(self, rest_client) assert form is None # Embeds - rest_client._entity_factory.serialize_embed.assert_called_once_with(embed) + patched_serialize_embed.assert_called_once_with(embed) # Generate allowed mentions generate_allowed_mentions.assert_called_once_with( undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED ) - def test__build_message_payload_attachment_content_syntactic_sugar(self, rest_client): + def test__build_message_payload_attachment_content_syntactic_sugar(self, rest_client: rest.RESTClientImpl): attachment = mock.Mock(files.Resource) resource_attachment = mock.Mock(filename="attachment.png") - stack = contextlib.ExitStack() - ensure_resource = stack.enter_context( - mock.patch.object(files, "ensure_resource", return_value=resource_attachment) - ) - generate_allowed_mentions = stack.enter_context( - mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) - ) - url_encoded_form = stack.enter_context(mock.patch.object(data_binding, "URLEncodedFormBuilder")) - - with stack: + with ( + mock.patch.object(files, "ensure_resource", return_value=resource_attachment) as ensure_resource, + mock.patch.object( + mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1} + ) as generate_allowed_mentions, + mock.patch.object(data_binding, "URLEncodedFormBuilder") as url_encoded_form, + ): body, form = rest_client._build_message_payload(content=attachment) # Returned @@ -1275,35 +1763,38 @@ def test__build_message_payload_attachment_content_syntactic_sugar(self, rest_cl url_encoded_form.assert_called_once_with() url_encoded_form.return_value.add_resource.assert_called_once_with("files[0]", resource_attachment) - def test__build_message_payload_with_singular_args(self, rest_client): - attachment = object() + def test__build_message_payload_with_singular_args( + self, rest_client: rest.RESTClientImpl, mock_partial_sticker: stickers.PartialSticker + ): + attachment = mock.Mock() resource_attachment1 = mock.Mock(filename="attachment.png") resource_attachment2 = mock.Mock(filename="attachment2.png") component = mock.Mock(build=mock.Mock(return_value=({"component": 1}, ()))) - embed = object() - embed_attachment = object() - mentions_everyone = object() - mentions_reply = object() - user_mentions = object() - role_mentions = object() - - stack = contextlib.ExitStack() - ensure_resource = stack.enter_context( - mock.patch.object(files, "ensure_resource", side_effect=[resource_attachment1, resource_attachment2]) - ) - generate_allowed_mentions = stack.enter_context( - mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) - ) - url_encoded_form = stack.enter_context(mock.patch.object(data_binding, "URLEncodedFormBuilder")) - rest_client._entity_factory.serialize_embed.return_value = ({"embed": 1}, [embed_attachment]) + embed = mock.Mock() + embed_attachment = mock.Mock() + mentions_everyone = mock.Mock() + mentions_reply = mock.Mock() + user_mentions = mock.Mock() + role_mentions = mock.Mock() - with stack: + with ( + mock.patch.object( + rest_client.entity_factory, "serialize_embed", return_value=({"embed": 1}, [embed_attachment]) + ) as patched_serialize_embed, + mock.patch.object( + files, "ensure_resource", side_effect=[resource_attachment1, resource_attachment2] + ) as ensure_resource, + mock.patch.object( + mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1} + ) as generate_allowed_mentions, + mock.patch.object(data_binding, "URLEncodedFormBuilder") as url_encoded_form, + ): body, form = rest_client._build_message_payload( content=987654321, attachment=attachment, component=component, embed=embed, - sticker=StubModel(5412123), + sticker=mock_partial_sticker, flags=120, tts=True, mentions_everyone=mentions_everyone, @@ -1317,7 +1808,7 @@ def test__build_message_payload_with_singular_args(self, rest_client): "content": "987654321", "tts": True, "flags": 120, - "sticker_ids": ["5412123"], + "sticker_ids": ["222"], "embeds": [{"embed": 1}], "components": [{"component": 1}], "attachments": [{"id": 0, "filename": "attachment.png"}, {"id": 1, "filename": "attachment2.png"}], @@ -1330,7 +1821,7 @@ def test__build_message_payload_with_singular_args(self, rest_client): ensure_resource.assert_has_calls([mock.call(attachment), mock.call(embed_attachment)]) # Embeds - rest_client._entity_factory.serialize_embed.assert_called_once_with(embed) + patched_serialize_embed.assert_called_once_with(embed) # Components component.build.assert_called_once_with() @@ -1347,9 +1838,11 @@ def test__build_message_payload_with_singular_args(self, rest_client): [mock.call("files[0]", resource_attachment1), mock.call("files[1]", resource_attachment2)] ) - def test__build_message_payload_with_plural_args(self, rest_client): - attachment1 = object() - attachment2 = mock.Mock(message_models.Attachment, id=123, filename="attachment123.png") + def test__build_message_payload_with_plural_args( + self, rest_client: rest.RESTClientImpl, mock_partial_sticker: stickers.PartialSticker + ): + attachment1 = mock.Mock() + attachment2 = mock.Mock(messages.Attachment, id=123, filename="attachment123.png") resource_attachment1 = mock.Mock(filename="attachment.png") resource_attachment2 = mock.Mock(filename="attachment2.png") resource_attachment3 = mock.Mock(filename="attachment3.png") @@ -1358,19 +1851,26 @@ def test__build_message_payload_with_plural_args(self, rest_client): resource_attachment6 = mock.Mock(filename="attachment6.png") component1 = mock.Mock(build=mock.Mock(return_value=({"component": 1}, ()))) component2 = mock.Mock(build=mock.Mock(return_value=({"component": 2}, ()))) - embed1 = object() - embed2 = object() - embed_attachment1 = object() - embed_attachment2 = object() - embed_attachment3 = object() - embed_attachment4 = object() - mentions_everyone = object() - mentions_reply = object() - user_mentions = object() - role_mentions = object() + embed1 = mock.Mock() + embed2 = mock.Mock() + embed_attachment1 = mock.Mock() + embed_attachment2 = mock.Mock() + embed_attachment3 = mock.Mock() + embed_attachment4 = mock.Mock() + mentions_everyone = mock.Mock() + mentions_reply = mock.Mock() + user_mentions = mock.Mock() + role_mentions = mock.Mock() + + serialize_embed_side_effect = [ + ({"embed": 1}, [embed_attachment1, embed_attachment2]), + ({"embed": 2}, [embed_attachment3, embed_attachment4]), + ] - stack = contextlib.ExitStack() - ensure_resource = stack.enter_context( + with ( + mock.patch.object( + rest_client.entity_factory, "serialize_embed", side_effect=serialize_embed_side_effect + ) as patched_serialize_embed, mock.patch.object( files, "ensure_resource", @@ -1382,24 +1882,18 @@ def test__build_message_payload_with_plural_args(self, rest_client): resource_attachment5, resource_attachment6, ], - ) - ) - generate_allowed_mentions = stack.enter_context( - mock.patch.object(mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1}) - ) - url_encoded_form = stack.enter_context(mock.patch.object(data_binding, "URLEncodedFormBuilder")) - rest_client._entity_factory.serialize_embed.side_effect = [ - ({"embed": 1}, [embed_attachment1, embed_attachment2]), - ({"embed": 2}, [embed_attachment3, embed_attachment4]), - ] - - with stack: + ) as ensure_resource, + mock.patch.object( + mentions, "generate_allowed_mentions", return_value={"allowed_mentions": 1} + ) as generate_allowed_mentions, + mock.patch.object(data_binding, "URLEncodedFormBuilder") as url_encoded_form, + ): body, form = rest_client._build_message_payload( content=987654321, attachments=[attachment1, attachment2], components=[component1, component2], embeds=[embed1, embed2], - stickers=[54612123, StubModel(123321)], + stickers=[54612123, mock_partial_sticker], flags=120, tts=True, mentions_everyone=mentions_everyone, @@ -1415,7 +1909,7 @@ def test__build_message_payload_with_plural_args(self, rest_client): "flags": 120, "embeds": [{"embed": 1}, {"embed": 2}], "components": [{"component": 1}, {"component": 2}], - "sticker_ids": ["54612123", "123321"], + "sticker_ids": ["54612123", "222"], "attachments": [ {"id": 0, "filename": "attachment.png"}, {"id": 1, "filename": "attachment2.png"}, @@ -1442,8 +1936,8 @@ def test__build_message_payload_with_plural_args(self, rest_client): ) # Embeds - assert rest_client._entity_factory.serialize_embed.call_count == 2 - rest_client._entity_factory.serialize_embed.assert_has_calls([mock.call(embed1), mock.call(embed2)]) + assert patched_serialize_embed.call_count == 2 + patched_serialize_embed.assert_has_calls([mock.call(embed1), mock.call(embed2)]) # Components component1.build.assert_called_once_with() @@ -1468,9 +1962,9 @@ def test__build_message_payload_with_plural_args(self, rest_client): ] ) - def test__build_message_payload_with_edit_and_attachment_object_passed(self, rest_client): - attachment1 = object() - attachment2 = mock.Mock(message_models.Attachment, id=123, filename="attachment123.png") + def test__build_message_payload_with_edit_and_attachment_object_passed(self, rest_client: rest.RESTClientImpl): + attachment1 = mock.Mock() + attachment2 = mock.Mock(messages.Attachment, id=123, filename="attachment123.png") resource_attachment1 = mock.Mock(filename="attachment.png") resource_attachment2 = mock.Mock(filename="attachment2.png") resource_attachment3 = mock.Mock(filename="attachment3.png") @@ -1478,15 +1972,20 @@ def test__build_message_payload_with_edit_and_attachment_object_passed(self, res resource_attachment5 = mock.Mock(filename="attachment5.png") component1 = mock.Mock(build=mock.Mock(return_value=({"component": 1}, ()))) component2 = mock.Mock(build=mock.Mock(return_value=({"component": 2}, ()))) - embed1 = object() - embed2 = object() - embed_attachment1 = object() - embed_attachment2 = object() - embed_attachment3 = object() - embed_attachment4 = object() + embed1 = mock.Mock() + embed2 = mock.Mock() + embed_attachment1 = mock.Mock() + embed_attachment2 = mock.Mock() + embed_attachment3 = mock.Mock() + embed_attachment4 = mock.Mock() - stack = contextlib.ExitStack() - ensure_resource = stack.enter_context( + serialize_embed_side_effect = [ + ({"embed": 1}, [embed_attachment1, embed_attachment2]), + ({"embed": 2}, [embed_attachment3, embed_attachment4]), + ] + + with ( + mock.patch.object(rest_client.entity_factory, "serialize_embed", side_effect=serialize_embed_side_effect), mock.patch.object( files, "ensure_resource", @@ -1497,15 +1996,9 @@ def test__build_message_payload_with_edit_and_attachment_object_passed(self, res resource_attachment4, resource_attachment5, ], - ) - ) - url_encoded_form = stack.enter_context(mock.patch.object(data_binding, "URLEncodedFormBuilder")) - rest_client._entity_factory.serialize_embed.side_effect = [ - ({"embed": 1}, [embed_attachment1, embed_attachment2]), - ({"embed": 2}, [embed_attachment3, embed_attachment4]), - ] - - with stack: + ) as ensure_resource, + mock.patch.object(data_binding, "URLEncodedFormBuilder") as url_encoded_form, + ): body, form = rest_client._build_message_payload( content=987654321, attachments=[attachment1, attachment2], @@ -1513,10 +2006,10 @@ def test__build_message_payload_with_edit_and_attachment_object_passed(self, res embeds=[embed1, embed2], flags=120, tts=True, - mentions_everyone=None, - mentions_reply=None, - user_mentions=None, - role_mentions=None, + mentions_everyone=undefined.UNDEFINED, + mentions_reply=undefined.UNDEFINED, + user_mentions=undefined.UNDEFINED, + role_mentions=undefined.UNDEFINED, edit=True, ) @@ -1564,9 +2057,9 @@ def test__build_message_payload_with_edit_and_attachment_object_passed(self, res ] ) - def test__build_message_payload_with_duplicated_resources(self, rest_client): + def test__build_message_payload_with_duplicated_resources(self, rest_client: rest.RESTClientImpl): attachment1 = mock.Mock() - attachment2 = mock.Mock(message_models.Attachment, id=123, filename="attachment123.png") + attachment2 = mock.Mock(messages.Attachment, id=123, filename="attachment123.png") component_attachment = mock.Mock(filename="component.png") component_attachment2 = mock.Mock(filename="component2.png") component1 = mock.Mock( @@ -1586,12 +2079,15 @@ def test__build_message_payload_with_duplicated_resources(self, rest_client): resource_attachment5 = mock.Mock(filename="attachment5.png") resource_attachment6 = mock.Mock(filename="attachment6.png") - rest_client._entity_factory.serialize_embed.side_effect = [ - ({"embed": 1}, [embed_attachment1, embed_attachment2]), - ({"embed": 2}, [embed_attachment2, embed_attachment1]), - ] - with ( + mock.patch.object( + rest_client._entity_factory, + "serialize_embed", + side_effect=[ + ({"embed": 1}, [embed_attachment1, embed_attachment2]), + ({"embed": 2}, [embed_attachment2, embed_attachment1]), + ], + ), mock.patch.object( files, "ensure_resource", @@ -1613,10 +2109,10 @@ def test__build_message_payload_with_duplicated_resources(self, rest_client): embeds=[embed1, embed2], flags=120, tts=True, - mentions_everyone=None, - mentions_reply=None, - user_mentions=None, - role_mentions=None, + mentions_everyone=undefined.UNDEFINED, + mentions_reply=undefined.UNDEFINED, + user_mentions=undefined.UNDEFINED, + role_mentions=undefined.UNDEFINED, edit=True, ) @@ -1672,16 +2168,16 @@ def test__build_message_payload_with_duplicated_resources(self, rest_client): [("attachment", "attachments"), ("component", "components"), ("embed", "embeds"), ("sticker", "stickers")], ) def test__build_message_payload_when_both_single_and_plural_args_passed( - self, rest_client, singular_arg, plural_arg + self, rest_client: rest.RESTClientImpl, singular_arg: str, plural_arg: str ): with pytest.raises( ValueError, match=rf"You may only specify one of '{singular_arg}' or '{plural_arg}', not both" ): - rest_client._build_message_payload(**{singular_arg: object(), plural_arg: object()}) + rest_client._build_message_payload(**{singular_arg: mock.Mock(), plural_arg: mock.Mock()}) - def test_build_voice_message_payload(self, rest_client): - body, form_builder = rest_client._build_voice_message_payload( - attachment=mock.Mock(message_models.Attachment, id=123, filename="attachment123.png"), + def test_build_voice_message_payload(self, rest_client: rest.RESTClientImpl): + body, _form_builder = rest_client._build_voice_message_payload( + attachment=mock.Mock(messages.Attachment, id=123, filename="attachment123.png"), waveform="AAAA", duration=3, flags=120, @@ -1695,9 +2191,9 @@ def test_build_voice_message_payload(self, rest_client): "attachments": [{"duration_secs": 3, "waveform": "AAAA", "id": 0, "filename": "attachment123.png"}], } - def test_build_voice_message_payload_with_mentions_reply(self, rest_client): - body, form_builder = rest_client._build_voice_message_payload( - attachment=mock.Mock(message_models.Attachment, id=123, filename="attachment123.png"), + def test_build_voice_message_payload_with_mentions_reply(self, rest_client: rest.RESTClientImpl): + body, _form_builder = rest_client._build_voice_message_payload( + attachment=mock.Mock(messages.Attachment, id=123, filename="attachment123.png"), waveform="AAAA", duration=3, flags=120, @@ -1721,7 +2217,9 @@ def test_build_voice_message_payload_with_mentions_reply(self, rest_client): components.ComponentType.CHANNEL_SELECT_MENU, ), ) - def test__build_message_payload_with_singular_components_v1(self, rest_client, type): + def test__build_message_payload_with_singular_components_v1( + self, rest_client: rest.RESTClientImpl, type: components.ComponentType + ): component = mock.Mock(type=type, build=mock.Mock(return_value=({}, ()))) payload, _ = rest_client._build_message_payload(component=component) @@ -1740,12 +2238,14 @@ def test__build_message_payload_with_singular_components_v1(self, rest_client, t components.ComponentType.CONTAINER, ), ) - def test__build_message_payload_with_singular_components_v2(self, rest_client, type): + def test__build_message_payload_with_singular_components_v2( + self, rest_client: rest.RESTClientImpl, type: components.ComponentType + ): component = mock.Mock(type=type, build=mock.Mock(return_value=({}, ()))) payload, _ = rest_client._build_message_payload(component=component) - assert payload.get("flags") is message_models.MessageFlag.IS_COMPONENTS_V2 + assert payload.get("flags") is messages.MessageFlag.IS_COMPONENTS_V2 @pytest.mark.parametrize( "type", @@ -1759,7 +2259,9 @@ def test__build_message_payload_with_singular_components_v2(self, rest_client, t components.ComponentType.CHANNEL_SELECT_MENU, ), ) - def test__build_message_payload_with_multiple_components_v1(self, rest_client, type): + def test__build_message_payload_with_multiple_components_v1( + self, rest_client: rest.RESTClientImpl, type: components.ComponentType + ): component = mock.Mock(type=type, build=mock.Mock(return_value=({}, ()))) payload, _ = rest_client._build_message_payload(component=component) @@ -1778,72 +2280,67 @@ def test__build_message_payload_with_multiple_components_v1(self, rest_client, t components.ComponentType.CONTAINER, ), ) - def test__build_message_payload_with_multiple_components_v2(self, rest_client, type): + def test__build_message_payload_with_multiple_components_v2( + self, rest_client: rest.RESTClientImpl, type: components.ComponentType + ): component = mock.Mock(type=type, build=mock.Mock(return_value=({}, ()))) payload, _ = rest_client._build_message_payload(components=[component]) - assert payload.get("flags") is message_models.MessageFlag.IS_COMPONENTS_V2 + assert payload.get("flags") is messages.MessageFlag.IS_COMPONENTS_V2 - def test__build_message_payload_with_mixed_components(self, rest_client): + def test__build_message_payload_with_mixed_components(self, rest_client: rest.RESTClientImpl): component_1 = mock.Mock(type=components.ComponentType.ACTION_ROW, build=mock.Mock(return_value=({}, ()))) component_2 = mock.Mock(type=components.ComponentType.CONTAINER, build=mock.Mock(return_value=({}, ()))) payload, _ = rest_client._build_message_payload(components=[component_1, component_2]) - assert payload.get("flags") is message_models.MessageFlag.IS_COMPONENTS_V2 + assert payload.get("flags") is messages.MessageFlag.IS_COMPONENTS_V2 - def test__build_message_payload_with_components_v2_and_flags(self, rest_client): + def test__build_message_payload_with_components_v2_and_flags(self, rest_client: rest.RESTClientImpl): component_1 = mock.Mock(type=components.ComponentType.ACTION_ROW, build=mock.Mock(return_value=({}, ()))) component_2 = mock.Mock(type=components.ComponentType.CONTAINER, build=mock.Mock(return_value=({}, ()))) payload, _ = rest_client._build_message_payload( - components=[component_1, component_2], flags=message_models.MessageFlag.EPHEMERAL + components=[component_1, component_2], flags=messages.MessageFlag.EPHEMERAL ) - assert ( - payload.get("flags") is message_models.MessageFlag.IS_COMPONENTS_V2 | message_models.MessageFlag.EPHEMERAL - ) + assert payload.get("flags") is messages.MessageFlag.IS_COMPONENTS_V2 | messages.MessageFlag.EPHEMERAL - def test_interaction_deferred_builder(self, rest_client): + def test_interaction_deferred_builder(self, rest_client: rest.RESTClientImpl): result = rest_client.interaction_deferred_builder(5) assert result.type == 5 assert isinstance(result, special_endpoints.InteractionDeferredBuilder) - def test_interaction_autocomplete_builder(self, rest_client): + def test_interaction_autocomplete_builder(self, rest_client: rest.RESTClientImpl): result = rest_client.interaction_autocomplete_builder( [special_endpoints.AutocompleteChoiceBuilder(name="name", value="value")] ) assert result.choices == [special_endpoints.AutocompleteChoiceBuilder(name="name", value="value")] - def test_interaction_message_builder(self, rest_client): + def test_interaction_message_builder(self, rest_client: rest.RESTClientImpl): result = rest_client.interaction_message_builder(4) assert result.type == 4 assert isinstance(result, special_endpoints.InteractionMessageBuilder) - def test_interaction_modal_builder(self, rest_client): + def test_interaction_modal_builder(self, rest_client: rest.RESTClientImpl): result = rest_client.interaction_modal_builder("aaaaa", "custom") assert result.type == 9 assert result.title == "aaaaa" assert result.custom_id == "custom" - def test_fetch_scheduled_event_users(self, rest_client: rest.RESTClientImpl): + def test_fetch_scheduled_event_users( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): with mock.patch.object(special_endpoints, "ScheduledEventUserIterator") as iterator_cls: - iterator = rest_client.fetch_scheduled_event_users( - 33432234, 6666655555, newest_first=True, start_at=StubModel(65652342134) - ) + iterator = rest_client.fetch_scheduled_event_users(123, 6666655555, newest_first=True, start_at=mock_user) iterator_cls.assert_called_once_with( - rest_client._entity_factory, - rest_client._request, - 33432234, - 6666655555, - first_id="65652342134", - newest_first=True, + rest_client._entity_factory, rest_client._request, 123, 6666655555, first_id="789", newest_first=True ) assert iterator is iterator_cls.return_value @@ -1893,91 +2390,112 @@ def test_fetch_scheduled_event_users_when_start_at_undefined_and_newest_first( assert iterator is iterator_cls.return_value +class ExitException(Exception): ... + + @pytest.mark.asyncio class TestRESTClientImplAsync: @pytest.fixture - def exit_exception(self): - class ExitException(Exception): ... - + def exit_exception(self) -> typing.Type[ExitException]: return ExitException - async def test___aenter__and__aexit__(self, rest_client): - rest_client.close = mock.AsyncMock() - rest_client.start = mock.Mock() - - async with rest_client as client: - assert client is rest_client - rest_client.start.assert_called_once() - rest_client.close.assert_not_called() + async def test___aenter__and__aexit__(self, rest_client: rest.RESTClientImpl): + with ( + mock.patch.object(rest_client, "close", new_callable=mock.AsyncMock) as patched_close, + mock.patch.object(rest_client, "start") as patched_start, + ): + async with rest_client as client: + assert client is rest_client + patched_start.assert_called_once() + patched_close.assert_not_called() - rest_client.close.assert_awaited_once_with() + patched_close.assert_awaited_once_with() @hikari_test_helpers.timeout() - async def test_request_errors_if_both_json_and_form_builder_passed(self, rest_client): + async def test_request_errors_if_both_json_and_form_builder_passed(self, rest_client: rest.RESTClientImpl): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) with pytest.raises(ValueError, match="Can only provide one of 'json' or 'form_builder', not both"): - await rest_client._request(route, json=object(), form_builder=object()) + await rest_client._request(route, json=mock.Mock(), form_builder=mock.Mock()) @hikari_test_helpers.timeout() - async def test_request_builds_json_when_passed(self, rest_client, exit_exception): + async def test_request_builds_json_when_passed( + self, rest_client: rest.RESTClientImpl, exit_exception: typing.Type[Exception] + ): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - rest_client._client_session.request.side_effect = exit_exception rest_client._token = None - with mock.patch.object(data_binding, "JSONPayload") as json_payload: - with pytest.raises(exit_exception): - await rest_client._request(route, json={"some": "data"}) + with ( + mock.patch.object(rest_client._client_session, "request", side_effect=exit_exception) as patched_request, + mock.patch.object(data_binding, "JSONPayload") as patched_json_payload, + pytest.raises(exit_exception), + ): + await rest_client._request(route, json={"some": "data"}) - json_payload.assert_called_once_with({"some": "data"}, dumps=rest_client._dumps) - _, kwargs = rest_client._client_session.request.call_args_list[0] - assert kwargs["data"] is json_payload.return_value + patched_json_payload.assert_called_once_with({"some": "data"}, dumps=rest_client._dumps) + _, kwargs = patched_request.call_args_list[0] + assert kwargs["data"] is patched_json_payload.return_value @hikari_test_helpers.timeout() - async def test_request_builds_form_when_passed(self, rest_client, exit_exception): + async def test_request_builds_form_when_passed( + self, rest_client: rest.RESTClientImpl, exit_exception: typing.Type[Exception] + ): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - rest_client._client_session.request.side_effect = exit_exception - rest_client._token = None mock_form = mock.AsyncMock() mock_stack = mock.AsyncMock() mock_stack.__aenter__ = mock_stack - with mock.patch.object(contextlib, "AsyncExitStack", return_value=mock_stack) as exit_stack: - with pytest.raises(exit_exception): - await rest_client._request(route, form_builder=mock_form) + with ( + mock.patch.object(contextlib, "AsyncExitStack", return_value=mock_stack) as exit_stack, + mock.patch.object(rest_client, "_token", None), + mock.patch.object(rest_client, "_client_session") as patched__client_session, + mock.patch.object(patched__client_session, "request", side_effect=exit_exception) as patched_request, + pytest.raises(exit_exception), + ): + await rest_client._request(route, form_builder=mock_form) - _, kwargs = rest_client._client_session.request.call_args_list[0] + _, kwargs = patched_request.call_args_list[0] mock_form.build.assert_awaited_once_with(exit_stack.return_value, executor=rest_client._executor) assert kwargs["data"] is mock_form.build.return_value @hikari_test_helpers.timeout() - async def test_request_url_encodes_reason_header(self, rest_client, exit_exception): + async def test_request_url_encodes_reason_header( + self, rest_client: rest.RESTClientImpl, exit_exception: typing.Type[Exception] + ): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - rest_client._client_session.request.side_effect = exit_exception - with pytest.raises(exit_exception): + with ( + mock.patch.object(rest_client._client_session, "request", side_effect=exit_exception) as patched_request, + pytest.raises(exit_exception), + ): await rest_client._request(route, reason="光のenergyが 大地に降りそそぐ") - _, kwargs = rest_client._client_session.request.call_args_list[0] + _, kwargs = patched_request.call_args_list[0] assert kwargs["headers"][rest._X_AUDIT_LOG_REASON_HEADER] == ( "%E5%85%89%E3%81%AEenergy%E3%81%8C%E3%80%80%E5%A4%" "A7%E5%9C%B0%E3%81%AB%E9%99%8D%E3%82%8A%E3%81%9D%E3%81%9D%E3%81%90" ) @hikari_test_helpers.timeout() - async def test_request_with_strategy_token(self, rest_client, exit_exception): + async def test_request_with_strategy_token( + self, rest_client: rest.RESTClientImpl, exit_exception: typing.Type[Exception] + ): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - rest_client._client_session.request.side_effect = exit_exception rest_client._token = mock.Mock(rest_api.TokenStrategy, acquire=mock.AsyncMock(return_value="Bearer ok.ok.ok")) - with pytest.raises(exit_exception): + with ( + mock.patch.object(rest_client._client_session, "request", side_effect=exit_exception) as patched_request, + pytest.raises(exit_exception), + ): await rest_client._request(route) - _, kwargs = rest_client._client_session.request.call_args_list[0] + _, kwargs = patched_request.call_args_list[0] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" @hikari_test_helpers.timeout() - async def test_request_retries_strategy_once(self, rest_client, exit_exception): + async def test_request_retries_strategy_once( + self, rest_client: rest.RESTClientImpl, exit_exception: typing.Type[Exception] + ): class StubResponse: status = http.HTTPStatus.UNAUTHORIZED content_type = rest._APPLICATION_JSON @@ -1988,23 +2506,30 @@ async def read(self): return '{"something": null}' route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - rest_client._client_session.request = hikari_test_helpers.CopyingAsyncMock( - side_effect=[StubResponse(), exit_exception] - ) rest_client._token = mock.Mock( rest_api.TokenStrategy, acquire=mock.AsyncMock(side_effect=["Bearer ok.ok.ok", "Bearer ok2.ok2.ok2"]) ) - with pytest.raises(exit_exception): + with ( + mock.patch.object( + rest_client._client_session, + "request", + new_callable=hikari_test_helpers.CopyingAsyncMock, + side_effect=[StubResponse(), exit_exception], + ) as patched_request, + pytest.raises(exit_exception), + ): await rest_client._request(route) - _, kwargs = rest_client._client_session.request.call_args_list[0] + _, kwargs = patched_request.call_args_list[0] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" - _, kwargs = rest_client._client_session.request.call_args_list[1] + _, kwargs = patched_request.call_args_list[1] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok2.ok2.ok2" @hikari_test_helpers.timeout() - async def test_request_raises_after_re_auth_attempt(self, rest_client, exit_exception): + async def test_request_raises_after_re_auth_attempt( + self, rest_client: rest.RESTClientImpl, exit_exception: Exception + ): class StubResponse: status = http.HTTPStatus.UNAUTHORIZED content_type = rest._APPLICATION_JSON @@ -2019,75 +2544,98 @@ async def json(self): return {"something": None} route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - rest_client._client_session.request = hikari_test_helpers.CopyingAsyncMock( - side_effect=[StubResponse(), StubResponse(), StubResponse()] - ) rest_client._token = mock.Mock( rest_api.TokenStrategy, acquire=mock.AsyncMock(side_effect=["Bearer ok.ok.ok", "Bearer ok2.ok2.ok2"]) ) - with pytest.raises(errors.UnauthorizedError): + with ( + mock.patch.object( + rest_client._client_session, + "request", + new_callable=hikari_test_helpers.CopyingAsyncMock, + side_effect=[StubResponse(), StubResponse(), StubResponse()], + ) as patched_request, + pytest.raises(errors.UnauthorizedError), + ): await rest_client._request(route) - _, kwargs = rest_client._client_session.request.call_args_list[0] + _, kwargs = patched_request.call_args_list[0] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok.ok.ok" - _, kwargs = rest_client._client_session.request.call_args_list[1] + _, kwargs = patched_request.call_args_list[1] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "Bearer ok2.ok2.ok2" @hikari_test_helpers.timeout() - async def test_request_when__token_is_None(self, rest_client, exit_exception): + async def test_request_when__token_is_None( + self, rest_client: rest.RESTClientImpl, exit_exception: typing.Type[Exception] + ): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - rest_client._client_session.request.side_effect = exit_exception rest_client._token = None - with pytest.raises(exit_exception): + with ( + mock.patch.object(rest_client._client_session, "request", side_effect=exit_exception) as patched_request, + pytest.raises(exit_exception), + ): await rest_client._request(route) - _, kwargs = rest_client._client_session.request.call_args_list[0] + _, kwargs = patched_request.call_args_list[0] assert rest._AUTHORIZATION_HEADER not in kwargs["headers"] @hikari_test_helpers.timeout() - async def test_request_when__token_is_not_None(self, rest_client, exit_exception): + async def test_request_when__token_is_not_None( + self, rest_client: rest.RESTClientImpl, exit_exception: typing.Type[Exception] + ): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - rest_client._client_session.request.side_effect = exit_exception rest_client._token = "token" - with pytest.raises(exit_exception): + with ( + mock.patch.object(rest_client._client_session, "request", side_effect=exit_exception) as patched_request, + pytest.raises(exit_exception), + ): await rest_client._request(route) - _, kwargs = rest_client._client_session.request.call_args_list[0] + _, kwargs = patched_request.call_args_list[0] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "token" @hikari_test_helpers.timeout() - async def test_request_when_no_auth_passed(self, rest_client, exit_exception): + async def test_request_when_no_auth_passed( + self, rest_client: rest.RESTClientImpl, exit_exception: typing.Type[Exception] + ): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - rest_client._client_session.request.side_effect = exit_exception rest_client._token = "token" - with pytest.raises(exit_exception): + with ( + mock.patch.object(rest_client._client_session, "request", side_effect=exit_exception) as patched_request, + mock.patch.object(rest_client._bucket_manager, "acquire_bucket") as patched_acquire_bucket, + pytest.raises(exit_exception), + ): await rest_client._request(route, auth=None) - _, kwargs = rest_client._client_session.request.call_args_list[0] + _, kwargs = patched_request.call_args_list[0] assert rest._AUTHORIZATION_HEADER not in kwargs["headers"] - rest_client._bucket_manager.acquire_bucket.assert_called_once_with(route, None) - rest_client._bucket_manager.acquire_bucket.return_value.assert_used_once() + patched_acquire_bucket.assert_called_once_with(route, None) + # patched_acquire_bucket.return_value.assert_used_once() # FIXME: This is a weird thing because it fails no matter how its fixed. assert_used_once() is also not a function lmao. @hikari_test_helpers.timeout() - async def test_request_when_auth_passed(self, rest_client, exit_exception): + async def test_request_when_auth_passed( + self, rest_client: rest.RESTClientImpl, exit_exception: typing.Type[Exception] + ): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - rest_client._client_session.request.side_effect = exit_exception rest_client._token = "token" - with pytest.raises(exit_exception): + with ( + mock.patch.object(rest_client._client_session, "request", side_effect=exit_exception) as patched_request, + mock.patch.object(rest_client._bucket_manager, "acquire_bucket") as patched_acquire_bucket, + pytest.raises(exit_exception), + ): await rest_client._request(route, auth="ooga booga") - _, kwargs = rest_client._client_session.request.call_args_list[0] + _, kwargs = patched_request.call_args_list[0] assert kwargs["headers"][rest._AUTHORIZATION_HEADER] == "ooga booga" - rest_client._bucket_manager.acquire_bucket.assert_called_once_with(route, "ooga booga") - rest_client._bucket_manager.acquire_bucket.return_value.assert_used_once() + patched_acquire_bucket.assert_called_once_with(route, "ooga booga") + # patched_acquire_bucket.return_value.assert_used_once() # FIXME: This is a weird thing because it fails no matter how its fixed. assert_used_once() is also not a function lmao. @hikari_test_helpers.timeout() - async def test_request_when_response_is_NO_CONTENT(self, rest_client): + async def test_request_when_response_is_NO_CONTENT(self, rest_client: rest.RESTClientImpl): class StubResponse: status = http.HTTPStatus.NO_CONTENT reason = "cause why not" @@ -2099,7 +2647,7 @@ class StubResponse: assert (await rest_client._request(route)) is None @hikari_test_helpers.timeout() - async def test_request_when_response_is_APPLICATION_JSON(self, rest_client): + async def test_request_when_response_is_APPLICATION_JSON(self, rest_client: rest.RESTClientImpl): class StubResponse: status = http.HTTPStatus.OK content_type = rest._APPLICATION_JSON @@ -2116,7 +2664,7 @@ async def read(self): assert (await rest_client._request(route)) == {"something": None} @hikari_test_helpers.timeout() - async def test_request_when_response_is_not_JSON(self, rest_client): + async def test_request_when_response_is_not_JSON(self, rest_client: rest.RESTClientImpl): class StubResponse: status = http.HTTPStatus.IM_USED content_type = "text/html" @@ -2131,7 +2679,9 @@ class StubResponse: await rest_client._request(route) @hikari_test_helpers.timeout() - async def test_request_when_response_unhandled_status(self, rest_client, exit_exception): + async def test_request_when_response_unhandled_status( + self, rest_client: rest.RESTClientImpl, exit_exception: Exception + ): class StubResponse: status = http.HTTPStatus.NOT_IMPLEMENTED content_type = "text/html" @@ -2147,7 +2697,9 @@ class StubResponse: await rest_client._request(route) @hikari_test_helpers.timeout() - async def test_request_when_status_in_retry_codes_will_retry_until_exhausted(self, rest_client, exit_exception): + async def test_request_when_status_in_retry_codes_will_retry_until_exhausted( + self, rest_client: rest.RESTClientImpl, exit_exception: Exception + ): class StubResponse: status = http.HTTPStatus.INTERNAL_SERVER_ERROR @@ -2180,7 +2732,9 @@ class StubResponse: @hikari_test_helpers.timeout() @pytest.mark.parametrize("exception", [asyncio.TimeoutError, aiohttp.ClientConnectionError]) - async def test_request_when_connection_error_will_retry_until_exhausted(self, rest_client, exception): + async def test_request_when_connection_error_will_retry_until_exhausted( + self, rest_client: rest.RESTClientImpl, exception: typing.Type[Exception] + ): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) mock_session = mock.AsyncMock(request=mock.AsyncMock(side_effect=exception)) rest_client._max_retries = 3 @@ -2207,7 +2761,7 @@ async def test_request_when_connection_error_will_retry_until_exhausted(self, re @pytest.mark.parametrize("enabled", [True, False]) @hikari_test_helpers.timeout() - async def test_request_logger(self, rest_client, enabled): + async def test_request_logger(self, rest_client: rest.RESTClientImpl, enabled: bool): class StubResponse: status = http.HTTPStatus.NO_CONTENT headers = {} @@ -2228,7 +2782,7 @@ async def read(self): else: assert logger.log.call_count == 0 - async def test__parse_ratelimits_when_bucket_provided_updates_rate_limits(self, rest_client): + async def test__parse_ratelimits_when_bucket_provided_updates_rate_limits(self, rest_client: rest.RESTClientImpl): class StubResponse: status = http.HTTPStatus.OK headers = { @@ -2242,9 +2796,13 @@ class StubResponse: response = StubResponse() route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - assert await rest_client._parse_ratelimits(route, "auth", response) is None + with ( + mock.patch.object(rest_client, "_bucket_manager") as patched__bucket_manager, + mock.patch.object(patched__bucket_manager, "update_rate_limits") as patched_update_rate_limits, + ): + assert await rest_client._parse_ratelimits(route, "auth", response) is None - rest_client._bucket_manager.update_rate_limits.assert_called_once_with( + patched_update_rate_limits.assert_called_once_with( compiled_route=route, bucket_header="bucket_header", authentication="auth", @@ -2254,7 +2812,7 @@ class StubResponse: reset_at=12123123.2, ) - async def test__parse_ratelimits_when_not_ratelimited(self, rest_client): + async def test__parse_ratelimits_when_not_ratelimited(self, rest_client: rest.RESTClientImpl): class StubResponse: status = http.HTTPStatus.OK headers = {} @@ -2268,7 +2826,9 @@ class StubResponse: response.json.assert_not_called() - async def test__parse_ratelimits_when_ratelimited(self, rest_client, exit_exception): + async def test__parse_ratelimits_when_ratelimited( + self, rest_client: rest.RESTClientImpl, exit_exception: typing.Type[ExitException] + ): class StubResponse: status = http.HTTPStatus.TOO_MANY_REQUESTS content_type = rest._APPLICATION_JSON @@ -2281,7 +2841,7 @@ async def read(self): with pytest.raises(exit_exception): await rest_client._parse_ratelimits(route, "auth", StubResponse()) - async def test__parse_ratelimits_when_unexpected_content_type(self, rest_client): + async def test__parse_ratelimits_when_unexpected_content_type(self, rest_client: rest.RESTClientImpl): class StubResponse: status = http.HTTPStatus.TOO_MANY_REQUESTS content_type = "text/html" @@ -2295,7 +2855,7 @@ async def read(self): with pytest.raises(errors.HTTPResponseError): await rest_client._parse_ratelimits(route, "auth", StubResponse()) - async def test__parse_ratelimits_when_global_ratelimit(self, rest_client): + async def test__parse_ratelimits_when_global_ratelimit(self, rest_client: rest.RESTClientImpl): class StubResponse: status = http.HTTPStatus.TOO_MANY_REQUESTS content_type = rest._APPLICATION_JSON @@ -2306,11 +2866,16 @@ async def read(self): return '{"global": true, "retry_after": "2"}' route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - assert (await rest_client._parse_ratelimits(route, "auth", StubResponse())) == 0 - rest_client._bucket_manager.throttle.assert_called_once_with(2.0) + with ( + mock.patch.object(rest_client, "_bucket_manager") as patched__bucket_manager, + mock.patch.object(patched__bucket_manager, "throttle") as patched_throttle, + ): + assert (await rest_client._parse_ratelimits(route, "auth", StubResponse())) == 0 - async def test__parse_ratelimits_when_remaining_header_under_or_equal_to_0(self, rest_client): + patched_throttle.assert_called_once_with(2.0) + + async def test__parse_ratelimits_when_remaining_header_under_or_equal_to_0(self, rest_client: rest.RESTClientImpl): class StubResponse: status = http.HTTPStatus.TOO_MANY_REQUESTS content_type = rest._APPLICATION_JSON @@ -2323,7 +2888,7 @@ async def json(self): route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) assert await rest_client._parse_ratelimits(route, "some auth", StubResponse()) == 0 - async def test__parse_ratelimits_when_retry_after_is_not_too_long(self, rest_client): + async def test__parse_ratelimits_when_retry_after_is_not_too_long(self, rest_client: rest.RESTClientImpl): class StubResponse: status = http.HTTPStatus.TOO_MANY_REQUESTS content_type = rest._APPLICATION_JSON @@ -2333,12 +2898,14 @@ class StubResponse: async def read(self): return '{"retry_after": "0.002"}' - rest_client._bucket_manager.max_rate_limit = 10 - - route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - assert await rest_client._parse_ratelimits(route, "some auth", StubResponse()) == 0.002 + with ( + mock.patch.object(rest_client, "_bucket_manager") as patched__bucket_manager, + mock.patch.object(patched__bucket_manager, "max_rate_limit", 10), + ): + route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) + assert await rest_client._parse_ratelimits(route, "some auth", StubResponse()) == 0.002 - async def test__parse_ratelimits_when_retry_after_is_too_long(self, rest_client): + async def test__parse_ratelimits_when_retry_after_is_too_long(self, rest_client: rest.RESTClientImpl): class StubResponse: status = http.HTTPStatus.TOO_MANY_REQUESTS content_type = rest._APPLICATION_JSON @@ -2348,52 +2915,92 @@ class StubResponse: async def read(self): return '{"retry_after": "4"}' - rest_client._bucket_manager.max_rate_limit = 3 - route = routes.Route("GET", "/something/{channel}/somewhere").compile(channel=123) - with pytest.raises(errors.RateLimitTooLongError): + + with ( + mock.patch.object(rest_client, "_bucket_manager") as patched__bucket_manager, + mock.patch.object(patched__bucket_manager, "max_rate_limit", 3), + pytest.raises(errors.RateLimitTooLongError), + ): await rest_client._parse_ratelimits(route, "auth", StubResponse()) ############# # Endpoints # ############# - async def test_fetch_channel(self, rest_client): + async def test_fetch_channel(self, rest_client: rest.RESTClientImpl): expected_route = routes.GET_CHANNEL.compile(channel=123) mock_object = mock.Mock() - rest_client._entity_factory.deserialize_channel = mock.Mock(return_value=mock_object) - rest_client._request = mock.AsyncMock(return_value={"payload": "NO"}) - assert await rest_client.fetch_channel(StubModel(123)) == mock_object - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_channel.assert_called_once_with(rest_client._request.return_value) + mock_channel = mock.Mock(channels.GuildTextChannel, id=snowflakes.Snowflake(123)) + + with ( + mock.patch.object( + rest_client.entity_factory, "deserialize_channel", return_value=mock_object + ) as patched_deserialize_channel, + mock.patch.object( + rest_client, "_request", mock.AsyncMock(return_value={"payload": "NO"}) + ) as patched__request, + ): + assert await rest_client.fetch_channel(mock_channel) == mock_object + + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_channel.assert_called_once_with(patched__request.return_value) - async def test_fetch_channel_with_dm_channel_when_cacheful(self, rest_client, mock_cache): + async def test_fetch_channel_with_dm_channel_when_cacheful( + self, rest_client: rest.RESTClientImpl, mock_cache: cache.MutableCache + ): expected_route = routes.GET_CHANNEL.compile(channel=123) mock_object = mock.Mock(spec=channels.DMChannel, type=channels.ChannelType.DM) - rest_client._entity_factory.deserialize_channel = mock.Mock(return_value=mock_object) - rest_client._request = mock.AsyncMock(return_value={"payload": "NO"}) - assert await rest_client.fetch_channel(StubModel(123)) == mock_object - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_channel.assert_called_once_with(rest_client._request.return_value) - mock_cache.set_dm_channel_id.assert_called_once_with(mock_object.recipient.id, mock_object.id) + mock_channel = mock.Mock(channels.DMChannel, id=snowflakes.Snowflake(123)) + + with ( + mock.patch.object( + rest_client.entity_factory, "deserialize_channel", return_value=mock_object + ) as patched_deserialize_channel, + mock.patch.object( + rest_client, "_request", mock.AsyncMock(return_value={"payload": "NO"}) + ) as patched__request, + mock.patch.object(mock_cache, "set_dm_channel_id") as patched_set_dm_channel_id, + ): + assert await rest_client.fetch_channel(mock_channel) == mock_object + + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_channel.assert_called_once_with(patched__request.return_value) + patched_set_dm_channel_id.assert_called_once_with(mock_object.recipient.id, mock_object.id) - async def test_fetch_channel_with_dm_channel_when_cacheless(self, rest_client, mock_cache): + async def test_fetch_channel_with_dm_channel_when_cacheless( + self, rest_client: rest.RESTClientImpl, mock_cache: cache.MutableCache + ): expected_route = routes.GET_CHANNEL.compile(channel=123) mock_object = mock.Mock(spec=channels.DMChannel, type=channels.ChannelType.DM) - rest_client._cache = None - rest_client._entity_factory.deserialize_channel = mock.Mock(return_value=mock_object) - rest_client._request = mock.AsyncMock(return_value={"payload": "NO"}) - assert await rest_client.fetch_channel(StubModel(123)) == mock_object - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_channel.assert_called_once_with(rest_client._request.return_value) - mock_cache.set_dm_channel_id.assert_not_called() + mock_channel = mock.Mock(channels.DMChannel, id=snowflakes.Snowflake(123)) + + with ( + mock.patch.object(rest_client, "_cache", None), + mock.patch.object( + rest_client.entity_factory, "deserialize_channel", return_value=mock_object + ) as patched_deserialize_channel, + mock.patch.object( + rest_client, "_request", mock.AsyncMock(return_value={"payload": "NO"}) + ) as patched__request, + mock.patch.object(mock_cache, "set_dm_channel_id") as patched_set_dm_channel_id, + ): + assert await rest_client.fetch_channel(mock_channel) == mock_object + + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_channel.assert_called_once_with(patched__request.return_value) + patched_set_dm_channel_id.assert_not_called() @pytest.mark.parametrize( ("emoji", "expected_emoji_id", "expected_emoji_name"), - [(123, 123, None), ("emoji", None, "emoji"), (None, None, None)], + [ + (emojis.CustomEmoji(id=snowflakes.Snowflake(989), name="emoji", is_animated=False), 989, None), + (emojis.UnicodeEmoji("❤️"), None, "❤️"), + (None, None, None), + ], ) @pytest.mark.parametrize( ("auto_archive_duration", "default_auto_archive_duration"), @@ -2401,424 +3008,480 @@ async def test_fetch_channel_with_dm_channel_when_cacheless(self, rest_client, m ) async def test_edit_channel( self, - rest_client, - auto_archive_duration, - default_auto_archive_duration, - emoji, - expected_emoji_id, - expected_emoji_name, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_guild_category: channels.GuildCategory, + auto_archive_duration: typing.Union[int, datetime.timedelta], + default_auto_archive_duration: typing.Union[int, float], + emoji: undefined.UndefinedNoneOr[emojis.Emoji], + expected_emoji_id: typing.Optional[snowflakes.Snowflake], + expected_emoji_name: typing.Optional[str], ): - expected_route = routes.PATCH_CHANNEL.compile(channel=123) + expected_route = routes.PATCH_CHANNEL.compile(channel=4560) mock_object = mock.Mock() - rest_client._entity_factory.deserialize_channel = mock.Mock(return_value=mock_object) - rest_client._request = mock.AsyncMock(return_value={"payload": "GO"}) - rest_client._entity_factory.serialize_permission_overwrite = mock.Mock( - return_value={"type": "member", "allow": 1024, "deny": 8192, "id": "1235431"} - ) - rest_client._entity_factory.serialize_forum_tag = mock.Mock( - return_value={"id": 0, "name": "testing", "moderated": True, "emoji_id": None, "emoji_name": None} - ) - expected_json = { - "name": "new name", - "position": 1, - "rtc_region": "ostrich-city", - "topic": "new topic", - "nsfw": True, - "bitrate": 10, - "video_quality_mode": 2, - "user_limit": 100, - "rate_limit_per_user": 30, - "parent_id": "1234", - "permission_overwrites": [{"type": "member", "allow": 1024, "deny": 8192, "id": "1235431"}], - "default_auto_archive_duration": 445123, - "default_thread_rate_limit_per_user": 40, - "default_forum_layout": 1, - "default_sort_order": 0, - "default_reaction_emoji": {"emoji_id": expected_emoji_id, "emoji_name": expected_emoji_name}, - "available_tags": [{"id": 0, "name": "testing", "moderated": True, "emoji_id": None, "emoji_name": None}], - "archived": True, - "locked": False, - "invitable": True, - "auto_archive_duration": 12322, - "flags": 12, - "applied_tags": ["0"], - } - result = await rest_client.edit_channel( - StubModel(123), - name="new name", - position=1, - topic="new topic", - nsfw=True, - bitrate=10, - video_quality_mode=channels.VideoQualityMode.FULL, - user_limit=100, - rate_limit_per_user=30, - permission_overwrites=[ - channels.PermissionOverwrite( - type=channels.PermissionOverwriteType.MEMBER, - allow=permissions.Permissions.VIEW_CHANNEL, - deny=permissions.Permissions.MANAGE_MESSAGES, - id=1235431, - ) - ], - parent_category=StubModel(1234), - region="ostrich-city", - reason="some reason :)", - default_auto_archive_duration=default_auto_archive_duration, - default_thread_rate_limit_per_user=40, - default_forum_layout=channels.ForumLayoutType.LIST_VIEW, - default_sort_order=channels.ForumSortOrderType.LATEST_ACTIVITY, - available_tags=[channels.ForumTag(name="testing", moderated=True)], - default_reaction_emoji=emoji, - archived=True, - locked=False, - invitable=True, - auto_archive_duration=auto_archive_duration, - flags=12, - applied_tags=[StubModel(0)], - ) + mock_tag = channels.ForumTag(id=snowflakes.Snowflake(0), name="tag", moderated=False, emoji=None) + + with ( + mock.patch.object(rest_client, "_cache", None), + mock.patch.object( + rest_client.entity_factory, "deserialize_channel", return_value=mock_object + ) as patched_deserialize_channel, + mock.patch.object( + rest_client.entity_factory, + "serialize_permission_overwrite", + return_value={"type": "member", "allow": 1024, "deny": 8192, "id": "1235431"}, + ), + mock.patch.object( + rest_client.entity_factory, + "serialize_forum_tag", + return_value={"id": 0, "name": "testing", "moderated": True, "emoji_id": None, "emoji_name": None}, + ), + mock.patch.object( + rest_client, "_request", mock.AsyncMock(return_value={"payload": "GO"}) + ) as patched__request, + ): + expected_json = { + "name": "new name", + "position": 1, + "rtc_region": "ostrich-city", + "topic": "new topic", + "nsfw": True, + "bitrate": 10, + "video_quality_mode": channels.VideoQualityMode.FULL, + "user_limit": 100, + "rate_limit_per_user": 30, + "parent_id": "4564", + "permission_overwrites": [{"type": "member", "allow": 1024, "deny": 8192, "id": "1235431"}], + "default_auto_archive_duration": 445123, + "default_thread_rate_limit_per_user": 40, + "default_forum_layout": channels.ForumLayoutType.LIST_VIEW, + "default_sort_order": channels.ForumSortOrderType.LATEST_ACTIVITY, + "default_reaction_emoji": {"emoji_id": expected_emoji_id, "emoji_name": expected_emoji_name}, + "available_tags": [ + {"id": 0, "name": "testing", "moderated": True, "emoji_id": None, "emoji_name": None} + ], + "archived": True, + "locked": False, + "invitable": True, + "auto_archive_duration": 12322, + "flags": channels.ChannelFlag.REQUIRE_TAG, + "applied_tags": ["0"], + } + + result = await rest_client.edit_channel( + mock_guild_text_channel, + name="new name", + position=1, + topic="new topic", + nsfw=True, + bitrate=10, + video_quality_mode=channels.VideoQualityMode.FULL, + user_limit=100, + rate_limit_per_user=30, + permission_overwrites=[ + channels.PermissionOverwrite( + type=channels.PermissionOverwriteType.MEMBER, + allow=permissions.Permissions.VIEW_CHANNEL, + deny=permissions.Permissions.MANAGE_MESSAGES, + id=1235431, + ) + ], + parent_category=mock_guild_category, + region="ostrich-city", + reason="some reason :)", + default_auto_archive_duration=default_auto_archive_duration, + default_thread_rate_limit_per_user=40, + default_forum_layout=channels.ForumLayoutType.LIST_VIEW, + default_sort_order=channels.ForumSortOrderType.LATEST_ACTIVITY, + available_tags=[channels.ForumTag(name="testing", moderated=True)], + default_reaction_emoji=emoji, + archived=True, + locked=False, + invitable=True, + auto_archive_duration=auto_archive_duration, + flags=channels.ChannelFlag.REQUIRE_TAG, + applied_tags=[mock_tag], + ) - assert result == mock_object - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="some reason :)") - rest_client._entity_factory.deserialize_channel.assert_called_once_with(rest_client._request.return_value) + assert result == mock_object - async def test_edit_channel_without_optionals(self, rest_client): + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="some reason :)") + patched_deserialize_channel.assert_called_once_with(patched__request.return_value) + + async def test_edit_channel_without_optionals(self, rest_client: rest.RESTClientImpl): expected_route = routes.PATCH_CHANNEL.compile(channel=123) mock_object = mock.Mock() - rest_client._entity_factory.deserialize_channel = mock.Mock(return_value=mock_object) - rest_client._request = mock.AsyncMock(return_value={"payload": "no"}) - - assert await rest_client.edit_channel(StubModel(123)) == mock_object - rest_client._request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) - rest_client._entity_factory.deserialize_channel.assert_called_once_with(rest_client._request.return_value) - async def test_delete_channel(self, rest_client): - expected_route = routes.DELETE_CHANNEL.compile(channel=123) - rest_client._request = mock.AsyncMock(return_value={"id": "NNNNN"}) + mock_channel = mock.Mock(channels.GuildTextChannel, id=snowflakes.Snowflake(123)) - result = await rest_client.delete_channel(StubModel(123), reason="some reason :)") + with ( + mock.patch.object( + rest_client.entity_factory, "deserialize_channel", return_value=mock_object + ) as patched_deserialize_channel, + mock.patch.object( + rest_client, "_request", mock.AsyncMock(return_value={"payload": "no"}) + ) as patched__request, + ): + assert await rest_client.edit_channel(mock_channel) == mock_object - assert result is rest_client._entity_factory.deserialize_channel.return_value - rest_client._entity_factory.deserialize_channel.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with(expected_route, reason="some reason :)") + patched__request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) + patched_deserialize_channel.assert_called_once_with(patched__request.return_value) - async def test_delete_channel_without_optionals(self, rest_client): + async def test_delete_channel(self, rest_client: rest.RESTClientImpl): expected_route = routes.DELETE_CHANNEL.compile(channel=123) - rest_client._request = mock.AsyncMock(return_value={"id": "NNNNN"}) - result = await rest_client.delete_channel(StubModel(123)) + mock_channel = mock.Mock(channels.GuildTextChannel, id=snowflakes.Snowflake(123)) + + with ( + mock.patch.object(rest_client.entity_factory, "deserialize_channel") as patched_deserialize_channel, + mock.patch.object( + rest_client, "_request", mock.AsyncMock(return_value={"id": "NNNNN"}) + ) as patched__request, + ): + result = await rest_client.delete_channel(mock_channel, reason="Why not :D") - assert result is rest_client._entity_factory.deserialize_channel.return_value - rest_client._entity_factory.deserialize_channel.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with(expected_route, reason=undefined.UNDEFINED) + assert result is patched_deserialize_channel.return_value + patched_deserialize_channel.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with(expected_route, reason="Why not :D") - async def test_edit_my_voice_state_when_requesting_to_speak(self, rest_client): - rest_client._request = mock.AsyncMock() - expected_route = routes.PATCH_MY_GUILD_VOICE_STATE.compile(guild=5421) + async def test_edit_my_voice_state_when_requesting_to_speak( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_stage_channel: channels.GuildStageChannel, + ): + expected_route = routes.PATCH_MY_GUILD_VOICE_STATE.compile(guild=123) mock_datetime = mock.Mock(isoformat=mock.Mock(return_value="blamblamblam")) - with mock.patch.object(time, "utc_datetime", return_value=mock_datetime): + with ( + mock.patch.object(time, "utc_datetime", return_value=mock_datetime) as patched_utc_datetime, + mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request, + ): result = await rest_client.edit_my_voice_state( - StubModel(5421), StubModel(999), suppress=True, request_to_speak=True + mock_partial_guild, mock_guild_stage_channel, suppress=True, request_to_speak=True ) - time.utc_datetime.assert_called_once() + patched_utc_datetime.assert_called_once() mock_datetime.isoformat.assert_called_once() assert result is None - rest_client._request.assert_awaited_once_with( - expected_route, json={"channel_id": "999", "suppress": True, "request_to_speak_timestamp": "blamblamblam"} - ) - - async def test_edit_my_voice_state_when_revoking_speak_request(self, rest_client): - rest_client._request = mock.AsyncMock() - expected_route = routes.PATCH_MY_GUILD_VOICE_STATE.compile(guild=5421) - - result = await rest_client.edit_my_voice_state( - StubModel(5421), StubModel(999), suppress=True, request_to_speak=False - ) - - assert result is None - rest_client._request.assert_awaited_once_with( - expected_route, json={"channel_id": "999", "suppress": True, "request_to_speak_timestamp": None} + patched__request.assert_awaited_once_with( + expected_route, json={"channel_id": "45613", "suppress": True, "request_to_speak_timestamp": "blamblamblam"} ) - async def test_fetch_my_voice_state(self, rest_client): - expected_route = routes.GET_MY_GUILD_VOICE_STATE.compile(guild=5454) - - expected_json = { - "guild_id": "5454", - "channel_id": "3940568093485", - "user_id": "237890809345627", - "member": { - "nick": "foobarbaz", - "roles": ["11111", "22222", "33333", "44444"], - "joined_at": "2015-04-26T06:26:56.936000+00:00", - "premium_since": "2019-05-17T06:26:56.936000+00:00", - "avatar": "estrogen", - "deaf": False, - "mute": True, - "pending": False, - "communication_disabled_until": "2021-10-18T06:26:56.936000+00:00", - }, - "session_id": "39405894b9058guhfguh43t9g", - "deaf": False, - "mute": True, - "self_deaf": False, - "self_mute": True, - "self_stream": False, - "self_video": True, - "suppress": False, - "request_to_speak_timestamp": "2021-04-17T10:11:19.970105+00:00", - } - - rest_client._request = mock.AsyncMock(return_value=expected_json) - - with mock.patch.object( - rest_client._entity_factory, "deserialize_voice_state", return_value=mock.Mock() - ) as patched_deserialize_voice_state: - await rest_client.fetch_my_voice_state(StubModel(5454)) - - patched_deserialize_voice_state.assert_called_once_with(expected_json) - - rest_client._request.assert_awaited_once_with(expected_route) - - async def test_fetch_voice_state(self, rest_client): - expected_route = routes.GET_GUILD_VOICE_STATE.compile(guild=5454, user=1234567890) - - expected_json = { - "guild_id": "5454", - "channel_id": "3940568093485", - "user_id": "1234567890", - "member": { - "nick": "foobarbaz", - "roles": ["11111", "22222", "33333", "44444"], - "joined_at": "2015-04-26T06:26:56.936000+00:00", - "premium_since": "2019-05-17T06:26:56.936000+00:00", - "avatar": "estrogen", - "deaf": False, - "mute": True, - "pending": False, - "communication_disabled_until": "2021-10-18T06:26:56.936000+00:00", - }, - "session_id": "39405894b9058guhfguh43t9g", - "deaf": False, - "mute": True, - "self_deaf": False, - "self_mute": True, - "self_stream": False, - "self_video": True, - "suppress": False, - "request_to_speak_timestamp": "2021-04-17T10:11:19.970105+00:00", - } - - rest_client._request = mock.AsyncMock(return_value=expected_json) - - with mock.patch.object( - rest_client._entity_factory, "deserialize_voice_state", return_value=mock.Mock() - ) as patched_deserialize_voice_state: - await rest_client.fetch_voice_state(StubModel(5454), StubModel(1234567890)) + async def test_edit_my_voice_state_when_revoking_speak_request( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_stage_channel: channels.GuildStageChannel, + ): + expected_route = routes.PATCH_MY_GUILD_VOICE_STATE.compile(guild=123) - patched_deserialize_voice_state.assert_called_once_with(expected_json) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + result = await rest_client.edit_my_voice_state( + mock_partial_guild, mock_guild_stage_channel, suppress=True, request_to_speak=False + ) - rest_client._request.assert_awaited_once_with(expected_route) + assert result is None + patched__request.assert_awaited_once_with( + expected_route, json={"channel_id": "45613", "suppress": True, "request_to_speak_timestamp": None} + ) - async def test_edit_my_voice_state_when_providing_datetime_for_request_to_speak(self, rest_client): - rest_client._request = mock.AsyncMock() - expected_route = routes.PATCH_MY_GUILD_VOICE_STATE.compile(guild=5421) + async def test_edit_my_voice_state_when_providing_datetime_for_request_to_speak( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_stage_channel: channels.GuildStageChannel, + ): + expected_route = routes.PATCH_MY_GUILD_VOICE_STATE.compile(guild=123) mock_datetime = mock.Mock(spec=datetime.datetime, isoformat=mock.Mock(return_value="blamblamblam2")) - result = await rest_client.edit_my_voice_state( - StubModel(5421), StubModel(999), suppress=True, request_to_speak=mock_datetime - ) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + result = await rest_client.edit_my_voice_state( + mock_partial_guild, mock_guild_stage_channel, suppress=True, request_to_speak=mock_datetime + ) - assert result is None - mock_datetime.isoformat.assert_called_once() - rest_client._request.assert_awaited_once_with( - expected_route, json={"channel_id": "999", "suppress": True, "request_to_speak_timestamp": "blamblamblam2"} - ) + assert result is None + mock_datetime.isoformat.assert_called_once() + patched__request.assert_awaited_once_with( + expected_route, + json={"channel_id": "45613", "suppress": True, "request_to_speak_timestamp": "blamblamblam2"}, + ) - async def test_edit_my_voice_state_without_optional_fields(self, rest_client): - rest_client._request = mock.AsyncMock() - expected_route = routes.PATCH_MY_GUILD_VOICE_STATE.compile(guild=5421) + async def test_edit_my_voice_state_without_optional_fields( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_stage_channel: channels.GuildStageChannel, + ): + expected_route = routes.PATCH_MY_GUILD_VOICE_STATE.compile(guild=123) - result = await rest_client.edit_my_voice_state(StubModel(5421), StubModel(999)) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + result = await rest_client.edit_my_voice_state(mock_partial_guild, mock_guild_stage_channel) - assert result is None - rest_client._request.assert_awaited_once_with(expected_route, json={"channel_id": "999"}) + assert result is None + patched__request.assert_awaited_once_with(expected_route, json={"channel_id": "45613"}) - async def test_edit_voice_state(self, rest_client): - rest_client._request = mock.AsyncMock() - expected_route = routes.PATCH_GUILD_VOICE_STATE.compile(guild=543123, user=32123) + async def test_edit_voice_state( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_stage_channel: channels.GuildStageChannel, + mock_user: users.User, + ): + expected_route = routes.PATCH_GUILD_VOICE_STATE.compile(guild=123, user=789) - result = await rest_client.edit_voice_state(StubModel(543123), StubModel(321), StubModel(32123), suppress=True) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + result = await rest_client.edit_voice_state( + mock_partial_guild, mock_guild_stage_channel, mock_user, suppress=True + ) - assert result is None - rest_client._request.assert_awaited_once_with(expected_route, json={"channel_id": "321", "suppress": True}) + assert result is None + patched__request.assert_awaited_once_with(expected_route, json={"channel_id": "45613", "suppress": True}) - async def test_edit_voice_state_without_optional_arguments(self, rest_client): - rest_client._request = mock.AsyncMock() - expected_route = routes.PATCH_GUILD_VOICE_STATE.compile(guild=543123, user=32123) + async def test_edit_voice_state_without_optional_arguments( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_stage_channel: channels.GuildStageChannel, + mock_user: users.User, + ): + expected_route = routes.PATCH_GUILD_VOICE_STATE.compile(guild=123, user=789) - result = await rest_client.edit_voice_state(StubModel(543123), StubModel(321), StubModel(32123)) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + result = await rest_client.edit_voice_state(mock_partial_guild, mock_guild_stage_channel, mock_user) - assert result is None - rest_client._request.assert_awaited_once_with(expected_route, json={"channel_id": "321"}) + assert result is None + patched__request.assert_awaited_once_with(expected_route, json={"channel_id": "45613"}) + + async def test_edit_permission_overwrite( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + expected_route = routes.PUT_CHANNEL_PERMISSIONS.compile(channel=4560, overwrite=2983) - async def test_edit_permission_overwrite(self, rest_client): - target = StubModel(456) - expected_route = routes.PUT_CHANNEL_PERMISSIONS.compile(channel=123, overwrite=456) - rest_client._request = mock.AsyncMock() expected_json = {"type": 1, "allow": 4, "deny": 1} - await rest_client.edit_permission_overwrite( - StubModel(123), - target, - target_type=channels.PermissionOverwriteType.MEMBER, - allow=permissions.Permissions.BAN_MEMBERS, - deny=permissions.Permissions.CREATE_INSTANT_INVITE, - reason="cause why not :)", - ) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="cause why not :)") + target = mock.Mock(users.PartialUser, id=snowflakes.Snowflake(2983)) + + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.edit_permission_overwrite( + mock_guild_text_channel, + target, + target_type=channels.PermissionOverwriteType.MEMBER, + allow=permissions.Permissions.BAN_MEMBERS, + deny=permissions.Permissions.CREATE_INSTANT_INVITE, + reason="cause why not :)", + ) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="cause why not :)") @pytest.mark.parametrize( ("target", "expected_type"), [ - (mock.Mock(users.UserImpl, id=456), channels.PermissionOverwriteType.MEMBER), - (mock.Mock(guilds.Role, id=456), channels.PermissionOverwriteType.ROLE), + (mock.Mock(users.UserImpl, id=34895734), channels.PermissionOverwriteType.MEMBER), + (mock.Mock(guilds.Role, id=34895734), channels.PermissionOverwriteType.ROLE), ( - mock.Mock(channels.PermissionOverwrite, id=456, type=channels.PermissionOverwriteType.MEMBER), + mock.Mock(channels.PermissionOverwrite, id=34895734, type=channels.PermissionOverwriteType.MEMBER), channels.PermissionOverwriteType.MEMBER, ), ], ) - async def test_edit_permission_overwrite_when_target_undefined(self, rest_client, target, expected_type): - expected_route = routes.PUT_CHANNEL_PERMISSIONS.compile(channel=123, overwrite=456) - rest_client._request = mock.AsyncMock() + async def test_edit_permission_overwrite_when_target_undefined( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + target: mock.Mock, + expected_type: channels.PermissionOverwriteType, + ): + expected_route = routes.PUT_CHANNEL_PERMISSIONS.compile(channel=4560, overwrite=34895734) + expected_json = {"type": expected_type} - await rest_client.edit_permission_overwrite(StubModel(123), target) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason=undefined.UNDEFINED) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.edit_permission_overwrite(mock_guild_text_channel, target) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason=undefined.UNDEFINED) + + async def test_edit_permission_overwrite_when_cant_determine_target_type(self, rest_client: rest.RESTClientImpl): + mock_channel = mock.Mock(channels.GuildStageChannel, id=snowflakes.Snowflake(123)) + mock_target = mock.Mock(id=snowflakes.Snowflake(456)) - async def test_edit_permission_overwrite_when_cant_determine_target_type(self, rest_client): with pytest.raises(TypeError): - await rest_client.edit_permission_overwrite(StubModel(123), StubModel(123)) + await rest_client.edit_permission_overwrite(mock_channel, mock_target) - async def test_delete_permission_overwrite(self, rest_client): - expected_route = routes.DELETE_CHANNEL_PERMISSIONS.compile(channel=123, overwrite=456) - rest_client._request = mock.AsyncMock() + async def test_delete_permission_overwrite( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + expected_route = routes.DELETE_CHANNEL_PERMISSIONS.compile(channel=4560, overwrite=23409582) - await rest_client.delete_permission_overwrite(StubModel(123), StubModel(456), reason="testing") - rest_client._request.assert_awaited_once_with(expected_route, reason="testing") + mock_target = mock.Mock(users.PartialUser, id=snowflakes.Snowflake(23409582)) - async def test_fetch_channel_invites(self, rest_client): - invite1 = StubModel(456) - invite2 = StubModel(789) - expected_route = routes.GET_CHANNEL_INVITES.compile(channel=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_invite_with_metadata = mock.Mock(side_effect=[invite1, invite2]) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_permission_overwrite(mock_guild_text_channel, mock_target) + patched__request.assert_awaited_once_with(expected_route, reason=undefined.UNDEFINED) - assert await rest_client.fetch_channel_invites(StubModel(123)) == [invite1, invite2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_invite_with_metadata.call_count == 2 - rest_client._entity_factory.deserialize_invite_with_metadata.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) + async def test_fetch_channel_invites( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + invite1 = make_invite_with_metadata("1111") + invite2 = make_invite_with_metadata("2222") + + expected_route = routes.GET_CHANNEL_INVITES.compile(channel=4560) + + with ( + mock.patch.object( + rest_client, "_request", new=mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_invite_with_metadata", side_effect=[invite1, invite2] + ) as patched_deserialize_invite_with_metadata, + ): + assert await rest_client.fetch_channel_invites(mock_guild_text_channel) == [invite1, invite2] + patched__request.assert_awaited_once_with(expected_route) + assert patched_deserialize_invite_with_metadata.call_count == 2 + patched_deserialize_invite_with_metadata.assert_has_calls( + [mock.call({"id": "456"}), mock.call({"id": "789"})] + ) - async def test_create_invite(self, rest_client): - expected_route = routes.POST_CHANNEL_INVITES.compile(channel=123) - rest_client._request = mock.AsyncMock(return_value={"ID": "NOOOOOOOOPOOOOOOOI!"}) + async def test_create_invite( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_user: users.User, + mock_application: applications.Application, + ): + expected_route = routes.POST_CHANNEL_INVITES.compile(channel=4560) expected_json = { "max_age": 60, "max_uses": 4, "temporary": True, "unique": True, "target_type": invites.TargetType.STREAM, - "target_user_id": "456", - "target_application_id": "789", + "target_user_id": "789", + "target_application_id": "111", } - result = await rest_client.create_invite( - StubModel(123), - max_age=datetime.timedelta(minutes=1), - max_uses=4, - temporary=True, - unique=True, - target_type=invites.TargetType.STREAM, - target_user=StubModel(456), - target_application=StubModel(789), - reason="cause why not :)", - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"ID": "NOOOOOOOOPOOOOOOOI!"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_invite_with_metadata" + ) as patched_deserialize_invite_with_metadata, + ): + result = await rest_client.create_invite( + mock_guild_text_channel, + max_age=datetime.timedelta(minutes=1), + max_uses=4, + temporary=True, + unique=True, + target_type=invites.TargetType.STREAM, + target_user=mock_user, + target_application=mock_application, + reason="cause why not :)", + ) - assert result is rest_client._entity_factory.deserialize_invite_with_metadata.return_value - rest_client._entity_factory.deserialize_invite_with_metadata.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="cause why not :)") + assert result is patched_deserialize_invite_with_metadata.return_value + patched_deserialize_invite_with_metadata.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="cause why not :)") - async def test_pin_message(self, rest_client): - expected_route = routes.PUT_CHANNEL_PINS.compile(channel=123, message=456) - rest_client._request = mock.AsyncMock() + async def test_pin_message( + self, + rest_client: rest.RESTClientImpl, + hikari_guild_text_channel: channels.GuildTextChannel, + hikari_message: messages.Message, + ): + expected_route = routes.PUT_CHANNEL_PINS.compile(channel=4560, message=101) - await rest_client.pin_message(StubModel(123), StubModel(456)) - rest_client._request.assert_awaited_once_with(expected_route) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.pin_message(hikari_guild_text_channel, hikari_message) - async def test_unpin_message(self, rest_client): - expected_route = routes.DELETE_CHANNEL_PIN.compile(channel=123, message=456) - rest_client._request = mock.AsyncMock() + patched__request.assert_awaited_once_with(expected_route) - await rest_client.unpin_message(StubModel(123), StubModel(456)) - rest_client._request.assert_awaited_once_with(expected_route) + async def test_unpin_message( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + expected_route = routes.DELETE_CHANNEL_PIN.compile(channel=4560, message=101) + + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.unpin_message(mock_guild_text_channel, mock_message) - async def test_fetch_message(self, rest_client): + patched__request.assert_awaited_once_with(expected_route) + + async def test_fetch_message( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): message_obj = mock.Mock() - expected_route = routes.GET_CHANNEL_MESSAGE.compile(channel=123, message=456) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_message = mock.Mock(return_value=message_obj) + expected_route = routes.GET_CHANNEL_MESSAGE.compile(channel=4560, message=101) - assert await rest_client.fetch_message(StubModel(123), StubModel(456)) is message_obj - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"id": "456"}) - - async def test_create_message_when_form(self, rest_client: rest.RESTClientImpl): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() - poll_obj = object() + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_message", return_value=message_obj + ) as patched_deserialize_message, + ): + assert await rest_client.fetch_message(mock_guild_text_channel, mock_message) is message_obj + + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_message.assert_called_once_with({"id": "456"}) + + async def test_create_message_when_form( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() + poll_obj = mock.Mock() mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.POST_CHANNEL_MESSAGES.compile(channel=123456789) - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) + expected_route = routes.POST_CHANNEL_MESSAGES.compile(channel=4560) - returned = await rest_client.create_message( - StubModel(123456789), - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - poll=poll_obj, - sticker=54234, - stickers=[564123, 431123], - tts=True, - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - reply=StubModel(987654321), - reply_must_exist=False, - flags=54123, - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 987654321} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_message_payload, + ): + returned = await rest_client.create_message( + mock_guild_text_channel, + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + poll=poll_obj, + sticker=54234, + stickers=[564123, 431123], + tts=True, + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + reply=mock_message, + reply_must_exist=False, + flags=54123, + ) + assert returned is patched_deserialize_message.return_value - rest_client._build_message_payload.assert_called_once_with( + patched__build_message_payload.assert_called_once_with( content="new content", attachment=attachment_obj, attachments=[attachment_obj2], @@ -2838,47 +3501,61 @@ async def test_create_message_when_form(self, rest_client: rest.RESTClientImpl): ) mock_form.add_field.assert_called_once_with( "payload_json", - b'{"testing":"ensure_in_test","message_reference":{"message_id":"987654321","fail_if_not_exists":false}}', + b'{"testing":"ensure_in_test","message_reference":{"message_id":"101","fail_if_not_exists":false}}', content_type="application/json", ) - rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + patched__request.assert_awaited_once_with(expected_route, form_builder=mock_form) + patched_deserialize_message.assert_called_once_with({"message_id": 987654321}) - async def test_create_message_when_no_form(self, rest_client): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() + async def test_create_message_when_no_form( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() + poll_obj = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.POST_CHANNEL_MESSAGES.compile(channel=123456789) - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, None)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) + expected_route = routes.POST_CHANNEL_MESSAGES.compile(channel=4560) - returned = await rest_client.create_message( - StubModel(123456789), - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - sticker=543345, - stickers=[123321, 6572345], - tts=True, - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - reply=StubModel(987654321), - reply_must_exist=False, - flags=6643, - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 987654321} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, None) + ) as patched__build_message_payload, + ): + returned = await rest_client.create_message( + mock_guild_text_channel, + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + poll=poll_obj, + sticker=543345, + stickers=[123321, 6572345], + tts=True, + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + reply=mock_message, + reply_must_exist=False, + flags=6643, + ) + assert returned is patched_deserialize_message.return_value - rest_client._build_message_payload.assert_called_once_with( + patched__build_message_payload.assert_called_once_with( content="new content", attachment=attachment_obj, attachments=[attachment_obj2], @@ -2886,7 +3563,7 @@ async def test_create_message_when_no_form(self, rest_client): components=[component_obj2], embed=embed_obj, embeds=[embed_obj2], - poll=undefined.UNDEFINED, + poll=poll_obj, sticker=543345, stickers=[123321, 6572345], tts=True, @@ -2896,50 +3573,59 @@ async def test_create_message_when_no_form(self, rest_client): role_mentions=[1234], flags=6643, ) - rest_client._request.assert_awaited_once_with( + patched__request.assert_awaited_once_with( expected_route, - json={ - "testing": "ensure_in_test", - "message_reference": {"message_id": "987654321", "fail_if_not_exists": False}, - }, + json={"testing": "ensure_in_test", "message_reference": {"message_id": "101", "fail_if_not_exists": False}}, ) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + patched_deserialize_message.assert_called_once_with({"message_id": 987654321}) - async def test_create_voice_message(self, rest_client): - expected_route = routes.POST_CHANNEL_MESSAGES.compile(channel=123456789) - attachment_obj = object() - reply_obj = StubModel(987654321) + async def test_create_voice_message( + self, + rest_client: rest.RESTClientImpl, + hikari_guild_text_channel: channels.GuildTextChannel, + hikari_message: messages.Message, + ): + expected_route = routes.POST_CHANNEL_MESSAGES.compile(channel=4560) + attachment_obj = mock.Mock() mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) - rest_client._build_voice_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - returned = await rest_client.create_voice_message( - StubModel(123456789), - attachment=attachment_obj, - waveform="AAA", - duration=3, - reply=reply_obj, - mentions_reply=True, - reply_must_exist=False, - flags=54123, - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object( + rest_client, "_build_voice_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_voice_message_payload, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + ): + returned = await rest_client.create_voice_message( + hikari_guild_text_channel, + attachment=attachment_obj, + waveform="AAA", + duration=3, + reply=hikari_message, + mentions_reply=True, + reply_must_exist=False, + flags=54123, + ) - rest_client._build_voice_message_payload.assert_called_once_with( + assert returned is patched_deserialize_message.return_value + + patched__build_voice_message_payload.assert_called_once_with( attachment=attachment_obj, mentions_reply=True, flags=54123, waveform="AAA", duration=3, - reply=reply_obj, + reply=hikari_message, reply_must_exist=False, ) - rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + patched__request.assert_awaited_once_with(expected_route, form_builder=mock_form) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - async def test_forward_message_id(self, rest_client): + async def test_forward_message_id(self, rest_client: rest.RESTClientImpl): rest_client._request = mock.AsyncMock(return_value={"message_id": 1239}) expected_route = routes.POST_CHANNEL_MESSAGES.compile(channel=1234) m = await rest_client.forward_message(channel_to=1234, message=123, channel_from=12345) @@ -2950,17 +3636,17 @@ async def test_forward_message_id(self, rest_client): "message_reference": { "message_id": "123", "channel_id": "12345", - "type": message_models.MessageReferenceType.FORWARD, + "type": messages.MessageReferenceType.FORWARD, } }, ) rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 1239}) - async def test_forward_partial_message(self, rest_client): + async def test_forward_partial_message(self, rest_client: rest.RESTClientImpl): rest_client._request = mock.AsyncMock(return_value={"message_id": 1239}) expected_route = routes.POST_CHANNEL_MESSAGES.compile(channel=1234) m = mock.Mock( - spec=message_models.PartialMessage, + spec=messages.PartialMessage, id=snowflakes.Snowflake(123), channel_id=snowflakes.Snowflake(12345), __int__=lambda self: int(self.id), @@ -2973,13 +3659,13 @@ async def test_forward_partial_message(self, rest_client): "message_reference": { "message_id": "123", "channel_id": "12345", - "type": message_models.MessageReferenceType.FORWARD, + "type": messages.MessageReferenceType.FORWARD, } }, ) rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 1239}) - async def test_forward_fails_with_no_channel(self, rest_client): + async def test_forward_fails_with_no_channel(self, rest_client: rest.RESTClientImpl): rest_client._request = mock.AsyncMock(return_value={"message_id": 1239}) with pytest.raises(ValueError) as excinfo: await rest_client.forward_message(channel_to=1234, message=123) @@ -2987,500 +3673,709 @@ async def test_forward_fails_with_no_channel(self, rest_client): excinfo.value ) - async def test_create_voice_message_no_flags(self, rest_client): + async def test_create_voice_message_no_flags( + self, rest_client: rest.RESTClientImpl, hikari_guild_text_channel: channels.GuildTextChannel + ): rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) returned = await rest_client.create_voice_message( - StubModel(123456789), attachment=object(), waveform="AAA", duration=3 + hikari_guild_text_channel, attachment=mock.Mock(), waveform="AAA", duration=3 ) assert returned is rest_client._entity_factory.deserialize_message.return_value rest_client._request.assert_awaited_once() - async def test_crosspost_message(self, rest_client): - expected_route = routes.POST_CHANNEL_CROSSPOST.compile(channel=444432, message=12353234) - mock_message = object() - rest_client._entity_factory.deserialize_message = mock.Mock(return_value=mock_message) - rest_client._request = mock.AsyncMock(return_value={"id": "93939383883", "content": "foobar"}) + async def test_crosspost_message( + self, + rest_client: rest.RESTClientImpl, + mock_guild_news_channel: channels.GuildNewsChannel, + hikari_message: messages.Message, + ): + expected_route = routes.POST_CHANNEL_CROSSPOST.compile(channel=4565, message=101) - result = await rest_client.crosspost_message(StubModel(444432), StubModel(12353234)) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "93939383883", "content": "foobar"}, + ) as patched__request, + mock.patch.object(rest_client._entity_factory, "deserialize_message") as patched_deserialize_message, + ): + result = await rest_client.crosspost_message(mock_guild_news_channel, hikari_message) - assert result is mock_message - rest_client._entity_factory.deserialize_message.assert_called_once_with( - {"id": "93939383883", "content": "foobar"} - ) - rest_client._request.assert_awaited_once_with(expected_route) + assert result is patched_deserialize_message.return_value + patched_deserialize_message.assert_called_once_with({"id": "93939383883", "content": "foobar"}) + patched__request.assert_awaited_once_with(expected_route) - async def test_edit_message_when_form(self, rest_client): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() + async def test_edit_message_when_form( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.PATCH_CHANNEL_MESSAGE.compile(channel=123456789, message=987654321) - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) - - returned = await rest_client.edit_message( - StubModel(123456789), - StubModel(987654321), - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - flags=120, - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + expected_route = routes.PATCH_CHANNEL_MESSAGE.compile(channel=4560, message=101) - rest_client._build_message_payload.assert_called_once_with( - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - flags=120, - mentions_everyone=False, - mentions_reply=undefined.UNDEFINED, - user_mentions=[9876], - role_mentions=[1234], - edit=True, - ) - mock_form.add_field.assert_called_once_with( - "payload_json", b'{"testing":"ensure_in_test"}', content_type="application/json" - ) - rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_message_payload, + ): + returned = await rest_client.edit_message( + mock_guild_text_channel, + mock_message, + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + flags=messages.MessageFlag.NONE, + ) + assert returned is patched_deserialize_message.return_value + + patched__build_message_payload.assert_called_once_with( + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + flags=messages.MessageFlag.NONE, + mentions_everyone=False, + mentions_reply=undefined.UNDEFINED, + user_mentions=[9876], + role_mentions=[1234], + edit=True, + ) + mock_form.add_field.assert_called_once_with( + "payload_json", b'{"testing":"ensure_in_test"}', content_type="application/json" + ) + patched__request.assert_awaited_once_with(expected_route, form_builder=mock_form) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - async def test_edit_message_when_no_form(self, rest_client): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() + async def test_edit_message_when_no_form( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.PATCH_CHANNEL_MESSAGE.compile(channel=123456789, message=987654321) - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, None)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) + expected_route = routes.PATCH_CHANNEL_MESSAGE.compile(channel=4560, message=101) - returned = await rest_client.edit_message( - StubModel(123456789), - StubModel(987654321), - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - flags=120, - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, None) + ) as patched__build_message_payload, + ): + returned = await rest_client.edit_message( + mock_guild_text_channel, + mock_message, + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + flags=messages.MessageFlag.NONE, + ) + assert returned is patched_deserialize_message.return_value + + patched__build_message_payload.assert_called_once_with( + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + flags=messages.MessageFlag.NONE, + mentions_everyone=False, + mentions_reply=undefined.UNDEFINED, + user_mentions=[9876], + role_mentions=[1234], + edit=True, + ) + patched__request.assert_awaited_once_with(expected_route, json={"testing": "ensure_in_test"}) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - rest_client._build_message_payload.assert_called_once_with( - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - flags=120, - mentions_everyone=False, - mentions_reply=undefined.UNDEFINED, - user_mentions=[9876], - role_mentions=[1234], - edit=True, - ) - rest_client._request.assert_awaited_once_with(expected_route, json={"testing": "ensure_in_test"}) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + async def test_follow_channel( + self, + rest_client: rest.RESTClientImpl, + mock_guild_news_channel: channels.GuildNewsChannel, + mock_guild_text_channel: channels.GuildTextChannel, + ): + expected_route = routes.POST_CHANNEL_FOLLOWERS.compile(channel=4565) + + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"channel_id": "929292", "webhook_id": "929383838"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_channel_follow" + ) as patched_deserialize_channel_follow, + ): + result = await rest_client.follow_channel( + mock_guild_news_channel, mock_guild_text_channel, reason="get followed" + ) - async def test_follow_channel(self, rest_client): - expected_route = routes.POST_CHANNEL_FOLLOWERS.compile(channel=3333) - rest_client._request = mock.AsyncMock(return_value={"channel_id": "929292", "webhook_id": "929383838"}) + assert result is patched_deserialize_channel_follow.return_value + patched_deserialize_channel_follow.assert_called_once_with( + {"channel_id": "929292", "webhook_id": "929383838"} + ) + patched__request.assert_awaited_once_with( + expected_route, json={"webhook_channel_id": "4560"}, reason="get followed" + ) - result = await rest_client.follow_channel(StubModel(3333), StubModel(606060), reason="get followed") + async def test_delete_message( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + expected_route = routes.DELETE_CHANNEL_MESSAGE.compile(channel=4560, message=101) - assert result is rest_client._entity_factory.deserialize_channel_follow.return_value - rest_client._entity_factory.deserialize_channel_follow.assert_called_once_with( - {"channel_id": "929292", "webhook_id": "929383838"} - ) - rest_client._request.assert_awaited_once_with( - expected_route, json={"webhook_channel_id": "606060"}, reason="get followed" - ) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_message(mock_guild_text_channel, mock_message, reason="broke laws") - async def test_delete_message(self, rest_client): - expected_route = routes.DELETE_CHANNEL_MESSAGE.compile(channel=123, message=456) - rest_client._request = mock.AsyncMock() + patched__request.assert_awaited_once_with(expected_route, reason="broke laws") - await rest_client.delete_message(StubModel(123), StubModel(456), reason="broke laws") - rest_client._request.assert_awaited_once_with(expected_route, reason="broke laws") + async def test_delete_messages( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + expected_route = routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=4560) - async def test_delete_messages(self, rest_client): - messages = [StubModel(i) for i in range(200)] - expected_route = routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=123) + messages_list = [make_mock_message(i) for i in range(200)] expected_json1 = {"messages": [str(i) for i in range(100)]} expected_json2 = {"messages": [str(i) for i in range(100, 200)]} - rest_client._request = mock.AsyncMock() - - await rest_client.delete_messages(StubModel(123), *messages, reason="broke laws") + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_messages(mock_guild_text_channel, *messages_list, reason="broke laws") - rest_client._request.assert_has_awaits( - [ - mock.call(expected_route, json=expected_json1, reason="broke laws"), - mock.call(expected_route, json=expected_json2, reason="broke laws"), - ] - ) + patched__request.assert_has_awaits( + [ + mock.call(expected_route, json=expected_json1, reason="broke laws"), + mock.call(expected_route, json=expected_json2, reason="broke laws"), + ] + ) async def test_delete_messages_when_one_message_left_in_chunk_and_delete_message_raises_message_not_found( - self, rest_client + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel ): - channel = StubModel(123) - messages = [StubModel(i) for i in range(101)] + messages = [make_mock_message(i) for i in range(101)] message = messages[-1] expected_json = {"messages": [str(i) for i in range(100)]} - rest_client._request = mock.AsyncMock() - rest_client.delete_message = mock.AsyncMock( - side_effect=errors.NotFoundError(url="", headers={}, raw_body="", code=10008) - ) - - await rest_client.delete_messages(channel, *messages, reason="broke laws") + with ( + mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request, + mock.patch.object( + rest_client, + "delete_message", + side_effect=errors.NotFoundError(url="", headers={}, raw_body="", code=10008), + ) as patched_delete_message, + ): + await rest_client.delete_messages(mock_guild_text_channel, *messages, reason="broke laws") - rest_client._request.assert_awaited_once_with( - routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=channel), json=expected_json, reason="broke laws" + patched__request.assert_awaited_once_with( + routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=mock_guild_text_channel), + json=expected_json, + reason="broke laws", ) - rest_client.delete_message.assert_awaited_once_with(channel, message, reason="broke laws") + patched_delete_message.assert_awaited_once_with(mock_guild_text_channel, message, reason="broke laws") async def test_delete_messages_when_one_message_left_in_chunk_and_delete_message_raises_channel_not_found( - self, rest_client + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel ): - channel = StubModel(123) - messages = [StubModel(i) for i in range(101)] + messages = [make_mock_message(i) for i in range(101)] message = messages[-1] expected_json = {"messages": [str(i) for i in range(100)]} - rest_client._request = mock.AsyncMock() mock_not_found = errors.NotFoundError(url="", headers={}, raw_body="", code=10003) - rest_client.delete_message = mock.AsyncMock(side_effect=mock_not_found) - with pytest.raises(errors.BulkDeleteError) as exc_info: - await rest_client.delete_messages(channel, *messages, reason="broke laws") + with ( + mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request, + mock.patch.object(rest_client, "delete_message", side_effect=mock_not_found) as patched_delete_message, + pytest.raises(errors.BulkDeleteError) as exc_info, + ): + await rest_client.delete_messages(mock_guild_text_channel, *messages, reason="broke laws") assert exc_info.value.__cause__ is mock_not_found - rest_client._request.assert_awaited_once_with( - routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=channel), json=expected_json, reason="broke laws" + patched__request.assert_awaited_once_with( + routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=mock_guild_text_channel), + json=expected_json, + reason="broke laws", ) - rest_client.delete_message.assert_awaited_once_with(channel, message, reason="broke laws") + patched_delete_message.assert_awaited_once_with(mock_guild_text_channel, message, reason="broke laws") - async def test_delete_messages_when_one_message_left_in_chunk(self, rest_client): - channel = StubModel(123) - messages = [StubModel(i) for i in range(101)] + async def test_delete_messages_when_one_message_left_in_chunk( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + messages = [make_mock_message(i) for i in range(101)] message = messages[-1] expected_json = {"messages": [str(i) for i in range(100)]} - rest_client._request = mock.AsyncMock() + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_messages(mock_guild_text_channel, *messages, reason="broke laws") + + patched__request.assert_has_awaits( + [ + mock.call( + routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=mock_guild_text_channel), + json=expected_json, + reason="broke laws", + ), + mock.call( + routes.DELETE_CHANNEL_MESSAGE.compile(channel=mock_guild_text_channel, message=message), + reason="broke laws", + ), + ] + ) - await rest_client.delete_messages(channel, *messages, reason="broke laws") + async def test_delete_messages_when_exception( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + messages = [make_mock_message(i) for i in range(101)] - rest_client._request.assert_has_awaits( - [ - mock.call( - routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=channel), - json=expected_json, - reason="broke laws", - ), - mock.call(routes.DELETE_CHANNEL_MESSAGE.compile(channel=channel, message=message), reason="broke laws"), - ] - ) + with ( + mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock, side_effect=Exception), + pytest.raises(errors.BulkDeleteError), + ): + await rest_client.delete_messages(mock_guild_text_channel, *messages, reason="broke laws") - async def test_delete_messages_when_exception(self, rest_client): - channel = StubModel(123) - messages = [StubModel(i) for i in range(101)] + async def test_delete_messages_with_iterable( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + message_list = (make_mock_message(i) for i in range(101)) - rest_client._request = mock.AsyncMock(side_effect=Exception) + message_1 = make_mock_message(444) + message_2 = make_mock_message(6523) - with pytest.raises(errors.BulkDeleteError): - await rest_client.delete_messages(channel, *messages) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_messages( + mock_guild_text_channel, message_list, message_1, message_2, reason="broke laws" + ) - async def test_delete_messages_with_iterable(self, rest_client): - channel = StubModel(54123) - messages = (StubModel(i) for i in range(101)) + patched__request.assert_has_awaits( + [ + mock.call( + routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=mock_guild_text_channel), + json={"messages": [str(i) for i in range(100)]}, + reason="broke laws", + ), + mock.call( + routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=mock_guild_text_channel), + json={"messages": ["100", "444", "6523"]}, + reason="broke laws", + ), + ] + ) - rest_client._request = mock.AsyncMock() + async def test_delete_messages_with_async_iterable( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + iterator = iterators.FlatLazyIterator(make_mock_message(i) for i in range(103)) + + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_messages(mock_guild_text_channel, iterator) + + patched__request.assert_has_awaits( + [ + mock.call( + routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=mock_guild_text_channel), + json={"messages": [str(i) for i in range(100)]}, + reason=undefined.UNDEFINED, + ), + mock.call( + routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=mock_guild_text_channel), + json={"messages": ["100", "101", "102"]}, + reason=undefined.UNDEFINED, + ), + ] + ) - await rest_client.delete_messages(channel, messages, StubModel(444), StubModel(6523), reason="broke laws") + async def test_delete_messages_with_async_iterable_and_args(self, rest_client: rest.RESTClientImpl): + with pytest.raises(TypeError, match=re.escape("Cannot use *args with an async iterable.")): + await rest_client.delete_messages(54123, iterators.FlatLazyIterator(()), 1, 2) - rest_client._request.assert_has_awaits( - [ - mock.call( - routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=channel), - json={"messages": [str(i) for i in range(100)]}, - reason="broke laws", - ), - mock.call( - routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=channel), - json={"messages": ["100", "444", "6523"]}, - reason="broke laws", - ), - ] - ) + async def test_add_reaction( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + expected_route = routes.PUT_MY_REACTION.compile(emoji="rooYay:123", channel=4560, message=101) - async def test_delete_messages_with_async_iterable(self, rest_client): - channel = StubModel(54123) - iterator = iterators.FlatLazyIterator(StubModel(i) for i in range(103)) + with ( + mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request, + mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"), + ): + await rest_client.add_reaction(mock_guild_text_channel, mock_message, "<:rooYay:123>") - rest_client._request = mock.AsyncMock() + patched__request.assert_awaited_once_with(expected_route) - await rest_client.delete_messages(channel, iterator, reason="broke laws") + async def test_delete_my_reaction( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + expected_route = routes.DELETE_MY_REACTION.compile(emoji="rooYay:123", channel=4560, message=101) - rest_client._request.assert_has_awaits( - [ - mock.call( - routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=channel), - json={"messages": [str(i) for i in range(100)]}, - reason="broke laws", - ), - mock.call( - routes.POST_DELETE_CHANNEL_MESSAGES_BULK.compile(channel=channel), - json={"messages": ["100", "101", "102"]}, - reason="broke laws", - ), - ] - ) + with ( + mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request, + mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"), + ): + await rest_client.delete_my_reaction(mock_guild_text_channel, mock_message, "<:rooYay:123>") - async def test_delete_messages_with_async_iterable_and_args(self, rest_client): - with pytest.raises(TypeError, match=re.escape("Cannot use *args with an async iterable.")): - await rest_client.delete_messages(54123, iterators.FlatLazyIterator(()), 1, 2) + patched__request.assert_awaited_once_with(expected_route) - async def test_add_reaction(self, rest_client): - expected_route = routes.PUT_MY_REACTION.compile(emoji="rooYay:123", channel=123, message=456) - rest_client._request = mock.AsyncMock() + async def test_delete_all_reactions_for_emoji( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + expected_route = routes.DELETE_REACTION_EMOJI.compile(emoji="rooYay:123", channel=4560, message=101) - with mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"): - await rest_client.add_reaction(StubModel(123), StubModel(456), "<:rooYay:123>") + with ( + mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request, + mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"), + ): + await rest_client.delete_all_reactions_for_emoji(mock_guild_text_channel, mock_message, "<:rooYay:123>") - rest_client._request.assert_awaited_once_with(expected_route) + patched__request.assert_awaited_once_with(expected_route) - async def test_delete_my_reaction(self, rest_client): - expected_route = routes.DELETE_MY_REACTION.compile(emoji="rooYay:123", channel=123, message=456) - rest_client._request = mock.AsyncMock() + async def test_delete_reaction( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_user: users.User, + mock_message: messages.Message, + ): + expected_route = routes.DELETE_REACTION_USER.compile(emoji="rooYay:123", channel=4560, message=101, user=789) - with mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"): - await rest_client.delete_my_reaction(StubModel(123), StubModel(456), "<:rooYay:123>") + with ( + mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request, + mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"), + ): + await rest_client.delete_reaction(mock_guild_text_channel, mock_message, mock_user, "<:rooYay:123>") - rest_client._request.assert_awaited_once_with(expected_route) + patched__request.assert_awaited_once_with(expected_route) - async def test_delete_all_reactions_for_emoji(self, rest_client): - expected_route = routes.DELETE_REACTION_EMOJI.compile(emoji="rooYay:123", channel=123, message=456) - rest_client._request = mock.AsyncMock() + async def test_delete_all_reactions( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + expected_route = routes.DELETE_ALL_REACTIONS.compile(channel=4560, message=101) - with mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"): - await rest_client.delete_all_reactions_for_emoji(StubModel(123), StubModel(456), "<:rooYay:123>") + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_all_reactions(mock_guild_text_channel, mock_message) - rest_client._request.assert_awaited_once_with(expected_route) + patched__request.assert_awaited_once_with(expected_route) - async def test_delete_reaction(self, rest_client): - expected_route = routes.DELETE_REACTION_USER.compile(emoji="rooYay:123", channel=123, message=456, user=789) - rest_client._request = mock.AsyncMock() + async def test_create_webhook( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + file_resource_patch: files.Resource[typing.Any], + ): + webhook = mock.Mock(webhooks.PartialWebhook) + expected_route = routes.POST_CHANNEL_WEBHOOKS.compile(channel=4560) - with mock.patch.object(rest, "_transform_emoji_to_url_format", return_value="rooYay:123"): - await rest_client.delete_reaction(StubModel(123), StubModel(456), StubModel(789), "<:rooYay:123>") + expected_json = {"name": "test webhook", "avatar": "some data"} - rest_client._request.assert_awaited_once_with(expected_route) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_incoming_webhook", return_value=webhook + ) as patched_deserialize_incoming_webhook, + ): + returned = await rest_client.create_webhook( + mock_guild_text_channel, "test webhook", avatar="someavatar.png", reason="why not" + ) + assert returned is webhook - async def test_delete_all_reactions(self, rest_client): - expected_route = routes.DELETE_ALL_REACTIONS.compile(channel=123, message=456) - rest_client._request = mock.AsyncMock() + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="why not") + patched_deserialize_incoming_webhook.assert_called_once_with({"id": "456"}) - await rest_client.delete_all_reactions(StubModel(123), StubModel(456)) - rest_client._request.assert_awaited_once_with(expected_route) + async def test_create_webhook_without_optionals( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_partial_webhook: webhooks.PartialWebhook, + ): + expected_route = routes.POST_CHANNEL_WEBHOOKS.compile(channel=4560) + expected_json = {"name": "test webhook"} - async def test_create_webhook(self, rest_client, file_resource_patch): - webhook = StubModel(456) - expected_route = routes.POST_CHANNEL_WEBHOOKS.compile(channel=123) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - expected_json = {"name": "test webhook", "avatar": "some data"} - rest_client._entity_factory.deserialize_incoming_webhook = mock.Mock(return_value=webhook) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_incoming_webhook", return_value=mock_partial_webhook + ) as patched_deserialize_incoming_webhook, + ): + assert await rest_client.create_webhook(mock_guild_text_channel, "test webhook") is mock_partial_webhook + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason=undefined.UNDEFINED) + patched_deserialize_incoming_webhook.assert_called_once_with({"id": "456"}) - returned = await rest_client.create_webhook( - StubModel(123), "test webhook", avatar="someavatar.png", reason="why not" - ) - assert returned is webhook + async def test_fetch_webhook(self, rest_client: rest.RESTClientImpl, mock_partial_webhook: webhooks.PartialWebhook): + expected_route = routes.GET_WEBHOOK_WITH_TOKEN.compile(webhook=112, token="token") - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="why not") - rest_client._entity_factory.deserialize_incoming_webhook.assert_called_once_with({"id": "456"}) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_webhook", return_value=mock_partial_webhook + ) as patched_deserialize_webhook, + ): + assert await rest_client.fetch_webhook(mock_partial_webhook, token="token") is mock_partial_webhook + patched__request.assert_awaited_once_with(expected_route, auth=None) + patched_deserialize_webhook.assert_called_once_with({"id": "456"}) - async def test_create_webhook_without_optionals(self, rest_client): - webhook = StubModel(456) - expected_route = routes.POST_CHANNEL_WEBHOOKS.compile(channel=123) - expected_json = {"name": "test webhook"} - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_incoming_webhook = mock.Mock(return_value=webhook) + async def test_fetch_webhook_without_token( + self, rest_client: rest.RESTClientImpl, mock_partial_webhook: webhooks.PartialWebhook + ): + expected_route = routes.GET_WEBHOOK.compile(webhook=112) - assert await rest_client.create_webhook(StubModel(123), "test webhook") is webhook - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason=undefined.UNDEFINED) - rest_client._entity_factory.deserialize_incoming_webhook.assert_called_once_with({"id": "456"}) - - async def test_fetch_webhook(self, rest_client): - webhook = StubModel(123) - expected_route = routes.GET_WEBHOOK_WITH_TOKEN.compile(webhook=123, token="token") - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_webhook = mock.Mock(return_value=webhook) - - assert await rest_client.fetch_webhook(StubModel(123), token="token") is webhook - rest_client._request.assert_awaited_once_with(expected_route, auth=None) - rest_client._entity_factory.deserialize_webhook.assert_called_once_with({"id": "456"}) - - async def test_fetch_webhook_without_token(self, rest_client): - webhook = StubModel(123) - expected_route = routes.GET_WEBHOOK.compile(webhook=123) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_webhook = mock.Mock(return_value=webhook) - - assert await rest_client.fetch_webhook(StubModel(123)) is webhook - rest_client._request.assert_awaited_once_with(expected_route, auth=undefined.UNDEFINED) - rest_client._entity_factory.deserialize_webhook.assert_called_once_with({"id": "456"}) - - async def test_fetch_channel_webhooks(self, rest_client): - webhook1 = StubModel(456) - webhook2 = StubModel(789) - expected_route = routes.GET_CHANNEL_WEBHOOKS.compile(channel=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_webhook = mock.Mock(side_effect=[webhook1, webhook2]) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_webhook", return_value=mock_partial_webhook + ) as patched_deserialize_webhook, + ): + assert await rest_client.fetch_webhook(mock_partial_webhook) is mock_partial_webhook + patched__request.assert_awaited_once_with(expected_route, auth=undefined.UNDEFINED) + patched_deserialize_webhook.assert_called_once_with({"id": "456"}) - assert await rest_client.fetch_channel_webhooks(StubModel(123)) == [webhook1, webhook2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_webhook.call_count == 2 - rest_client._entity_factory.deserialize_webhook.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) + async def test_fetch_channel_webhooks( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + webhook1 = make_partial_webhook(238947239847) + webhook2 = make_partial_webhook(218937419827) + expected_route = routes.GET_CHANNEL_WEBHOOKS.compile(channel=4560) - async def test_fetch_channel_webhooks_ignores_unrecognised_webhook_type(self, rest_client): - webhook1 = StubModel(456) - expected_route = routes.GET_CHANNEL_WEBHOOKS.compile(channel=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_webhook = mock.Mock( - side_effect=[errors.UnrecognisedEntityError("yeet"), webhook1] - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "456"}, {"id": "789"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_webhook", side_effect=[webhook1, webhook2] + ) as patched_deserialize_webhook, + ): + assert await rest_client.fetch_channel_webhooks(mock_guild_text_channel) == [webhook1, webhook2] + patched__request.assert_awaited_once_with(expected_route) + assert patched_deserialize_webhook.call_count == 2 + patched_deserialize_webhook.assert_has_calls([mock.call({"id": "456"}), mock.call({"id": "789"})]) - assert await rest_client.fetch_channel_webhooks(StubModel(123)) == [webhook1] - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_webhook.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) + async def test_fetch_channel_webhooks_ignores_unrecognised_webhook_type( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_partial_webhook: webhooks.PartialWebhook, + ): + expected_route = routes.GET_CHANNEL_WEBHOOKS.compile(channel=4560) + + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "456"}, {"id": "789"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, + "deserialize_webhook", + side_effect=[errors.UnrecognisedEntityError("yeet"), mock_partial_webhook], + ) as patched_deserialize_webhook, + ): + assert await rest_client.fetch_channel_webhooks(mock_guild_text_channel) == [mock_partial_webhook] + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_webhook.assert_has_calls([mock.call({"id": "456"}), mock.call({"id": "789"})]) + + async def test_fetch_guild_webhooks(self, rest_client: rest.RESTClientImpl): + webhook1 = make_partial_webhook(456) + webhook2 = make_partial_webhook(789) - async def test_fetch_guild_webhooks(self, rest_client): - webhook1 = StubModel(456) - webhook2 = StubModel(789) expected_route = routes.GET_GUILD_WEBHOOKS.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_webhook = mock.Mock(side_effect=[webhook1, webhook2]) - assert await rest_client.fetch_guild_webhooks(StubModel(123)) == [webhook1, webhook2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_webhook.call_count == 2 - rest_client._entity_factory.deserialize_webhook.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) + mock_guild = mock.Mock(guilds.PartialGuild, id=snowflakes.Snowflake(123)) + + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "456"}, {"id": "789"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_webhook", side_effect=[webhook1, webhook2] + ) as patched_deserialize_webhook, + ): + assert await rest_client.fetch_guild_webhooks(mock_guild) == [webhook1, webhook2] + patched__request.assert_awaited_once_with(expected_route) + assert patched_deserialize_webhook.call_count == 2 + patched_deserialize_webhook.assert_has_calls([mock.call({"id": "456"}), mock.call({"id": "789"})]) - async def test_fetch_guild_webhooks_ignores_unrecognised_webhook_types(self, rest_client): - webhook1 = StubModel(456) + async def test_fetch_guild_webhooks_ignores_unrecognised_webhook_types( + self, rest_client: rest.RESTClientImpl, mock_partial_webhook: webhooks.PartialWebhook + ): expected_route = routes.GET_GUILD_WEBHOOKS.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_webhook = mock.Mock( - side_effect=[errors.UnrecognisedEntityError("meow meow"), webhook1] - ) - assert await rest_client.fetch_guild_webhooks(StubModel(123)) == [webhook1] - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_webhook.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) + mock_guild = mock.Mock(guilds.PartialGuild, id=snowflakes.Snowflake(123)) + + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "456"}, {"id": "789"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, + "deserialize_webhook", + side_effect=[errors.UnrecognisedEntityError("meow meow"), mock_partial_webhook], + ) as patched_deserialize_webhook, + ): + assert await rest_client.fetch_guild_webhooks(mock_guild) == [mock_partial_webhook] + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_webhook.assert_has_calls([mock.call({"id": "456"}), mock.call({"id": "789"})]) - async def test_edit_webhook(self, rest_client): - webhook = StubModel(456) - expected_route = routes.PATCH_WEBHOOK_WITH_TOKEN.compile(webhook=123, token="token") - expected_json = {"name": "some other name", "channel": "789", "avatar": None} - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_webhook = mock.Mock(return_value=webhook) + async def test_edit_webhook( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_partial_webhook: webhooks.PartialWebhook, + ): + expected_route = routes.PATCH_WEBHOOK_WITH_TOKEN.compile(webhook=112, token="token") + expected_json = {"name": "some other name", "channel": "4560", "avatar": None} - returned = await rest_client.edit_webhook( - StubModel(123), - token="token", - name="some other name", - avatar=None, - channel=StubModel(789), - reason="some smart reason to do this", - ) - assert returned is webhook + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_webhook", return_value=mock_partial_webhook + ) as patched_deserialize_webhook, + ): + returned = await rest_client.edit_webhook( + mock_partial_webhook, + token="token", + name="some other name", + avatar=None, + channel=mock_guild_text_channel, + reason="some smart reason to do this", + ) + assert returned is mock_partial_webhook - rest_client._request.assert_awaited_once_with( - expected_route, json=expected_json, reason="some smart reason to do this", auth=None - ) - rest_client._entity_factory.deserialize_webhook.assert_called_once_with({"id": "456"}) + patched__request.assert_awaited_once_with( + expected_route, json=expected_json, reason="some smart reason to do this", auth=None + ) + patched_deserialize_webhook.assert_called_once_with({"id": "456"}) - async def test_edit_webhook_without_token(self, rest_client): - webhook = StubModel(456) - expected_route = routes.PATCH_WEBHOOK.compile(webhook=123) + async def test_edit_webhook_without_token( + self, rest_client: rest.RESTClientImpl, mock_partial_webhook: webhooks.PartialWebhook + ): + expected_route = routes.PATCH_WEBHOOK.compile(webhook=112) expected_json = {} - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_webhook = mock.Mock(return_value=webhook) - returned = await rest_client.edit_webhook(StubModel(123)) - assert returned is webhook + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_webhook", return_value=mock_partial_webhook + ) as patched_deserialize_webhook, + ): + returned = await rest_client.edit_webhook(mock_partial_webhook) + assert returned is mock_partial_webhook - rest_client._request.assert_awaited_once_with( - expected_route, json=expected_json, reason=undefined.UNDEFINED, auth=undefined.UNDEFINED - ) - rest_client._entity_factory.deserialize_webhook.assert_called_once_with({"id": "456"}) + patched__request.assert_awaited_once_with( + expected_route, json=expected_json, reason=undefined.UNDEFINED, auth=undefined.UNDEFINED + ) + patched_deserialize_webhook.assert_called_once_with({"id": "456"}) - async def test_edit_webhook_when_avatar_is_file(self, rest_client, file_resource_patch): - webhook = StubModel(456) - expected_route = routes.PATCH_WEBHOOK.compile(webhook=123) + async def test_edit_webhook_when_avatar_is_file( + self, + rest_client: rest.RESTClientImpl, + mock_partial_webhook: webhooks.PartialWebhook, + file_resource_patch: files.Resource[typing.Any], + ): + expected_route = routes.PATCH_WEBHOOK.compile(webhook=112) expected_json = {"avatar": "some data"} - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_webhook = mock.Mock(return_value=webhook) - assert await rest_client.edit_webhook(StubModel(123), avatar="someavatar.png") is webhook + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_webhook", return_value=mock_partial_webhook + ) as patched_deserialize_webhook, + ): + assert await rest_client.edit_webhook(mock_partial_webhook, avatar="someavatar.png") is mock_partial_webhook - rest_client._request.assert_awaited_once_with( - expected_route, json=expected_json, reason=undefined.UNDEFINED, auth=undefined.UNDEFINED - ) - rest_client._entity_factory.deserialize_webhook.assert_called_once_with({"id": "456"}) + patched__request.assert_awaited_once_with( + expected_route, json=expected_json, reason=undefined.UNDEFINED, auth=undefined.UNDEFINED + ) + patched_deserialize_webhook.assert_called_once_with({"id": "456"}) - async def test_delete_webhook(self, rest_client): - expected_route = routes.DELETE_WEBHOOK_WITH_TOKEN.compile(webhook=123, token="token") - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) + async def test_delete_webhook( + self, rest_client: rest.RESTClientImpl, mock_partial_webhook: webhooks.PartialWebhook + ): + expected_route = routes.DELETE_WEBHOOK_WITH_TOKEN.compile(webhook=112, token="token") - await rest_client.delete_webhook(StubModel(123), token="token", reason="testing") - rest_client._request.assert_awaited_once_with(expected_route, auth=None, reason="testing") + with mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request: + await rest_client.delete_webhook(mock_partial_webhook, token="token", reason="testing") + patched__request.assert_awaited_once_with(expected_route, auth=None, reason="testing") - async def test_delete_webhook_without_token(self, rest_client): - expected_route = routes.DELETE_WEBHOOK.compile(webhook=123) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) + async def test_delete_webhook_without_token( + self, rest_client: rest.RESTClientImpl, mock_partial_webhook: webhooks.PartialWebhook + ): + expected_route = routes.DELETE_WEBHOOK.compile(webhook=112) - await rest_client.delete_webhook(StubModel(123)) - rest_client._request.assert_awaited_once_with( - expected_route, auth=undefined.UNDEFINED, reason=undefined.UNDEFINED - ) + with mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request: + await rest_client.delete_webhook(mock_partial_webhook) + patched__request.assert_awaited_once_with( + expected_route, auth=undefined.UNDEFINED, reason=undefined.UNDEFINED + ) @pytest.mark.parametrize( ("webhook", "avatar_url"), @@ -3489,43 +4384,52 @@ async def test_delete_webhook_without_token(self, rest_client): (432, "https://website.com/davfsa_logo"), ], ) - async def test_execute_webhook_when_form(self, rest_client, webhook, avatar_url): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() - poll_obj = object() + async def test_execute_webhook_when_form( + self, rest_client: rest.RESTClientImpl, webhook: webhooks.ExecutableWebhook, avatar_url: files.URL + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() + poll_obj = mock.Mock() mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") expected_route = routes.POST_WEBHOOK_WITH_TOKEN.compile(webhook=432, token="hi, im a token") - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) - returned = await rest_client.execute_webhook( - webhook, - "hi, im a token", - username="davfsa", - avatar_url=avatar_url, - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - poll=poll_obj, - tts=True, - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - flags=120, - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_message_payload, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + ): + returned = await rest_client.execute_webhook( + webhook, + "hi, im a token", + username="davfsa", + avatar_url=avatar_url, + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + poll=poll_obj, + tts=True, + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + flags=120, + ) + assert returned is patched_deserialize_message.return_value - rest_client._build_message_payload.assert_called_once_with( + patched__build_message_payload.assert_called_once_with( content="new content", attachment=attachment_obj, attachments=[attachment_obj2], @@ -3545,25 +4449,34 @@ async def test_execute_webhook_when_form(self, rest_client, webhook, avatar_url) b'{"testing":"ensure_in_test","username":"davfsa","avatar_url":"https://website.com/davfsa_logo"}', content_type="application/json", ) - rest_client._request.assert_awaited_once_with( + patched__request.assert_awaited_once_with( expected_route, form_builder=mock_form, query={"wait": "true", "with_components": "true"}, auth=None ) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - async def test_execute_webhook_when_form_and_thread(self, rest_client): + async def test_execute_webhook_when_form_and_thread( + self, rest_client: rest.RESTClientImpl, mock_guild_public_thread_channel: channels.GuildThreadChannel + ): mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.POST_WEBHOOK_WITH_TOKEN.compile(webhook=432, token="hi, im a token") - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) + expected_route = routes.POST_WEBHOOK_WITH_TOKEN.compile(webhook=112, token="hi, im a token") - returned = await rest_client.execute_webhook( - 432, "hi, im a token", content="new content", thread=StubModel(1234543123) - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_message_payload, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + ): + returned = await rest_client.execute_webhook( + 112, "hi, im a token", content="new content", thread=mock_guild_public_thread_channel + ) + assert returned is patched_deserialize_message.return_value - rest_client._build_message_payload.assert_called_once_with( + patched__build_message_payload.assert_called_once_with( content="new content", attachment=undefined.UNDEFINED, attachments=undefined.UNDEFINED, @@ -3581,27 +4494,36 @@ async def test_execute_webhook_when_form_and_thread(self, rest_client): mock_form.add_field.assert_called_once_with( "payload_json", b'{"testing":"ensure_in_test"}', content_type="application/json" ) - rest_client._request.assert_awaited_once_with( + patched__request.assert_awaited_once_with( expected_route, form_builder=mock_form, - query={"wait": "true", "with_components": "true", "thread_id": "1234543123"}, + query={"wait": "true", "with_components": "true", "thread_id": "45611"}, auth=None, ) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - async def test_execute_webhook_when_no_form(self, rest_client): + async def test_execute_webhook_when_no_form( + self, rest_client: rest.RESTClientImpl, mock_guild_public_thread_channel: channels.GuildThreadChannel + ): mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") expected_route = routes.POST_WEBHOOK_WITH_TOKEN.compile(webhook=432, token="hi, im a token") - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, None)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) - returned = await rest_client.execute_webhook( - 432, "hi, im a token", content="new content", thread=StubModel(2134312123) - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, None) + ) as patched__build_message_payload, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + ): + returned = await rest_client.execute_webhook( + 432, "hi, im a token", content="new content", thread=mock_guild_public_thread_channel + ) + assert returned is patched_deserialize_message.return_value - rest_client._build_message_payload.assert_called_once_with( + patched__build_message_payload.assert_called_once_with( content="new content", attachment=undefined.UNDEFINED, attachments=undefined.UNDEFINED, @@ -3616,50 +4538,57 @@ async def test_execute_webhook_when_no_form(self, rest_client): user_mentions=undefined.UNDEFINED, role_mentions=undefined.UNDEFINED, ) - rest_client._request.assert_awaited_once_with( + patched__request.assert_awaited_once_with( expected_route, json={"testing": "ensure_in_test"}, - query={"wait": "true", "with_components": "true", "thread_id": "2134312123"}, + query={"wait": "true", "with_components": "true", "thread_id": "45611"}, auth=None, ) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - async def test_execute_webhook_when_thread_and_no_form(self, rest_client): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() - poll_obj = object() + async def test_execute_webhook_when_thread_and_no_form(self, rest_client: rest.RESTClientImpl): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() + poll_obj = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") expected_route = routes.POST_WEBHOOK_WITH_TOKEN.compile(webhook=432, token="hi, im a token") - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, None)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) - returned = await rest_client.execute_webhook( - 432, - "hi, im a token", - username="davfsa", - avatar_url="https://website.com/davfsa_logo", - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - poll=poll_obj, - tts=True, - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - flags=120, - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, None) + ) as patched__build_message_payload, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + ): + returned = await rest_client.execute_webhook( + 432, + "hi, im a token", + username="davfsa", + avatar_url="https://website.com/davfsa_logo", + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + poll=poll_obj, + tts=True, + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + flags=120, + ) + assert returned is patched_deserialize_message.return_value - rest_client._build_message_payload.assert_called_once_with( + patched__build_message_payload.assert_called_once_with( content="new content", attachment=attachment_obj, attachments=[attachment_obj2], @@ -3674,81 +4603,135 @@ async def test_execute_webhook_when_thread_and_no_form(self, rest_client): user_mentions=[9876], role_mentions=[1234], ) - rest_client._request.assert_awaited_once_with( + patched__request.assert_awaited_once_with( expected_route, json={"testing": "ensure_in_test", "username": "davfsa", "avatar_url": "https://website.com/davfsa_logo"}, query={"wait": "true", "with_components": "true"}, auth=None, ) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - async def test_execute_webhook_voice_message(self, rest_client): + async def test_execute_webhook_voice_message(self, rest_client: rest.RESTClientImpl): webhook = 432 token = "some token" expected_route = routes.POST_WEBHOOK_WITH_TOKEN.compile(webhook=webhook, token=token) - attachment_obj = object() + attachment_obj = mock.Mock() mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) - rest_client._build_voice_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - await rest_client.execute_webhook_voice_message( - webhook=webhook, token=token, attachment=attachment_obj, waveform="AAA", duration=3 - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object( + rest_client, "_build_voice_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_voice_message_payload, + ): + await rest_client.execute_webhook_voice_message( + webhook=webhook, token=token, attachment=attachment_obj, waveform="AAA", duration=3 + ) - rest_client._build_voice_message_payload.assert_called_once_with( + patched__build_voice_message_payload.assert_called_once_with( attachment=attachment_obj, waveform="AAA", duration=3, flags=undefined.UNDEFINED ) - rest_client._request.assert_awaited_once_with( + patched__request.assert_awaited_once_with( expected_route, form_builder=mock_form, query={"wait": "true"}, auth=None ) @pytest.mark.parametrize("webhook", [mock.Mock(webhooks.ExecutableWebhook, webhook_id=432), 432]) - async def test_fetch_webhook_message(self, rest_client, webhook): - message_obj = object() - expected_route = routes.GET_WEBHOOK_MESSAGE.compile(webhook=432, token="hi, im a token", message=456) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_message = mock.Mock(return_value=message_obj) + async def test_fetch_webhook_message( + self, + rest_client: rest.RESTClientImpl, + mock_message: messages.Message, + webhook: webhooks.ExecutableWebhook | int, + ): + message_obj = mock.Mock() + expected_route = routes.GET_WEBHOOK_MESSAGE.compile(webhook=432, token="hi, im a token", message=101) - assert await rest_client.fetch_webhook_message(webhook, "hi, im a token", StubModel(456)) is message_obj + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_message", return_value=message_obj + ) as patched_deserialize_message, + ): + assert await rest_client.fetch_webhook_message(webhook, "hi, im a token", mock_message) is message_obj - rest_client._request.assert_awaited_once_with(expected_route, auth=None, query={}) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"id": "456"}) + patched__request.assert_awaited_once_with(expected_route, auth=None, query={}) + patched_deserialize_message.assert_called_once_with({"id": "456"}) - async def test_fetch_webhook_message_when_thread(self, rest_client): - message_obj = object() - expected_route = routes.GET_WEBHOOK_MESSAGE.compile(webhook=43234312, token="hi, im a token", message=456) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_message = mock.Mock(return_value=message_obj) + async def test_fetch_webhook_message_when_thread( + self, + rest_client: rest.RESTClientImpl, + mock_guild_public_thread_channel: channels.GuildThreadChannel, + mock_message: messages.Message, + ): + message_obj = mock.Mock() + expected_route = routes.GET_WEBHOOK_MESSAGE.compile(webhook=112, token="hi, im a token", message=101) - result = await rest_client.fetch_webhook_message( - 43234312, "hi, im a token", StubModel(456), thread=StubModel(54123123) - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_message", return_value=message_obj + ) as patched_deserialize_message, + ): + result = await rest_client.fetch_webhook_message( + 112, "hi, im a token", mock_message, thread=mock_guild_public_thread_channel + ) - assert result is message_obj - rest_client._request.assert_awaited_once_with(expected_route, auth=None, query={"thread_id": "54123123"}) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"id": "456"}) + assert result is message_obj + patched__request.assert_awaited_once_with(expected_route, auth=None, query={"thread_id": "45611"}) + patched_deserialize_message.assert_called_once_with({"id": "456"}) @pytest.mark.parametrize("webhook", [mock.Mock(webhooks.ExecutableWebhook, webhook_id=432), 432]) - async def test_edit_webhook_message_when_form(self, rest_client, webhook): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() + async def test_edit_webhook_message_when_form( + self, + rest_client: rest.RESTClientImpl, + mock_message: messages.Message, + webhook: webhooks.ExecutableWebhook | int, + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.PATCH_WEBHOOK_MESSAGE.compile(webhook=432, token="hi, im a token", message=456) - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) + expected_route = routes.PATCH_WEBHOOK_MESSAGE.compile(webhook=432, token="hi, im a token", message=101) - returned = await rest_client.edit_webhook_message( - webhook, - "hi, im a token", - StubModel(456), + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_message_payload, + ): + returned = await rest_client.edit_webhook_message( + webhook, + "hi, im a token", + mock_message, + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + ) + assert returned is patched_deserialize_message.return_value + + patched__build_message_payload.assert_called_once_with( content="new content", attachment=attachment_obj, attachments=[attachment_obj2], @@ -3759,377 +4742,513 @@ async def test_edit_webhook_message_when_form(self, rest_client, webhook): mentions_everyone=False, user_mentions=[9876], role_mentions=[1234], + edit=True, ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + mock_form.add_field.assert_called_once_with( + "payload_json", b'{"testing":"ensure_in_test"}', content_type="application/json" + ) + patched__request.assert_awaited_once_with( + expected_route, form_builder=mock_form, query={"with_components": "true"}, auth=None + ) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - rest_client._build_message_payload.assert_called_once_with( - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - edit=True, - ) - mock_form.add_field.assert_called_once_with( - "payload_json", b'{"testing":"ensure_in_test"}', content_type="application/json" - ) - rest_client._request.assert_awaited_once_with( - expected_route, form_builder=mock_form, query={"with_components": "true"}, auth=None - ) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) - - async def test_edit_webhook_message_when_form_and_thread(self, rest_client): + async def test_edit_webhook_message_when_form_and_thread( + self, + rest_client: rest.RESTClientImpl, + mock_guild_public_thread_channel: channels.GuildThreadChannel, + mock_message: messages.Message, + ): mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.PATCH_WEBHOOK_MESSAGE.compile(webhook=12354123, token="hi, im a token", message=456) - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) + expected_route = routes.PATCH_WEBHOOK_MESSAGE.compile(webhook=12354123, token="hi, im a token", message=101) - returned = await rest_client.edit_webhook_message( - 12354123, "hi, im a token", StubModel(456), content="new content", thread=StubModel(123543123) - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value - - rest_client._build_message_payload.assert_called_once_with( - content="new content", - attachment=undefined.UNDEFINED, - attachments=undefined.UNDEFINED, - component=undefined.UNDEFINED, - components=undefined.UNDEFINED, - embed=undefined.UNDEFINED, - embeds=undefined.UNDEFINED, - mentions_everyone=undefined.UNDEFINED, - user_mentions=undefined.UNDEFINED, - role_mentions=undefined.UNDEFINED, - edit=True, - ) - mock_form.add_field.assert_called_once_with( - "payload_json", b'{"testing":"ensure_in_test"}', content_type="application/json" - ) - rest_client._request.assert_awaited_once_with( - expected_route, - form_builder=mock_form, - query={"with_components": "true", "thread_id": "123543123"}, - auth=None, - ) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_message_payload, + ): + returned = await rest_client.edit_webhook_message( + 12354123, "hi, im a token", mock_message, content="new content", thread=mock_guild_public_thread_channel + ) + assert returned is patched_deserialize_message.return_value + + patched__build_message_payload.assert_called_once_with( + content="new content", + attachment=undefined.UNDEFINED, + attachments=undefined.UNDEFINED, + component=undefined.UNDEFINED, + components=undefined.UNDEFINED, + embed=undefined.UNDEFINED, + embeds=undefined.UNDEFINED, + mentions_everyone=undefined.UNDEFINED, + user_mentions=undefined.UNDEFINED, + role_mentions=undefined.UNDEFINED, + edit=True, + ) + mock_form.add_field.assert_called_once_with( + "payload_json", b'{"testing":"ensure_in_test"}', content_type="application/json" + ) + patched__request.assert_awaited_once_with( + expected_route, + form_builder=mock_form, + query={"with_components": "true", "thread_id": "45611"}, + auth=None, + ) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - async def test_edit_webhook_message_when_no_form(self, rest_client: rest_api.RESTClient): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() + async def test_edit_webhook_message_when_no_form( + self, rest_client: rest.RESTClientImpl, mock_message: messages.Message + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.PATCH_WEBHOOK_MESSAGE.compile(webhook=432, token="hi, im a token", message=456) - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, None)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) - - returned = await rest_client.edit_webhook_message( - 432, - "hi, im a token", - StubModel(456), - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + expected_route = routes.PATCH_WEBHOOK_MESSAGE.compile(webhook=432, token="hi, im a token", message=101) - rest_client._build_message_payload.assert_called_once_with( - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - edit=True, - ) - rest_client._request.assert_awaited_once_with( - expected_route, json={"testing": "ensure_in_test"}, query={"with_components": "true"}, auth=None - ) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, None) + ) as patched__build_message_payload, + ): + returned = await rest_client.edit_webhook_message( + 432, + "hi, im a token", + mock_message, + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + ) + assert returned is patched_deserialize_message.return_value + + patched__build_message_payload.assert_called_once_with( + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + edit=True, + ) + patched__request.assert_awaited_once_with( + expected_route, json={"testing": "ensure_in_test"}, query={"with_components": "true"}, auth=None + ) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - async def test_edit_webhook_message_when_thread_and_no_form(self, rest_client: rest_api.RESTClient): + async def test_edit_webhook_message_when_thread_and_no_form( + self, + rest_client: rest.RESTClientImpl, + mock_guild_public_thread_channel: channels.GuildThreadChannel, + mock_message: messages.Message, + ): mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.PATCH_WEBHOOK_MESSAGE.compile(webhook=432, token="hi, im a token", message=456) - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, None)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) - - returned = await rest_client.edit_webhook_message( - 432, "hi, im a token", StubModel(456), content="new content", thread=StubModel(2346523432) - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + expected_route = routes.PATCH_WEBHOOK_MESSAGE.compile(webhook=432, token="hi, im a token", message=101) - rest_client._build_message_payload.assert_called_once_with( - content="new content", - attachment=undefined.UNDEFINED, - attachments=undefined.UNDEFINED, - component=undefined.UNDEFINED, - components=undefined.UNDEFINED, - embed=undefined.UNDEFINED, - embeds=undefined.UNDEFINED, - mentions_everyone=undefined.UNDEFINED, - user_mentions=undefined.UNDEFINED, - role_mentions=undefined.UNDEFINED, - edit=True, - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={"testing": "ensure_in_test"}, - query={"with_components": "true", "thread_id": "2346523432"}, - auth=None, - ) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, None) + ) as patched__build_message_payload, + ): + returned = await rest_client.edit_webhook_message( + 432, "hi, im a token", mock_message, content="new content", thread=mock_guild_public_thread_channel + ) + assert returned is patched_deserialize_message.return_value + + patched__build_message_payload.assert_called_once_with( + content="new content", + attachment=undefined.UNDEFINED, + attachments=undefined.UNDEFINED, + component=undefined.UNDEFINED, + components=undefined.UNDEFINED, + embed=undefined.UNDEFINED, + embeds=undefined.UNDEFINED, + mentions_everyone=undefined.UNDEFINED, + user_mentions=undefined.UNDEFINED, + role_mentions=undefined.UNDEFINED, + edit=True, + ) + patched__request.assert_awaited_once_with( + expected_route, + json={"testing": "ensure_in_test"}, + query={"with_components": "true", "thread_id": "45611"}, + auth=None, + ) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) @pytest.mark.parametrize("webhook", [mock.Mock(webhooks.ExecutableWebhook, webhook_id=123), 123]) - async def test_delete_webhook_message(self, rest_client, webhook): - expected_route = routes.DELETE_WEBHOOK_MESSAGE.compile(webhook=123, token="token", message=456) - rest_client._request = mock.AsyncMock() + async def test_delete_webhook_message( + self, + rest_client: rest.RESTClientImpl, + mock_message: messages.Message, + webhook: webhooks.ExecutableWebhook | int, + ): + expected_route = routes.DELETE_WEBHOOK_MESSAGE.compile(webhook=123, token="token", message=101) - await rest_client.delete_webhook_message(webhook, "token", StubModel(456)) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_webhook_message(webhook, "token", mock_message) - rest_client._request.assert_awaited_once_with(expected_route, auth=None, query={}) + patched__request.assert_awaited_once_with(expected_route, auth=None, query={}) - async def test_delete_webhook_message_when_thread(self, rest_client): - expected_route = routes.DELETE_WEBHOOK_MESSAGE.compile(webhook=123, token="token", message=456) - rest_client._request = mock.AsyncMock() + async def test_delete_webhook_message_when_thread( + self, + rest_client: rest.RESTClientImpl, + mock_guild_public_thread_channel: channels.GuildThreadChannel, + mock_message: messages.Message, + ): + expected_route = routes.DELETE_WEBHOOK_MESSAGE.compile(webhook=123, token="token", message=101) - await rest_client.delete_webhook_message(123, "token", StubModel(456), thread=StubModel(432123)) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_webhook_message( + 123, "token", mock_message, thread=mock_guild_public_thread_channel + ) - rest_client._request.assert_awaited_once_with(expected_route, auth=None, query={"thread_id": "432123"}) + patched__request.assert_awaited_once_with(expected_route, auth=None, query={"thread_id": "45611"}) - async def test_fetch_gateway_url(self, rest_client): + async def test_fetch_gateway_url(self, rest_client: rest.RESTClientImpl): expected_route = routes.GET_GATEWAY.compile() - rest_client._request = mock.AsyncMock(return_value={"url": "wss://some.url"}) - assert await rest_client.fetch_gateway_url() == "wss://some.url" + with mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"url": "wss://some.url"} + ) as patched__request: + assert await rest_client.fetch_gateway_url() == "wss://some.url" - rest_client._request.assert_awaited_once_with(expected_route, auth=None) + patched__request.assert_awaited_once_with(expected_route, auth=None) - async def test_fetch_gateway_bot(self, rest_client): - bot = StubModel(123) + async def test_fetch_gateway_bot(self, rest_client: rest.RESTClientImpl): + bot = mock.Mock(sessions.GatewayBotInfo, id=123) expected_route = routes.GET_GATEWAY_BOT.compile() - rest_client._request = mock.AsyncMock(return_value={"id": "123"}) - rest_client._entity_factory.deserialize_gateway_bot_info = mock.Mock(return_value=bot) - assert await rest_client.fetch_gateway_bot_info() is bot + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_gateway_bot_info", return_value=bot + ) as patched_deserialize_gateway_bot_info, + ): + assert await rest_client.fetch_gateway_bot_info() is bot - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_gateway_bot_info.assert_called_once_with({"id": "123"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_gateway_bot_info.assert_called_once_with({"id": "123"}) - async def test_fetch_invite(self, rest_client): - return_invite = StubModel() - input_invite = StubModel() - input_invite.code = "Jx4cNGG" + async def test_fetch_invite(self, rest_client: rest.RESTClientImpl): + input_invite = mock.Mock(invites.InviteCode, code="Jx4cNGG") + return_invite = mock.Mock(invites.Invite) expected_route = routes.GET_INVITE.compile(invite_code="Jx4cNGG") - rest_client._request = mock.AsyncMock(return_value={"code": "Jx4cNGG"}) - rest_client._entity_factory.deserialize_invite = mock.Mock(return_value=return_invite) - assert await rest_client.fetch_invite(input_invite, with_counts=True) == return_invite - rest_client._request.assert_awaited_once_with(expected_route, query={"with_counts": "true"}) - rest_client._entity_factory.deserialize_invite.assert_called_once_with({"code": "Jx4cNGG"}) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"code": "Jx4cNGG"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_invite", return_value=return_invite + ) as patched_deserialize_invite, + ): + assert await rest_client.fetch_invite(input_invite, with_counts=True) == return_invite + + patched__request.assert_awaited_once_with(expected_route, query={"with_counts": "true"}) + patched_deserialize_invite.assert_called_once_with({"code": "Jx4cNGG"}) - async def test_delete_invite(self, rest_client): - input_invite = StubModel() - input_invite.code = "Jx4cNGG" + async def test_delete_invite(self, rest_client: rest.RESTClientImpl): + input_invite = mock.Mock(invites.InviteCode, code="Jx4cNGG") expected_route = routes.DELETE_INVITE.compile(invite_code="Jx4cNGG") - rest_client._request = mock.AsyncMock(return_value={"ok": "NO"}) - result = await rest_client.delete_invite(input_invite, reason="testing") + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"ok": "NO"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_invite") as patched_deserialize_invite, + ): + result = await rest_client.delete_invite(input_invite, reason="testing") - assert result is rest_client._entity_factory.deserialize_invite.return_value + assert result is patched_deserialize_invite.return_value - rest_client._entity_factory.deserialize_invite.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with(expected_route, reason="testing") + patched_deserialize_invite.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with(expected_route, reason="testing") - async def test_fetch_my_user(self, rest_client): - user = StubModel(123) + async def test_fetch_my_user(self, rest_client: rest.RESTClientImpl, mock_user: users.User): expected_route = routes.GET_MY_USER.compile() - rest_client._request = mock.AsyncMock(return_value={"id": "123"}) - rest_client._entity_factory.deserialize_my_user = mock.Mock(return_value=user) - assert await rest_client.fetch_my_user() is user + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_my_user", return_value=mock_user + ) as patched_deserialize_my_user, + ): + assert await rest_client.fetch_my_user() is mock_user - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_my_user.assert_called_once_with({"id": "123"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_my_user.assert_called_once_with({"id": "123"}) - async def test_edit_my_user(self, rest_client): - user = StubModel(123) + async def test_edit_my_user(self, rest_client: rest.RESTClientImpl, mock_user: users.User): expected_route = routes.PATCH_MY_USER.compile() expected_json = {"username": "new username"} - rest_client._request = mock.AsyncMock(return_value={"id": "123"}) - rest_client._entity_factory.deserialize_my_user = mock.Mock(return_value=user) - assert await rest_client.edit_my_user(username="new username") is user + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_my_user", return_value=mock_user + ) as patched_deserialize_my_user, + ): + assert await rest_client.edit_my_user(username="new username") is mock_user - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) - rest_client._entity_factory.deserialize_my_user.assert_called_once_with({"id": "123"}) + patched__request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_my_user.assert_called_once_with({"id": "123"}) - async def test_edit_my_user_when_avatar_is_None(self, rest_client): - user = StubModel(123) + async def test_edit_my_user_when_avatar_is_None(self, rest_client: rest.RESTClientImpl, mock_user: users.User): expected_route = routes.PATCH_MY_USER.compile() expected_json = {"username": "new username", "avatar": None} - rest_client._request = mock.AsyncMock(return_value={"id": "123"}) - rest_client._entity_factory.deserialize_my_user = mock.Mock(return_value=user) - assert await rest_client.edit_my_user(username="new username", avatar=None) is user + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_my_user", return_value=mock_user + ) as patched_deserialize_my_user, + ): + assert await rest_client.edit_my_user(username="new username", avatar=None) is mock_user - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) - rest_client._entity_factory.deserialize_my_user.assert_called_once_with({"id": "123"}) + patched__request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_my_user.assert_called_once_with({"id": "123"}) - async def test_edit_my_user_when_avatar_is_file(self, rest_client, file_resource_patch): - user = StubModel(123) + async def test_edit_my_user_when_avatar_is_file( + self, rest_client: rest.RESTClientImpl, mock_user: users.User, file_resource_patch: files.Resource[typing.Any] + ): expected_route = routes.PATCH_MY_USER.compile() expected_json = {"username": "new username", "avatar": "some data"} - rest_client._request = mock.AsyncMock(return_value={"id": "123"}) - rest_client._entity_factory.deserialize_my_user = mock.Mock(return_value=user) - assert await rest_client.edit_my_user(username="new username", avatar="someavatar.png") is user + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_my_user", return_value=mock_user + ) as patched_deserialize_my_user, + ): + assert await rest_client.edit_my_user(username="new username", avatar="someavatar.png") is mock_user - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) - rest_client._entity_factory.deserialize_my_user.assert_called_once_with({"id": "123"}) + patched__request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_my_user.assert_called_once_with({"id": "123"}) - async def test_edit_my_user_when_banner_is_None(self, rest_client): - user = StubModel(123) + async def test_edit_my_user_when_banner_is_None(self, rest_client: rest.RESTClientImpl, mock_user: users.User): expected_route = routes.PATCH_MY_USER.compile() expected_json = {"username": "new username", "banner": None} - rest_client._request = mock.AsyncMock(return_value={"id": "123"}) - rest_client._entity_factory.deserialize_my_user = mock.Mock(return_value=user) - assert await rest_client.edit_my_user(username="new username", banner=None) is user + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_my_user", return_value=mock_user + ) as patched_deserialize_my_user, + ): + assert await rest_client.edit_my_user(username="new username", banner=None) is mock_user - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) - rest_client._entity_factory.deserialize_my_user.assert_called_once_with({"id": "123"}) + patched__request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_my_user.assert_called_once_with({"id": "123"}) - async def test_edit_my_user_when_banner_is_file(self, rest_client, file_resource_patch): - user = StubModel(123) + async def test_edit_my_user_when_banner_is_file( + self, rest_client: rest.RESTClientImpl, mock_user: users.User, file_resource_patch: files.Resource[typing.Any] + ): expected_route = routes.PATCH_MY_USER.compile() expected_json = {"username": "new username", "banner": "some data"} - rest_client._request = mock.AsyncMock(return_value={"id": "123"}) - rest_client._entity_factory.deserialize_my_user = mock.Mock(return_value=user) - assert await rest_client.edit_my_user(username="new username", banner="somebanner.png") is user + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_my_user", return_value=mock_user + ) as patched_deserialize_my_user, + ): + assert await rest_client.edit_my_user(username="new username", banner="somebanner.png") is mock_user - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) - rest_client._entity_factory.deserialize_my_user.assert_called_once_with({"id": "123"}) + patched__request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_my_user.assert_called_once_with({"id": "123"}) - async def test_fetch_my_connections(self, rest_client): - connection1 = StubModel(123) - connection2 = StubModel(456) + async def test_fetch_my_connections(self, rest_client: rest.RESTClientImpl): + connection1 = mock.Mock(applications.OwnConnection, id=123) + connection2 = mock.Mock(applications.OwnConnection, id=456) expected_route = routes.GET_MY_CONNECTIONS.compile() - rest_client._request = mock.AsyncMock(return_value=[{"id": "123"}, {"id": "456"}]) - rest_client._entity_factory.deserialize_own_connection = mock.Mock(side_effect=[connection1, connection2]) - assert await rest_client.fetch_my_connections() == [connection1, connection2] + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "123"}, {"id": "456"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_own_connection", side_effect=[connection1, connection2] + ) as patched_deserialize_own_connection, + ): + assert await rest_client.fetch_my_connections() == [connection1, connection2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_own_connection.call_count == 2 - rest_client._entity_factory.deserialize_own_connection.assert_has_calls( - [mock.call({"id": "123"}), mock.call({"id": "456"})] - ) + patched__request.assert_awaited_once_with(expected_route) + assert patched_deserialize_own_connection.call_count == 2 + patched_deserialize_own_connection.assert_has_calls([mock.call({"id": "123"}), mock.call({"id": "456"})]) - async def test_leave_guild(self, rest_client): + async def test_leave_guild(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): expected_route = routes.DELETE_MY_GUILD.compile(guild=123) - rest_client._request = mock.AsyncMock() - await rest_client.leave_guild(StubModel(123)) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.leave_guild(mock_partial_guild) - rest_client._request.assert_awaited_once_with(expected_route) + patched__request.assert_awaited_once_with(expected_route) - async def test_create_dm_channel(self, rest_client, mock_cache): - dm_channel = StubModel(43234) - user = StubModel(123) + async def test_create_dm_channel( + self, + rest_client: rest.RESTClientImpl, + mock_cache: cache.MutableCache, + mock_dm_channel: channels.DMChannel, + mock_user: users.User, + ): expected_route = routes.POST_MY_CHANNELS.compile() - expected_json = {"recipient_id": "123"} - rest_client._request = mock.AsyncMock(return_value={"id": "43234"}) - rest_client._entity_factory.deserialize_dm = mock.Mock(return_value=dm_channel) + expected_json = {"recipient_id": "789"} - assert await rest_client.create_dm_channel(user) == dm_channel + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "43234"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_dm", return_value=mock_dm_channel + ) as patched_deserialize_dm, + mock.patch.object(mock_cache, "set_dm_channel_id") as patched_set_dm_channel_id, + ): + assert await rest_client.create_dm_channel(mock_user) == mock_dm_channel - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) - rest_client._entity_factory.deserialize_dm.assert_called_once_with({"id": "43234"}) - mock_cache.set_dm_channel_id.assert_called_once_with(user, dm_channel.id) + patched__request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_dm.assert_called_once_with({"id": "43234"}) + patched_set_dm_channel_id.assert_called_once_with(mock_user, mock_dm_channel.id) - async def test_create_dm_channel_when_cacheless(self, rest_client, mock_cache): - rest_client._cache = None - dm_channel = StubModel(43234) + async def test_create_dm_channel_when_cacheless( + self, + rest_client: rest.RESTClientImpl, + mock_cache: cache.MutableCache, + mock_dm_channel: channels.DMChannel, + mock_user: users.User, + ): expected_route = routes.POST_MY_CHANNELS.compile() - expected_json = {"recipient_id": "123"} - rest_client._request = mock.AsyncMock(return_value={"id": "43234"}) - rest_client._entity_factory.deserialize_dm = mock.Mock(return_value=dm_channel) + expected_json = {"recipient_id": "789"} - assert await rest_client.create_dm_channel(StubModel(123)) == dm_channel + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "43234"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_dm", return_value=mock_dm_channel + ) as patched_deserialize_dm, + mock.patch.object(rest_client, "_cache", None), + mock.patch.object(mock_cache, "set_dm_channel_id") as patched_set_dm_channel_id, + ): + assert await rest_client.create_dm_channel(mock_user) == mock_dm_channel - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) - rest_client._entity_factory.deserialize_dm.assert_called_once_with({"id": "43234"}) - mock_cache.set_dm_channel_id.assert_not_called() + patched__request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_dm.assert_called_once_with({"id": "43234"}) + patched_set_dm_channel_id.assert_not_called() - async def test_fetch_application(self, rest_client): - application = StubModel(123) + async def test_fetch_application( + self, rest_client: rest.RESTClientImpl, mock_application: applications.Application + ): expected_route = routes.GET_MY_APPLICATION.compile() - rest_client._request = mock.AsyncMock(return_value={"id": "123"}) - rest_client._entity_factory.deserialize_application = mock.Mock(return_value=application) - assert await rest_client.fetch_application() is application + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_application", return_value=mock_application + ) as patched_deserialize_application, + ): + assert await rest_client.fetch_application() is mock_application - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_application.assert_called_once_with({"id": "123"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_application.assert_called_once_with({"id": "123"}) - async def test_fetch_authorization(self, rest_client): + async def test_fetch_authorization(self, rest_client: rest.RESTClientImpl): expected_route = routes.GET_MY_AUTHORIZATION.compile() - rest_client._request = mock.AsyncMock(return_value={"application": {}}) - result = await rest_client.fetch_authorization() + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"application": {}} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_authorization_information" + ) as patched_deserialize_authorization_information, + ): + result = await rest_client.fetch_authorization() - assert result is rest_client._entity_factory.deserialize_authorization_information.return_value + assert result is patched_deserialize_authorization_information.return_value - rest_client._entity_factory.deserialize_authorization_information.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._request.assert_awaited_once_with(expected_route) + patched_deserialize_authorization_information.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with(expected_route) - async def test_authorize_client_credentials_token(self, rest_client): + async def test_authorize_client_credentials_token(self, rest_client: rest.RESTClientImpl): expected_route = routes.POST_TOKEN.compile() mock_url_encoded_form = mock.Mock() - rest_client._request = mock.AsyncMock(return_value={"access_token": "43212123123123"}) - with mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form): + with ( + mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form), + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"access_token": "43212123123123"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_partial_token" + ) as patched_deserialize_partial_token, + ): await rest_client.authorize_client_credentials_token(65234123, "4312312", scopes=["scope1", "scope2"]) mock_url_encoded_form.add_field.assert_has_calls( [mock.call("grant_type", "client_credentials"), mock.call("scope", "scope1 scope2")] ) - rest_client._request.assert_awaited_once_with( + patched__request.assert_awaited_once_with( expected_route, form_builder=mock_url_encoded_form, auth="Basic NjUyMzQxMjM6NDMxMjMxMg==" ) - rest_client._entity_factory.deserialize_partial_token.assert_called_once_with(rest_client._request.return_value) + patched_deserialize_partial_token.assert_called_once_with(patched__request.return_value) - async def test_authorize_access_token_without_scopes(self, rest_client): + async def test_authorize_access_token_without_scopes(self, rest_client: rest.RESTClientImpl): expected_route = routes.POST_TOKEN.compile() mock_url_encoded_form = mock.Mock() - rest_client._request = mock.AsyncMock(return_value={"access_token": 42}) - with mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form): + with ( + mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form), + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"access_token": 42} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_authorization_token" + ) as patched_deserialize_authorization_token, + ): result = await rest_client.authorize_access_token(65234, "43123", "a.code", "htt:redirect//me") mock_url_encoded_form.add_field.assert_has_calls( @@ -4139,20 +5258,25 @@ async def test_authorize_access_token_without_scopes(self, rest_client): mock.call("redirect_uri", "htt:redirect//me"), ] ) - assert result is rest_client._entity_factory.deserialize_authorization_token.return_value - rest_client._entity_factory.deserialize_authorization_token.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._request.assert_awaited_once_with( + assert result is patched_deserialize_authorization_token.return_value + patched_deserialize_authorization_token.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with( expected_route, form_builder=mock_url_encoded_form, auth="Basic NjUyMzQ6NDMxMjM=" ) - async def test_authorize_access_token_with_scopes(self, rest_client): + async def test_authorize_access_token_with_scopes(self, rest_client: rest.RESTClientImpl): expected_route = routes.POST_TOKEN.compile() mock_url_encoded_form = mock.Mock() - rest_client._request = mock.AsyncMock(return_value={"access_token": 42}) - with mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form): + with ( + mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form), + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"access_token": 42} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_authorization_token" + ) as patched_deserialize_authorization_token, + ): result = await rest_client.authorize_access_token(12343, "1235555", "a.codee", "htt:redirect//mee") mock_url_encoded_form.add_field.assert_has_calls( @@ -4162,39 +5286,49 @@ async def test_authorize_access_token_with_scopes(self, rest_client): mock.call("redirect_uri", "htt:redirect//mee"), ] ) - assert result is rest_client._entity_factory.deserialize_authorization_token.return_value - rest_client._entity_factory.deserialize_authorization_token.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._request.assert_awaited_once_with( + assert result is patched_deserialize_authorization_token.return_value + patched_deserialize_authorization_token.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with( expected_route, form_builder=mock_url_encoded_form, auth="Basic MTIzNDM6MTIzNTU1NQ==" ) - async def test_refresh_access_token_without_scopes(self, rest_client): + async def test_refresh_access_token_without_scopes(self, rest_client: rest.RESTClientImpl): expected_route = routes.POST_TOKEN.compile() mock_url_encoded_form = mock.Mock() - rest_client._request = mock.AsyncMock(return_value={"access_token": 42}) - with mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form): + with ( + mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form), + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"access_token": 42} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_authorization_token" + ) as patched_deserialize_authorization_token, + ): result = await rest_client.refresh_access_token(454123, "123123", "a.codet") mock_url_encoded_form.add_field.assert_has_calls( [mock.call("grant_type", "refresh_token"), mock.call("refresh_token", "a.codet")] ) - assert result is rest_client._entity_factory.deserialize_authorization_token.return_value - rest_client._entity_factory.deserialize_authorization_token.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._request.assert_awaited_once_with( + assert result is patched_deserialize_authorization_token.return_value + patched_deserialize_authorization_token.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with( expected_route, form_builder=mock_url_encoded_form, auth="Basic NDU0MTIzOjEyMzEyMw==" ) - async def test_refresh_access_token_with_scopes(self, rest_client): + async def test_refresh_access_token_with_scopes(self, rest_client: rest.RESTClientImpl): expected_route = routes.POST_TOKEN.compile() mock_url_encoded_form = mock.Mock() - rest_client._request = mock.AsyncMock(return_value={"access_token": 42}) - with mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form): + with ( + mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form), + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"access_token": 42} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_authorization_token" + ) as patched_deserialize_authorization_token, + ): result = await rest_client.refresh_access_token(54123, "312312", "a.codett", scopes=["1", "3", "scope43"]) mock_url_encoded_form.add_field.assert_has_calls( @@ -4204,30 +5338,33 @@ async def test_refresh_access_token_with_scopes(self, rest_client): mock.call("scope", "1 3 scope43"), ] ) - assert result is rest_client._entity_factory.deserialize_authorization_token.return_value - rest_client._entity_factory.deserialize_authorization_token.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._request.assert_awaited_once_with( + assert result is patched_deserialize_authorization_token.return_value + patched_deserialize_authorization_token.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with( expected_route, form_builder=mock_url_encoded_form, auth="Basic NTQxMjM6MzEyMzEy" ) - async def test_revoke_access_token(self, rest_client): + async def test_revoke_access_token(self, rest_client: rest.RESTClientImpl): expected_route = routes.POST_TOKEN_REVOKE.compile() mock_url_encoded_form = mock.Mock() - rest_client._request = mock.AsyncMock() - with mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form): + with ( + mock.patch.object(data_binding, "URLEncodedFormBuilder", return_value=mock_url_encoded_form), + mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_authorization_token"), + ): await rest_client.revoke_access_token(54123, "123542", "not.a.token") mock_url_encoded_form.add_field.assert_called_once_with("token", "not.a.token") - rest_client._request.assert_awaited_once_with( + patched__request.assert_awaited_once_with( expected_route, form_builder=mock_url_encoded_form, auth="Basic NTQxMjM6MTIzNTQy" ) - async def test_add_user_to_guild(self, rest_client): - member = StubModel(789) - expected_route = routes.PUT_GUILD_MEMBER.compile(guild=123, user=456) + async def test_add_user_to_guild( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): + member = mock.Mock(guilds.Member, id=789) + expected_route = routes.PUT_GUILD_MEMBER.compile(guild=123, user=789) expected_json = { "access_token": "token", "nick": "cool nick", @@ -4235,258 +5372,423 @@ async def test_add_user_to_guild(self, rest_client): "mute": True, "deaf": False, } - rest_client._request = mock.AsyncMock(return_value={"id": "789"}) - rest_client._entity_factory.deserialize_member = mock.Mock(return_value=member) - returned = await rest_client.add_user_to_guild( - "token", - StubModel(123), - StubModel(456), - nickname="cool nick", - roles=[StubModel(234), StubModel(567)], - mute=True, - deaf=False, - ) - assert returned is member + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "789"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_member", return_value=member + ) as patched_deserialize_member, + ): + returned = await rest_client.add_user_to_guild( + "token", + mock_partial_guild, + mock_user, + nickname="cool nick", + roles=[make_partial_role(234), make_partial_role(567)], + mute=True, + deaf=False, + ) + assert returned is member - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) - rest_client._entity_factory.deserialize_member.assert_called_once_with({"id": "789"}, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_member.assert_called_once_with({"id": "789"}, guild_id=123) - async def test_add_user_to_guild_when_already_in_guild(self, rest_client): - expected_route = routes.PUT_GUILD_MEMBER.compile(guild=123, user=456) + async def test_add_user_to_guild_when_already_in_guild( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): + expected_route = routes.PUT_GUILD_MEMBER.compile(guild=123, user=789) expected_json = {"access_token": "token"} - rest_client._request = mock.AsyncMock(return_value=None) - rest_client._entity_factory.deserialize_member = mock.Mock() - assert await rest_client.add_user_to_guild("token", StubModel(123), StubModel(456)) is None + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=None + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_member") as patched_deserialize_member, + ): + assert await rest_client.add_user_to_guild("token", mock_partial_guild, mock_user) is None - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) - rest_client._entity_factory.deserialize_member.assert_not_called() + patched__request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_member.assert_not_called() - async def test_fetch_voice_regions(self, rest_client): - voice_region1 = StubModel(123) - voice_region2 = StubModel(456) + async def test_fetch_voice_regions(self, rest_client: rest.RESTClientImpl): + voice_region1 = mock.Mock(voices.VoiceRegion, id="123") + voice_region2 = mock.Mock(voices.VoiceRegion, id="456") expected_route = routes.GET_VOICE_REGIONS.compile() - rest_client._request = mock.AsyncMock(return_value=[{"id": "123"}, {"id": "456"}]) - rest_client._entity_factory.deserialize_voice_region = mock.Mock(side_effect=[voice_region1, voice_region2]) - assert await rest_client.fetch_voice_regions() == [voice_region1, voice_region2] + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "123"}, {"id": "456"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_voice_region", side_effect=[voice_region1, voice_region2] + ) as patched_deserialize_voice_region, + ): + assert await rest_client.fetch_voice_regions() == [voice_region1, voice_region2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_voice_region.call_count == 2 - rest_client._entity_factory.deserialize_voice_region.assert_has_calls( - [mock.call({"id": "123"}), mock.call({"id": "456"})] - ) + patched__request.assert_awaited_once_with(expected_route) + assert patched_deserialize_voice_region.call_count == 2 + patched_deserialize_voice_region.assert_has_calls([mock.call({"id": "123"}), mock.call({"id": "456"})]) - async def test_fetch_user(self, rest_client): - user = StubModel(456) - expected_route = routes.GET_USER.compile(user=123) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_user = mock.Mock(return_value=user) + async def test_fetch_user(self, rest_client: rest.RESTClientImpl, mock_user: users.User): + user = mock.Mock(users.User, id=789) + expected_route = routes.GET_USER.compile(user=789) + + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_user", return_value=user + ) as patched_deserialize_user, + ): + assert await rest_client.fetch_user(mock_user) is user - assert await rest_client.fetch_user(StubModel(123)) is user + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_user.assert_called_once_with({"id": "456"}) - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_user.assert_called_once_with({"id": "456"}) + async def test_fetch_emoji( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_custom_emoji: emojis.CustomEmoji, + ): + expected_route = routes.GET_GUILD_EMOJI.compile(emoji=4440, guild=123) - async def test_fetch_emoji(self, rest_client): - emoji = StubModel(456) - expected_route = routes.GET_GUILD_EMOJI.compile(emoji=456, guild=123) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_known_custom_emoji = mock.Mock(return_value=emoji) + emoji = make_custom_emoji(9989) - assert await rest_client.fetch_emoji(StubModel(123), StubModel(456)) is emoji + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_known_custom_emoji", return_value=emoji + ) as patched_deserialize_known_custom_emoji, + ): + assert await rest_client.fetch_emoji(mock_partial_guild, mock_custom_emoji) is emoji - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_known_custom_emoji.assert_called_once_with({"id": "456"}, guild_id=123) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_known_custom_emoji.assert_called_once_with({"id": "456"}, guild_id=123) - async def test_fetch_guild_emojis(self, rest_client): - emoji1 = StubModel(456) - emoji2 = StubModel(789) + async def test_fetch_guild_emojis(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): + emoji1 = make_custom_emoji(2893472193) + emoji2 = make_custom_emoji(9823748921) expected_route = routes.GET_GUILD_EMOJIS.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_known_custom_emoji = mock.Mock(side_effect=[emoji1, emoji2]) - assert await rest_client.fetch_guild_emojis(StubModel(123)) == [emoji1, emoji2] + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "456"}, {"id": "789"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_known_custom_emoji", side_effect=[emoji1, emoji2] + ) as patched_deserialize_known_custom_emoji, + ): + assert await rest_client.fetch_guild_emojis(mock_partial_guild) == [emoji1, emoji2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_known_custom_emoji.call_count == 2 - rest_client._entity_factory.deserialize_known_custom_emoji.assert_has_calls( - [mock.call({"id": "456"}, guild_id=123), mock.call({"id": "789"}, guild_id=123)] - ) + patched__request.assert_awaited_once_with(expected_route) + assert patched_deserialize_known_custom_emoji.call_count == 2 + patched_deserialize_known_custom_emoji.assert_has_calls( + [mock.call({"id": "456"}, guild_id=123), mock.call({"id": "789"}, guild_id=123)] + ) - async def test_create_emoji(self, rest_client, file_resource_patch): - emoji = StubModel(234) + async def test_create_emoji( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_custom_emoji: emojis.CustomEmoji, + file_resource_patch: files.Resource[typing.Any], + ): expected_route = routes.POST_GUILD_EMOJIS.compile(guild=123) - expected_json = {"name": "rooYay", "image": "some data", "roles": ["456", "789"]} - rest_client._request = mock.AsyncMock(return_value={"id": "234"}) - rest_client._entity_factory.deserialize_known_custom_emoji = mock.Mock(return_value=emoji) + expected_json = {"name": "rooYay", "image": "some data", "roles": ["22398429", "82903740"]} - returned = await rest_client.create_emoji( - StubModel(123), "rooYay", "rooYay.png", roles=[StubModel(456), StubModel(789)], reason="cause rooYay" - ) - assert returned is emoji + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "234"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_known_custom_emoji", return_value=mock_custom_emoji + ) as patched_deserialize_known_custom_emoji, + ): + returned = await rest_client.create_emoji( + mock_partial_guild, + "rooYay", + "rooYay.png", + roles=[make_partial_role(22398429), make_partial_role(82903740)], + reason="cause rooYay", + ) + assert returned is mock_custom_emoji - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="cause rooYay") - rest_client._entity_factory.deserialize_known_custom_emoji.assert_called_once_with({"id": "234"}, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="cause rooYay") + patched_deserialize_known_custom_emoji.assert_called_once_with({"id": "234"}, guild_id=123) - async def test_edit_emoji(self, rest_client): - emoji = StubModel(234) - expected_route = routes.PATCH_GUILD_EMOJI.compile(guild=123, emoji=456) - expected_json = {"name": "rooYay2", "roles": ["789", "987"]} - rest_client._request = mock.AsyncMock(return_value={"id": "234"}) - rest_client._entity_factory.deserialize_known_custom_emoji = mock.Mock(return_value=emoji) + async def test_edit_emoji( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_custom_emoji: emojis.CustomEmoji, + ): + emoji = mock.Mock(emojis.CustomEmoji, id=234) + expected_route = routes.PATCH_GUILD_EMOJI.compile(guild=123, emoji=4440) + expected_json = {"name": "rooYay2", "roles": ["22398429", "82903740"]} - returned = await rest_client.edit_emoji( - StubModel(123), - StubModel(456), - name="rooYay2", - roles=[StubModel(789), StubModel(987)], - reason="Because we have got the power", - ) - assert returned is emoji + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "234"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_known_custom_emoji", return_value=emoji + ) as patched_deserialize_known_custom_emoji, + ): + returned = await rest_client.edit_emoji( + mock_partial_guild, + mock_custom_emoji, + name="rooYay2", + roles=[make_partial_role(22398429), make_partial_role(82903740)], + reason="Because we have got the power", + ) + assert returned is emoji - rest_client._request.assert_awaited_once_with( - expected_route, json=expected_json, reason="Because we have got the power" - ) - rest_client._entity_factory.deserialize_known_custom_emoji.assert_called_once_with({"id": "234"}, guild_id=123) + patched__request.assert_awaited_once_with( + expected_route, json=expected_json, reason="Because we have got the power" + ) + patched_deserialize_known_custom_emoji.assert_called_once_with({"id": "234"}, guild_id=123) - async def test_delete_emoji(self, rest_client): - expected_route = routes.DELETE_GUILD_EMOJI.compile(guild=123, emoji=456) - rest_client._request = mock.AsyncMock() + async def test_delete_emoji( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_custom_emoji: emojis.CustomEmoji, + ): + expected_route = routes.DELETE_GUILD_EMOJI.compile(guild=123, emoji=4440) - await rest_client.delete_emoji(StubModel(123), StubModel(456), reason="testing") + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_emoji(mock_partial_guild, mock_custom_emoji, reason="testing") - rest_client._request.assert_awaited_once_with(expected_route, reason="testing") + patched__request.assert_awaited_once_with(expected_route, reason="testing") - async def test_fetch_application_emoji(self, rest_client): - emoji = StubModel(456) - expected_route = routes.GET_APPLICATION_EMOJI.compile(emoji=456, application=123) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_known_custom_emoji = mock.Mock(return_value=emoji) + async def test_fetch_application_emoji( + self, + rest_client: rest.RESTClientImpl, + mock_application: applications.Application, + mock_custom_emoji: emojis.CustomEmoji, + ): + expected_route = routes.GET_APPLICATION_EMOJI.compile(emoji=28937492734, application=111) - assert await rest_client.fetch_application_emoji(StubModel(123), StubModel(456)) is emoji + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_known_custom_emoji", return_value=mock_custom_emoji + ) as patched_deserialize_known_custom_emoji, + ): + assert ( + await rest_client.fetch_application_emoji(mock_application, make_custom_emoji(28937492734)) + is mock_custom_emoji + ) - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_known_custom_emoji.assert_called_once_with({"id": "456"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_known_custom_emoji.assert_called_once_with({"id": "456"}) - async def test_fetch_application_emojis(self, rest_client): - emoji1 = StubModel(456) - emoji2 = StubModel(789) - expected_route = routes.GET_APPLICATION_EMOJIS.compile(application=123) - rest_client._request = mock.AsyncMock(return_value={"items": [{"id": "456"}, {"id": "789"}]}) - rest_client._entity_factory.deserialize_known_custom_emoji = mock.Mock(side_effect=[emoji1, emoji2]) + async def test_fetch_application_emojis( + self, rest_client: rest.RESTClientImpl, mock_application: applications.Application + ): + emoji1 = make_custom_emoji(2398472983) + emoji2 = make_custom_emoji(2309842398) + expected_route = routes.GET_APPLICATION_EMOJIS.compile(application=111) - assert await rest_client.fetch_application_emojis(StubModel(123)) == [emoji1, emoji2] + with ( + mock.patch.object( + rest_client, "_request", return_value={"items": [{"id": "456"}, {"id": "789"}]} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_known_custom_emoji", side_effect=[emoji1, emoji2] + ) as patched_deserialize_known_custom_emoji, + ): + assert await rest_client.fetch_application_emojis(mock_application) == [emoji1, emoji2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_known_custom_emoji.call_count == 2 - rest_client._entity_factory.deserialize_known_custom_emoji.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) + patched__request.assert_awaited_once_with(expected_route) + assert patched_deserialize_known_custom_emoji.call_count == 2 + patched_deserialize_known_custom_emoji.assert_has_calls( + [mock.call({"id": "456"}), mock.call({"id": "789"})] + ) - async def test_create_application_emoji(self, rest_client, file_resource_patch): - emoji = StubModel(234) - expected_route = routes.POST_APPLICATION_EMOJIS.compile(application=123) + async def test_create_application_emoji( + self, + rest_client: rest.RESTClientImpl, + mock_application: applications.Application, + mock_custom_emoji: emojis.CustomEmoji, + file_resource_patch: files.Resource[typing.Any], + ): + expected_route = routes.POST_APPLICATION_EMOJIS.compile(application=111) expected_json = {"name": "rooYay", "image": "some data"} - rest_client._request = mock.AsyncMock(return_value={"id": "234"}) - rest_client._entity_factory.deserialize_known_custom_emoji = mock.Mock(return_value=emoji) - returned = await rest_client.create_application_emoji(StubModel(123), "rooYay", "rooYay.png") - assert returned is emoji + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "234"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_known_custom_emoji", return_value=mock_custom_emoji + ) as patched_deserialize_known_custom_emoji, + ): + returned = await rest_client.create_application_emoji(mock_application, "rooYay", "rooYay.png") + assert returned is mock_custom_emoji - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) - rest_client._entity_factory.deserialize_known_custom_emoji.assert_called_once_with({"id": "234"}) + patched__request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_known_custom_emoji.assert_called_once_with({"id": "234"}) - async def test_edit_application_emoji(self, rest_client): - emoji = StubModel(234) - expected_route = routes.PATCH_APPLICATION_EMOJI.compile(application=123, emoji=456) + async def test_edit_application_emoji( + self, + rest_client: rest.RESTClientImpl, + mock_application: applications.Application, + mock_custom_emoji: emojis.CustomEmoji, + ): + expected_route = routes.PATCH_APPLICATION_EMOJI.compile(application=111, emoji=23847234) expected_json = {"name": "rooYay2"} - rest_client._request = mock.AsyncMock(return_value={"id": "234"}) - rest_client._entity_factory.deserialize_known_custom_emoji = mock.Mock(return_value=emoji) - returned = await rest_client.edit_application_emoji(StubModel(123), StubModel(456), name="rooYay2") - assert returned is emoji + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "234"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_known_custom_emoji", return_value=mock_custom_emoji + ) as patched_deserialize_known_custom_emoji, + ): + returned = await rest_client.edit_application_emoji( + mock_application, make_custom_emoji(23847234), name="rooYay2" + ) + assert returned is mock_custom_emoji - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) - rest_client._entity_factory.deserialize_known_custom_emoji.assert_called_once_with({"id": "234"}) + patched__request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_known_custom_emoji.assert_called_once_with({"id": "234"}) - async def test_delete_application_emoji(self, rest_client): - expected_route = routes.DELETE_APPLICATION_EMOJI.compile(application=123, emoji=456) - rest_client._request = mock.AsyncMock() + async def test_delete_application_emoji( + self, + rest_client: rest.RESTClientImpl, + mock_application: applications.Application, + mock_custom_emoji: emojis.CustomEmoji, + ): + expected_route = routes.DELETE_APPLICATION_EMOJI.compile(application=111, emoji=4440) - await rest_client.delete_application_emoji(StubModel(123), StubModel(456)) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_application_emoji(mock_application, mock_custom_emoji) - rest_client._request.assert_awaited_once_with(expected_route) + patched__request.assert_awaited_once_with(expected_route) - async def test_fetch_sticker_packs(self, rest_client): - pack1 = object() - pack2 = object() - pack3 = object() + async def test_fetch_sticker_packs(self, rest_client: rest.RESTClientImpl): + pack1 = mock.Mock() + pack2 = mock.Mock() + pack3 = mock.Mock() expected_route = routes.GET_STICKER_PACKS.compile() - rest_client._request = mock.AsyncMock( - return_value={"sticker_packs": [{"id": "123"}, {"id": "456"}, {"id": "789"}]} - ) - rest_client._entity_factory.deserialize_sticker_pack = mock.Mock(side_effect=[pack1, pack2, pack3]) - assert await rest_client.fetch_available_sticker_packs() == [pack1, pack2, pack3] + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"sticker_packs": [{"id": "123"}, {"id": "456"}, {"id": "789"}]}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_sticker_pack", side_effect=[pack1, pack2, pack3] + ) as patched_deserialize_sticker_pack, + ): + assert await rest_client.fetch_available_sticker_packs() == [pack1, pack2, pack3] - rest_client._request.assert_awaited_once_with(expected_route, auth=None) - rest_client._entity_factory.deserialize_sticker_pack.assert_has_calls( - [mock.call({"id": "123"}), mock.call({"id": "456"}), mock.call({"id": "789"})] - ) + patched__request.assert_awaited_once_with(expected_route, auth=None) + patched_deserialize_sticker_pack.assert_has_calls( + [mock.call({"id": "123"}), mock.call({"id": "456"}), mock.call({"id": "789"})] + ) - async def test_fetch_sticker_when_guild_sticker(self, rest_client): - expected_route = routes.GET_STICKER.compile(sticker=123) - rest_client._request = mock.AsyncMock(return_value={"id": "123", "guild_id": "456"}) - rest_client._entity_factory.deserialize_guild_sticker = mock.Mock() + async def test_fetch_sticker_when_guild_sticker( + self, rest_client: rest.RESTClientImpl, mock_partial_sticker: stickers.PartialSticker + ): + expected_route = routes.GET_STICKER.compile(sticker=222) - returned = await rest_client.fetch_sticker(StubModel(123)) - assert returned is rest_client._entity_factory.deserialize_guild_sticker.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123", "guild_id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_sticker" + ) as patched_deserialize_guild_sticker, + ): + returned = await rest_client.fetch_sticker(mock_partial_sticker) + assert returned is patched_deserialize_guild_sticker.return_value - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_guild_sticker.assert_called_once_with({"id": "123", "guild_id": "456"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_guild_sticker.assert_called_once_with({"id": "123", "guild_id": "456"}) - async def test_fetch_sticker_when_standard_sticker(self, rest_client): - expected_route = routes.GET_STICKER.compile(sticker=123) - rest_client._request = mock.AsyncMock(return_value={"id": "123"}) - rest_client._entity_factory.deserialize_standard_sticker = mock.Mock() + async def test_fetch_sticker_when_standard_sticker( + self, rest_client: rest.RESTClientImpl, mock_partial_sticker: stickers.PartialSticker + ): + expected_route = routes.GET_STICKER.compile(sticker=222) - returned = await rest_client.fetch_sticker(StubModel(123)) - assert returned is rest_client._entity_factory.deserialize_standard_sticker.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_standard_sticker" + ) as patched_deserialize_standard_sticker, + ): + returned = await rest_client.fetch_sticker(mock_partial_sticker) + assert returned is patched_deserialize_standard_sticker.return_value - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_standard_sticker.assert_called_once_with({"id": "123"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_standard_sticker.assert_called_once_with({"id": "123"}) - async def test_fetch_guild_stickers(self, rest_client): - sticker1 = object() - sticker2 = object() - sticker3 = object() - expected_route = routes.GET_GUILD_STICKERS.compile(guild=987) - rest_client._request = mock.AsyncMock(return_value=[{"id": "123"}, {"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_guild_sticker = mock.Mock(side_effect=[sticker1, sticker2, sticker3]) + async def test_fetch_guild_stickers( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + sticker1 = mock.Mock() + sticker2 = mock.Mock() + sticker3 = mock.Mock() + expected_route = routes.GET_GUILD_STICKERS.compile(guild=123) - assert await rest_client.fetch_guild_stickers(StubModel(987)) == [sticker1, sticker2, sticker3] + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value=[{"id": "123"}, {"id": "456"}, {"id": "789"}], + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_sticker", side_effect=[sticker1, sticker2, sticker3] + ) as patched_deserialize_guild_sticker, + ): + assert await rest_client.fetch_guild_stickers(mock_partial_guild) == [sticker1, sticker2, sticker3] - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_guild_sticker.assert_has_calls( - [mock.call({"id": "123"}), mock.call({"id": "456"}), mock.call({"id": "789"})] - ) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_guild_sticker.assert_has_calls( + [mock.call({"id": "123"}), mock.call({"id": "456"}), mock.call({"id": "789"})] + ) - async def test_fetch_guild_sticker(self, rest_client): - expected_route = routes.GET_GUILD_STICKER.compile(guild=456, sticker=123) - rest_client._request = mock.AsyncMock(return_value={"id": "123"}) - rest_client._entity_factory.deserialize_guild_sticker = mock.Mock() + async def test_fetch_guild_sticker( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_partial_sticker: stickers.PartialSticker, + ): + expected_route = routes.GET_GUILD_STICKER.compile(guild=123, sticker=222) - returned = await rest_client.fetch_guild_sticker(StubModel(456), StubModel(123)) - assert returned is rest_client._entity_factory.deserialize_guild_sticker.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_sticker" + ) as patched_deserialize_guild_sticker, + ): + returned = await rest_client.fetch_guild_sticker(mock_partial_guild, mock_partial_sticker) + assert returned is patched_deserialize_guild_sticker.return_value - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_guild_sticker.assert_called_once_with({"id": "123"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_guild_sticker.assert_called_once_with({"id": "123"}) - async def test_create_sticker(self, rest_client): + async def test_create_sticker(self, rest_client: rest.RESTClientImpl): rest_client.create_sticker = mock.AsyncMock() - file = object() + file = mock.Mock() sticker = await rest_client.create_sticker( 90210, "NewSticker", "funny", file, description="A sticker", reason="blah blah blah" @@ -4497,75 +5799,110 @@ async def test_create_sticker(self, rest_client): 90210, "NewSticker", "funny", file, description="A sticker", reason="blah blah blah" ) - async def test_edit_sticker(self, rest_client): - expected_route = routes.PATCH_GUILD_STICKER.compile(guild=123, sticker=456) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_guild_sticker = mock.Mock() + async def test_edit_sticker( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_partial_sticker: stickers.PartialSticker, + ): + expected_route = routes.PATCH_GUILD_STICKER.compile(guild=123, sticker=222) - returned = await rest_client.edit_sticker( - StubModel(123), - StubModel(456), - name="testing_sticker", - description="blah", - tag=":cry:", - reason="i am bored and have too much time in my hands", - ) - assert returned is rest_client._entity_factory.deserialize_guild_sticker.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_sticker" + ) as patched_deserialize_guild_sticker, + ): + returned = await rest_client.edit_sticker( + mock_partial_guild, + mock_partial_sticker, + name="testing_sticker", + description="blah", + tag=":cry:", + reason="i am bored and have too much time in my hands", + ) + assert returned is patched_deserialize_guild_sticker.return_value - rest_client._request.assert_awaited_once_with( - expected_route, - json={"name": "testing_sticker", "description": "blah", "tags": ":cry:"}, - reason="i am bored and have too much time in my hands", - ) - rest_client._entity_factory.deserialize_guild_sticker.assert_called_once_with({"id": "456"}) + patched__request.assert_awaited_once_with( + expected_route, + json={"name": "testing_sticker", "description": "blah", "tags": ":cry:"}, + reason="i am bored and have too much time in my hands", + ) + patched_deserialize_guild_sticker.assert_called_once_with({"id": "456"}) - async def test_delete_sticker(self, rest_client): - expected_route = routes.DELETE_GUILD_STICKER.compile(guild=123, sticker=456) - rest_client._request = mock.AsyncMock() + async def test_delete_sticker( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_partial_sticker: stickers.PartialSticker, + ): + expected_route = routes.DELETE_GUILD_STICKER.compile(guild=123, sticker=222) - await rest_client.delete_sticker( - StubModel(123), StubModel(456), reason="i am bored and have too much time in my hands" - ) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_sticker( + mock_partial_guild, mock_partial_sticker, reason="i am bored and have too much time in my hands" + ) - rest_client._request.assert_awaited_once_with( - expected_route, reason="i am bored and have too much time in my hands" - ) + patched__request.assert_awaited_once_with( + expected_route, reason="i am bored and have too much time in my hands" + ) - async def test_fetch_guild(self, rest_client): - guild = StubModel(1234) + async def test_fetch_guild(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): expected_route = routes.GET_GUILD.compile(guild=123) expected_query = {"with_counts": "true"} - rest_client._request = mock.AsyncMock(return_value={"id": "1234"}) - rest_client._entity_factory.deserialize_rest_guild = mock.Mock(return_value=guild) - assert await rest_client.fetch_guild(StubModel(123)) is guild + guild = mock.Mock(guilds.PartialGuild, id=23478274) + + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "1234"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_rest_guild", return_value=guild + ) as patched_deserialize_rest_guild, + ): + assert await rest_client.fetch_guild(mock_partial_guild) is guild - rest_client._request.assert_awaited_once_with(expected_route, query=expected_query) - rest_client._entity_factory.deserialize_rest_guild.assert_called_once_with({"id": "1234"}) + patched__request.assert_awaited_once_with(expected_route, query=expected_query) + patched_deserialize_rest_guild.assert_called_once_with({"id": "1234"}) - async def test_fetch_guild_preview(self, rest_client): - guild_preview = StubModel(1234) + async def test_fetch_guild_preview(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): + guild_preview = mock.Mock(guilds.GuildPreview, id=1234) expected_route = routes.GET_GUILD_PREVIEW.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value={"id": "1234"}) - rest_client._entity_factory.deserialize_guild_preview = mock.Mock(return_value=guild_preview) - assert await rest_client.fetch_guild_preview(StubModel(123)) is guild_preview + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "1234"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_preview", return_value=guild_preview + ) as patched_deserialize_guild_preview, + ): + assert await rest_client.fetch_guild_preview(mock_partial_guild) is guild_preview - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_guild_preview.assert_called_once_with({"id": "1234"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_guild_preview.assert_called_once_with({"id": "1234"}) - async def test_delete_guild(self, rest_client): + async def test_delete_guild(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): expected_route = routes.DELETE_GUILD.compile(guild=123) - rest_client._request = mock.AsyncMock() - await rest_client.delete_guild(StubModel(123)) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_guild(mock_partial_guild) - rest_client._request.assert_awaited_once_with(expected_route) + patched__request.assert_awaited_once_with(expected_route) - async def test_edit_guild(self, rest_client, file_resource): - icon_resource = file_resource("icon data") - splash_resource = file_resource("splash data") - banner_resource = file_resource("banner data") + async def test_edit_guild( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_voice_channel: channels.GuildVoiceChannel, + mock_user: users.User, + ): + icon_resource = MockFileResource("icon data") + splash_resource = MockFileResource("splash data") + banner_resource = MockFileResource("banner data") expected_route = routes.PATCH_GUILD.compile(guild=123) expected_json = { "name": "hikari", @@ -4574,7 +5911,7 @@ async def test_edit_guild(self, rest_client, file_resource): "explicit_content_filter": 1, "afk_timeout": 60, "preferred_locale": "en-UK", - "afk_channel_id": "456", + "afk_channel_id": "4562", "owner_id": "789", "system_channel_id": "789", "rules_channel_id": "987", @@ -4584,34 +5921,45 @@ async def test_edit_guild(self, rest_client, file_resource): "banner": "banner data", "features": ["COMMUNITY", "RAID_ALERTS_DISABLED"], } - rest_client._request = mock.AsyncMock(return_value={"id": "123"}) - with mock.patch.object(files, "ensure_resource", side_effect=[icon_resource, splash_resource, banner_resource]): + with ( + mock.patch.object(files, "ensure_resource", side_effect=[icon_resource, splash_resource, banner_resource]), + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_rest_guild") as patched_deserialize_rest_guild, + ): result = await rest_client.edit_guild( - StubModel(123), + mock_partial_guild, name="hikari", verification_level=guilds.GuildVerificationLevel.HIGH, default_message_notifications=guilds.GuildMessageNotificationsLevel.ONLY_MENTIONS, explicit_content_filter_level=guilds.GuildExplicitContentFilterLevel.MEMBERS_WITHOUT_ROLES, - afk_channel=StubModel(456), + afk_channel=mock_guild_voice_channel, afk_timeout=60, icon="icon.png", - owner=StubModel(789), + owner=mock_user, splash="splash.png", banner="banner.png", - system_channel=StubModel(789), - rules_channel=StubModel(987), - public_updates_channel=(654), + system_channel=make_guild_text_channel(789), + rules_channel=make_guild_text_channel(987), + public_updates_channel=make_guild_text_channel(654), preferred_locale="en-UK", features=[guilds.GuildFeature.COMMUNITY, guilds.GuildFeature.RAID_ALERTS_DISABLED], reason="hikari best", ) - assert result is rest_client._entity_factory.deserialize_rest_guild.return_value + assert result is patched_deserialize_rest_guild.return_value - rest_client._entity_factory.deserialize_rest_guild.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="hikari best") + patched_deserialize_rest_guild.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="hikari best") - async def test_edit_guild_when_images_are_None(self, rest_client): + async def test_edit_guild_when_images_are_None( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_voice_channel: channels.GuildVoiceChannel, + mock_user: users.User, + ): expected_route = routes.PATCH_GUILD.compile(guild=123) expected_json = { "name": "hikari", @@ -4620,7 +5968,7 @@ async def test_edit_guild_when_images_are_None(self, rest_client): "explicit_content_filter": 1, "afk_timeout": 60, "preferred_locale": "en-UK", - "afk_channel_id": "456", + "afk_channel_id": "4562", "owner_id": "789", "system_channel_id": "789", "rules_channel_id": "987", @@ -4630,244 +5978,308 @@ async def test_edit_guild_when_images_are_None(self, rest_client): "banner": None, "features": ["COMMUNITY", "RAID_ALERTS_DISABLED"], } - rest_client._request = mock.AsyncMock(return_value={"ok": "NO"}) - - result = await rest_client.edit_guild( - StubModel(123), - name="hikari", - verification_level=guilds.GuildVerificationLevel.HIGH, - default_message_notifications=guilds.GuildMessageNotificationsLevel.ONLY_MENTIONS, - explicit_content_filter_level=guilds.GuildExplicitContentFilterLevel.MEMBERS_WITHOUT_ROLES, - afk_channel=StubModel(456), - afk_timeout=60, - icon=None, - owner=StubModel(789), - splash=None, - banner=None, - system_channel=StubModel(789), - rules_channel=StubModel(987), - public_updates_channel=(654), - preferred_locale="en-UK", - features=[guilds.GuildFeature.COMMUNITY, guilds.GuildFeature.RAID_ALERTS_DISABLED], - reason="hikari best", - ) - assert result is rest_client._entity_factory.deserialize_rest_guild.return_value - - rest_client._entity_factory.deserialize_rest_guild.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="hikari best") - - async def test_edit_guild_without_optionals(self, rest_client): + + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"ok": "NO"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_rest_guild") as patched_deserialize_rest_guild, + ): + result = await rest_client.edit_guild( + mock_partial_guild, + name="hikari", + verification_level=guilds.GuildVerificationLevel.HIGH, + default_message_notifications=guilds.GuildMessageNotificationsLevel.ONLY_MENTIONS, + explicit_content_filter_level=guilds.GuildExplicitContentFilterLevel.MEMBERS_WITHOUT_ROLES, + afk_channel=mock_guild_voice_channel, + afk_timeout=60, + icon=None, + owner=mock_user, + splash=None, + banner=None, + system_channel=make_guild_text_channel(789), + rules_channel=make_guild_text_channel(987), + public_updates_channel=make_guild_text_channel(654), + preferred_locale="en-UK", + features=[guilds.GuildFeature.COMMUNITY, guilds.GuildFeature.RAID_ALERTS_DISABLED], + reason="hikari best", + ) + assert result is patched_deserialize_rest_guild.return_value + + patched_deserialize_rest_guild.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="hikari best") + + async def test_edit_guild_without_optionals( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): expected_route = routes.PATCH_GUILD.compile(guild=123) expected_json = {} - rest_client._request = mock.AsyncMock(return_value={"id": "42"}) - result = await rest_client.edit_guild(StubModel(123)) - assert result is rest_client._entity_factory.deserialize_rest_guild.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "42"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_rest_guild") as patched_deserialize_rest_guild, + ): + result = await rest_client.edit_guild(mock_partial_guild) + assert result is patched_deserialize_rest_guild.return_value - rest_client._entity_factory.deserialize_rest_guild.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason=undefined.UNDEFINED) + patched_deserialize_rest_guild.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason=undefined.UNDEFINED) - async def test_set_guild_incident_actions(self, rest_client): + async def test_set_guild_incident_actions(self, rest_client: rest.RESTClientImpl): expected_route = routes.PUT_GUILD_INCIDENT_ACTIONS.compile(guild=123) expected_json = {"invites_disabled_until": "2023-09-01T14:48:02.222000+00:00", "dms_disabled_until": None} - rest_client._request = mock.AsyncMock(return_value={"testing": "data"}) - result = await rest_client.set_guild_incident_actions( - 123, invites_disabled_until=datetime.datetime(2023, 9, 1, 14, 48, 2, 222000, tzinfo=datetime.timezone.utc) - ) - assert result is rest_client._entity_factory.deserialize_guild_incidents.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"testing": "data"} + ) as patched__request, + mock.patch.object( + rest_client._entity_factory, "deserialize_guild_incidents" + ) as patched_deserialize_guild_incidents, + ): + result = await rest_client.set_guild_incident_actions( + 123, + invites_disabled_until=datetime.datetime(2023, 9, 1, 14, 48, 2, 222000, tzinfo=datetime.timezone.utc), + ) + assert result is patched_deserialize_guild_incidents.return_value - rest_client._entity_factory.deserialize_guild_incidents.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json) + patched_deserialize_guild_incidents.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with(expected_route, json=expected_json) - async def test_fetch_guild_channels(self, rest_client): - channel1 = StubModel(456) - channel2 = StubModel(789) + async def test_fetch_guild_channels( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + channel1 = make_guild_text_channel(456) + channel2 = make_guild_text_channel(789) expected_route = routes.GET_GUILD_CHANNELS.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_channel = mock.Mock(side_effect=[channel1, channel2]) - assert await rest_client.fetch_guild_channels(StubModel(123)) == [channel1, channel2] + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "456"}, {"id": "789"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_channel", side_effect=[channel1, channel2] + ) as patched_deserialize_channel, + ): + assert await rest_client.fetch_guild_channels(mock_partial_guild) == [channel1, channel2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_channel.call_count == 2 - rest_client._entity_factory.deserialize_channel.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) + patched__request.assert_awaited_once_with(expected_route) + assert patched_deserialize_channel.call_count == 2 + patched_deserialize_channel.assert_has_calls([mock.call({"id": "456"}), mock.call({"id": "789"})]) - async def test_fetch_guild_channels_ignores_unknown_channel_type(self, rest_client): - channel1 = StubModel(456) + async def test_fetch_guild_channels_ignores_unknown_channel_type( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_text_channel: channels.GuildTextChannel, + ): expected_route = routes.GET_GUILD_CHANNELS.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_channel = mock.Mock( - side_effect=[errors.UnrecognisedEntityError("echelon"), channel1] - ) - assert await rest_client.fetch_guild_channels(StubModel(123)) == [channel1] + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "456"}, {"id": "789"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, + "deserialize_channel", + side_effect=[errors.UnrecognisedEntityError("echelon"), mock_guild_text_channel], + ) as patched_deserialize_channel, + ): + assert await rest_client.fetch_guild_channels(mock_partial_guild) == [mock_guild_text_channel] - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_channel.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_channel.assert_has_calls([mock.call({"id": "456"}), mock.call({"id": "789"})]) - async def test_create_guild_text_channel(self, rest_client: rest.RESTClientImpl): - guild = StubModel(123) - category_channel = StubModel(789) - overwrite1 = StubModel(987) - overwrite2 = StubModel(654) - rest_client._create_guild_channel = mock.AsyncMock() + async def test_create_guild_text_channel( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_category: channels.GuildCategory, + ): + overwrite1 = make_permission_overwrite(9283749) + overwrite2 = make_permission_overwrite(2837472) - returned = await rest_client.create_guild_text_channel( - guild, - "general", - position=1, - topic="general chat", - nsfw=False, - rate_limit_per_user=60, - permission_overwrites=[overwrite1, overwrite2], - category=category_channel, - reason="because we need one", - default_auto_archive_duration=123332, - ) - assert returned is rest_client._entity_factory.deserialize_guild_text_channel.return_value + with ( + mock.patch.object( + rest_client, "_create_guild_channel", new_callable=mock.AsyncMock + ) as patched__create_guild_channel, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_text_channel" + ) as patched_deserialize_guild_text_channel, + ): + returned = await rest_client.create_guild_text_channel( + mock_partial_guild, + "general", + position=1, + topic="general chat", + nsfw=False, + rate_limit_per_user=60, + permission_overwrites=[overwrite1, overwrite2], + category=mock_guild_category, + reason="because we need one", + default_auto_archive_duration=123332, + ) + assert returned is patched_deserialize_guild_text_channel.return_value + + patched__create_guild_channel.assert_awaited_once_with( + mock_partial_guild, + "general", + channels.ChannelType.GUILD_TEXT, + position=1, + topic="general chat", + nsfw=False, + rate_limit_per_user=60, + permission_overwrites=[overwrite1, overwrite2], + category=mock_guild_category, + reason="because we need one", + default_auto_archive_duration=123332, + ) + patched_deserialize_guild_text_channel.assert_called_once_with(patched__create_guild_channel.return_value) - rest_client._create_guild_channel.assert_awaited_once_with( - guild, - "general", - channels.ChannelType.GUILD_TEXT, - position=1, - topic="general chat", - nsfw=False, - rate_limit_per_user=60, - permission_overwrites=[overwrite1, overwrite2], - category=category_channel, - reason="because we need one", - default_auto_archive_duration=123332, - ) - rest_client._entity_factory.deserialize_guild_text_channel.assert_called_once_with( - rest_client._create_guild_channel.return_value - ) + async def test_create_guild_news_channel( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_category: channels.GuildCategory, + ): + overwrite1 = make_permission_overwrite(9283749) + overwrite2 = make_permission_overwrite(2837472) - async def test_create_guild_news_channel(self, rest_client: rest.RESTClientImpl): - guild = StubModel(123) - category_channel = StubModel(789) - overwrite1 = StubModel(987) - overwrite2 = StubModel(654) - rest_client._create_guild_channel = mock.AsyncMock() + with ( + mock.patch.object( + rest_client, "_create_guild_channel", new_callable=mock.AsyncMock + ) as patched__create_guild_channel, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_news_channel" + ) as patched_deserialize_guild_news_channel, + ): + returned = await rest_client.create_guild_news_channel( + mock_partial_guild, + "general", + position=1, + topic="general news", + nsfw=False, + rate_limit_per_user=60, + permission_overwrites=[overwrite1, overwrite2], + category=mock_guild_category, + reason="because we need one", + default_auto_archive_duration=5445234, + ) + assert returned is patched_deserialize_guild_news_channel.return_value + + patched__create_guild_channel.assert_awaited_once_with( + mock_partial_guild, + "general", + channels.ChannelType.GUILD_NEWS, + position=1, + topic="general news", + nsfw=False, + rate_limit_per_user=60, + permission_overwrites=[overwrite1, overwrite2], + category=mock_guild_category, + reason="because we need one", + default_auto_archive_duration=5445234, + ) + patched_deserialize_guild_news_channel.assert_called_once_with(patched__create_guild_channel.return_value) - returned = await rest_client.create_guild_news_channel( - guild, - "general", - position=1, - topic="general news", - nsfw=False, - rate_limit_per_user=60, - permission_overwrites=[overwrite1, overwrite2], - category=category_channel, - reason="because we need one", - default_auto_archive_duration=5445234, - ) - assert returned is rest_client._entity_factory.deserialize_guild_news_channel.return_value - - rest_client._create_guild_channel.assert_awaited_once_with( - guild, - "general", - channels.ChannelType.GUILD_NEWS, - position=1, - topic="general news", - nsfw=False, - rate_limit_per_user=60, - permission_overwrites=[overwrite1, overwrite2], - category=category_channel, - reason="because we need one", - default_auto_archive_duration=5445234, - ) - rest_client._entity_factory.deserialize_guild_news_channel.assert_called_once_with( - rest_client._create_guild_channel.return_value - ) - - async def test_create_guild_forum_channel(self, rest_client: rest.RESTClientImpl): - guild = StubModel(123) - category_channel = StubModel(789) - overwrite1 = StubModel(987) - overwrite2 = StubModel(654) - tag1 = StubModel(1203) - tag2 = StubModel(1204) - rest_client._create_guild_channel = mock.AsyncMock() + async def test_create_guild_forum_channel( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_category: channels.GuildCategory, + ): + overwrite1 = make_permission_overwrite(9283749) + overwrite2 = make_permission_overwrite(2837472) - returned = await rest_client.create_guild_forum_channel( - guild, - "help-center", - position=1, - topic="get help!", - nsfw=False, - rate_limit_per_user=60, - permission_overwrites=[overwrite1, overwrite2], - category=category_channel, - reason="because we need one", - default_auto_archive_duration=5445234, - default_thread_rate_limit_per_user=40, - default_forum_layout=channels.ForumLayoutType.LIST_VIEW, - default_sort_order=channels.ForumSortOrderType.LATEST_ACTIVITY, - available_tags=[tag1, tag2], - default_reaction_emoji="some reaction", - ) - assert returned is rest_client._entity_factory.deserialize_guild_forum_channel.return_value + tag1 = mock.Mock(channels.ForumTag, id=1203) + tag2 = mock.Mock(channels.ForumTag, id=1204) - rest_client._create_guild_channel.assert_awaited_once_with( - guild, - "help-center", - channels.ChannelType.GUILD_FORUM, - position=1, - topic="get help!", - nsfw=False, - rate_limit_per_user=60, - permission_overwrites=[overwrite1, overwrite2], - category=category_channel, - reason="because we need one", - default_auto_archive_duration=5445234, - default_thread_rate_limit_per_user=40, - default_forum_layout=channels.ForumLayoutType.LIST_VIEW, - default_sort_order=channels.ForumSortOrderType.LATEST_ACTIVITY, - available_tags=[tag1, tag2], - default_reaction_emoji="some reaction", - ) - rest_client._entity_factory.deserialize_guild_forum_channel.assert_called_once_with( - rest_client._create_guild_channel.return_value - ) + with ( + mock.patch.object( + rest_client, "_create_guild_channel", new_callable=mock.AsyncMock + ) as patched__create_guild_channel, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_forum_channel" + ) as patched_deserialize_guild_forum_channel, + ): + returned = await rest_client.create_guild_forum_channel( + mock_partial_guild, + "help-center", + position=1, + topic="get help!", + nsfw=False, + rate_limit_per_user=60, + permission_overwrites=[overwrite1, overwrite2], + category=mock_guild_category, + reason="because we need one", + default_auto_archive_duration=5445234, + default_thread_rate_limit_per_user=40, + default_forum_layout=channels.ForumLayoutType.LIST_VIEW, + default_sort_order=channels.ForumSortOrderType.LATEST_ACTIVITY, + available_tags=[tag1, tag2], + default_reaction_emoji="some reaction", + ) + assert returned is patched_deserialize_guild_forum_channel.return_value + + patched__create_guild_channel.assert_awaited_once_with( + mock_partial_guild, + "help-center", + channels.ChannelType.GUILD_FORUM, + position=1, + topic="get help!", + nsfw=False, + rate_limit_per_user=60, + permission_overwrites=[overwrite1, overwrite2], + category=mock_guild_category, + reason="because we need one", + default_auto_archive_duration=5445234, + default_thread_rate_limit_per_user=40, + default_forum_layout=channels.ForumLayoutType.LIST_VIEW, + default_sort_order=channels.ForumSortOrderType.LATEST_ACTIVITY, + available_tags=[tag1, tag2], + default_reaction_emoji="some reaction", + ) + patched_deserialize_guild_forum_channel.assert_called_once_with(patched__create_guild_channel.return_value) - async def test_create_guild_media_channel(self, rest_client: rest.RESTClientImpl): - guild = StubModel(123) - category_channel = StubModel(789) - overwrite1 = StubModel(987) - overwrite2 = StubModel(654) - tag1 = StubModel(1203) - tag2 = StubModel(1204) + async def test_create_guild_media_channel( + self, + rest_client: rest.RESTClientImpl, + hikari_partial_guild: guilds.PartialGuild, + mock_guild_category: channels.GuildCategory, + ): + overwrite1 = make_permission_overwrite(987) + overwrite2 = make_permission_overwrite(654) + tag1 = make_forum_tag(1203) + tag2 = make_forum_tag(1204) rest_client._create_guild_channel = mock.AsyncMock() - returned = await rest_client.create_guild_media_channel( - guild, - "help-center", - position=1, - topic="get help!", - nsfw=False, - rate_limit_per_user=60, - permission_overwrites=[overwrite1, overwrite2], - category=category_channel, - reason="because we need one", - default_auto_archive_duration=5445234, - default_thread_rate_limit_per_user=40, - default_forum_layout=channels.ForumLayoutType.LIST_VIEW, - default_sort_order=channels.ForumSortOrderType.LATEST_ACTIVITY, - available_tags=[tag1, tag2], - default_reaction_emoji="some reaction", - ) - assert returned is rest_client._entity_factory.deserialize_guild_media_channel.return_value + with mock.patch.object( + rest_client._entity_factory, "deserialize_guild_media_channel" + ) as patched_deserialize_guild_media_channel: + returned = await rest_client.create_guild_media_channel( + hikari_partial_guild, + "help-center", + position=1, + topic="get help!", + nsfw=False, + rate_limit_per_user=60, + permission_overwrites=[overwrite1, overwrite2], + category=mock_guild_category, + reason="because we need one", + default_auto_archive_duration=5445234, + default_thread_rate_limit_per_user=40, + default_forum_layout=channels.ForumLayoutType.LIST_VIEW, + default_sort_order=channels.ForumSortOrderType.LATEST_ACTIVITY, + available_tags=[tag1, tag2], + default_reaction_emoji="some reaction", + ) + + assert returned is patched_deserialize_guild_media_channel.return_value rest_client._create_guild_channel.assert_awaited_once_with( - guild, + hikari_partial_guild, "help-center", channels.ChannelType.GUILD_MEDIA, position=1, @@ -4875,7 +6287,7 @@ async def test_create_guild_media_channel(self, rest_client: rest.RESTClientImpl nsfw=False, rate_limit_per_user=60, permission_overwrites=[overwrite1, overwrite2], - category=category_channel, + category=mock_guild_category, reason="because we need one", default_auto_archive_duration=5445234, default_thread_rate_limit_per_user=40, @@ -4884,33 +6296,37 @@ async def test_create_guild_media_channel(self, rest_client: rest.RESTClientImpl available_tags=[tag1, tag2], default_reaction_emoji="some reaction", ) - rest_client._entity_factory.deserialize_guild_media_channel.assert_called_once_with( - rest_client._create_guild_channel.return_value - ) + patched_deserialize_guild_media_channel.assert_called_once_with(rest_client._create_guild_channel.return_value) - async def test_create_guild_voice_channel(self, rest_client: rest.RESTClientImpl): - guild = StubModel(123) - category_channel = StubModel(789) - overwrite1 = StubModel(987) - overwrite2 = StubModel(654) + async def test_create_guild_voice_channel( + self, + rest_client: rest.RESTClientImpl, + hikari_partial_guild: guilds.PartialGuild, + mock_guild_category: channels.GuildCategory, + ): + overwrite1 = make_permission_overwrite(987) + overwrite2 = make_permission_overwrite(654) rest_client._create_guild_channel = mock.AsyncMock() - returned = await rest_client.create_guild_voice_channel( - guild, - "general", - position=1, - user_limit=60, - bitrate=64, - video_quality_mode=channels.VideoQualityMode.FULL, - permission_overwrites=[overwrite1, overwrite2], - category=category_channel, - region="ok boomer", - reason="because we need one", - ) - assert returned is rest_client._entity_factory.deserialize_guild_voice_channel.return_value + with mock.patch.object( + rest_client._entity_factory, "deserialize_guild_voice_channel" + ) as patched_deserialize_guild_voice_channel: + returned = await rest_client.create_guild_voice_channel( + hikari_partial_guild, + "general", + position=1, + user_limit=60, + bitrate=64, + video_quality_mode=channels.VideoQualityMode.FULL, + permission_overwrites=[overwrite1, overwrite2], + category=mock_guild_category, + region="ok boomer", + reason="because we need one", + ) + assert returned is patched_deserialize_guild_voice_channel.return_value rest_client._create_guild_channel.assert_awaited_once_with( - guild, + hikari_partial_guild, "general", channels.ChannelType.GUILD_VOICE, position=1, @@ -4919,83 +6335,111 @@ async def test_create_guild_voice_channel(self, rest_client: rest.RESTClientImpl video_quality_mode=channels.VideoQualityMode.FULL, permission_overwrites=[overwrite1, overwrite2], region="ok boomer", - category=category_channel, + category=mock_guild_category, reason="because we need one", ) - rest_client._entity_factory.deserialize_guild_voice_channel.assert_called_once_with( - rest_client._create_guild_channel.return_value - ) - - async def test_create_guild_stage_channel(self, rest_client: rest.RESTClientImpl): - guild = StubModel(123) - category_channel = StubModel(789) - overwrite1 = StubModel(987) - overwrite2 = StubModel(654) - rest_client._create_guild_channel = mock.AsyncMock() - - returned = await rest_client.create_guild_stage_channel( - guild, - "general", - position=1, - user_limit=60, - bitrate=64, - permission_overwrites=[overwrite1, overwrite2], - category=category_channel, - region="Doge Moon", - reason="When doge == 1$", - ) - assert returned is rest_client._entity_factory.deserialize_guild_stage_channel.return_value + patched_deserialize_guild_voice_channel.assert_called_once_with(rest_client._create_guild_channel.return_value) - rest_client._create_guild_channel.assert_awaited_once_with( - guild, - "general", - channels.ChannelType.GUILD_STAGE, - position=1, - user_limit=60, - bitrate=64, - permission_overwrites=[overwrite1, overwrite2], - region="Doge Moon", - category=category_channel, - reason="When doge == 1$", - ) - rest_client._entity_factory.deserialize_guild_stage_channel.assert_called_once_with( - rest_client._create_guild_channel.return_value - ) + async def test_create_guild_stage_channel( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_category: channels.GuildCategory, + ): + overwrite1 = make_permission_overwrite(9283749) + overwrite2 = make_permission_overwrite(2837472) - async def test_create_guild_category(self, rest_client: rest.RESTClientImpl): - guild = StubModel(123) - overwrite1 = StubModel(987) - overwrite2 = StubModel(654) - rest_client._create_guild_channel = mock.AsyncMock() + with ( + mock.patch.object( + rest_client, "_create_guild_channel", new_callable=mock.AsyncMock + ) as patched__create_guild_channel, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_stage_channel" + ) as patched_deserialize_guild_stage_channel, + ): + returned = await rest_client.create_guild_stage_channel( + mock_partial_guild, + "general", + position=1, + user_limit=60, + bitrate=64, + permission_overwrites=[overwrite1, overwrite2], + category=mock_guild_category, + region="Doge Moon", + reason="When doge == 1$", + ) + assert returned is patched_deserialize_guild_stage_channel.return_value + + patched__create_guild_channel.assert_awaited_once_with( + mock_partial_guild, + "general", + channels.ChannelType.GUILD_STAGE, + position=1, + user_limit=60, + bitrate=64, + permission_overwrites=[overwrite1, overwrite2], + region="Doge Moon", + category=mock_guild_category, + reason="When doge == 1$", + ) + patched_deserialize_guild_stage_channel.assert_called_once_with(patched__create_guild_channel.return_value) - returned = await rest_client.create_guild_category( - guild, "general", position=1, permission_overwrites=[overwrite1, overwrite2], reason="because we need one" - ) - assert returned is rest_client._entity_factory.deserialize_guild_category.return_value + async def test_create_guild_category( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + overwrite1 = make_permission_overwrite(9283749) + overwrite2 = make_permission_overwrite(2837472) - rest_client._create_guild_channel.assert_awaited_once_with( - guild, - "general", - channels.ChannelType.GUILD_CATEGORY, - position=1, - permission_overwrites=[overwrite1, overwrite2], - reason="because we need one", - ) - rest_client._entity_factory.deserialize_guild_category.assert_called_once_with( - rest_client._create_guild_channel.return_value - ) + with ( + mock.patch.object( + rest_client, "_create_guild_channel", new_callable=mock.AsyncMock + ) as patched__create_guild_channel, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_category" + ) as patched_deserialize_guild_category, + ): + returned = await rest_client.create_guild_category( + mock_partial_guild, + "general", + position=1, + permission_overwrites=[overwrite1, overwrite2], + reason="because we need one", + ) + assert returned is patched_deserialize_guild_category.return_value + + patched__create_guild_channel.assert_awaited_once_with( + mock_partial_guild, + "general", + channels.ChannelType.GUILD_CATEGORY, + position=1, + permission_overwrites=[overwrite1, overwrite2], + reason="because we need one", + ) + patched_deserialize_guild_category.assert_called_once_with(patched__create_guild_channel.return_value) @pytest.mark.parametrize( - ("emoji", "expected_emoji_id", "expected_emoji_name"), [(123, 123, None), ("emoji", None, "emoji")] + ("emoji", "expected_emoji_id", "expected_emoji_name"), + [ + (emojis.CustomEmoji(id=snowflakes.Snowflake(989), name="emoji", is_animated=False), 989, None), + (emojis.UnicodeEmoji("❤️"), None, "❤️"), + ], ) @pytest.mark.parametrize("default_auto_archive_duration", [12322, (datetime.timedelta(minutes=12322)), 12322.0]) async def test__create_guild_channel( - self, rest_client, default_auto_archive_duration, emoji, expected_emoji_id, expected_emoji_name + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_category: channels.GuildCategory, + default_auto_archive_duration: int | float | datetime.timedelta, + emoji: emojis.Emoji, + expected_emoji_id: int | None, + expected_emoji_name: str | None, ): - overwrite1 = StubModel(987) - overwrite2 = StubModel(654) - tag1 = StubModel(321) - tag2 = StubModel(123) + overwrite1 = make_permission_overwrite(9283749) + overwrite2 = make_permission_overwrite(2837472) + tag1 = mock.Mock(channels.ForumTag, id=321) + tag2 = mock.Mock(channels.ForumTag, id=123) + expected_route = routes.POST_GUILD_CHANNELS.compile(guild=123) expected_json = { "type": 0, @@ -5007,7 +6451,7 @@ async def test__create_guild_channel( "user_limit": 99, "rate_limit_per_user": 60, "rtc_region": "wicky wicky", - "parent_id": "321", + "parent_id": "4564", "permission_overwrites": [{"id": "987"}, {"id": "654"}], "default_auto_archive_duration": 12322, "default_thread_rate_limit_per_user": 40, @@ -5016,44 +6460,48 @@ async def test__create_guild_channel( "default_reaction_emoji": {"emoji_id": expected_emoji_id, "emoji_name": expected_emoji_name}, "available_tags": [{"id": "321"}, {"id": "123"}], } - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.serialize_permission_overwrite = mock.Mock( - side_effect=[{"id": "987"}, {"id": "654"}] - ) - rest_client._entity_factory.serialize_forum_tag = mock.Mock(side_effect=[{"id": "321"}, {"id": "123"}]) - returned = await rest_client._create_guild_channel( - StubModel(123), - "general", - channels.ChannelType.GUILD_TEXT, - position=1, - topic="some topic", - nsfw=True, - bitrate=64, - user_limit=99, - rate_limit_per_user=60, - permission_overwrites=[overwrite1, overwrite2], - region="wicky wicky", - category=StubModel(321), - reason="we have got the power", - default_auto_archive_duration=default_auto_archive_duration, - default_thread_rate_limit_per_user=40, - default_forum_layout=channels.ForumLayoutType.LIST_VIEW, - default_sort_order=channels.ForumSortOrderType.LATEST_ACTIVITY, - available_tags=[tag1, tag2], - default_reaction_emoji=emoji, - ) - assert returned is rest_client._request.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "serialize_permission_overwrite", side_effect=[{"id": "987"}, {"id": "654"}] + ) as patched_serialize_permission_overwrite, + mock.patch.object( + rest_client.entity_factory, "serialize_forum_tag", side_effect=[{"id": "321"}, {"id": "123"}] + ) as patched_serialize_forum_tag, + ): + returned = await rest_client._create_guild_channel( + mock_partial_guild, + "general", + channels.ChannelType.GUILD_TEXT, + position=1, + topic="some topic", + nsfw=True, + bitrate=64, + user_limit=99, + rate_limit_per_user=60, + permission_overwrites=[overwrite1, overwrite2], + region="wicky wicky", + category=mock_guild_category, + reason="we have got the power", + default_auto_archive_duration=default_auto_archive_duration, + default_thread_rate_limit_per_user=40, + default_forum_layout=channels.ForumLayoutType.LIST_VIEW, + default_sort_order=channels.ForumSortOrderType.LATEST_ACTIVITY, + available_tags=[tag1, tag2], + default_reaction_emoji=emoji, + ) + assert returned is patched__request.return_value - rest_client._request.assert_awaited_once_with( - expected_route, json=expected_json, reason="we have got the power" - ) - assert rest_client._entity_factory.serialize_permission_overwrite.call_count == 2 - rest_client._entity_factory.serialize_permission_overwrite.assert_has_calls( - [mock.call(overwrite1), mock.call(overwrite2)] - ) - assert rest_client._entity_factory.serialize_forum_tag.call_count == 2 - rest_client._entity_factory.serialize_forum_tag.assert_has_calls([mock.call(tag1), mock.call(tag2)]) + patched__request.assert_awaited_once_with( + expected_route, json=expected_json, reason="we have got the power" + ) + assert patched_serialize_permission_overwrite.call_count == 2 + patched_serialize_permission_overwrite.assert_has_calls([mock.call(overwrite1), mock.call(overwrite2)]) + assert patched_serialize_forum_tag.call_count == 2 + patched_serialize_forum_tag.assert_has_calls([mock.call(tag1), mock.call(tag2)]) @pytest.mark.parametrize( ("auto_archive_duration", "rate_limit_per_user"), @@ -5062,60 +6510,101 @@ async def test__create_guild_channel( async def test_create_message_thread( self, rest_client: rest.RESTClientImpl, - auto_archive_duration: typing.Union[int, datetime.datetime, float], - rate_limit_per_user: typing.Union[int, datetime.datetime, float], + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + auto_archive_duration: time.Intervalish, + rate_limit_per_user: time.Intervalish, ): - expected_route = routes.POST_MESSAGE_THREADS.compile(channel=123432, message=595959) + expected_route = routes.POST_MESSAGE_THREADS.compile(channel=4560, message=101) expected_payload = {"name": "Sass alert!!!", "auto_archive_duration": 12322, "rate_limit_per_user": 42069} - rest_client._request = mock.AsyncMock(return_value={"id": "54123123", "name": "dlksksldalksad"}) - rest_client._entity_factory.deserialize_guild_thread.return_value = mock.Mock(channels.GuildPublicThread) - - result = await rest_client.create_message_thread( - StubModel(123432), - StubModel(595959), - "Sass alert!!!", - auto_archive_duration=auto_archive_duration, - rate_limit_per_user=rate_limit_per_user, - reason="because we need one", - ) - assert result is rest_client._entity_factory.deserialize_guild_thread.return_value - rest_client._request.assert_awaited_once_with( - expected_route, json=expected_payload, reason="because we need one" - ) - rest_client._entity_factory.deserialize_guild_thread.assert_called_once_with(rest_client._request.return_value) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "54123123", "name": "dlksksldalksad"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, + "deserialize_guild_thread", + return_value=mock.Mock(channels.GuildPublicThread), + ) as patched_deserialize_guild_thread, + ): + result = await rest_client.create_message_thread( + mock_guild_text_channel, + mock_message, + "Sass alert!!!", + auto_archive_duration=auto_archive_duration, + rate_limit_per_user=rate_limit_per_user, + reason="because we need one", + ) + + assert result is patched_deserialize_guild_thread.return_value + patched__request.assert_awaited_once_with( + expected_route, json=expected_payload, reason="because we need one" + ) + patched_deserialize_guild_thread.assert_called_once_with(patched__request.return_value) - async def test_create_message_thread_without_optionals(self, rest_client: rest.RESTClientImpl): - expected_route = routes.POST_MESSAGE_THREADS.compile(channel=123432, message=595959) + async def test_create_message_thread_without_optionals( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + expected_route = routes.POST_MESSAGE_THREADS.compile(channel=4560, message=101) expected_payload = {"name": "Sass alert!!!", "auto_archive_duration": 1440} - rest_client._request = mock.AsyncMock(return_value={"id": "54123123", "name": "dlksksldalksad"}) - rest_client._entity_factory.deserialize_guild_thread.return_value = mock.Mock(channels.GuildNewsThread) - result = await rest_client.create_message_thread( - StubModel(123432), StubModel(595959), "Sass alert!!!", reason="because why not" - ) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "54123123", "name": "dlksksldalksad"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_thread", return_value=mock.Mock(channels.GuildNewsThread) + ) as patched_deserialize_guild_thread, + ): + result = await rest_client.create_message_thread( + mock_guild_text_channel, mock_message, "Sass alert!!!", reason="because why not" + ) - assert result is rest_client._entity_factory.deserialize_guild_thread.return_value - rest_client._request.assert_awaited_once_with(expected_route, json=expected_payload, reason="because why not") - rest_client._entity_factory.deserialize_guild_thread.assert_called_once_with(rest_client._request.return_value) + assert result is patched_deserialize_guild_thread.return_value + patched__request.assert_awaited_once_with(expected_route, json=expected_payload, reason="because why not") + patched_deserialize_guild_thread.assert_called_once_with(patched__request.return_value) - async def test_create_message_thread_with_all_undefined(self, rest_client: rest.RESTClientImpl): - expected_route = routes.POST_MESSAGE_THREADS.compile(channel=123432, message=595959) + async def test_create_message_thread_with_all_undefined( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): + expected_route = routes.POST_MESSAGE_THREADS.compile(channel=4560, message=101) expected_payload = {"name": "Sass alert!!!"} - rest_client._request = mock.AsyncMock(return_value={"id": "54123123", "name": "dlksksldalksad"}) - rest_client._entity_factory.deserialize_guild_thread.return_value = mock.Mock(channels.GuildNewsThread) - result = await rest_client.create_message_thread( - StubModel(123432), - StubModel(595959), - "Sass alert!!!", - auto_archive_duration=undefined.UNDEFINED, - reason="because why not", - ) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "54123123", "name": "dlksksldalksad"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_thread", return_value=mock.Mock(channels.GuildNewsThread) + ) as patched_deserialize_guild_thread, + ): + result = await rest_client.create_message_thread( + mock_guild_text_channel, + mock_message, + "Sass alert!!!", + auto_archive_duration=undefined.UNDEFINED, + reason="because why not", + ) - assert result is rest_client._entity_factory.deserialize_guild_thread.return_value - rest_client._request.assert_awaited_once_with(expected_route, json=expected_payload, reason="because why not") - rest_client._entity_factory.deserialize_guild_thread.assert_called_once_with(rest_client._request.return_value) + assert result is patched_deserialize_guild_thread.return_value + patched__request.assert_awaited_once_with(expected_route, json=expected_payload, reason="because why not") + patched_deserialize_guild_thread.assert_called_once_with(patched__request.return_value) @pytest.mark.parametrize( ("auto_archive_duration", "rate_limit_per_user"), @@ -5124,10 +6613,11 @@ async def test_create_message_thread_with_all_undefined(self, rest_client: rest. async def test_create_thread( self, rest_client: rest.RESTClientImpl, - auto_archive_duration: typing.Union[int, datetime.datetime, float], - rate_limit_per_user: typing.Union[int, datetime.datetime, float], + mock_guild_text_channel: channels.GuildTextChannel, + auto_archive_duration: time.Intervalish, + rate_limit_per_user: time.Intervalish, ): - expected_route = routes.POST_CHANNEL_THREADS.compile(channel=321123) + expected_route = routes.POST_CHANNEL_THREADS.compile(channel=4560) expected_payload = { "name": "Something something send help, they're keeping the catgirls locked up at ", "auto_archive_duration": 54123, @@ -5135,63 +6625,95 @@ async def test_create_thread( "invitable": True, "rate_limit_per_user": 101, } - rest_client._request = mock.AsyncMock(return_value={"id": "54123123", "name": "dlksksldalksad"}) - result = await rest_client.create_thread( - StubModel(321123), - channels.ChannelType.GUILD_NEWS_THREAD, - "Something something send help, they're keeping the catgirls locked up at ", - auto_archive_duration=auto_archive_duration, - invitable=True, - rate_limit_per_user=rate_limit_per_user, - reason="think of the catgirls!!! >:3", - ) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "54123123", "name": "dlksksldalksad"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_thread" + ) as patched_deserialize_guild_thread, + ): + result = await rest_client.create_thread( + mock_guild_text_channel, + channels.ChannelType.GUILD_NEWS_THREAD, + "Something something send help, they're keeping the catgirls locked up at ", + auto_archive_duration=auto_archive_duration, + invitable=True, + rate_limit_per_user=rate_limit_per_user, + reason="think of the catgirls!!! >:3", + ) - assert result is rest_client._entity_factory.deserialize_guild_thread.return_value - rest_client._request.assert_awaited_once_with( - expected_route, json=expected_payload, reason="think of the catgirls!!! >:3" - ) - rest_client._entity_factory.deserialize_guild_thread.assert_called_once_with(rest_client._request.return_value) + assert result is patched_deserialize_guild_thread.return_value + patched__request.assert_awaited_once_with( + expected_route, json=expected_payload, reason="think of the catgirls!!! >:3" + ) + patched_deserialize_guild_thread.assert_called_once_with(patched__request.return_value) - async def test_create_thread_without_optionals(self, rest_client: rest.RESTClientImpl): - expected_route = routes.POST_CHANNEL_THREADS.compile(channel=321123) + async def test_create_thread_without_optionals( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + expected_route = routes.POST_CHANNEL_THREADS.compile(channel=4560) expected_payload = { "name": "Something something send help, they're keeping the catgirls locked up at ", "auto_archive_duration": 1440, "type": 12, } - rest_client._request = mock.AsyncMock(return_value={"id": "54123123", "name": "dlksksldalksad"}) - result = await rest_client.create_thread( - StubModel(321123), - channels.ChannelType.GUILD_PRIVATE_THREAD, - "Something something send help, they're keeping the catgirls locked up at ", - reason="because why not", - ) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "54123123", "name": "dlksksldalksad"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_thread" + ) as patched_deserialize_guild_thread, + ): + result = await rest_client.create_thread( + mock_guild_text_channel, + channels.ChannelType.GUILD_PRIVATE_THREAD, + "Something something send help, they're keeping the catgirls locked up at ", + ) - assert result is rest_client._entity_factory.deserialize_guild_thread.return_value - rest_client._request.assert_awaited_once_with(expected_route, json=expected_payload, reason="because why not") - rest_client._entity_factory.deserialize_guild_thread.assert_called_once_with(rest_client._request.return_value) + assert result is patched_deserialize_guild_thread.return_value + patched__request.assert_awaited_once_with(expected_route, json=expected_payload, reason=undefined.UNDEFINED) + patched_deserialize_guild_thread.assert_called_once_with(patched__request.return_value) - async def test_create_thread_with_all_undefined(self, rest_client: rest.RESTClientImpl): - expected_route = routes.POST_CHANNEL_THREADS.compile(channel=321123) + async def test_create_thread_with_all_undefined( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + expected_route = routes.POST_CHANNEL_THREADS.compile(channel=4560) expected_payload = { "name": "Something something send help, they're keeping the catgirls locked up at ", "type": 12, } - rest_client._request = mock.AsyncMock(return_value={"id": "54123123", "name": "dlksksldalksad"}) - result = await rest_client.create_thread( - StubModel(321123), - channels.ChannelType.GUILD_PRIVATE_THREAD, - "Something something send help, they're keeping the catgirls locked up at ", - auto_archive_duration=undefined.UNDEFINED, - reason="because why not", - ) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "54123123", "name": "dlksksldalksad"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_thread" + ) as patched_deserialize_guild_thread, + ): + result = await rest_client.create_thread( + mock_guild_text_channel, + channels.ChannelType.GUILD_PRIVATE_THREAD, + "Something something send help, they're keeping the catgirls locked up at ", + auto_archive_duration=undefined.UNDEFINED, + ) - assert result is rest_client._entity_factory.deserialize_guild_thread.return_value - rest_client._request.assert_awaited_once_with(expected_route, json=expected_payload, reason="because why not") - rest_client._entity_factory.deserialize_guild_thread.assert_called_once_with(rest_client._request.return_value) + assert result is patched_deserialize_guild_thread.return_value + patched__request.assert_awaited_once_with(expected_route, json=expected_payload, reason=undefined.UNDEFINED) + patched_deserialize_guild_thread.assert_called_once_with(patched__request.return_value) @pytest.mark.parametrize( ("auto_archive_duration", "rate_limit_per_user"), @@ -5200,19 +6722,20 @@ async def test_create_thread_with_all_undefined(self, rest_client: rest.RESTClie async def test_create_forum_post_when_no_form( self, rest_client: rest.RESTClientImpl, - auto_archive_duration: typing.Union[int, datetime.datetime, float], - rate_limit_per_user: typing.Union[int, datetime.datetime, float], - ): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() - poll_obj = object() + mock_guild_text_channel: channels.GuildTextChannel, + auto_archive_duration: time.Intervalish, + rate_limit_per_user: time.Intervalish, + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() + poll_obj = mock.Mock() mock_body = data_binding.JSONObjectBuilder() - expected_route = routes.POST_CHANNEL_THREADS.compile(channel=321123) + expected_route = routes.POST_CHANNEL_THREADS.compile(channel=4560) expected_payload = { "name": "Post with secret content!", "auto_archive_duration": 54123, @@ -5220,34 +6743,43 @@ async def test_create_forum_post_when_no_form( "applied_tags": ["12220", "12201"], "message": mock_body, } - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, None)) - rest_client._request = mock.AsyncMock(return_value={"some": "message"}) - result = await rest_client.create_forum_post( - StubModel(321123), - "Post with secret content!", - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - poll=poll_obj, - sticker=132543, - stickers=[654234, 123321], - tts=True, - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - flags=54123, - auto_archive_duration=auto_archive_duration, - rate_limit_per_user=rate_limit_per_user, - tags=[12220, 12201], - reason="Secrets!!", - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"some": "message"} + ) as patched__request, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, None) + ) as patched__build_message_payload, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_public_thread" + ) as patched_deserialize_guild_public_thread, + ): + result = await rest_client.create_forum_post( + mock_guild_text_channel, + "Post with secret content!", + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + poll=poll_obj, + sticker=132543, + stickers=[654234, 123321], + tts=True, + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + flags=54123, + auto_archive_duration=auto_archive_duration, + rate_limit_per_user=rate_limit_per_user, + tags=[snowflakes.Snowflake(12220), snowflakes.Snowflake(12201)], + reason="Secrets!!", + ) - rest_client._build_message_payload.assert_called_once_with( + patched__build_message_payload.assert_called_once_with( content="new content", attachment=attachment_obj, attachments=[attachment_obj2], @@ -5266,11 +6798,9 @@ async def test_create_forum_post_when_no_form( flags=54123, ) - assert result is rest_client._entity_factory.deserialize_guild_public_thread.return_value - rest_client._request.assert_awaited_once_with(expected_route, json=expected_payload, reason="Secrets!!") - rest_client._entity_factory.deserialize_guild_public_thread.assert_called_once_with( - rest_client._request.return_value - ) + assert result is patched_deserialize_guild_public_thread.return_value + patched__request.assert_awaited_once_with(expected_route, json=expected_payload, reason="Secrets!!") + patched_deserialize_guild_public_thread.assert_called_once_with(patched__request.return_value) @pytest.mark.parametrize( ("auto_archive_duration", "rate_limit_per_user"), @@ -5279,26 +6809,59 @@ async def test_create_forum_post_when_no_form( async def test_create_forum_post_when_form( self, rest_client: rest.RESTClientImpl, - auto_archive_duration: typing.Union[int, datetime.datetime, float], - rate_limit_per_user: typing.Union[int, datetime.datetime, float], - ): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() - poll_obj = object() + mock_guild_text_channel: channels.GuildTextChannel, + auto_archive_duration: time.Intervalish, + rate_limit_per_user: time.Intervalish, + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() + poll_obj = mock.Mock() mock_body = {"mock": "message body"} mock_form = mock.Mock() - expected_route = routes.POST_CHANNEL_THREADS.compile(channel=321123) - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - rest_client._request = mock.AsyncMock(return_value={"some": "message"}) + expected_route = routes.POST_CHANNEL_THREADS.compile(channel=4560) + + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"some": "message"} + ) as patched__request, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_message_payload, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_public_thread" + ) as patched_deserialize_guild_public_thread, + ): + result = await rest_client.create_forum_post( + mock_guild_text_channel, + "Post with secret content!", + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + poll=poll_obj, + sticker=314542, + stickers=[56234, 123312], + tts=True, + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + flags=54123, + auto_archive_duration=auto_archive_duration, + rate_limit_per_user=rate_limit_per_user, + tags=[snowflakes.Snowflake(12220), snowflakes.Snowflake(12201)], + reason="Secrets!!", + ) + assert result is patched_deserialize_guild_public_thread.return_value - result = await rest_client.create_forum_post( - StubModel(321123), - "Post with secret content!", + patched__build_message_payload.assert_called_once_with( content="new content", attachment=attachment_obj, attachments=[attachment_obj2], @@ -5311,29 +6874,7 @@ async def test_create_forum_post_when_form( stickers=[56234, 123312], tts=True, mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - flags=54123, - auto_archive_duration=auto_archive_duration, - rate_limit_per_user=rate_limit_per_user, - tags=[12220, 12201], - reason="Secrets!!", - ) - - rest_client._build_message_payload.assert_called_once_with( - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - poll=poll_obj, - sticker=314542, - stickers=[56234, 123312], - tts=True, - mentions_everyone=False, - mentions_reply=undefined.UNDEFINED, + mentions_reply=undefined.UNDEFINED, user_mentions=[9876], role_mentions=[1234], flags=54123, @@ -5345,340 +6886,455 @@ async def test_create_forum_post_when_form( b'"applied_tags":["12220","12201"],"message":{"mock":"message body"}}', content_type="application/json", ) + patched__request.assert_awaited_once_with(expected_route, form_builder=mock_form, reason="Secrets!!") + patched_deserialize_guild_public_thread.assert_called_once_with(patched__request.return_value) - assert result is rest_client._entity_factory.deserialize_guild_public_thread.return_value - rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form, reason="Secrets!!") - rest_client._entity_factory.deserialize_guild_public_thread.assert_called_once_with( - rest_client._request.return_value - ) - - async def test_join_thread(self, rest_client: rest.RESTClientImpl): - rest_client._request = mock.AsyncMock() - - await rest_client.join_thread(StubModel(54123123)) - - rest_client._request.assert_awaited_once_with(routes.PUT_MY_THREAD_MEMBER.compile(channel=54123123)) - - async def test_add_thread_member(self, rest_client: rest.RESTClientImpl): - rest_client._request = mock.AsyncMock() - - # why is 8 afraid of 6 and 7? - await rest_client.add_thread_member(StubModel(789), StubModel(666)) - - rest_client._request.assert_awaited_once_with(routes.PUT_THREAD_MEMBER.compile(channel=789, user=666)) + async def test_join_thread( + self, rest_client: rest.RESTClientImpl, mock_guild_text_channel: channels.GuildTextChannel + ): + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.join_thread(mock_guild_text_channel) - async def test_leave_thread(self, rest_client: rest.RESTClientImpl): - rest_client._request = mock.AsyncMock() + patched__request.assert_awaited_once_with(routes.PUT_MY_THREAD_MEMBER.compile(channel=4560)) - await rest_client.leave_thread(StubModel(54123123)) + async def test_add_thread_member( + self, + rest_client: rest.RESTClientImpl, + mock_guild_public_thread_channel: channels.GuildThreadChannel, + mock_user: users.User, + ): + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + # why is 8 afraid of 6 and 7? + await rest_client.add_thread_member(mock_guild_public_thread_channel, mock_user) - rest_client._request.assert_awaited_once_with(routes.DELETE_MY_THREAD_MEMBER.compile(channel=54123123)) + patched__request.assert_awaited_once_with(routes.PUT_THREAD_MEMBER.compile(channel=45611, user=789)) - async def test_remove_thread_member(self, rest_client: rest.RESTClientImpl): - rest_client._request = mock.AsyncMock() + async def test_leave_thread( + self, rest_client: rest.RESTClientImpl, mock_guild_public_thread_channel: channels.GuildThreadChannel + ): + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.leave_thread(mock_guild_public_thread_channel) - await rest_client.remove_thread_member(StubModel(669), StubModel(421)) + patched__request.assert_awaited_once_with(routes.DELETE_MY_THREAD_MEMBER.compile(channel=45611)) - rest_client._request.assert_awaited_once_with(routes.DELETE_THREAD_MEMBER.compile(channel=669, user=421)) + async def test_remove_thread_member( + self, + rest_client: rest.RESTClientImpl, + mock_guild_public_thread_channel: channels.GuildThreadChannel, + mock_user: users.User, + ): + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.remove_thread_member(mock_guild_public_thread_channel, mock_user) - async def test_fetch_thread_member(self, rest_client: rest.RESTClientImpl): - rest_client._request = mock.AsyncMock(return_value={"id": "9239292", "user_id": "949494"}) + patched__request.assert_awaited_once_with(routes.DELETE_THREAD_MEMBER.compile(channel=45611, user=789)) - result = await rest_client.fetch_thread_member(StubModel(55445454), StubModel(45454454)) + async def test_fetch_thread_member( + self, + rest_client: rest.RESTClientImpl, + mock_guild_public_thread_channel: channels.GuildThreadChannel, + mock_user: users.User, + ): + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "9239292", "user_id": "949494"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_thread_member" + ) as patched_deserialize_thread_member, + ): + result = await rest_client.fetch_thread_member(mock_guild_public_thread_channel, mock_user) - assert result is rest_client.entity_factory.deserialize_thread_member.return_value - rest_client.entity_factory.deserialize_thread_member.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with(routes.GET_THREAD_MEMBER.compile(channel=55445454, user=45454454)) + assert result is patched_deserialize_thread_member.return_value + patched_deserialize_thread_member.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with(routes.GET_THREAD_MEMBER.compile(channel=45611, user=789)) - async def test_fetch_thread_members(self, rest_client: rest.RESTClientImpl): + async def test_fetch_thread_members( + self, rest_client: rest.RESTClientImpl, mock_guild_public_thread_channel: channels.GuildThreadChannel + ): mock_payload_1 = mock.Mock() mock_payload_2 = mock.Mock() mock_payload_3 = mock.Mock() mock_member_1 = mock.Mock() mock_member_2 = mock.Mock() mock_member_3 = mock.Mock() - rest_client._request = mock.AsyncMock(return_value=[mock_payload_1, mock_payload_2, mock_payload_3]) - rest_client._entity_factory.deserialize_thread_member = mock.Mock( - side_effect=[mock_member_1, mock_member_2, mock_member_3] - ) - result = await rest_client.fetch_thread_members(StubModel(110101010101)) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value=[mock_payload_1, mock_payload_2, mock_payload_3], + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, + "deserialize_thread_member", + side_effect=[mock_member_1, mock_member_2, mock_member_3], + ) as patched_deserialize_thread_member, + ): + result = await rest_client.fetch_thread_members(mock_guild_public_thread_channel) - assert result == [mock_member_1, mock_member_2, mock_member_3] - rest_client._request.assert_awaited_once_with(routes.GET_THREAD_MEMBERS.compile(channel=110101010101)) - rest_client._entity_factory.deserialize_thread_member.assert_has_calls( - [mock.call(mock_payload_1), mock.call(mock_payload_2), mock.call(mock_payload_3)] - ) + assert result == [mock_member_1, mock_member_2, mock_member_3] + patched__request.assert_awaited_once_with(routes.GET_THREAD_MEMBERS.compile(channel=45611)) + patched_deserialize_thread_member.assert_has_calls( + [mock.call(mock_payload_1), mock.call(mock_payload_2), mock.call(mock_payload_3)] + ) + @pytest.mark.skip(reason="TODO") async def test_fetch_active_threads(self, rest_client: rest.RESTClientImpl): ... - async def test_reposition_channels(self, rest_client): + async def test_reposition_channels(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): expected_route = routes.PATCH_GUILD_CHANNELS.compile(guild=123) expected_json = [{"id": "456", "position": 1}, {"id": "789", "position": 2}] - rest_client._request = mock.AsyncMock() - await rest_client.reposition_channels( - StubModel(123), {1: StubModel(456), 2: StubModel(789)}, reason="because why not" - ) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.reposition_channels( + mock_partial_guild, + {1: make_guild_text_channel(456), 2: make_guild_text_channel(789)}, + reason="because why not", + ) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="because why not") + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="because why not") - async def test_reposition_channels_no_channels(self, rest_client): + async def test_reposition_channels_no_channels( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): expected_route = routes.PATCH_GUILD_CHANNELS.compile(guild=123) - rest_client._request = mock.AsyncMock() - await rest_client.reposition_channels(StubModel(123), reason="because why not") + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.reposition_channels(mock_partial_guild, reason="because why not") - rest_client._request.assert_awaited_once_with(expected_route, json=[], reason="because why not") + patched__request.assert_awaited_once_with(expected_route, json=[], reason="because why not") - async def test_fetch_member(self, rest_client): - member = StubModel(789) - expected_route = routes.GET_GUILD_MEMBER.compile(guild=123, user=456) - rest_client._request = mock.AsyncMock(return_value={"id": "789"}) - rest_client._entity_factory.deserialize_member = mock.Mock(return_value=member) + async def test_fetch_member( + self, rest_client: rest.RESTClientImpl, hikari_partial_guild: guilds.PartialGuild, hikari_user: users.User + ): + member = hikari_user + expected_route = routes.GET_GUILD_MEMBER.compile(guild=123, user=789) - assert await rest_client.fetch_member(StubModel(123), StubModel(456)) == member + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "789"} + ) as patched__request, + mock.patch.object( + rest_client._entity_factory, "deserialize_member", return_value=member + ) as patched_deserialize_member, + ): + assert await rest_client.fetch_member(hikari_partial_guild, hikari_user) == member - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_member.assert_called_once_with({"id": "789"}, guild_id=123) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_member.assert_called_once_with({"id": "789"}, guild_id=123) - async def test_fetch_my_member(self, rest_client): - expected_route = routes.GET_MY_GUILD_MEMBER.compile(guild=45123) - rest_client._request = mock.AsyncMock(return_value={"id": "595995"}) + async def test_fetch_my_member(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): + expected_route = routes.GET_MY_GUILD_MEMBER.compile(guild=123) - result = await rest_client.fetch_my_member(StubModel(45123)) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "595995"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_member") as patched_deserialize_member, + ): + result = await rest_client.fetch_my_member(mock_partial_guild) - assert result is rest_client._entity_factory.deserialize_member.return_value - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_member.assert_called_once_with( - rest_client._request.return_value, guild_id=45123 - ) + assert result is patched_deserialize_member.return_value + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_member.assert_called_once_with(patched__request.return_value, guild_id=123) - async def test_search_members(self, rest_client): - member = StubModel(645234123) - expected_route = routes.GET_GUILD_MEMBERS_SEARCH.compile(guild=645234123) + async def test_search_members(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): + member = mock.Mock(guilds.Member, id=645234123) + expected_route = routes.GET_GUILD_MEMBERS_SEARCH.compile(guild=123) expected_query = {"query": "a name", "limit": "1000"} - rest_client._request = mock.AsyncMock(return_value=[{"id": "764435"}]) - rest_client._entity_factory.deserialize_member = mock.Mock(return_value=member) - assert await rest_client.search_members(StubModel(645234123), "a name") == [member] + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "764435"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_member", return_value=member + ) as patched_deserialize_member, + ): + assert await rest_client.search_members(mock_partial_guild, "a name") == [member] - rest_client._entity_factory.deserialize_member.assert_called_once_with({"id": "764435"}, guild_id=645234123) - rest_client._request.assert_awaited_once_with(expected_route, query=expected_query) + patched_deserialize_member.assert_called_once_with({"id": "764435"}, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, query=expected_query) - async def test_edit_member(self, rest_client): - expected_route = routes.PATCH_GUILD_MEMBER.compile(guild=123, user=456) + async def test_edit_member( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_voice_channel: channels.GuildVoiceChannel, + mock_user: users.User, + ): + expected_route = routes.PATCH_GUILD_MEMBER.compile(guild=123, user=789) expected_json = { "nick": "test", "roles": ["654", "321"], "mute": True, "deaf": False, - "channel_id": "987", + "channel_id": "4562", "communication_disabled_until": "2021-10-18T07:18:11.554023+00:00", } - rest_client._request = mock.AsyncMock(return_value={"id": "789"}) mock_timestamp = datetime.datetime(2021, 10, 18, 7, 18, 11, 554023, tzinfo=datetime.timezone.utc) - result = await rest_client.edit_member( - StubModel(123), - StubModel(456), - nickname="test", - roles=[StubModel(654), StubModel(321)], - mute=True, - deaf=False, - voice_channel=StubModel(987), - communication_disabled_until=mock_timestamp, - reason="because i can", - ) - assert result is rest_client._entity_factory.deserialize_member.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "789"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_member") as patched_deserialize_member, + ): + result = await rest_client.edit_member( + mock_partial_guild, + mock_user, + nickname="test", + roles=[make_partial_role(654), make_partial_role(321)], + mute=True, + deaf=False, + voice_channel=mock_guild_voice_channel, + communication_disabled_until=mock_timestamp, + reason="because i can", + ) + assert result is patched_deserialize_member.return_value - rest_client._entity_factory.deserialize_member.assert_called_once_with( - rest_client._request.return_value, guild_id=123 - ) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="because i can") + patched_deserialize_member.assert_called_once_with(patched__request.return_value, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="because i can") - async def test_edit_member_when_voice_channel_is_None(self, rest_client): - expected_route = routes.PATCH_GUILD_MEMBER.compile(guild=123, user=456) + async def test_edit_member_when_voice_channel_is_None( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): + expected_route = routes.PATCH_GUILD_MEMBER.compile(guild=123, user=789) expected_json = {"nick": "test", "roles": ["654", "321"], "mute": True, "deaf": False, "channel_id": None} - rest_client._request = mock.AsyncMock(return_value={"id": "789"}) - result = await rest_client.edit_member( - StubModel(123), - StubModel(456), - nickname="test", - roles=[StubModel(654), StubModel(321)], - mute=True, - deaf=False, - voice_channel=None, - reason="because i can", - ) - assert result is rest_client._entity_factory.deserialize_member.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "789"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_member") as patched_deserialize_member, + ): + result = await rest_client.edit_member( + mock_partial_guild, + mock_user, + nickname="test", + roles=[make_partial_role(654), make_partial_role(321)], + mute=True, + deaf=False, + voice_channel=None, + reason="because i can", + ) + assert result is patched_deserialize_member.return_value - rest_client._entity_factory.deserialize_member.assert_called_once_with( - rest_client._request.return_value, guild_id=123 - ) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="because i can") + patched_deserialize_member.assert_called_once_with(patched__request.return_value, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="because i can") - async def test_edit_member_when_communication_disabled_until_is_None(self, rest_client): - expected_route = routes.PATCH_GUILD_MEMBER.compile(guild=123, user=456) + async def test_edit_member_when_communication_disabled_until_is_None( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): + expected_route = routes.PATCH_GUILD_MEMBER.compile(guild=123, user=789) expected_json = {"communication_disabled_until": None} - rest_client._request = mock.AsyncMock(return_value={"id": "789"}) - result = await rest_client.edit_member( - StubModel(123), StubModel(456), communication_disabled_until=None, reason="because i can" - ) - assert result is rest_client._entity_factory.deserialize_member.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "789"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_member") as patched_deserialize_member, + ): + result = await rest_client.edit_member( + mock_partial_guild, mock_user, communication_disabled_until=None, reason="because i can" + ) + assert result is patched_deserialize_member.return_value - rest_client._entity_factory.deserialize_member.assert_called_once_with( - rest_client._request.return_value, guild_id=123 - ) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="because i can") + patched_deserialize_member.assert_called_once_with(patched__request.return_value, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="because i can") - async def test_edit_member_without_optionals(self, rest_client): - expected_route = routes.PATCH_GUILD_MEMBER.compile(guild=123, user=456) - rest_client._request = mock.AsyncMock(return_value={"id": "789"}) + async def test_edit_member_without_optionals( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): + expected_route = routes.PATCH_GUILD_MEMBER.compile(guild=123, user=789) - result = await rest_client.edit_member(StubModel(123), StubModel(456)) - assert result is rest_client._entity_factory.deserialize_member.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "789"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_member") as patched_deserialize_member, + ): + result = await rest_client.edit_member(mock_partial_guild, mock_user) + assert result is patched_deserialize_member.return_value - rest_client._entity_factory.deserialize_member.assert_called_once_with( - rest_client._request.return_value, guild_id=123 - ) - rest_client._request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) + patched_deserialize_member.assert_called_once_with(patched__request.return_value, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) - async def test_edit_my_member(self, rest_client, file_resource): + async def test_edit_my_member( + self, + rest_client: rest.RESTClientImpl, + file_resource_patch: files.Resource[typing.Any], + hikari_partial_guild: guilds.PartialGuild, + ): expected_route = routes.PATCH_MY_GUILD_MEMBER.compile(guild=123) - expected_json = {"nick": "test", "avatar": "avatar data", "banner": "banner data", "bio": "do not the beanos."} + expected_json = {"nick": "test", "avatar": "some data", "banner": "some data", "bio": "do not the beanos."} rest_client._request = mock.AsyncMock(return_value={"id": "789"}) - avatar_resource = file_resource("avatar data") - banner_resource = file_resource("banner data") - - with mock.patch.object(files, "ensure_resource", side_effect=[avatar_resource, banner_resource]): + with mock.patch.object(rest_client._entity_factory, "deserialize_member") as patched_deserialize_member: result = await rest_client.edit_my_member( - StubModel(123), + hikari_partial_guild, nickname="test", - avatar=avatar_resource, - banner=banner_resource, + avatar="avatar.png", + banner="banner.jpeg", bio="do not the beanos.", reason="because i can", ) - assert result is rest_client._entity_factory.deserialize_member.return_value + assert result is patched_deserialize_member.return_value - rest_client._entity_factory.deserialize_member.assert_called_once_with( - rest_client._request.return_value, guild_id=123 - ) + patched_deserialize_member.assert_called_once_with(rest_client._request.return_value, guild_id=123) rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="because i can") - async def test_edit_my_member_without_optionals(self, rest_client): + async def test_edit_my_member_without_optionals( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): expected_route = routes.PATCH_MY_GUILD_MEMBER.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value={"id": "789"}) - result = await rest_client.edit_my_member(StubModel(123)) - assert result is rest_client._entity_factory.deserialize_member.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "789"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_member") as patched_deserialize_member, + ): + result = await rest_client.edit_my_member(mock_partial_guild) + assert result is patched_deserialize_member.return_value - rest_client._entity_factory.deserialize_member.assert_called_once_with( - rest_client._request.return_value, guild_id=123 - ) - rest_client._request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) + patched_deserialize_member.assert_called_once_with(patched__request.return_value, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) - async def test_edit_my_member_with_nulls(self, rest_client): + async def test_edit_my_member_with_nulls( + self, rest_client: rest.RESTClientImpl, hikari_partial_guild: guilds.PartialGuild + ): expected_route = routes.PATCH_MY_GUILD_MEMBER.compile(guild=123) expected_json = {"nick": None, "avatar": None, "banner": None, "bio": None} rest_client._request = mock.AsyncMock(return_value={"id": "789"}) - result = await rest_client.edit_my_member(StubModel(123), nickname=None, avatar=None, banner=None, bio=None) - assert result is rest_client._entity_factory.deserialize_member.return_value + with mock.patch.object(rest_client._entity_factory, "deserialize_member") as patched_deserialize_member: + result = await rest_client.edit_my_member( + hikari_partial_guild, nickname=None, avatar=None, banner=None, bio=None + ) - rest_client._entity_factory.deserialize_member.assert_called_once_with( - rest_client._request.return_value, guild_id=123 - ) + assert result is patched_deserialize_member.return_value + + patched_deserialize_member.assert_called_once_with(rest_client._request.return_value, guild_id=123) rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason=undefined.UNDEFINED) - async def test_add_role_to_member(self, rest_client): - expected_route = routes.PUT_GUILD_MEMBER_ROLE.compile(guild=123, user=456, role=789) + async def test_add_role_to_member( + self, + rest_client: rest.RESTClientImpl, + hikari_partial_guild: guilds.PartialGuild, + hikari_user: users.User, + mock_partial_role: guilds.PartialRole, + ): + expected_route = routes.PUT_GUILD_MEMBER_ROLE.compile(guild=123, user=789, role=333) rest_client._request = mock.AsyncMock() - await rest_client.add_role_to_member(StubModel(123), StubModel(456), StubModel(789), reason="because i can") + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.add_role_to_member( + hikari_partial_guild, hikari_user, mock_partial_role, reason="because i can" + ) - rest_client._request.assert_awaited_once_with(expected_route, reason="because i can") + patched__request.assert_awaited_once_with(expected_route, reason="because i can") - async def test_remove_role_from_member(self, rest_client): - expected_route = routes.DELETE_GUILD_MEMBER_ROLE.compile(guild=123, user=456, role=789) - rest_client._request = mock.AsyncMock() + async def test_remove_role_from_member( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_user: users.User, + mock_partial_role: guilds.PartialRole, + ): + expected_route = routes.DELETE_GUILD_MEMBER_ROLE.compile(guild=123, user=789, role=333) - await rest_client.remove_role_from_member( - StubModel(123), StubModel(456), StubModel(789), reason="because i can" - ) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.remove_role_from_member( + mock_partial_guild, mock_user, mock_partial_role, reason="because i can" + ) - rest_client._request.assert_awaited_once_with(expected_route, reason="because i can") + patched__request.assert_awaited_once_with(expected_route, reason="because i can") - async def test_kick_user(self, rest_client): - expected_route = routes.DELETE_GUILD_MEMBER.compile(guild=123, user=456) - rest_client._request = mock.AsyncMock() + async def test_kick_user( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): + expected_route = routes.DELETE_GUILD_MEMBER.compile(guild=123, user=789) - await rest_client.kick_user(StubModel(123), StubModel(456), reason="because i can") + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.kick_user(mock_partial_guild, mock_user, reason="because i can") - rest_client._request.assert_awaited_once_with(expected_route, reason="because i can") + patched__request.assert_awaited_once_with(expected_route, reason="because i can") - async def test_ban_user(self, rest_client): - expected_route = routes.PUT_GUILD_BAN.compile(guild=123, user=456) + async def test_ban_user( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): + expected_route = routes.PUT_GUILD_BAN.compile(guild=123, user=789) expected_json = {"delete_message_seconds": 604800} - rest_client._request = mock.AsyncMock() - - await rest_client.ban_user( - StubModel(123), StubModel(456), delete_message_seconds=604800, reason="because i can" - ) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="because i can") - - async def test_unban_user(self, rest_client): - expected_route = routes.DELETE_GUILD_BAN.compile(guild=123, user=456) - rest_client._request = mock.AsyncMock() - - await rest_client.unban_user(StubModel(123), StubModel(456), reason="because i can") + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.ban_user( + mock_partial_guild, mock_user, delete_message_seconds=604800, reason="because i can" + ) - rest_client._request.assert_awaited_once_with(expected_route, reason="because i can") + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="because i can") - async def test_fetch_ban(self, rest_client): - ban = StubModel(789) - expected_route = routes.GET_GUILD_BAN.compile(guild=123, user=456) - rest_client._request = mock.AsyncMock(return_value={"id": "789"}) - rest_client._entity_factory.deserialize_guild_member_ban = mock.Mock(return_value=ban) + async def test_unban_user( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): + expected_route = routes.DELETE_GUILD_BAN.compile(guild=123, user=789) - assert await rest_client.fetch_ban(StubModel(123), StubModel(456)) == ban + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.unban_user(mock_partial_guild, mock_user, reason="because i can") - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_guild_member_ban.assert_called_once_with({"id": "789"}) + patched__request.assert_awaited_once_with(expected_route, reason="because i can") - async def test_fetch_role(self, rest_client): - role = StubModel(456) - expected_route = routes.GET_GUILD_ROLE.compile(guild=123, role=456) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_role = mock.Mock(return_value=role) + async def test_fetch_ban( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild, mock_user: users.User + ): + ban = mock.Mock(guilds.GuildBan) + expected_route = routes.GET_GUILD_BAN.compile(guild=123, user=789) - assert await rest_client.fetch_role(StubModel(123), StubModel(456)) is role + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "789"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_member_ban", return_value=ban + ) as patched_deserialize_guild_member_ban, + ): + assert await rest_client.fetch_ban(mock_partial_guild, mock_user) == ban - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_role.assert_called_once_with({"id": "456"}, guild_id=123) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_guild_member_ban.assert_called_once_with({"id": "789"}) - async def test_fetch_roles(self, rest_client): - role1 = StubModel(456) - role2 = StubModel(789) + async def test_fetch_roles(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): + role1 = make_partial_role(456) + role2 = make_partial_role(789) expected_route = routes.GET_GUILD_ROLES.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_role = mock.Mock(side_effect=[role1, role2]) - assert await rest_client.fetch_roles(StubModel(123)) == [role1, role2] + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "456"}, {"id": "789"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_role", side_effect=[role1, role2] + ) as patched_deserialize_role, + ): + assert await rest_client.fetch_roles(mock_partial_guild) == [role1, role2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_role.call_count == 2 - rest_client._entity_factory.deserialize_role.assert_has_calls( - [mock.call({"id": "456"}, guild_id=123), mock.call({"id": "789"}, guild_id=123)] - ) + patched__request.assert_awaited_once_with(expected_route) + assert patched_deserialize_role.call_count == 2 + patched_deserialize_role.assert_has_calls( + [mock.call({"id": "456"}, guild_id=123), mock.call({"id": "789"}, guild_id=123)] + ) - async def test_create_role(self, rest_client, file_resource_patch): + async def test_create_role( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + file_resource_patch: files.Resource[typing.Any], + ): expected_route = routes.POST_GUILD_ROLES.compile(guild=123) expected_json = { "name": "admin", @@ -5688,25 +7344,35 @@ async def test_create_role(self, rest_client, file_resource_patch): "icon": "some data", "mentionable": False, } - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - returned = await rest_client.create_role( - StubModel(123), - name="admin", - permissions=permissions.Permissions.ADMINISTRATOR, - color=colors.Color.from_int(12345), - hoist=True, - icon="icon.png", - mentionable=False, - reason="roles are cool", - ) - assert returned is rest_client._entity_factory.deserialize_role.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_role") as patched_deserialize_role, + ): + returned = await rest_client.create_role( + mock_partial_guild, + name="admin", + permissions=permissions.Permissions.ADMINISTRATOR, + color=colors.Color.from_int(12345), + hoist=True, + icon="icon.png", + mentionable=False, + reason="roles are cool", + ) + + assert returned is patched_deserialize_role.return_value - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="roles are cool") - rest_client._entity_factory.deserialize_role.assert_called_once_with({"id": "456"}, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="roles are cool") + patched_deserialize_role.assert_called_once_with({"id": "456"}, guild_id=123) - async def test_create_role_when_permissions_undefined(self, rest_client): - role = StubModel(456) + async def test_create_role_when_permissions_undefined( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_partial_role: guilds.PartialRole, + ): expected_route = routes.POST_GUILD_ROLES.compile(guild=123) expected_json = { "name": "admin", @@ -5715,45 +7381,60 @@ async def test_create_role_when_permissions_undefined(self, rest_client): "hoist": True, "mentionable": False, } - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_role = mock.Mock(return_value=role) - returned = await rest_client.create_role( - StubModel(123), - name="admin", - color=colors.Color.from_int(12345), - hoist=True, - mentionable=False, - reason="roles are cool", - ) - assert returned is role + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_role", return_value=mock_partial_role + ) as patched_deserialize_role, + ): + returned = await rest_client.create_role( + mock_partial_guild, + name="admin", + color=colors.Color.from_int(12345), + hoist=True, + mentionable=False, + reason="roles are cool", + ) + assert returned is mock_partial_role - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="roles are cool") - rest_client._entity_factory.deserialize_role.assert_called_once_with({"id": "456"}, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="roles are cool") + patched_deserialize_role.assert_called_once_with({"id": "456"}, guild_id=123) - async def test_create_role_when_color_and_colour_specified(self, rest_client): + async def test_create_role_when_color_and_colour_specified( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): with pytest.raises(TypeError, match=r"Can not specify 'color' and 'colour' together."): await rest_client.create_role( - StubModel(123), color=colors.Color.from_int(12345), colour=colors.Color.from_int(12345) + mock_partial_guild, color=colors.Color.from_int(12345), colour=colors.Color.from_int(12345) ) - async def test_create_role_when_icon_unicode_emoji_specified(self, rest_client): + async def test_create_role_when_icon_unicode_emoji_specified( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): with pytest.raises(TypeError, match=r"Can not specify 'icon' and 'unicode_emoji' together."): - await rest_client.create_role(StubModel(123), icon="icon.png", unicode_emoji="\N{OK HAND SIGN}") + await rest_client.create_role(mock_partial_guild, icon="icon.png", unicode_emoji="\N{OK HAND SIGN}") - async def test_reposition_roles(self, rest_client): + async def test_reposition_roles(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): expected_route = routes.PATCH_GUILD_ROLES.compile(guild=123) expected_json = [{"id": "456", "position": 1}, {"id": "789", "position": 2}] - rest_client._request = mock.AsyncMock() - - await rest_client.reposition_roles( - StubModel(123), {1: StubModel(456), 2: StubModel(789)}, reason="because why not" - ) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.reposition_roles( + mock_partial_guild, {1: make_partial_role(456), 2: make_partial_role(789)}, reason="because why not" + ) - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="because why not") + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="because why not") - async def test_edit_role(self, rest_client, file_resource_patch): - expected_route = routes.PATCH_GUILD_ROLE.compile(guild=123, role=789) + async def test_edit_role( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_partial_role: guilds.PartialRole, + file_resource_patch: files.Resource[typing.Any], + ): + expected_route = routes.PATCH_GUILD_ROLE.compile(guild=123, role=333) expected_json = { "name": "admin", "permissions": 8, @@ -5762,105 +7443,144 @@ async def test_edit_role(self, rest_client, file_resource_patch): "icon": "some data", "mentionable": False, } - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - returned = await rest_client.edit_role( - StubModel(123), - StubModel(789), - name="admin", - permissions=permissions.Permissions.ADMINISTRATOR, - color=colors.Color.from_int(12345), - hoist=True, - icon="icon.png", - mentionable=False, - reason="roles are cool", - ) - assert returned is rest_client._entity_factory.deserialize_role.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_role") as patched_deserialize_role, + ): + returned = await rest_client.edit_role( + mock_partial_guild, + mock_partial_role, + name="admin", + permissions=permissions.Permissions.ADMINISTRATOR, + color=colors.Color.from_int(12345), + hoist=True, + icon="icon.png", + mentionable=False, + reason="roles are cool", + ) + assert returned is patched_deserialize_role.return_value - rest_client._request.assert_awaited_once_with(expected_route, json=expected_json, reason="roles are cool") - rest_client._entity_factory.deserialize_role.assert_called_once_with({"id": "456"}, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, json=expected_json, reason="roles are cool") + patched_deserialize_role.assert_called_once_with({"id": "456"}, guild_id=123) - async def test_edit_role_when_color_and_colour_specified(self, rest_client): + async def test_edit_role_when_color_and_colour_specified( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_partial_role: guilds.PartialRole, + ): with pytest.raises(TypeError, match=r"Can not specify 'color' and 'colour' together."): await rest_client.edit_role( - StubModel(123), StubModel(456), color=colors.Color.from_int(12345), colour=colors.Color.from_int(12345) + mock_partial_guild, + mock_partial_role, + color=colors.Color.from_int(12345), + colour=colors.Color.from_int(12345), ) - async def test_edit_role_when_icon_and_unicode_emoji_specified(self, rest_client): + async def test_edit_role_when_icon_and_unicode_emoji_specified( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_partial_role: guilds.PartialRole, + ): with pytest.raises(TypeError, match=r"Can not specify 'icon' and 'unicode_emoji' together."): await rest_client.edit_role( - StubModel(123), StubModel(456), icon="icon.png", unicode_emoji="\N{OK HAND SIGN}" + mock_partial_guild, mock_partial_role, icon="icon.png", unicode_emoji="\N{OK HAND SIGN}" ) - async def test_delete_role(self, rest_client): - expected_route = routes.DELETE_GUILD_ROLE.compile(guild=123, role=456) - rest_client._request = mock.AsyncMock() + async def test_delete_role( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_partial_role: guilds.PartialRole, + ): + expected_route = routes.DELETE_GUILD_ROLE.compile(guild=123, role=333) - await rest_client.delete_role(StubModel(123), StubModel(456), reason="testing") + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_role(mock_partial_guild, mock_partial_role, reason="because i can") - rest_client._request.assert_awaited_once_with(expected_route, reason="testing") + patched__request.assert_awaited_once_with(expected_route, reason="because i can") - async def test_estimate_guild_prune_count(self, rest_client): + async def test_estimate_guild_prune_count( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): expected_route = routes.GET_GUILD_PRUNE.compile(guild=123) expected_query = {"days": "1"} - rest_client._request = mock.AsyncMock(return_value={"pruned": "69"}) - assert await rest_client.estimate_guild_prune_count(StubModel(123), days=1) == 69 - rest_client._request.assert_awaited_once_with(expected_route, query=expected_query) + with mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"pruned": "69"} + ) as patched__request: + assert await rest_client.estimate_guild_prune_count(mock_partial_guild, days=1) == 69 + patched__request.assert_awaited_once_with(expected_route, query=expected_query) - async def test_estimate_guild_prune_count_with_include_roles(self, rest_client): + async def test_estimate_guild_prune_count_with_include_roles( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): expected_route = routes.GET_GUILD_PRUNE.compile(guild=123) expected_query = {"days": "1", "include_roles": "456,678"} - rest_client._request = mock.AsyncMock(return_value={"pruned": "69"}) - returned = await rest_client.estimate_guild_prune_count( - StubModel(123), days=1, include_roles=[StubModel(456), StubModel(678)] - ) - assert returned == 69 + with mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"pruned": "69"} + ) as patched__request: + returned = await rest_client.estimate_guild_prune_count( + mock_partial_guild, days=1, include_roles=[make_partial_role(456), make_partial_role(678)] + ) + assert returned == 69 - rest_client._request.assert_awaited_once_with(expected_route, query=expected_query) + patched__request.assert_awaited_once_with(expected_route, query=expected_query) - async def test_begin_guild_prune(self, rest_client): + async def test_begin_guild_prune(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): expected_route = routes.POST_GUILD_PRUNE.compile(guild=123) expected_json = {"days": 1, "compute_prune_count": True, "include_roles": ["456", "678"]} - rest_client._request = mock.AsyncMock(return_value={"pruned": "69"}) - returned = await rest_client.begin_guild_prune( - StubModel(123), - days=1, - compute_prune_count=True, - include_roles=[StubModel(456), StubModel(678)], - reason="cause inactive people bad", - ) - assert returned == 69 + with mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"pruned": "69"} + ) as patched__request: + returned = await rest_client.begin_guild_prune( + mock_partial_guild, + days=1, + compute_prune_count=True, + include_roles=[make_partial_role(456), make_partial_role(678)], + reason="cause inactive people bad", + ) + assert returned == 69 - rest_client._request.assert_awaited_once_with( - expected_route, json=expected_json, reason="cause inactive people bad" - ) + patched__request.assert_awaited_once_with( + expected_route, json=expected_json, reason="cause inactive people bad" + ) - async def test_fetch_guild_voice_regions(self, rest_client): - voice_region1 = StubModel(456) - voice_region2 = StubModel(789) + async def test_fetch_guild_voice_regions( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + voice_region1 = mock.Mock(voices.VoiceRegion, id="456") + voice_region2 = mock.Mock(voices.VoiceRegion, id="789") expected_route = routes.GET_GUILD_VOICE_REGIONS.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_voice_region = mock.Mock(side_effect=[voice_region1, voice_region2]) - assert await rest_client.fetch_guild_voice_regions(StubModel(123)) == [voice_region1, voice_region2] + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "456"}, {"id": "789"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_voice_region", side_effect=[voice_region1, voice_region2] + ) as patched_deserialize_voice_region, + ): + assert await rest_client.fetch_guild_voice_regions(mock_partial_guild) == [voice_region1, voice_region2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_voice_region.call_count == 2 - rest_client._entity_factory.deserialize_voice_region.assert_has_calls( - [mock.call({"id": "456"}), mock.call({"id": "789"})] - ) + patched__request.assert_awaited_once_with(expected_route) + assert patched_deserialize_voice_region.call_count == 2 + patched_deserialize_voice_region.assert_has_calls([mock.call({"id": "456"}), mock.call({"id": "789"})]) - async def test_fetch_guild_invites(self, rest_client): - invite1 = StubModel(456) - invite2 = StubModel(789) + async def test_fetch_guild_invites(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): + invite1 = make_invite_with_metadata("ashdfjhas") + invite2 = make_invite_with_metadata("asdjfhasj") expected_route = routes.GET_GUILD_INVITES.compile(guild=123) rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) rest_client._entity_factory.deserialize_invite = mock.Mock(side_effect=[invite1, invite2]) - assert await rest_client.fetch_guild_invites(StubModel(123)) == [invite1, invite2] + assert await rest_client.fetch_guild_invites(mock_partial_guild) == [invite1, invite2] rest_client._request.assert_awaited_once_with(expected_route) assert rest_client._entity_factory.deserialize_invite.call_count == 2 @@ -5868,16 +7588,18 @@ async def test_fetch_guild_invites(self, rest_client): [mock.call({"id": "456"}), mock.call({"id": "789"})] ) - async def test_fetch_guild_invites_with_metadata(self, rest_client): - invite1 = StubModel(456) - invite2 = StubModel(789) + async def test_fetch_guild_invites_with_metadata( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + invite1 = make_invite_with_metadata("456") + invite2 = make_invite_with_metadata("789") expected_route = routes.GET_GUILD_INVITES.compile(guild=123) rest_client._request = mock.AsyncMock( return_value=[{"id": "456", "created_at": "metadata"}, {"id": "789", "created_at": "metadata"}] ) rest_client._entity_factory.deserialize_invite_with_metadata = mock.Mock(side_effect=[invite1, invite2]) - assert await rest_client.fetch_guild_invites(StubModel(123)) == [invite1, invite2] + assert await rest_client.fetch_guild_invites(mock_partial_guild) == [invite1, invite2] rest_client._request.assert_awaited_once_with(expected_route) assert rest_client._entity_factory.deserialize_invite_with_metadata.call_count == 2 @@ -5885,170 +7607,255 @@ async def test_fetch_guild_invites_with_metadata(self, rest_client): [mock.call({"id": "456", "created_at": "metadata"}), mock.call({"id": "789", "created_at": "metadata"})] ) - async def test_fetch_integrations(self, rest_client): - integration1 = StubModel(456) - integration2 = StubModel(789) + async def test_fetch_integrations(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): + integration1 = mock.Mock(guilds.Integration, id=456) + integration2 = mock.Mock(guilds.Integration, id=789) expected_route = routes.GET_GUILD_INTEGRATIONS.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "456"}, {"id": "789"}]) - rest_client._entity_factory.deserialize_integration = mock.Mock(side_effect=[integration1, integration2]) - assert await rest_client.fetch_integrations(StubModel(123)) == [integration1, integration2] + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "456"}, {"id": "789"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_integration", side_effect=[integration1, integration2] + ) as patched_deserialize_integration, + ): + assert await rest_client.fetch_integrations(mock_partial_guild) == [integration1, integration2] - rest_client._request.assert_awaited_once_with(expected_route) - assert rest_client._entity_factory.deserialize_integration.call_count == 2 - rest_client._entity_factory.deserialize_integration.assert_has_calls( - [mock.call({"id": "456"}, guild_id=123), mock.call({"id": "789"}, guild_id=123)] - ) + patched__request.assert_awaited_once_with(expected_route) + assert patched_deserialize_integration.call_count == 2 + patched_deserialize_integration.assert_has_calls( + [mock.call({"id": "456"}, guild_id=123), mock.call({"id": "789"}, guild_id=123)] + ) - async def test_fetch_widget(self, rest_client): - widget = StubModel(789) + async def test_fetch_widget(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): + widget = mock.Mock(guilds.GuildWidget, id=23847293) expected_route = routes.GET_GUILD_WIDGET.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value={"id": "789"}) - rest_client._entity_factory.deserialize_guild_widget = mock.Mock(return_value=widget) - assert await rest_client.fetch_widget(StubModel(123)) == widget + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "789"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_widget", return_value=widget + ) as patched_deserialize_guild_widget, + ): + assert await rest_client.fetch_widget(mock_partial_guild) == widget - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_guild_widget.assert_called_once_with({"id": "789"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_guild_widget.assert_called_once_with({"id": "789"}) - async def test_edit_widget(self, rest_client): - widget = StubModel(456) + async def test_edit_widget( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_text_channel: channels.GuildTextChannel, + ): + widget = mock.Mock(guilds.GuildWidget, id=456) expected_route = routes.PATCH_GUILD_WIDGET.compile(guild=123) - expected_json = {"enabled": True, "channel": "456"} - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_guild_widget = mock.Mock(return_value=widget) + expected_json = {"enabled": True, "channel": "4560"} - returned = await rest_client.edit_widget( - StubModel(123), channel=StubModel(456), enabled=True, reason="this should have been enabled" - ) - assert returned is widget + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_widget", return_value=widget + ) as patched_deserialize_guild_widget, + ): + returned = await rest_client.edit_widget( + mock_partial_guild, + channel=mock_guild_text_channel, + enabled=True, + reason="this should have been enabled", + ) + assert returned is widget - rest_client._request.assert_awaited_once_with( - expected_route, json=expected_json, reason="this should have been enabled" - ) - rest_client._entity_factory.deserialize_guild_widget.assert_called_once_with({"id": "456"}) + patched__request.assert_awaited_once_with( + expected_route, json=expected_json, reason="this should have been enabled" + ) + patched_deserialize_guild_widget.assert_called_once_with({"id": "456"}) - async def test_edit_widget_when_channel_is_None(self, rest_client): - widget = StubModel(456) + async def test_edit_widget_when_channel_is_None( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + widget = mock.Mock(guilds.GuildWidget, id=456) expected_route = routes.PATCH_GUILD_WIDGET.compile(guild=123) expected_json = {"enabled": True, "channel": None} - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_guild_widget = mock.Mock(return_value=widget) - - returned = await rest_client.edit_widget( - StubModel(123), channel=None, enabled=True, reason="this should have been enabled" - ) - assert returned is widget - - rest_client._request.assert_awaited_once_with( - expected_route, json=expected_json, reason="this should have been enabled" - ) - rest_client._entity_factory.deserialize_guild_widget.assert_called_once_with({"id": "456"}) - async def test_edit_widget_without_optionals(self, rest_client): - widget = StubModel(456) - expected_route = routes.PATCH_GUILD_WIDGET.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value={"id": "456"}) - rest_client._entity_factory.deserialize_guild_widget = mock.Mock(return_value=widget) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_widget", return_value=widget + ) as patched_deserialize_guild_widget, + ): + returned = await rest_client.edit_widget( + mock_partial_guild, channel=None, enabled=True, reason="this should have been enabled" + ) + assert returned is widget - assert await rest_client.edit_widget(StubModel(123)) == widget + patched__request.assert_awaited_once_with( + expected_route, json=expected_json, reason="this should have been enabled" + ) + patched_deserialize_guild_widget.assert_called_once_with({"id": "456"}) - rest_client._request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) - rest_client._entity_factory.deserialize_guild_widget.assert_called_once_with({"id": "456"}) + async def test_edit_widget_without_optionals( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + widget = mock.Mock(guilds.GuildWidget, id=456) + expected_route = routes.PATCH_GUILD_WIDGET.compile(guild=123) - async def test_fetch_welcome_screen(self, rest_client): - rest_client._request = mock.AsyncMock(return_value={"haha": "funny"}) - expected_route = routes.GET_GUILD_WELCOME_SCREEN.compile(guild=52341231) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "456"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_widget", return_value=widget + ) as patched_deserialize_guild_widget, + ): + assert await rest_client.edit_widget(mock_partial_guild) == widget - result = await rest_client.fetch_welcome_screen(StubModel(52341231)) - assert result is rest_client._entity_factory.deserialize_welcome_screen.return_value + patched__request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) + patched_deserialize_guild_widget.assert_called_once_with({"id": "456"}) - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_welcome_screen.assert_called_once_with( - rest_client._request.return_value - ) + async def test_fetch_welcome_screen( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + expected_route = routes.GET_GUILD_WELCOME_SCREEN.compile(guild=123) - async def test_edit_welcome_screen_with_optional_kwargs(self, rest_client): - mock_channel = object() - rest_client._request = mock.AsyncMock(return_value={"go": "home", "you're": "drunk"}) - expected_route = routes.PATCH_GUILD_WELCOME_SCREEN.compile(guild=54123564) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"haha": "funny"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_welcome_screen" + ) as patched_deserialize_welcome_screen, + ): + result = await rest_client.fetch_welcome_screen(mock_partial_guild) + assert result is patched_deserialize_welcome_screen.return_value - result = await rest_client.edit_welcome_screen( - StubModel(54123564), description="blam blam", enabled=True, channels=[mock_channel] - ) - assert result is rest_client._entity_factory.deserialize_welcome_screen.return_value + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_welcome_screen.assert_called_once_with(patched__request.return_value) - rest_client._request.assert_awaited_once_with( - expected_route, - json={ - "description": "blam blam", - "enabled": True, - "welcome_channels": [rest_client._entity_factory.serialize_welcome_channel.return_value], - }, - ) - rest_client._entity_factory.deserialize_welcome_screen.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._entity_factory.serialize_welcome_channel.assert_called_once_with(mock_channel) + async def test_edit_welcome_screen_with_optional_kwargs( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + mock_channel = mock.Mock() + expected_route = routes.PATCH_GUILD_WELCOME_SCREEN.compile(guild=123) + + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"go": "home", "you're": "drunk"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_welcome_screen" + ) as patched_deserialize_welcome_screen, + mock.patch.object( + rest_client.entity_factory, "serialize_welcome_channel" + ) as patched_serialize_welcome_channel, + ): + result = await rest_client.edit_welcome_screen( + mock_partial_guild, description="blam blam", enabled=True, channels=[mock_channel] + ) + assert result is patched_deserialize_welcome_screen.return_value + + patched__request.assert_awaited_once_with( + expected_route, + json={ + "description": "blam blam", + "enabled": True, + "welcome_channels": [patched_serialize_welcome_channel.return_value], + }, + ) + patched_deserialize_welcome_screen.assert_called_once_with(patched__request.return_value) + patched_serialize_welcome_channel.assert_called_once_with(mock_channel) - async def test_edit_welcome_screen_with_null_kwargs(self, rest_client): - rest_client._request = mock.AsyncMock(return_value={"go": "go", "power": "rangers"}) - expected_route = routes.PATCH_GUILD_WELCOME_SCREEN.compile(guild=54123564) + async def test_edit_welcome_screen_with_null_kwargs( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + expected_route = routes.PATCH_GUILD_WELCOME_SCREEN.compile(guild=123) - result = await rest_client.edit_welcome_screen(StubModel(54123564), description=None, channels=None) - assert result is rest_client._entity_factory.deserialize_welcome_screen.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"go": "go", "power": "rangers"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_welcome_screen" + ) as patched_deserialize_welcome_screen, + mock.patch.object( + rest_client.entity_factory, "serialize_welcome_channel" + ) as patched_serialize_welcome_channel, + ): + result = await rest_client.edit_welcome_screen(mock_partial_guild, description=None, channels=None) + assert result is patched_deserialize_welcome_screen.return_value - rest_client._request.assert_awaited_once_with( - expected_route, json={"description": None, "welcome_channels": None} - ) - rest_client._entity_factory.deserialize_welcome_screen.assert_called_once_with( - rest_client._request.return_value - ) - rest_client._entity_factory.serialize_welcome_channel.assert_not_called() + patched__request.assert_awaited_once_with( + expected_route, json={"description": None, "welcome_channels": None} + ) + patched_deserialize_welcome_screen.assert_called_once_with(patched__request.return_value) + patched_serialize_welcome_channel.assert_not_called() - async def test_edit_welcome_screen_without_optional_kwargs(self, rest_client): - rest_client._request = mock.AsyncMock(return_value={"screen": "NBO"}) - expected_route = routes.PATCH_GUILD_WELCOME_SCREEN.compile(guild=54123564) + async def test_edit_welcome_screen_without_optional_kwargs( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + expected_route = routes.PATCH_GUILD_WELCOME_SCREEN.compile(guild=123) - result = await rest_client.edit_welcome_screen(StubModel(54123564)) - assert result is rest_client._entity_factory.deserialize_welcome_screen.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"screen": "NBO"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_welcome_screen" + ) as patched_deserialize_welcome_screen, + mock.patch.object(rest_client.entity_factory, "serialize_welcome_channel"), + ): + result = await rest_client.edit_welcome_screen(mock_partial_guild) + assert result is patched_deserialize_welcome_screen.return_value - rest_client._request.assert_awaited_once_with(expected_route, json={}) - rest_client._entity_factory.deserialize_welcome_screen.assert_called_once_with( - rest_client._request.return_value - ) + patched__request.assert_awaited_once_with(expected_route, json={}) + patched_deserialize_welcome_screen.assert_called_once_with(patched__request.return_value) - async def test_fetch_guild_onboarding(self, rest_client): + async def test_fetch_guild_onboarding(self, rest_client: rest.RESTClientImpl): GUILD = 123 rest_client._request = mock.AsyncMock(return_value={"haha": "funny"}) expected_route = routes.GET_GUILD_ONBOARDING.compile(guild=GUILD) - result = await rest_client.fetch_guild_onboarding(GUILD) - assert result is rest_client._entity_factory.deserialize_guild_onboarding.return_value + + with mock.patch.object( + rest_client._entity_factory, "deserialize_guild_onboarding" + ) as patched_deserialize_guild_onboarding: + result = await rest_client.fetch_guild_onboarding(GUILD) + + assert result is patched_deserialize_guild_onboarding.return_value rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_guild_onboarding.assert_called_once_with( - rest_client._request.return_value - ) + patched_deserialize_guild_onboarding.assert_called_once_with(rest_client._request.return_value) - async def test_edit_guild_onboarding(self, rest_client): + async def test_edit_guild_onboarding(self, rest_client: rest.RESTClientImpl): GUILD = 123 expected_route = routes.PUT_GUILD_ONBOARDING.compile(guild=GUILD) rest_client._request = mock.AsyncMock(return_value={"haha": "funny"}) - result = await rest_client.edit_guild_onboarding( - guild=GUILD, - default_channel_ids=[456], - enabled=True, - mode=guilds.GuildOnboardingMode.ONBOARDING_DEFAULT, - prompts=[ - special_endpoints.GuildOnboardingPromptBuilder( - title="Test Title", single_select=True, in_onboarding=True, required=True - ), - special_endpoints.GuildOnboardingPromptBuilder( - title="Test Title", single_select=True, in_onboarding=True, required=True - ).set_id(187), - ], - reason="test reason", - ) - assert result is rest_client._entity_factory.deserialize_guild_onboarding.return_value + + with mock.patch.object( + rest_client._entity_factory, "deserialize_guild_onboarding" + ) as patched_deserialize_guild_onboarding: + result = await rest_client.edit_guild_onboarding( + guild=GUILD, + default_channel_ids=[456], + enabled=True, + mode=guilds.GuildOnboardingMode.ONBOARDING_DEFAULT, + prompts=[ + special_endpoints.GuildOnboardingPromptBuilder( + title="Test Title", single_select=True, in_onboarding=True, required=True + ), + special_endpoints.GuildOnboardingPromptBuilder( + title="Test Title", single_select=True, in_onboarding=True, required=True + ).set_id(187), + ], + reason="test reason", + ) + + assert result is patched_deserialize_guild_onboarding.return_value rest_client._request.assert_awaited_once_with( expected_route, json={ @@ -6076,501 +7883,762 @@ async def test_edit_guild_onboarding(self, rest_client): }, reason="test reason", ) - rest_client._entity_factory.deserialize_guild_onboarding.assert_called_once_with( - rest_client._request.return_value - ) + patched_deserialize_guild_onboarding.assert_called_once_with(rest_client._request.return_value) - async def test_fetch_vanity_url(self, rest_client): - vanity_url = StubModel(789) + async def test_fetch_vanity_url(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): + vanity_url = make_vanity_url("hikari") expected_route = routes.GET_GUILD_VANITY_URL.compile(guild=123) - rest_client._request = mock.AsyncMock(return_value={"id": "789"}) - rest_client._entity_factory.deserialize_vanity_url = mock.Mock(return_value=vanity_url) - assert await rest_client.fetch_vanity_url(StubModel(123)) == vanity_url + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "789"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_vanity_url", return_value=vanity_url + ) as patched_deserialize_vanity_url, + ): + assert await rest_client.fetch_vanity_url(mock_partial_guild) == vanity_url - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_vanity_url.assert_called_once_with({"id": "789"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_vanity_url.assert_called_once_with({"id": "789"}) - async def test_fetch_template(self, rest_client): + async def test_fetch_template(self, rest_client: rest.RESTClientImpl): expected_route = routes.GET_TEMPLATE.compile(template="kodfskoijsfikoiok") - rest_client._request = mock.AsyncMock(return_value={"code": "KSDAOKSDKIO"}) - result = await rest_client.fetch_template("kodfskoijsfikoiok") - assert result is rest_client._entity_factory.deserialize_template.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"code": "KSDAOKSDKIO"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_template") as patched_deserialize_template, + ): + result = await rest_client.fetch_template("kodfskoijsfikoiok") + assert result is patched_deserialize_template.return_value - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_template.assert_called_once_with({"code": "KSDAOKSDKIO"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_template.assert_called_once_with({"code": "KSDAOKSDKIO"}) - async def test_fetch_guild_templates(self, rest_client): - expected_route = routes.GET_GUILD_TEMPLATES.compile(guild=43123123) - rest_client._request = mock.AsyncMock(return_value=[{"code": "jirefu98ai90w"}]) + async def test_fetch_guild_templates( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + expected_route = routes.GET_GUILD_TEMPLATES.compile(guild=123) - result = await rest_client.fetch_guild_templates(StubModel(43123123)) - assert result == [rest_client._entity_factory.deserialize_template.return_value] + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"code": "jirefu98ai90w"}] + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_template") as patched_deserialize_template, + ): + result = await rest_client.fetch_guild_templates(mock_partial_guild) + assert result == [patched_deserialize_template.return_value] - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_template.assert_called_once_with({"code": "jirefu98ai90w"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_template.assert_called_once_with({"code": "jirefu98ai90w"}) - async def test_sync_guild_template(self, rest_client): - expected_route = routes.PUT_GUILD_TEMPLATE.compile(guild=431231, template="oeroeoeoeoeo") - rest_client._request = mock.AsyncMock(return_value={"code": "ldsaosdokskdoa"}) + async def test_sync_guild_template(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): + expected_route = routes.PUT_GUILD_TEMPLATE.compile(guild=123, template="oeroeoeoeoeo") - result = await rest_client.sync_guild_template(StubModel(431231), template="oeroeoeoeoeo") - assert result is rest_client._entity_factory.deserialize_template.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"code": "ldsaosdokskdoa"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_template") as patched_deserialize_template, + ): + result = await rest_client.sync_guild_template(mock_partial_guild, template="oeroeoeoeoeo") + assert result is patched_deserialize_template.return_value - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_template.assert_called_once_with({"code": "ldsaosdokskdoa"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_template.assert_called_once_with({"code": "ldsaosdokskdoa"}) - async def test_create_template_without_description(self, rest_client): - expected_routes = routes.POST_GUILD_TEMPLATES.compile(guild=1235432) - rest_client._request = mock.AsyncMock(return_value={"code": "94949sdfkds"}) + async def test_create_template_without_description( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + expected_routes = routes.POST_GUILD_TEMPLATES.compile(guild=123) - result = await rest_client.create_template(StubModel(1235432), "OKOKOK") - assert result is rest_client._entity_factory.deserialize_template.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"code": "94949sdfkds"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_template") as patched_deserialize_template, + ): + result = await rest_client.create_template(mock_partial_guild, "OKOKOK") + assert result is patched_deserialize_template.return_value - rest_client._request.assert_awaited_once_with(expected_routes, json={"name": "OKOKOK"}) - rest_client._entity_factory.deserialize_template.assert_called_once_with({"code": "94949sdfkds"}) + patched__request.assert_awaited_once_with(expected_routes, json={"name": "OKOKOK"}) + patched_deserialize_template.assert_called_once_with({"code": "94949sdfkds"}) - async def test_create_template_with_description(self, rest_client): - expected_route = routes.POST_GUILD_TEMPLATES.compile(guild=4123123) - rest_client._request = mock.AsyncMock(return_value={"code": "76345345"}) + async def test_create_template_with_description( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + expected_route = routes.POST_GUILD_TEMPLATES.compile(guild=123) - result = await rest_client.create_template(StubModel(4123123), "33", description="43123123") - assert result is rest_client._entity_factory.deserialize_template.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"code": "76345345"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_template") as patched_deserialize_template, + ): + result = await rest_client.create_template(mock_partial_guild, "33", description="43123123") + assert result is patched_deserialize_template.return_value - rest_client._request.assert_awaited_once_with(expected_route, json={"name": "33", "description": "43123123"}) - rest_client._entity_factory.deserialize_template.assert_called_once_with({"code": "76345345"}) + patched__request.assert_awaited_once_with(expected_route, json={"name": "33", "description": "43123123"}) + patched_deserialize_template.assert_called_once_with({"code": "76345345"}) - async def test_edit_template_without_optionals(self, rest_client): - expected_route = routes.PATCH_GUILD_TEMPLATE.compile(guild=3412312, template="oeodsosda") - rest_client._request = mock.AsyncMock(return_value={"code": "9493293ikiwopop"}) + async def test_edit_template_without_optionals( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + expected_route = routes.PATCH_GUILD_TEMPLATE.compile(guild=123, template="oeodsosda") - result = await rest_client.edit_template(StubModel(3412312), "oeodsosda") - assert result is rest_client._entity_factory.deserialize_template.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"code": "9493293ikiwopop"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_template") as patched_deserialize_template, + ): + result = await rest_client.edit_template(mock_partial_guild, "oeodsosda") + assert result is patched_deserialize_template.return_value - rest_client._request.assert_awaited_once_with(expected_route, json={}) - rest_client._entity_factory.deserialize_template.assert_called_once_with({"code": "9493293ikiwopop"}) + patched__request.assert_awaited_once_with(expected_route, json={}) + patched_deserialize_template.assert_called_once_with({"code": "9493293ikiwopop"}) - async def test_edit_template_with_optionals(self, rest_client): - expected_route = routes.PATCH_GUILD_TEMPLATE.compile(guild=34123122, template="oeodsosda2") - rest_client._request = mock.AsyncMock(return_value={"code": "9493293ikiwopop"}) + async def test_edit_template_with_optionals( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + expected_route = routes.PATCH_GUILD_TEMPLATE.compile(guild=123, template="oeodsosda2") - result = await rest_client.edit_template( - StubModel(34123122), "oeodsosda2", name="new name", description="i'm lazy" - ) - assert result is rest_client._entity_factory.deserialize_template.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"code": "9493293ikiwopop"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_template") as patched_deserialize_template, + ): + result = await rest_client.edit_template( + mock_partial_guild, "oeodsosda2", name="new name", description="i'm lazy" + ) + assert result is patched_deserialize_template.return_value - rest_client._request.assert_awaited_once_with( - expected_route, json={"name": "new name", "description": "i'm lazy"} - ) - rest_client._entity_factory.deserialize_template.assert_called_once_with({"code": "9493293ikiwopop"}) + patched__request.assert_awaited_once_with( + expected_route, json={"name": "new name", "description": "i'm lazy"} + ) + patched_deserialize_template.assert_called_once_with({"code": "9493293ikiwopop"}) - async def test_delete_template(self, rest_client): - expected_route = routes.DELETE_GUILD_TEMPLATE.compile(guild=3123123, template="eoiesri9er99") - rest_client._request = mock.AsyncMock(return_value={"code": "oeoekfgkdkf"}) + async def test_delete_template(self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild): + expected_route = routes.DELETE_GUILD_TEMPLATE.compile(guild=123, template="eoiesri9er99") - result = await rest_client.delete_template(StubModel(3123123), "eoiesri9er99") - assert result is rest_client._entity_factory.deserialize_template.return_value + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"code": "oeoekfgkdkf"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_template") as patched_deserialize_template, + ): + result = await rest_client.delete_template(mock_partial_guild, "eoiesri9er99") + assert result is patched_deserialize_template.return_value - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_template.assert_called_once_with({"code": "oeoekfgkdkf"}) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_template.assert_called_once_with({"code": "oeoekfgkdkf"}) - async def test_fetch_application_command_with_guild(self, rest_client): - expected_route = routes.GET_APPLICATION_GUILD_COMMAND.compile(application=32154, guild=5312312, command=42123) - rest_client._request = mock.AsyncMock(return_value={"id": "424242"}) + async def test_fetch_application_command_with_guild( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + mock_partial_command: commands.PartialCommand, + ): + expected_route = routes.GET_APPLICATION_GUILD_COMMAND.compile(application=111, guild=123, command=666) - result = await rest_client.fetch_application_command(StubModel(32154), StubModel(42123), StubModel(5312312)) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "424242"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_command") as patched_deserialize_command, + ): + result = await rest_client.fetch_application_command( + mock_application, mock_partial_command, mock_partial_guild + ) - assert result is rest_client._entity_factory.deserialize_command.return_value - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_command.assert_called_once_with( - rest_client._request.return_value, guild_id=5312312 - ) + assert result is patched_deserialize_command.return_value + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_command.assert_called_once_with(patched__request.return_value, guild_id=123) - async def test_fetch_application_command_without_guild(self, rest_client): - expected_route = routes.GET_APPLICATION_COMMAND.compile(application=32154, command=42123) - rest_client._request = mock.AsyncMock(return_value={"id": "424242"}) + async def test_fetch_application_command_without_guild( + self, + rest_client: rest.RESTClientImpl, + mock_application: applications.Application, + mock_partial_command: commands.PartialCommand, + ): + expected_route = routes.GET_APPLICATION_COMMAND.compile(application=111, command=666) - result = await rest_client.fetch_application_command(StubModel(32154), StubModel(42123)) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "424242"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_command") as patched_deserialize_command, + ): + result = await rest_client.fetch_application_command(mock_application, mock_partial_command) - assert result is rest_client._entity_factory.deserialize_command.return_value - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_command.assert_called_once_with( - rest_client._request.return_value, guild_id=None - ) + assert result is patched_deserialize_command.return_value + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_command.assert_called_once_with(patched__request.return_value, guild_id=None) - async def test_fetch_application_commands_with_guild(self, rest_client): - expected_route = routes.GET_APPLICATION_GUILD_COMMANDS.compile(application=54123, guild=7623423) - rest_client._request = mock.AsyncMock(return_value=[{"id": "34512312"}]) + async def test_fetch_application_commands_with_guild( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + ): + expected_route = routes.GET_APPLICATION_GUILD_COMMANDS.compile(application=111, guild=123) - result = await rest_client.fetch_application_commands(StubModel(54123), StubModel(7623423)) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "34512312"}] + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_command") as patched_deserialize_command, + ): + result = await rest_client.fetch_application_commands(mock_application, mock_partial_guild) - assert result == [rest_client._entity_factory.deserialize_command.return_value] - rest_client._request.assert_awaited_once_with(expected_route, query={"with_localizations": "true"}) - rest_client._entity_factory.deserialize_command.assert_called_once_with({"id": "34512312"}, guild_id=7623423) + assert result == [patched_deserialize_command.return_value] + patched__request.assert_awaited_once_with(expected_route, query={"with_localizations": "true"}) + patched_deserialize_command.assert_called_once_with({"id": "34512312"}, guild_id=123) - async def test_fetch_application_commands_without_guild(self, rest_client): - expected_route = routes.GET_APPLICATION_COMMANDS.compile(application=54123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "34512312"}]) + async def test_fetch_application_commands_without_guild( + self, rest_client: rest.RESTClientImpl, mock_application: applications.Application + ): + expected_route = routes.GET_APPLICATION_COMMANDS.compile(application=111) - result = await rest_client.fetch_application_commands(StubModel(54123)) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "34512312"}] + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_command") as patched_deserialize_command, + ): + result = await rest_client.fetch_application_commands(mock_application) - assert result == [rest_client._entity_factory.deserialize_command.return_value] - rest_client._request.assert_awaited_once_with(expected_route, query={"with_localizations": "true"}) - rest_client._entity_factory.deserialize_command.assert_called_once_with({"id": "34512312"}, guild_id=None) + assert result == [patched_deserialize_command.return_value] + patched__request.assert_awaited_once_with(expected_route, query={"with_localizations": "true"}) + patched_deserialize_command.assert_called_once_with({"id": "34512312"}, guild_id=None) - async def test_fetch_application_commands_ignores_unknown_command_types(self, rest_client): + async def test_fetch_application_commands_ignores_unknown_command_types( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + ): mock_command = mock.Mock() - expected_route = routes.GET_APPLICATION_GUILD_COMMANDS.compile(application=54123, guild=432234) - rest_client._entity_factory.deserialize_command.side_effect = [ - errors.UnrecognisedEntityError("eep"), - mock_command, - ] - rest_client._request = mock.AsyncMock(return_value=[{"id": "541234"}, {"id": "553234"}]) + expected_route = routes.GET_APPLICATION_GUILD_COMMANDS.compile(application=111, guild=123) - result = await rest_client.fetch_application_commands(StubModel(54123), StubModel(432234)) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "541234"}, {"id": "553234"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, + "deserialize_command", + side_effect=[errors.UnrecognisedEntityError("eep"), mock_command], + ) as patched_deserialize_command, + ): + result = await rest_client.fetch_application_commands(mock_application, mock_partial_guild) - assert result == [mock_command] - rest_client._request.assert_awaited_once_with(expected_route, query={"with_localizations": "true"}) - rest_client._entity_factory.deserialize_command.assert_has_calls( - [mock.call({"id": "541234"}, guild_id=432234), mock.call({"id": "553234"}, guild_id=432234)] - ) + assert result == [mock_command] + patched__request.assert_awaited_once_with(expected_route, query={"with_localizations": "true"}) + patched_deserialize_command.assert_has_calls( + [mock.call({"id": "541234"}, guild_id=123), mock.call({"id": "553234"}, guild_id=123)] + ) - async def test__create_application_command_with_optionals(self, rest_client: rest.RESTClientImpl): - expected_route = routes.POST_APPLICATION_GUILD_COMMAND.compile(application=4332123, guild=653452134) - rest_client._request = mock.AsyncMock(return_value={"id": "29393939"}) - mock_option = object() + async def test__create_application_command_with_optionals( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + ): + expected_route = routes.POST_APPLICATION_GUILD_COMMAND.compile(application=111, guild=123) + mock_option = mock.Mock() - result = await rest_client._create_application_command( - application=StubModel(4332123), - type=100, - name="okokok", - description="not ok anymore", - guild=StubModel(653452134), - options=[mock_option], - default_member_permissions=permissions.Permissions.ADMINISTRATOR, - nsfw=True, - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "29393939"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "serialize_command_option" + ) as patched_serialize_command_option, + ): + result = await rest_client._create_application_command( + application=mock_application, + type=100, + name="okokok", + description="not ok anymore", + guild=mock_partial_guild, + options=[mock_option], + default_member_permissions=permissions.Permissions.ADMINISTRATOR, + nsfw=True, + ) - assert result is rest_client._request.return_value - rest_client._entity_factory.serialize_command_option.assert_called_once_with(mock_option) - rest_client._request.assert_awaited_once_with( - expected_route, - json={ - "type": 100, - "name": "okokok", - "description": "not ok anymore", - "options": [rest_client._entity_factory.serialize_command_option.return_value], - "default_member_permissions": 8, - "nsfw": True, - }, - ) + assert result is patched__request.return_value + patched_serialize_command_option.assert_called_once_with(mock_option) + patched__request.assert_awaited_once_with( + expected_route, + json={ + "type": 100, + "name": "okokok", + "description": "not ok anymore", + "options": [patched_serialize_command_option.return_value], + "default_member_permissions": 8, + "nsfw": True, + }, + ) - async def test_create_application_command_without_optionals(self, rest_client: rest.RESTClientImpl): - expected_route = routes.POST_APPLICATION_COMMAND.compile(application=4332123) - rest_client._request = mock.AsyncMock(return_value={"id": "29393939"}) + async def test_create_application_command_without_optionals( + self, rest_client: rest.RESTClientImpl, mock_application: applications.Application + ): + expected_route = routes.POST_APPLICATION_COMMAND.compile(application=111) - result = await rest_client._create_application_command( - application=StubModel(4332123), type=100, name="okokok", description="not ok anymore" - ) + with mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "29393939"} + ) as patched__request: + result = await rest_client._create_application_command( + application=mock_application, type=100, name="okokok", description="not ok anymore" + ) - assert result is rest_client._request.return_value - rest_client._request.assert_awaited_once_with( - expected_route, json={"type": 100, "name": "okokok", "description": "not ok anymore"} - ) + assert result is patched__request.return_value + patched__request.assert_awaited_once_with( + expected_route, json={"type": 100, "name": "okokok", "description": "not ok anymore"} + ) async def test__create_application_command_standardizes_default_member_permissions( - self, rest_client: rest.RESTClientImpl + self, rest_client: rest.RESTClientImpl, mock_application: applications.Application ): - expected_route = routes.POST_APPLICATION_COMMAND.compile(application=4332123) - rest_client._request = mock.AsyncMock(return_value={"id": "29393939"}) + expected_route = routes.POST_APPLICATION_COMMAND.compile(application=111) - result = await rest_client._create_application_command( - application=StubModel(4332123), - type=100, - name="okokok", - description="not ok anymore", - default_member_permissions=permissions.Permissions.NONE, - ) + with mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "29393939"} + ) as patched__request: + result = await rest_client._create_application_command( + application=mock_application, + type=100, + name="okokok", + description="not ok anymore", + default_member_permissions=permissions.Permissions.NONE, + ) - assert result is rest_client._request.return_value - rest_client._request.assert_awaited_once_with( - expected_route, - json={"type": 100, "name": "okokok", "description": "not ok anymore", "default_member_permissions": None}, - ) - - async def test_create_slash_command(self, rest_client: rest.RESTClientImpl): - rest_client._create_application_command = mock.AsyncMock() - mock_options = object() - mock_application = StubModel(4332123) - mock_guild = StubModel(123123123) - - result = await rest_client.create_slash_command( - mock_application, - "okokok", - "not ok anymore", - guild=mock_guild, - options=mock_options, - name_localizations={locales.Locale.TR: "hhh"}, - description_localizations={locales.Locale.TR: "jello"}, - default_member_permissions=permissions.Permissions.ADMINISTRATOR, - nsfw=True, - ) - - assert result is rest_client._entity_factory.deserialize_slash_command.return_value - rest_client._entity_factory.deserialize_slash_command.assert_called_once_with( - rest_client._create_application_command.return_value, guild_id=123123123 - ) - rest_client._create_application_command.assert_awaited_once_with( - application=mock_application, - type=commands.CommandType.SLASH, - name="okokok", - description="not ok anymore", - guild=mock_guild, - options=mock_options, - name_localizations={"tr": "hhh"}, - description_localizations={"tr": "jello"}, - default_member_permissions=permissions.Permissions.ADMINISTRATOR, - nsfw=True, - ) - - async def test_create_context_menu_command(self, rest_client: rest.RESTClientImpl): - rest_client._create_application_command = mock.AsyncMock() - mock_application = StubModel(4332123) - mock_guild = StubModel(123123123) - - result = await rest_client.create_context_menu_command( - mock_application, - commands.CommandType.USER, - "okokok", - guild=mock_guild, - default_member_permissions=permissions.Permissions.ADMINISTRATOR, - nsfw=True, - name_localizations={locales.Locale.TR: "hhh"}, - ) - - assert result is rest_client._entity_factory.deserialize_context_menu_command.return_value - rest_client._entity_factory.deserialize_context_menu_command.assert_called_once_with( - rest_client._create_application_command.return_value, guild_id=123123123 - ) - rest_client._create_application_command.assert_awaited_once_with( - application=mock_application, - type=commands.CommandType.USER, - name="okokok", - guild=mock_guild, - default_member_permissions=permissions.Permissions.ADMINISTRATOR, - nsfw=True, - name_localizations={"tr": "hhh"}, - ) - - async def test_set_application_commands_with_guild(self, rest_client): - expected_route = routes.PUT_APPLICATION_GUILD_COMMANDS.compile(application=4321231, guild=6543234) - rest_client._request = mock.AsyncMock(return_value=[{"id": "9459329932"}]) + assert result is patched__request.return_value + patched__request.assert_awaited_once_with( + expected_route, + json={ + "type": 100, + "name": "okokok", + "description": "not ok anymore", + "default_member_permissions": None, + }, + ) + + async def test_create_slash_command( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + ): + mock_options = mock.Mock() + + with ( + mock.patch.object( + rest_client.entity_factory, "deserialize_slash_command" + ) as patched_deserialize_slash_command, + mock.patch.object(rest_client, "_create_application_command") as patched__create_application_command, + ): + result = await rest_client.create_slash_command( + mock_application, + "okokok", + "not ok anymore", + guild=mock_partial_guild, + options=mock_options, + name_localizations={locales.Locale.TR: "hhh"}, + description_localizations={locales.Locale.TR: "jello"}, + default_member_permissions=permissions.Permissions.ADMINISTRATOR, + nsfw=True, + ) + + assert result is patched_deserialize_slash_command.return_value + patched_deserialize_slash_command.assert_called_once_with( + patched__create_application_command.return_value, guild_id=123 + ) + patched__create_application_command.assert_awaited_once_with( + application=mock_application, + type=commands.CommandType.SLASH, + name="okokok", + description="not ok anymore", + guild=mock_partial_guild, + options=mock_options, + name_localizations={"tr": "hhh"}, + description_localizations={"tr": "jello"}, + default_member_permissions=permissions.Permissions.ADMINISTRATOR, + nsfw=True, + ) + + async def test_create_context_menu_command( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + ): + with ( + mock.patch.object( + rest_client.entity_factory, "deserialize_context_menu_command" + ) as patched_deserialize_context_menu_command, + mock.patch.object(rest_client, "_create_application_command") as patched__create_application_command, + ): + result = await rest_client.create_context_menu_command( + mock_application, + commands.CommandType.USER, + "okokok", + guild=mock_partial_guild, + default_member_permissions=permissions.Permissions.ADMINISTRATOR, + nsfw=True, + name_localizations={locales.Locale.TR: "hhh"}, + ) + + assert result is patched_deserialize_context_menu_command.return_value + patched_deserialize_context_menu_command.assert_called_once_with( + patched__create_application_command.return_value, guild_id=123 + ) + patched__create_application_command.assert_awaited_once_with( + application=mock_application, + type=commands.CommandType.USER, + name="okokok", + guild=mock_partial_guild, + default_member_permissions=permissions.Permissions.ADMINISTRATOR, + nsfw=True, + name_localizations={"tr": "hhh"}, + ) + + async def test_set_application_commands_with_guild( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + ): + expected_route = routes.PUT_APPLICATION_GUILD_COMMANDS.compile(application=111, guild=123) mock_command_builder = mock.Mock() - result = await rest_client.set_application_commands( - StubModel(4321231), [mock_command_builder], StubModel(6543234) - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "9459329932"}] + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_command") as patched_deserialize_command, + ): + result = await rest_client.set_application_commands( + mock_application, [mock_command_builder], mock_partial_guild + ) - assert result == [rest_client._entity_factory.deserialize_command.return_value] - rest_client._entity_factory.deserialize_command.assert_called_once_with({"id": "9459329932"}, guild_id=6543234) - rest_client._request.assert_awaited_once_with(expected_route, json=[mock_command_builder.build.return_value]) - mock_command_builder.build.assert_called_once_with(rest_client._entity_factory) + assert result == [patched_deserialize_command.return_value] + patched_deserialize_command.assert_called_once_with({"id": "9459329932"}, guild_id=123) + patched__request.assert_awaited_once_with(expected_route, json=[mock_command_builder.build.return_value]) + mock_command_builder.build.assert_called_once_with(rest_client.entity_factory) - async def test_set_application_commands_without_guild(self, rest_client): - expected_route = routes.PUT_APPLICATION_COMMANDS.compile(application=4321231) - rest_client._request = mock.AsyncMock(return_value=[{"id": "9459329932"}]) + async def test_set_application_commands_without_guild( + self, rest_client: rest.RESTClientImpl, mock_application: applications.Application + ): + expected_route = routes.PUT_APPLICATION_COMMANDS.compile(application=111) mock_command_builder = mock.Mock() - result = await rest_client.set_application_commands(StubModel(4321231), [mock_command_builder]) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "9459329932"}] + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_command") as patched_deserialize_command, + ): + result = await rest_client.set_application_commands(mock_application, [mock_command_builder]) - assert result == [rest_client._entity_factory.deserialize_command.return_value] - rest_client._entity_factory.deserialize_command.assert_called_once_with({"id": "9459329932"}, guild_id=None) - rest_client._request.assert_awaited_once_with(expected_route, json=[mock_command_builder.build.return_value]) - mock_command_builder.build.assert_called_once_with(rest_client._entity_factory) + assert result == [patched_deserialize_command.return_value] + patched_deserialize_command.assert_called_once_with({"id": "9459329932"}, guild_id=None) + patched__request.assert_awaited_once_with(expected_route, json=[mock_command_builder.build.return_value]) + mock_command_builder.build.assert_called_once_with(rest_client.entity_factory) - async def test_set_application_commands_without_guild_handles_unknown_command_types(self, rest_client): + async def test_set_application_commands_without_guild_handles_unknown_command_types( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + ): mock_command = mock.Mock() - expected_route = routes.PUT_APPLICATION_GUILD_COMMANDS.compile(application=532123123, guild=453123) - rest_client._entity_factory.deserialize_command.side_effect = [ - errors.UnrecognisedEntityError("meow"), - mock_command, - ] - rest_client._request = mock.AsyncMock(return_value=[{"id": "435765"}, {"id": "4949493933"}]) + expected_route = routes.PUT_APPLICATION_GUILD_COMMANDS.compile(application=111, guild=123) mock_command_builder = mock.Mock() - result = await rest_client.set_application_commands( - StubModel(532123123), [mock_command_builder], StubModel(453123) - ) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value=[{"id": "435765"}, {"id": "4949493933"}], + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, + "deserialize_command", + side_effect=[errors.UnrecognisedEntityError("meow"), mock_command], + ) as patched_deserialize_command, + ): + result = await rest_client.set_application_commands( + mock_application, [mock_command_builder], mock_partial_guild + ) - assert result == [mock_command] - rest_client._entity_factory.deserialize_command.assert_has_calls( - [mock.call({"id": "435765"}, guild_id=453123), mock.call({"id": "4949493933"}, guild_id=453123)] - ) - rest_client._request.assert_awaited_once_with(expected_route, json=[mock_command_builder.build.return_value]) - mock_command_builder.build.assert_called_once_with(rest_client._entity_factory) + assert result == [mock_command] + patched_deserialize_command.assert_has_calls( + [mock.call({"id": "435765"}, guild_id=123), mock.call({"id": "4949493933"}, guild_id=123)] + ) + patched__request.assert_awaited_once_with(expected_route, json=[mock_command_builder.build.return_value]) + mock_command_builder.build.assert_called_once_with(rest_client.entity_factory) - async def test_edit_application_command_with_optionals(self, rest_client): - expected_route = routes.PATCH_APPLICATION_GUILD_COMMAND.compile( - application=1235432, guild=54123, command=3451231 - ) - rest_client._request = mock.AsyncMock(return_value={"id": "94594994"}) - mock_option = object() + async def test_edit_application_command_with_optionals( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + mock_partial_command: commands.PartialCommand, + ): + expected_route = routes.PATCH_APPLICATION_GUILD_COMMAND.compile(application=111, guild=123, command=666) + mock_option = mock.Mock() - result = await rest_client.edit_application_command( - StubModel(1235432), - StubModel(3451231), - StubModel(54123), - name="ok sis", - description="cancelled", - options=[mock_option], - default_member_permissions=permissions.Permissions.BAN_MEMBERS, - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "94594994"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_command") as patched_deserialize_command, + mock.patch.object( + rest_client.entity_factory, "serialize_command_option" + ) as patched_serialize_command_option, + ): + result = await rest_client.edit_application_command( + mock_application, + mock_partial_command, + mock_partial_guild, + name="ok sis", + description="cancelled", + options=[mock_option], + default_member_permissions=permissions.Permissions.BAN_MEMBERS, + ) - assert result is rest_client._entity_factory.deserialize_command.return_value - rest_client._entity_factory.deserialize_command.assert_called_once_with( - rest_client._request.return_value, guild_id=54123 - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={ - "name": "ok sis", - "description": "cancelled", - "options": [rest_client._entity_factory.serialize_command_option.return_value], - "default_member_permissions": 4, - }, - ) - rest_client._entity_factory.serialize_command_option.assert_called_once_with(mock_option) + assert result is patched_deserialize_command.return_value + patched_deserialize_command.assert_called_once_with(patched__request.return_value, guild_id=123) + patched__request.assert_awaited_once_with( + expected_route, + json={ + "name": "ok sis", + "description": "cancelled", + "options": [patched_serialize_command_option.return_value], + "default_member_permissions": 4, + }, + ) + patched_serialize_command_option.assert_called_once_with(mock_option) - async def test_edit_application_command_without_optionals(self, rest_client): - expected_route = routes.PATCH_APPLICATION_COMMAND.compile(application=1235432, command=3451231) - rest_client._request = mock.AsyncMock(return_value={"id": "94594994"}) + async def test_edit_application_command_without_optionals( + self, + rest_client: rest.RESTClientImpl, + mock_application: applications.Application, + mock_partial_command: commands.PartialCommand, + ): + expected_route = routes.PATCH_APPLICATION_COMMAND.compile(application=111, command=666) - result = await rest_client.edit_application_command(StubModel(1235432), StubModel(3451231)) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "94594994"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_command") as patched_deserialize_command, + ): + result = await rest_client.edit_application_command(mock_application, mock_partial_command) - assert result is rest_client._entity_factory.deserialize_command.return_value - rest_client._entity_factory.deserialize_command.assert_called_once_with( - rest_client._request.return_value, guild_id=None - ) - rest_client._request.assert_awaited_once_with(expected_route, json={}) + assert result is patched_deserialize_command.return_value + patched_deserialize_command.assert_called_once_with(patched__request.return_value, guild_id=None) + patched__request.assert_awaited_once_with(expected_route, json={}) async def test_edit_application_command_standardizes_default_member_permissions( - self, rest_client: rest.RESTClientImpl + self, + rest_client: rest.RESTClientImpl, + mock_application: applications.Application, + mock_partial_command: commands.PartialCommand, ): - expected_route = routes.PATCH_APPLICATION_COMMAND.compile(application=1235432, command=3451231) - rest_client._request = mock.AsyncMock(return_value={"id": "94594994"}) + expected_route = routes.PATCH_APPLICATION_COMMAND.compile(application=111, command=666) - result = await rest_client.edit_application_command( - StubModel(1235432), StubModel(3451231), default_member_permissions=permissions.Permissions.NONE - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "94594994"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_command") as patched_deserialize_command, + ): + result = await rest_client.edit_application_command( + mock_application, mock_partial_command, default_member_permissions=permissions.Permissions.NONE + ) - assert result is rest_client._entity_factory.deserialize_command.return_value - rest_client._entity_factory.deserialize_command.assert_called_once_with( - rest_client._request.return_value, guild_id=None - ) - rest_client._request.assert_awaited_once_with(expected_route, json={"default_member_permissions": None}) + assert result is patched_deserialize_command.return_value + patched_deserialize_command.assert_called_once_with(patched__request.return_value, guild_id=None) + patched__request.assert_awaited_once_with(expected_route, json={"default_member_permissions": None}) - async def test_delete_application_command_with_guild(self, rest_client): - expected_route = routes.DELETE_APPLICATION_GUILD_COMMAND.compile( - application=312312, command=65234323, guild=5421312 - ) - rest_client._request = mock.AsyncMock() + async def test_delete_application_command_with_guild( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + mock_partial_command: commands.PartialCommand, + ): + expected_route = routes.DELETE_APPLICATION_GUILD_COMMAND.compile(application=111, command=666, guild=123) - await rest_client.delete_application_command(StubModel(312312), StubModel(65234323), StubModel(5421312)) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_application_command(mock_application, mock_partial_command, mock_partial_guild) - rest_client._request.assert_awaited_once_with(expected_route) + patched__request.assert_awaited_once_with(expected_route) - async def test_delete_application_command_without_guild(self, rest_client): - expected_route = routes.DELETE_APPLICATION_COMMAND.compile(application=312312, command=65234323) - rest_client._request = mock.AsyncMock() + async def test_delete_application_command_without_guild( + self, + rest_client: rest.RESTClientImpl, + mock_application: applications.Application, + mock_partial_command: commands.PartialCommand, + ): + expected_route = routes.DELETE_APPLICATION_COMMAND.compile(application=111, command=666) - await rest_client.delete_application_command(StubModel(312312), StubModel(65234323)) + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_application_command(mock_application, mock_partial_command) - rest_client._request.assert_awaited_once_with(expected_route) + patched__request.assert_awaited_once_with(expected_route) - async def test_fetch_application_guild_commands_permissions(self, rest_client): - expected_route = routes.GET_APPLICATION_GUILD_COMMANDS_PERMISSIONS.compile(application=321431, guild=54123) - mock_command_payload = object() - rest_client._request = mock.AsyncMock(return_value=[mock_command_payload]) + async def test_fetch_application_guild_commands_permissions( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + ): + expected_route = routes.GET_APPLICATION_GUILD_COMMANDS_PERMISSIONS.compile(application=111, guild=123) + mock_command_payload = mock.Mock() - result = await rest_client.fetch_application_guild_commands_permissions(321431, 54123) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[mock_command_payload] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_command_permissions" + ) as patched_deserialize_guild_command_permissions, + ): + result = await rest_client.fetch_application_guild_commands_permissions( + mock_application, mock_partial_guild + ) - assert result == [rest_client._entity_factory.deserialize_guild_command_permissions.return_value] - rest_client._entity_factory.deserialize_guild_command_permissions.assert_called_once_with(mock_command_payload) - rest_client._request.assert_awaited_once_with(expected_route) + assert result == [patched_deserialize_guild_command_permissions.return_value] + patched_deserialize_guild_command_permissions.assert_called_once_with(mock_command_payload) + patched__request.assert_awaited_once_with(expected_route) - async def test_fetch_application_command_permissions(self, rest_client): - expected_route = routes.GET_APPLICATION_COMMAND_PERMISSIONS.compile( - application=543421, guild=123321321, command=543123 - ) + async def test_fetch_application_command_permissions( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + mock_partial_command: commands.PartialCommand, + ): + expected_route = routes.GET_APPLICATION_COMMAND_PERMISSIONS.compile(application=111, guild=123, command=666) mock_command_payload = {"id": "9393939393"} - rest_client._request = mock.AsyncMock(return_value=mock_command_payload) - result = await rest_client.fetch_application_command_permissions(543421, 123321321, 543123) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=mock_command_payload + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_command_permissions" + ) as patched_deserialize_guild_command_permissions, + ): + result = await rest_client.fetch_application_command_permissions( + mock_application, mock_partial_guild, mock_partial_command + ) - assert result is rest_client._entity_factory.deserialize_guild_command_permissions.return_value - rest_client._entity_factory.deserialize_guild_command_permissions.assert_called_once_with(mock_command_payload) - rest_client._request.assert_awaited_once_with(expected_route) + assert result is patched_deserialize_guild_command_permissions.return_value + patched_deserialize_guild_command_permissions.assert_called_once_with(mock_command_payload) + patched__request.assert_awaited_once_with(expected_route) - async def test_set_application_command_permissions(self, rest_client): - route = routes.PUT_APPLICATION_COMMAND_PERMISSIONS.compile(application=2321, guild=431, command=666666) - mock_permission = object() + async def test_set_application_command_permissions( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_application: applications.Application, + mock_partial_command: commands.PartialCommand, + ): + route = routes.PUT_APPLICATION_COMMAND_PERMISSIONS.compile(application=111, guild=123, command=666) + mock_permission = mock.Mock() mock_command_payload = {"id": "29292929"} - rest_client._request = mock.AsyncMock(return_value=mock_command_payload) - result = await rest_client.set_application_command_permissions(2321, 431, 666666, [mock_permission]) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=mock_command_payload + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_guild_command_permissions" + ) as patched_deserialize_guild_command_permissions, + mock.patch.object( + rest_client.entity_factory, "serialize_command_permission" + ) as patched_serialize_command_permission, + ): + result = await rest_client.set_application_command_permissions( + mock_application, mock_partial_guild, mock_partial_command, [mock_permission] + ) - assert result is rest_client._entity_factory.deserialize_guild_command_permissions.return_value - rest_client._entity_factory.deserialize_guild_command_permissions.assert_called_once_with(mock_command_payload) - rest_client._request.assert_awaited_once_with( - route, json={"permissions": [rest_client._entity_factory.serialize_command_permission.return_value]} - ) + assert result is patched_deserialize_guild_command_permissions.return_value + patched_deserialize_guild_command_permissions.assert_called_once_with(mock_command_payload) + patched__request.assert_awaited_once_with( + route, json={"permissions": [patched_serialize_command_permission.return_value]} + ) - async def test_fetch_interaction_response(self, rest_client): - expected_route = routes.GET_INTERACTION_RESPONSE.compile(webhook=1235432, token="go homo or go gnomo") - rest_client._request = mock.AsyncMock(return_value={"id": "94949494949"}) + async def test_fetch_interaction_response( + self, rest_client: rest.RESTClientImpl, mock_application: applications.Application + ): + expected_route = routes.GET_INTERACTION_RESPONSE.compile(webhook=111, token="go homo or go gnomo") - result = await rest_client.fetch_interaction_response(StubModel(1235432), "go homo or go gnomo") + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "94949494949"} + ) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + ): + result = await rest_client.fetch_interaction_response(mock_application, "go homo or go gnomo") - assert result is rest_client._entity_factory.deserialize_message.return_value - rest_client._entity_factory.deserialize_message.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with(expected_route, auth=None) + assert result is patched_deserialize_message.return_value + patched_deserialize_message.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with(expected_route, auth=None) - async def test_create_interaction_response_when_form(self, rest_client): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() - poll_obj = object() + async def test_create_interaction_response_when_form( + self, rest_client: rest.RESTClientImpl, mock_partial_interaction: interactions.PartialInteraction + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() + poll_obj = mock.Mock() mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=432, token="some token") - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - rest_client._request = mock.AsyncMock(return_value={"interaction": "callback"}) + expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=777, token="some token") - await rest_client.create_interaction_response( - StubModel(432), - "some token", - 1, - "some content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - poll=poll_obj, - tts=True, - flags=120, - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"interaction": "callback"} + ) as patched_request, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_message_payload, + ): + await rest_client.create_interaction_response( + mock_partial_interaction, + "some token", + 1, + "some content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + poll=poll_obj, + tts=True, + flags=120, + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + ) - rest_client._build_message_payload.assert_called_once_with( + patched__build_message_payload.assert_called_once_with( content="some content", attachment=attachment_obj, attachments=[attachment_obj2], @@ -6588,44 +8656,53 @@ async def test_create_interaction_response_when_form(self, rest_client): mock_form.add_field.assert_called_once_with( "payload_json", b'{"type":1,"data":{"testing":"ensure_in_test"}}', content_type="application/json" ) - rest_client._request.assert_awaited_once_with( + patched_request.assert_awaited_once_with( expected_route, form_builder=mock_form, auth=None, query={"with_response": "true"} ) - async def test_create_interaction_response_when_no_form(self, rest_client): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() - poll_obj = object() + async def test_create_interaction_response_when_no_form( + self, rest_client: rest.RESTClientImpl, mock_partial_interaction: interactions.PartialInteraction + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() + poll_obj = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=432, token="some token") + expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=777, token="some token") rest_client._build_message_payload = mock.Mock(return_value=(mock_body, None)) - rest_client._request = mock.AsyncMock(return_value={"interaction": "callback"}) - await rest_client.create_interaction_response( - StubModel(432), - "some token", - 1, - "some content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - poll=poll_obj, - tts=True, - flags=120, - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"interaction": "callback"} + ) as patched__request, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, None) + ) as patched__build_message_payload, + ): + await rest_client.create_interaction_response( + mock_partial_interaction, + "some token", + 1, + "some content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + poll=poll_obj, + tts=True, + flags=120, + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + ) - rest_client._build_message_payload.assert_called_once_with( + patched__build_message_payload.assert_called_once_with( content="some content", attachment=attachment_obj, attachments=[attachment_obj2], @@ -6640,215 +8717,252 @@ async def test_create_interaction_response_when_no_form(self, rest_client): user_mentions=[9876], role_mentions=[1234], ) - rest_client._request.assert_awaited_once_with( + + patched__request.assert_awaited_once_with( expected_route, json={"type": 1, "data": {"testing": "ensure_in_test"}}, query={"with_response": "true"}, auth=None, ) - async def test_create_interaction_voice_message_response(self, rest_client): - interaction = StubModel(432) + async def test_create_interaction_voice_message_response( + self, rest_client: rest.RESTClientImpl, mock_partial_interaction: interactions.PartialInteraction + ): token = "some token" - expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=interaction, token=token) - attachment_obj = object() + expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=777, token=token) + attachment_obj = mock.Mock() mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() - rest_client._request = mock.AsyncMock(return_value={"interaction": "callback"}) - rest_client._build_voice_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - await rest_client.create_interaction_voice_message_response( - interaction, token=token, attachment=attachment_obj, waveform="AAA", duration=3, flags=54123 - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"interaction": "callback"} + ) as patched__request, + mock.patch.object( + rest_client, "_build_voice_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_voice_message_payload, + ): + await rest_client.create_interaction_voice_message_response( + mock_partial_interaction, + token=token, + attachment=attachment_obj, + waveform="AAA", + duration=3, + flags=54123, + ) - rest_client._build_voice_message_payload.assert_called_once_with( + patched__build_voice_message_payload.assert_called_once_with( attachment=attachment_obj, flags=54123, waveform="AAA", duration=3 ) - rest_client._request.assert_awaited_once_with( + patched__request.assert_awaited_once_with( expected_route, form_builder=mock_form, query={"with_response": "true"}, auth=None ) - async def test_edit_interaction_response_when_form(self, rest_client): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() + async def test_edit_interaction_response_when_form( + self, rest_client: rest.RESTClientImpl, mock_application: applications.Application + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() mock_form = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.PATCH_INTERACTION_RESPONSE.compile(webhook=432, token="some token") - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) - - returned = await rest_client.edit_interaction_response( - StubModel(432), - "some token", - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value + expected_route = routes.PATCH_INTERACTION_RESPONSE.compile(webhook=111, token="some token") - rest_client._build_message_payload.assert_called_once_with( - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - edit=True, - ) - mock_form.add_field.assert_called_once_with( - "payload_json", b'{"testing":"ensure_in_test"}', content_type="application/json" - ) - rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form, auth=None) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, mock_form) + ) as patched__build_message_payload, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + ): + returned = await rest_client.edit_interaction_response( + mock_application, + "some token", + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + ) + assert returned is patched_deserialize_message.return_value + + patched__build_message_payload.assert_called_once_with( + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + edit=True, + ) + mock_form.add_field.assert_called_once_with( + "payload_json", b'{"testing":"ensure_in_test"}', content_type="application/json" + ) + patched__request.assert_awaited_once_with(expected_route, form_builder=mock_form, auth=None) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - async def test_edit_interaction_response_when_no_form(self, rest_client): - attachment_obj = object() - attachment_obj2 = object() - component_obj = object() - component_obj2 = object() - embed_obj = object() - embed_obj2 = object() + async def test_edit_interaction_response_when_no_form( + self, rest_client: rest.RESTClientImpl, mock_application: applications.Application + ): + attachment_obj = mock.Mock() + attachment_obj2 = mock.Mock() + component_obj = mock.Mock() + component_obj2 = mock.Mock() + embed_obj = mock.Mock() + embed_obj2 = mock.Mock() mock_body = data_binding.JSONObjectBuilder() mock_body.put("testing", "ensure_in_test") - expected_route = routes.PATCH_INTERACTION_RESPONSE.compile(webhook=432, token="some token") - rest_client._build_message_payload = mock.Mock(return_value=(mock_body, None)) - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) - - returned = await rest_client.edit_interaction_response( - StubModel(432), - "some token", - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - ) - assert returned is rest_client._entity_factory.deserialize_message.return_value - - rest_client._build_message_payload.assert_called_once_with( - content="new content", - attachment=attachment_obj, - attachments=[attachment_obj2], - component=component_obj, - components=[component_obj2], - embed=embed_obj, - embeds=[embed_obj2], - mentions_everyone=False, - user_mentions=[9876], - role_mentions=[1234], - edit=True, - ) - rest_client._request.assert_awaited_once_with(expected_route, json={"testing": "ensure_in_test"}, auth=None) - rest_client._entity_factory.deserialize_message.assert_called_once_with({"message_id": 123}) - - async def test_edit_interaction_voice_message_response(self, rest_client): - interaction = StubModel(432) - token = "some token" - expected_route = routes.PATCH_INTERACTION_RESPONSE.compile(webhook=interaction, token=token) - attachment_obj = object() - mock_form = mock.Mock() - mock_body = data_binding.JSONObjectBuilder() - rest_client._request = mock.AsyncMock(return_value={"message_id": 123}) - rest_client._build_voice_message_payload = mock.Mock(return_value=(mock_body, mock_form)) - - await rest_client.edit_interaction_voice_message_response( - interaction, token=token, attachment=attachment_obj, waveform="AAA", duration=3 - ) - - rest_client._build_voice_message_payload.assert_called_once_with( - attachment=attachment_obj, waveform="AAA", duration=3 - ) + expected_route = routes.PATCH_INTERACTION_RESPONSE.compile(webhook=111, token="some token") - rest_client._request.assert_awaited_once_with(expected_route, form_builder=mock_form, auth=None) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"message_id": 123} + ) as patched__request, + mock.patch.object( + rest_client, "_build_message_payload", return_value=(mock_body, None) + ) as patched__build_message_payload, + mock.patch.object(rest_client.entity_factory, "deserialize_message") as patched_deserialize_message, + ): + returned = await rest_client.edit_interaction_response( + mock_application, + "some token", + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + ) + assert returned is patched_deserialize_message.return_value + + patched__build_message_payload.assert_called_once_with( + content="new content", + attachment=attachment_obj, + attachments=[attachment_obj2], + component=component_obj, + components=[component_obj2], + embed=embed_obj, + embeds=[embed_obj2], + mentions_everyone=False, + user_mentions=[9876], + role_mentions=[1234], + edit=True, + ) + patched__request.assert_awaited_once_with(expected_route, json={"testing": "ensure_in_test"}, auth=None) + patched_deserialize_message.assert_called_once_with({"message_id": 123}) - async def test_delete_interaction_response(self, rest_client): - expected_route = routes.DELETE_INTERACTION_RESPONSE.compile(webhook=1235431, token="go homo now") - rest_client._request = mock.AsyncMock() + async def test_delete_interaction_response( + self, rest_client: rest.RESTClientImpl, mock_application: applications.Application + ): + expected_route = routes.DELETE_INTERACTION_RESPONSE.compile(webhook=111, token="go homo now") - await rest_client.delete_interaction_response(StubModel(1235431), "go homo now") + with mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request: + await rest_client.delete_interaction_response(mock_application, "go homo now") - rest_client._request.assert_awaited_once_with(expected_route, auth=None) + patched__request.assert_awaited_once_with(expected_route, auth=None) - async def test_create_autocomplete_response(self, rest_client): - expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=1235431, token="snek") - rest_client._request = mock.AsyncMock(return_value={"interaction": "callback"}) + async def test_create_autocomplete_response( + self, rest_client: rest.RESTClientImpl, mock_partial_interaction: interactions.PartialInteraction + ): + expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=777, token="snek") choices = [ special_endpoints.AutocompleteChoiceBuilder(name="c", value="d"), special_endpoints.AutocompleteChoiceBuilder(name="eee", value="fff"), ] - await rest_client.create_autocomplete_response(StubModel(1235431), "snek", choices) - rest_client._request.assert_awaited_once_with( + with mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"interaction": "callback"} + ) as patched__request: + await rest_client.create_autocomplete_response(mock_partial_interaction, "snek", choices) + + patched__request.assert_awaited_once_with( expected_route, json={"type": 8, "data": {"choices": [{"name": "c", "value": "d"}, {"name": "eee", "value": "fff"}]}}, query={"with_response": "true"}, auth=None, ) - async def test_create_autocomplete_response_for_deprecated_command_choices(self, rest_client): - expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=1235431, token="snek") - rest_client._request = mock.AsyncMock(return_value={"interaction": "callback"}) + async def test_create_autocomplete_response_for_deprecated_command_choices( + self, rest_client: rest.RESTClientImpl, mock_partial_interaction: interactions.PartialInteraction + ): + expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=777, token="snek") - choices = [commands.CommandChoice(name="a", value="b"), commands.CommandChoice(name="foo", value="bar")] - await rest_client.create_autocomplete_response(StubModel(1235431), "snek", choices) + choice_1 = mock.Mock(name="a", value="b") + choice_2 = mock.Mock(name="foo", value="bar") + with mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"interaction": "callback"} + ) as patched__request: + await rest_client.create_autocomplete_response(mock_partial_interaction, "snek", [choice_1, choice_2]) - rest_client._request.assert_awaited_once_with( + patched__request.assert_awaited_once_with( expected_route, - json={"type": 8, "data": {"choices": [{"name": "a", "value": "b"}, {"name": "foo", "value": "bar"}]}}, + json={ + "type": 8, + "data": { + "choices": [ + {"name": choice_1.name, "value": choice_1.value}, + {"name": choice_2.name, "value": choice_2.value}, + ] + }, + }, query={"with_response": "true"}, auth=None, ) - async def test_create_modal_response(self, rest_client): - expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=1235431, token="snek") - rest_client._request = mock.AsyncMock(return_value={"interaction": "callback"}) + async def test_create_modal_response( + self, rest_client: rest.RESTClientImpl, mock_partial_interaction: interactions.PartialInteraction + ): + expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=777, token="snek") mock_payload = mock.Mock() mock_files = mock.Mock() component = mock.Mock(build=mock.Mock(return_value=(mock_payload, mock_files))) - await rest_client.create_modal_response( - StubModel(1235431), "snek", title="title", custom_id="idd", component=component - ) + with mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"interaction": "callback"} + ) as patched__request: + await rest_client.create_modal_response( + mock_partial_interaction, "snek", title="title", custom_id="idd", component=component + ) - rest_client._request.assert_awaited_once_with( + patched__request.assert_awaited_once_with( expected_route, json={"type": 9, "data": {"title": "title", "custom_id": "idd", "components": [mock_payload]}}, query={"with_response": "true"}, auth=None, ) - async def test_create_modal_response_with_plural_args(self, rest_client): - expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=1235431, token="snek") + async def test_create_modal_response_with_plural_args( + self, rest_client: rest.RESTClientImpl, mock_partial_interaction: interactions.PartialInteraction + ): + expected_route = routes.POST_INTERACTION_RESPONSE.compile(interaction=777, token="snek") rest_client._request = mock.AsyncMock(return_value={"interaction": "callback"}) mock_payload = mock.Mock() mock_files = mock.Mock() component = mock.Mock(build=mock.Mock(return_value=(mock_payload, mock_files))) await rest_client.create_modal_response( - StubModel(1235431), "snek", title="title", custom_id="idd", components=[component] + mock_partial_interaction, "snek", title="title", custom_id="idd", components=[component] ) rest_client._request.assert_awaited_once_with( @@ -6858,363 +8972,530 @@ async def test_create_modal_response_with_plural_args(self, rest_client): auth=None, ) - async def test_create_modal_response_when_both_component_and_components_passed(self, rest_client): + async def test_create_modal_response_when_both_component_and_components_passed( + self, rest_client: rest.RESTClientImpl, mock_partial_interaction: interactions.PartialInteraction + ): with pytest.raises(ValueError, match="Must specify exactly only one of 'component' or 'components'"): await rest_client.create_modal_response( - StubModel(1235431), "snek", title="title", custom_id="idd", component="not none", components=[] + mock_partial_interaction, "snek", title="title", custom_id="idd", component=mock.Mock(), components=[] ) - async def test_fetch_scheduled_event(self, rest_client: rest.RESTClientImpl): - expected_route = routes.GET_GUILD_SCHEDULED_EVENT.compile(guild=453123, scheduled_event=222332323) - rest_client._request = mock.AsyncMock(return_value={"id": "4949494949"}) + async def test_fetch_scheduled_event( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_scheduled_event: scheduled_events.ScheduledEvent, + ): + expected_route = routes.GET_GUILD_SCHEDULED_EVENT.compile(guild=123, scheduled_event=888) - result = await rest_client.fetch_scheduled_event(StubModel(453123), StubModel(222332323)) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "4949494949"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_event" + ) as patched_deserialize_scheduled_event, + ): + result = await rest_client.fetch_scheduled_event(mock_partial_guild, mock_scheduled_event) - assert result is rest_client._entity_factory.deserialize_scheduled_event.return_value - rest_client._entity_factory.deserialize_scheduled_event.assert_called_once_with({"id": "4949494949"}) - rest_client._request.assert_awaited_once_with(expected_route, query={"with_user_count": "true"}) + assert result is patched_deserialize_scheduled_event.return_value + patched_deserialize_scheduled_event.assert_called_once_with({"id": "4949494949"}) + patched__request.assert_awaited_once_with(expected_route, query={"with_user_count": "true"}) - async def test_fetch_scheduled_events(self, rest_client: rest.RESTClientImpl): - expected_route = routes.GET_GUILD_SCHEDULED_EVENTS.compile(guild=65234123) - rest_client._request = mock.AsyncMock(return_value=[{"id": "494920234", "type": "1"}]) + async def test_fetch_scheduled_events( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + expected_route = routes.GET_GUILD_SCHEDULED_EVENTS.compile(guild=123) - result = await rest_client.fetch_scheduled_events(StubModel(65234123)) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "494920234", "type": "1"}] + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_event" + ) as patched_deserialize_scheduled_event, + ): + result = await rest_client.fetch_scheduled_events(mock_partial_guild) - assert result == [rest_client._entity_factory.deserialize_scheduled_event.return_value] - rest_client._entity_factory.deserialize_scheduled_event.assert_called_once_with( - {"id": "494920234", "type": "1"} - ) - rest_client._request.assert_awaited_once_with(expected_route, query={"with_user_count": "true"}) + assert result == [patched_deserialize_scheduled_event.return_value] + patched_deserialize_scheduled_event.assert_called_once_with({"id": "494920234", "type": "1"}) + patched__request.assert_awaited_once_with(expected_route, query={"with_user_count": "true"}) - async def test_fetch_scheduled_events_handles_unrecognised_events(self, rest_client: rest.RESTClientImpl): + async def test_fetch_scheduled_events_handles_unrecognised_events( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): mock_event = mock.Mock() - rest_client._entity_factory.deserialize_scheduled_event.side_effect = [ - errors.UnrecognisedEntityError("evil laugh"), - mock_event, - ] - expected_route = routes.GET_GUILD_SCHEDULED_EVENTS.compile(guild=65234123) - rest_client._request = mock.AsyncMock( - return_value=[{"id": "432234", "type": "1"}, {"id": "4939394", "type": "494949"}] - ) + expected_route = routes.GET_GUILD_SCHEDULED_EVENTS.compile(guild=123) - result = await rest_client.fetch_scheduled_events(StubModel(65234123)) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value=[{"id": "432234", "type": "1"}, {"id": "4939394", "type": "494949"}], + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, + "deserialize_scheduled_event", + side_effect=[errors.UnrecognisedEntityError("evil laugh"), mock_event], + ) as patched_deserialize_scheduled_event, + ): + result = await rest_client.fetch_scheduled_events(mock_partial_guild) - assert result == [mock_event] - rest_client._entity_factory.deserialize_scheduled_event.assert_has_calls( - [mock.call({"id": "432234", "type": "1"}), mock.call({"id": "4939394", "type": "494949"})] - ) - rest_client._request.assert_awaited_once_with(expected_route, query={"with_user_count": "true"}) + assert result == [mock_event] + patched_deserialize_scheduled_event.assert_has_calls( + [mock.call({"id": "432234", "type": "1"}), mock.call({"id": "4939394", "type": "494949"})] + ) + patched__request.assert_awaited_once_with(expected_route, query={"with_user_count": "true"}) - async def test_create_stage_event(self, rest_client: rest.RESTClientImpl, file_resource_patch): - expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=123321) - rest_client._request = mock.AsyncMock(return_value={"id": "494949", "name": "MEOsdasdWWWWW"}) + async def test_create_stage_event( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_stage_channel: channels.GuildStageChannel, + file_resource_patch: files.Resource[typing.Any], + ): + expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=123) - result = await rest_client.create_stage_event( - StubModel(123321), - StubModel(7654345), - "boob man", - datetime.datetime(2001, 1, 1, 17, 42, 41, 891222, tzinfo=datetime.timezone.utc), - description="o", - end_time=datetime.datetime(2002, 2, 2, 17, 42, 41, 891222, tzinfo=datetime.timezone.utc), - image="tksksk.txt", - privacy_level=654134, - reason="bye bye", - ) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "494949", "name": "MEOsdasdWWWWW"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_stage_event" + ) as patched_deserialize_scheduled_stage_event, + ): + result = await rest_client.create_stage_event( + mock_partial_guild, + mock_guild_stage_channel, + "boob man", + datetime.datetime(2001, 1, 1, 17, 42, 41, 891222, tzinfo=datetime.timezone.utc), + description="o", + end_time=datetime.datetime(2002, 2, 2, 17, 42, 41, 891222, tzinfo=datetime.timezone.utc), + image="tksksk.txt", + privacy_level=654134, + reason="bye bye", + ) - assert result is rest_client._entity_factory.deserialize_scheduled_stage_event.return_value - rest_client._entity_factory.deserialize_scheduled_stage_event.assert_called_once_with( - {"id": "494949", "name": "MEOsdasdWWWWW"} - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={ - "channel_id": "7654345", - "name": "boob man", - "description": "o", - "entity_type": scheduled_events.ScheduledEventType.STAGE_INSTANCE, - "privacy_level": 654134, - "scheduled_start_time": "2001-01-01T17:42:41.891222+00:00", - "scheduled_end_time": "2002-02-02T17:42:41.891222+00:00", - "image": "some data", - }, - reason="bye bye", - ) + assert result is patched_deserialize_scheduled_stage_event.return_value + patched_deserialize_scheduled_stage_event.assert_called_once_with({"id": "494949", "name": "MEOsdasdWWWWW"}) + patched__request.assert_awaited_once_with( + expected_route, + json={ + "channel_id": "45613", + "name": "boob man", + "description": "o", + "entity_type": scheduled_events.ScheduledEventType.STAGE_INSTANCE, + "privacy_level": 654134, + "scheduled_start_time": "2001-01-01T17:42:41.891222+00:00", + "scheduled_end_time": "2002-02-02T17:42:41.891222+00:00", + "image": "some data", + }, + reason="bye bye", + ) - async def test_create_stage_event_without_optionals(self, rest_client: rest.RESTClientImpl): - expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=234432234) - rest_client._request = mock.AsyncMock(return_value={"id": "494949", "name": "MEOWWWWW"}) + async def test_create_stage_event_without_optionals( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_stage_channel: channels.GuildStageChannel, + ): + expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=123) - result = await rest_client.create_stage_event( - StubModel(234432234), - StubModel(654234432), - "boob man", - datetime.datetime(2021, 3, 11, 17, 42, 41, 891222, tzinfo=datetime.timezone.utc), - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "494949", "name": "MEOWWWWW"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_stage_event" + ) as patched_deserialize_scheduled_stage_event, + ): + result = await rest_client.create_stage_event( + mock_partial_guild, + mock_guild_stage_channel, + "boob man", + datetime.datetime(2021, 3, 11, 17, 42, 41, 891222, tzinfo=datetime.timezone.utc), + ) - assert result is rest_client._entity_factory.deserialize_scheduled_stage_event.return_value - rest_client._entity_factory.deserialize_scheduled_stage_event.assert_called_once_with( - {"id": "494949", "name": "MEOWWWWW"} - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={ - "channel_id": "654234432", - "name": "boob man", - "entity_type": scheduled_events.ScheduledEventType.STAGE_INSTANCE, - "privacy_level": scheduled_events.EventPrivacyLevel.GUILD_ONLY, - "scheduled_start_time": "2021-03-11T17:42:41.891222+00:00", - }, - reason=undefined.UNDEFINED, - ) + assert result is patched_deserialize_scheduled_stage_event.return_value + patched_deserialize_scheduled_stage_event.assert_called_once_with({"id": "494949", "name": "MEOWWWWW"}) + patched__request.assert_awaited_once_with( + expected_route, + json={ + "channel_id": "45613", + "name": "boob man", + "entity_type": scheduled_events.ScheduledEventType.STAGE_INSTANCE, + "privacy_level": scheduled_events.EventPrivacyLevel.GUILD_ONLY, + "scheduled_start_time": "2021-03-11T17:42:41.891222+00:00", + }, + reason=undefined.UNDEFINED, + ) - async def test_create_voice_event(self, rest_client: rest.RESTClientImpl, file_resource_patch): - expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=76234123) - rest_client._request = mock.AsyncMock(return_value={"id": "494942342439", "name": "MEOW"}) + async def test_create_voice_event( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_stage_channel: channels.GuildStageChannel, + file_resource_patch: files.Resource[typing.Any], + ): + expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=123) - result = await rest_client.create_voice_event( - StubModel(76234123), - StubModel(65243123), - "boom man", - datetime.datetime(2021, 3, 9, 13, 42, 41, 891222, tzinfo=datetime.timezone.utc), - description="hhhhh", - end_time=datetime.datetime(2069, 3, 9, 13, 1, 41, 891222, tzinfo=datetime.timezone.utc), - image="meow.txt", - privacy_level=6523123, - reason="it was the {insert political part here}", - ) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "494942342439", "name": "MEOW"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_voice_event" + ) as patched_deserialize_scheduled_voice_event, + ): + result = await rest_client.create_voice_event( + mock_partial_guild, + mock_guild_stage_channel, + "boom man", + datetime.datetime(2021, 3, 9, 13, 42, 41, 891222, tzinfo=datetime.timezone.utc), + description="hhhhh", + end_time=datetime.datetime(2069, 3, 9, 13, 1, 41, 891222, tzinfo=datetime.timezone.utc), + image="meow.txt", + privacy_level=6523123, + reason="it was the {insert political part here}", + ) - assert result is rest_client._entity_factory.deserialize_scheduled_voice_event.return_value - rest_client._entity_factory.deserialize_scheduled_voice_event.assert_called_once_with( - {"id": "494942342439", "name": "MEOW"} - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={ - "channel_id": "65243123", - "name": "boom man", - "entity_type": scheduled_events.ScheduledEventType.VOICE, - "privacy_level": 6523123, - "scheduled_start_time": "2021-03-09T13:42:41.891222+00:00", - "scheduled_end_time": "2069-03-09T13:01:41.891222+00:00", - "description": "hhhhh", - "image": "some data", - }, - reason="it was the {insert political part here}", - ) + assert result is patched_deserialize_scheduled_voice_event.return_value + patched_deserialize_scheduled_voice_event.assert_called_once_with({"id": "494942342439", "name": "MEOW"}) + patched__request.assert_awaited_once_with( + expected_route, + json={ + "channel_id": "45613", + "name": "boom man", + "entity_type": scheduled_events.ScheduledEventType.VOICE, + "privacy_level": 6523123, + "scheduled_start_time": "2021-03-09T13:42:41.891222+00:00", + "scheduled_end_time": "2069-03-09T13:01:41.891222+00:00", + "description": "hhhhh", + "image": "some data", + }, + reason="it was the {insert political part here}", + ) - async def test_create_voice_event_without_optionals(self, rest_client: rest.RESTClientImpl): - expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=76234123) - rest_client._request = mock.AsyncMock(return_value={"id": "123321123", "name": "MEOW"}) + async def test_create_voice_event_without_optionals( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_stage_channel: channels.GuildStageChannel, + ): + expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=123) - result = await rest_client.create_voice_event( - StubModel(76234123), - StubModel(65243123), - "boom man", - datetime.datetime(2021, 3, 9, 13, 42, 41, 891222, tzinfo=datetime.timezone.utc), - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "123321123", "name": "MEOW"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_voice_event" + ) as patched_deserialize_scheduled_voice_event, + ): + result = await rest_client.create_voice_event( + mock_partial_guild, + mock_guild_stage_channel, + "boom man", + datetime.datetime(2021, 3, 9, 13, 42, 41, 891222, tzinfo=datetime.timezone.utc), + ) - assert result is rest_client._entity_factory.deserialize_scheduled_voice_event.return_value - rest_client._entity_factory.deserialize_scheduled_voice_event.assert_called_once_with( - {"id": "123321123", "name": "MEOW"} - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={ - "channel_id": "65243123", - "name": "boom man", - "entity_type": scheduled_events.ScheduledEventType.VOICE, - "privacy_level": scheduled_events.EventPrivacyLevel.GUILD_ONLY, - "scheduled_start_time": "2021-03-09T13:42:41.891222+00:00", - }, - reason=undefined.UNDEFINED, - ) + assert result is patched_deserialize_scheduled_voice_event.return_value + patched_deserialize_scheduled_voice_event.assert_called_once_with({"id": "123321123", "name": "MEOW"}) + patched__request.assert_awaited_once_with( + expected_route, + json={ + "channel_id": "45613", + "name": "boom man", + "entity_type": scheduled_events.ScheduledEventType.VOICE, + "privacy_level": scheduled_events.EventPrivacyLevel.GUILD_ONLY, + "scheduled_start_time": "2021-03-09T13:42:41.891222+00:00", + }, + reason=undefined.UNDEFINED, + ) - async def test_create_external_event(self, rest_client: rest.RESTClientImpl, file_resource_patch): - expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=34232412) - rest_client._request = mock.AsyncMock(return_value={"id": "494949", "name": "MerwwerEOW"}) + async def test_create_external_event( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + file_resource_patch: files.Resource[typing.Any], + ): + expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=123) - result = await rest_client.create_external_event( - StubModel(34232412), - "hi", - "Outside", - datetime.datetime(2021, 3, 6, 2, 42, 41, 891222, tzinfo=datetime.timezone.utc), - datetime.datetime(2023, 5, 6, 16, 42, 41, 891222, tzinfo=datetime.timezone.utc), - description="This is a description", - image="icon.png", - privacy_level=6454, - reason="chairman meow", - ) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "494949", "name": "MerwwerEOW"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_external_event" + ) as patched_deserialize_scheduled_external_event, + ): + result = await rest_client.create_external_event( + mock_partial_guild, + "hi", + "Outside", + datetime.datetime(2021, 3, 6, 2, 42, 41, 891222, tzinfo=datetime.timezone.utc), + datetime.datetime(2023, 5, 6, 16, 42, 41, 891222, tzinfo=datetime.timezone.utc), + description="This is a description", + image="icon.png", + privacy_level=6454, + reason="chairman meow", + ) - assert result is rest_client._entity_factory.deserialize_scheduled_external_event.return_value - rest_client._entity_factory.deserialize_scheduled_external_event.assert_called_once_with( - {"id": "494949", "name": "MerwwerEOW"} - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={ - "name": "hi", - "entity_metadata": {"location": "Outside"}, - "entity_type": scheduled_events.ScheduledEventType.EXTERNAL, - "privacy_level": 6454, - "scheduled_start_time": "2021-03-06T02:42:41.891222+00:00", - "scheduled_end_time": "2023-05-06T16:42:41.891222+00:00", - "description": "This is a description", - "image": "some data", - }, - reason="chairman meow", - ) + assert result is patched_deserialize_scheduled_external_event.return_value + patched_deserialize_scheduled_external_event.assert_called_once_with({"id": "494949", "name": "MerwwerEOW"}) + patched__request.assert_awaited_once_with( + expected_route, + json={ + "name": "hi", + "entity_metadata": {"location": "Outside"}, + "entity_type": scheduled_events.ScheduledEventType.EXTERNAL, + "privacy_level": 6454, + "scheduled_start_time": "2021-03-06T02:42:41.891222+00:00", + "scheduled_end_time": "2023-05-06T16:42:41.891222+00:00", + "description": "This is a description", + "image": "some data", + }, + reason="chairman meow", + ) - async def test_create_external_event_without_optionals(self, rest_client: rest.RESTClientImpl): - expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=34232412) - rest_client._request = mock.AsyncMock(return_value={"id": "494923443249", "name": "MEOW"}) + async def test_create_external_event_without_optionals( + self, rest_client: rest.RESTClientImpl, mock_partial_guild: guilds.PartialGuild + ): + expected_route = routes.POST_GUILD_SCHEDULED_EVENT.compile(guild=123) - result = await rest_client.create_external_event( - StubModel(34232412), - "hi", - "Outside", - datetime.datetime(2021, 3, 6, 2, 42, 41, 891222, tzinfo=datetime.timezone.utc), - datetime.datetime(2023, 5, 6, 16, 42, 41, 891222, tzinfo=datetime.timezone.utc), - ) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "494923443249", "name": "MEOW"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_external_event" + ) as patched_deserialize_scheduled_external_event, + ): + result = await rest_client.create_external_event( + mock_partial_guild, + "hi", + "Outside", + datetime.datetime(2021, 3, 6, 2, 42, 41, 891222, tzinfo=datetime.timezone.utc), + datetime.datetime(2023, 5, 6, 16, 42, 41, 891222, tzinfo=datetime.timezone.utc), + ) - assert result is rest_client._entity_factory.deserialize_scheduled_external_event.return_value - rest_client._entity_factory.deserialize_scheduled_external_event.assert_called_once_with( - {"id": "494923443249", "name": "MEOW"} - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={ - "name": "hi", - "entity_metadata": {"location": "Outside"}, - "entity_type": scheduled_events.ScheduledEventType.EXTERNAL, - "privacy_level": scheduled_events.EventPrivacyLevel.GUILD_ONLY, - "scheduled_start_time": "2021-03-06T02:42:41.891222+00:00", - "scheduled_end_time": "2023-05-06T16:42:41.891222+00:00", - }, - reason=undefined.UNDEFINED, - ) + assert result is patched_deserialize_scheduled_external_event.return_value + patched_deserialize_scheduled_external_event.assert_called_once_with({"id": "494923443249", "name": "MEOW"}) + patched__request.assert_awaited_once_with( + expected_route, + json={ + "name": "hi", + "entity_metadata": {"location": "Outside"}, + "entity_type": scheduled_events.ScheduledEventType.EXTERNAL, + "privacy_level": scheduled_events.EventPrivacyLevel.GUILD_ONLY, + "scheduled_start_time": "2021-03-06T02:42:41.891222+00:00", + "scheduled_end_time": "2023-05-06T16:42:41.891222+00:00", + }, + reason=undefined.UNDEFINED, + ) - async def test_edit_scheduled_event(self, rest_client: rest.RESTClientImpl, file_resource_patch): - expected_route = routes.PATCH_GUILD_SCHEDULED_EVENT.compile(guild=345543, scheduled_event=123321123) - rest_client._request = mock.AsyncMock(return_value={"id": "494949", "name": "MEO43345W"}) + async def test_edit_scheduled_event( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_text_channel: channels.GuildTextChannel, + mock_scheduled_event: scheduled_events.ScheduledEvent, + file_resource_patch: files.Resource[typing.Any], + ): + expected_route = routes.PATCH_GUILD_SCHEDULED_EVENT.compile(guild=123, scheduled_event=888) - result = await rest_client.edit_scheduled_event( - StubModel(345543), - StubModel(123321123), - channel=StubModel(45423423), - description="hihihi", - entity_type=scheduled_events.ScheduledEventType.VOICE, - image="icon.png", - location="Trans-land", - name="Nihongo", - privacy_level=69, - start_time=datetime.datetime(2022, 3, 6, 12, 42, 41, 891222, tzinfo=datetime.timezone.utc), - end_time=datetime.datetime(2022, 5, 6, 12, 42, 41, 891222, tzinfo=datetime.timezone.utc), - status=64, - reason="go home", - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "494949", "name": "MEO43345W"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_event" + ) as patched_deserialize_scheduled_event, + ): + result = await rest_client.edit_scheduled_event( + mock_partial_guild, + mock_scheduled_event, + channel=mock_guild_text_channel, + description="hihihi", + entity_type=scheduled_events.ScheduledEventType.VOICE, + image="icon.png", + location="Trans-land", + name="Nihongo", + privacy_level=69, + start_time=datetime.datetime(2022, 3, 6, 12, 42, 41, 891222, tzinfo=datetime.timezone.utc), + end_time=datetime.datetime(2022, 5, 6, 12, 42, 41, 891222, tzinfo=datetime.timezone.utc), + status=64, + reason="go home", + ) - assert result is rest_client._entity_factory.deserialize_scheduled_event.return_value - rest_client._entity_factory.deserialize_scheduled_event.assert_called_once_with( - {"id": "494949", "name": "MEO43345W"} - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={ - "channel_id": "45423423", - "entity_metadata": {"location": "Trans-land"}, - "name": "Nihongo", - "privacy_level": 69, - "scheduled_start_time": "2022-03-06T12:42:41.891222+00:00", - "scheduled_end_time": "2022-05-06T12:42:41.891222+00:00", - "description": "hihihi", - "entity_type": scheduled_events.ScheduledEventType.VOICE, - "status": 64, - "image": "some data", - }, - reason="go home", - ) + assert result is patched_deserialize_scheduled_event.return_value + patched_deserialize_scheduled_event.assert_called_once_with({"id": "494949", "name": "MEO43345W"}) + patched__request.assert_awaited_once_with( + expected_route, + json={ + "channel_id": "4560", + "entity_metadata": {"location": "Trans-land"}, + "name": "Nihongo", + "privacy_level": 69, + "scheduled_start_time": "2022-03-06T12:42:41.891222+00:00", + "scheduled_end_time": "2022-05-06T12:42:41.891222+00:00", + "description": "hihihi", + "entity_type": scheduled_events.ScheduledEventType.VOICE, + "status": 64, + "image": "some data", + }, + reason="go home", + ) - async def test_edit_scheduled_event_with_null_fields(self, rest_client: rest.RESTClientImpl): - expected_route = routes.PATCH_GUILD_SCHEDULED_EVENT.compile(guild=345543, scheduled_event=123321123) - rest_client._request = mock.AsyncMock(return_value={"id": "494949", "name": "ME222222OW"}) + async def test_edit_scheduled_event_with_null_fields( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_scheduled_event: scheduled_events.ScheduledEvent, + ): + expected_route = routes.PATCH_GUILD_SCHEDULED_EVENT.compile(guild=123, scheduled_event=888) - result = await rest_client.edit_scheduled_event( - StubModel(345543), StubModel(123321123), channel=None, description=None, end_time=None - ) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "494949", "name": "ME222222OW"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_event" + ) as patched_deserialize_scheduled_event, + ): + result = await rest_client.edit_scheduled_event( + mock_partial_guild, mock_scheduled_event, channel=None, description=None, end_time=None + ) - assert result is rest_client._entity_factory.deserialize_scheduled_event.return_value - rest_client._entity_factory.deserialize_scheduled_event.assert_called_once_with( - {"id": "494949", "name": "ME222222OW"} - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={"channel_id": None, "description": None, "scheduled_end_time": None}, - reason=undefined.UNDEFINED, - ) + assert result is patched_deserialize_scheduled_event.return_value + patched_deserialize_scheduled_event.assert_called_once_with({"id": "494949", "name": "ME222222OW"}) + patched__request.assert_awaited_once_with( + expected_route, + json={"channel_id": None, "description": None, "scheduled_end_time": None}, + reason=undefined.UNDEFINED, + ) - async def test_edit_scheduled_event_without_optionals(self, rest_client: rest.RESTClientImpl): - expected_route = routes.PATCH_GUILD_SCHEDULED_EVENT.compile(guild=345543, scheduled_event=123321123) - rest_client._request = mock.AsyncMock(return_value={"id": "494123321949", "name": "MEOW"}) + async def test_edit_scheduled_event_without_optionals( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_scheduled_event: scheduled_events.ScheduledEvent, + ): + expected_route = routes.PATCH_GUILD_SCHEDULED_EVENT.compile(guild=123, scheduled_event=888) - result = await rest_client.edit_scheduled_event(StubModel(345543), StubModel(123321123)) + with ( + mock.patch.object( + rest_client, + "_request", + new_callable=mock.AsyncMock, + return_value={"id": "494123321949", "name": "MEOW"}, + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_event" + ) as patched_deserialize_scheduled_event, + ): + result = await rest_client.edit_scheduled_event(mock_partial_guild, mock_scheduled_event) - assert result is rest_client._entity_factory.deserialize_scheduled_event.return_value - rest_client._entity_factory.deserialize_scheduled_event.assert_called_once_with( - {"id": "494123321949", "name": "MEOW"} - ) - rest_client._request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) + assert result is patched_deserialize_scheduled_event.return_value + patched_deserialize_scheduled_event.assert_called_once_with({"id": "494123321949", "name": "MEOW"}) + patched__request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) - async def test_edit_scheduled_event_when_changing_to_external(self, rest_client: rest.RESTClientImpl): - expected_route = routes.PATCH_GUILD_SCHEDULED_EVENT.compile(guild=345543, scheduled_event=123321123) - rest_client._request = mock.AsyncMock(return_value={"id": "49342344949", "name": "MEOW"}) + async def test_edit_scheduled_event_when_changing_to_external( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_guild_text_channel: channels.GuildTextChannel, + mock_scheduled_event: scheduled_events.ScheduledEvent, + ): + expected_route = routes.PATCH_GUILD_SCHEDULED_EVENT.compile(guild=123, scheduled_event=888) - result = await rest_client.edit_scheduled_event( - StubModel(345543), - StubModel(123321123), - entity_type=scheduled_events.ScheduledEventType.EXTERNAL, - channel=StubModel(5461231), - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "49342344949", "name": "MEOW"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_event" + ) as patched_deserialize_scheduled_event, + ): + result = await rest_client.edit_scheduled_event( + mock_partial_guild, + mock_scheduled_event, + entity_type=scheduled_events.ScheduledEventType.EXTERNAL, + channel=mock_guild_text_channel, + ) - assert result is rest_client._entity_factory.deserialize_scheduled_event.return_value - rest_client._entity_factory.deserialize_scheduled_event.assert_called_once_with( - {"id": "49342344949", "name": "MEOW"} - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={"channel_id": "5461231", "entity_type": scheduled_events.ScheduledEventType.EXTERNAL}, - reason=undefined.UNDEFINED, - ) + assert result is patched_deserialize_scheduled_event.return_value + patched_deserialize_scheduled_event.assert_called_once_with({"id": "49342344949", "name": "MEOW"}) + patched__request.assert_awaited_once_with( + expected_route, + json={"channel_id": "4560", "entity_type": scheduled_events.ScheduledEventType.EXTERNAL}, + reason=undefined.UNDEFINED, + ) async def test_edit_scheduled_event_when_changing_to_external_and_channel_id_not_explicitly_passed( - self, rest_client: rest.RESTClientImpl + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_scheduled_event: scheduled_events.ScheduledEvent, ): - expected_route = routes.PATCH_GUILD_SCHEDULED_EVENT.compile(guild=345543, scheduled_event=123321123) - rest_client._request = mock.AsyncMock(return_value={"id": "494949", "name": "MEOW"}) + expected_route = routes.PATCH_GUILD_SCHEDULED_EVENT.compile(guild=123, scheduled_event=888) - result = await rest_client.edit_scheduled_event( - StubModel(345543), StubModel(123321123), entity_type=scheduled_events.ScheduledEventType.EXTERNAL - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "494949", "name": "MEOW"} + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_scheduled_event" + ) as patched_deserialize_scheduled_event, + ): + result = await rest_client.edit_scheduled_event( + mock_partial_guild, mock_scheduled_event, entity_type=scheduled_events.ScheduledEventType.EXTERNAL + ) - assert result is rest_client._entity_factory.deserialize_scheduled_event.return_value - rest_client._entity_factory.deserialize_scheduled_event.assert_called_once_with( - {"id": "494949", "name": "MEOW"} - ) - rest_client._request.assert_awaited_once_with( - expected_route, - json={"channel_id": None, "entity_type": scheduled_events.ScheduledEventType.EXTERNAL}, - reason=undefined.UNDEFINED, - ) + assert result is patched_deserialize_scheduled_event.return_value + patched_deserialize_scheduled_event.assert_called_once_with({"id": "494949", "name": "MEOW"}) + patched__request.assert_awaited_once_with( + expected_route, + json={"channel_id": None, "entity_type": scheduled_events.ScheduledEventType.EXTERNAL}, + reason=undefined.UNDEFINED, + ) - async def test_delete_scheduled_event(self, rest_client: rest.RESTClientImpl): - expected_route = routes.DELETE_GUILD_SCHEDULED_EVENT.compile(guild=54531123, scheduled_event=123321123321) - rest_client._request = mock.AsyncMock() + async def test_delete_scheduled_event( + self, + rest_client: rest.RESTClientImpl, + mock_partial_guild: guilds.PartialGuild, + mock_scheduled_event: scheduled_events.ScheduledEvent, + ): + expected_route = routes.DELETE_GUILD_SCHEDULED_EVENT.compile(guild=123, scheduled_event=888) - await rest_client.delete_scheduled_event(StubModel(54531123), StubModel(123321123321)) + with ( + mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_scheduled_event"), + ): + await rest_client.delete_scheduled_event(mock_partial_guild, mock_scheduled_event) - rest_client._request.assert_awaited_once_with(expected_route) + patched__request.assert_awaited_once_with(expected_route) - async def test_fetch_stage_instance(self, rest_client): - expected_route = routes.GET_STAGE_INSTANCE.compile(channel=123) + async def test_fetch_stage_instance( + self, rest_client: rest.RESTClientImpl, mock_guild_stage_channel: channels.GuildStageChannel + ): + expected_route = routes.GET_STAGE_INSTANCE.compile(channel=45613) mock_payload = { "id": "8406", "guild_id": "19703", @@ -7223,17 +9504,29 @@ async def test_fetch_stage_instance(self, rest_client): "privacy_level": 1, "discoverable_disabled": False, } - rest_client._request = mock.AsyncMock(return_value=mock_payload) - result = await rest_client.fetch_stage_instance(channel=StubModel(123)) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=mock_payload + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_stage_instance" + ) as patched_deserialize_stage_instance, + ): + result = await rest_client.fetch_stage_instance(channel=mock_guild_stage_channel) - assert result is rest_client._entity_factory.deserialize_stage_instance.return_value - rest_client._request.assert_called_once_with(expected_route) - rest_client._entity_factory.deserialize_stage_instance.assert_called_once_with(mock_payload) + assert result is patched_deserialize_stage_instance.return_value + patched__request.assert_called_once_with(expected_route) + patched_deserialize_stage_instance.assert_called_once_with(mock_payload) - async def test_create_stage_instance(self, rest_client): + async def test_create_stage_instance( + self, + rest_client: rest.RESTClientImpl, + mock_guild_stage_channel: channels.GuildStageChannel, + mock_scheduled_event: scheduled_events.ScheduledEvent, + ): expected_route = routes.POST_STAGE_INSTANCE.compile() - expected_json = {"channel_id": "7334", "topic": "ur mom", "guild_scheduled_event_id": "3361203239"} + expected_json = {"channel_id": "45613", "topic": "ur mom", "guild_scheduled_event_id": "888"} mock_payload = { "id": "8406", "guild_id": "19703", @@ -7243,18 +9536,30 @@ async def test_create_stage_instance(self, rest_client): "guild_scheduled_event_id": "3361203239", "discoverable_disabled": False, } - rest_client._request = mock.AsyncMock(return_value=mock_payload) - result = await rest_client.create_stage_instance( - channel=StubModel(7334), topic="ur mom", scheduled_event_id=StubModel(3361203239), reason="testing" - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=mock_payload + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_stage_instance" + ) as patched_deserialize_stage_instance, + ): + result = await rest_client.create_stage_instance( + channel=mock_guild_stage_channel, + topic="ur mom", + scheduled_event_id=mock_scheduled_event, + reason="testing", + ) - assert result is rest_client._entity_factory.deserialize_stage_instance.return_value - rest_client._request.assert_called_once_with(expected_route, json=expected_json, reason="testing") - rest_client._entity_factory.deserialize_stage_instance.assert_called_once_with(mock_payload) + assert result is patched_deserialize_stage_instance.return_value + patched__request.assert_called_once_with(expected_route, json=expected_json, reason="testing") + patched_deserialize_stage_instance.assert_called_once_with(mock_payload) - async def test_edit_stage_instance(self, rest_client): - expected_route = routes.PATCH_STAGE_INSTANCE.compile(channel=7334) + async def test_edit_stage_instance( + self, rest_client: rest.RESTClientImpl, mock_guild_stage_channel: channels.GuildStageChannel + ): + expected_route = routes.PATCH_STAGE_INSTANCE.compile(channel=45613) expected_json = {"topic": "ur mom", "privacy_level": 2} mock_payload = { "id": "8406", @@ -7264,109 +9569,173 @@ async def test_edit_stage_instance(self, rest_client): "privacy_level": 2, "discoverable_disabled": False, } - rest_client._request = mock.AsyncMock(return_value=mock_payload) - - result = await rest_client.edit_stage_instance( - channel=StubModel(7334), - topic="ur mom", - privacy_level=stage_instances.StageInstancePrivacyLevel.GUILD_ONLY, - reason="testing", - ) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=mock_payload + ) as patched__request, + mock.patch.object( + rest_client.entity_factory, "deserialize_stage_instance" + ) as patched_deserialize_stage_instance, + ): + result = await rest_client.edit_stage_instance( + channel=mock_guild_stage_channel, + topic="ur mom", + privacy_level=stage_instances.StageInstancePrivacyLevel.GUILD_ONLY, + reason="testing", + ) - assert result is rest_client._entity_factory.deserialize_stage_instance.return_value - rest_client._request.assert_called_once_with(expected_route, json=expected_json, reason="testing") - rest_client._entity_factory.deserialize_stage_instance.assert_called_once_with(mock_payload) + assert result is patched_deserialize_stage_instance.return_value + patched__request.assert_called_once_with(expected_route, json=expected_json, reason="testing") + patched_deserialize_stage_instance.assert_called_once_with(mock_payload) - async def test_delete_stage_instance(self, rest_client): - expected_route = routes.DELETE_STAGE_INSTANCE.compile(channel=7334) - rest_client._request = mock.AsyncMock() + async def test_delete_stage_instance( + self, rest_client: rest.RESTClientImpl, mock_guild_stage_channel: channels.GuildStageChannel + ): + expected_route = routes.DELETE_STAGE_INSTANCE.compile(channel=45613) - await rest_client.delete_stage_instance(channel=StubModel(7334), reason="testing") + with ( + mock.patch.object(rest_client, "_request", new_callable=mock.AsyncMock) as patched__request, + mock.patch.object(rest_client.entity_factory, "deserialize_stage_instance"), + ): + await rest_client.delete_stage_instance(channel=mock_guild_stage_channel, reason="testing") - rest_client._request.assert_called_once_with(expected_route, reason="testing") + patched__request.assert_called_once_with(expected_route, reason="testing") - async def test_fetch_poll_voters(self, rest_client: rest.RESTClientImpl): + async def test_fetch_poll_voters( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): expected_route = routes.GET_POLL_ANSWER.compile( - channel=StubModel(45874392), message=StubModel(398475938475), answer=StubModel(4) + channel=mock_guild_text_channel, message=mock_message, answer=snowflakes.Snowflake(4) ) - rest_client._request = mock.AsyncMock(return_value=[{"id": "1234"}]) - - with mock.patch.object( - rest_client._entity_factory, "deserialize_user", return_value=mock.Mock() - ) as patched_deserialize_user: + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[{"id": "1234"}] + ) as patched__request, + mock.patch.object( + rest_client._entity_factory, "deserialize_user", return_value=mock.Mock() + ) as patched_deserialize_user, + ): await rest_client.fetch_poll_voters( - StubModel(45874392), StubModel(398475938475), StubModel(4), after=StubModel(43587935), limit=6 + mock_guild_text_channel, + mock_message, + snowflakes.Snowflake(4), + after=snowflakes.Snowflake(574893), + limit=6, ) - patched_deserialize_user.assert_called_once_with({"id": "1234"}) + patched_deserialize_user.assert_called_once_with({"id": "1234"}) - rest_client._request.assert_awaited_once_with(expected_route, query={"after": "43587935", "limit": "6"}) + patched__request.assert_awaited_once_with(expected_route, query={"after": "574893", "limit": "6"}) - async def test_end_poll(self, rest_client: rest.RESTClientImpl): + async def test_end_poll( + self, + rest_client: rest.RESTClientImpl, + mock_guild_text_channel: channels.GuildTextChannel, + mock_message: messages.Message, + ): expected_route = routes.POST_EXPIRE_POLL.compile( - channel=StubModel(45874392), message=StubModel(398475938475), answer=StubModel(4) + channel=mock_guild_text_channel, message=mock_message, answer=snowflakes.Snowflake(4) ) message_obj = mock.Mock() - rest_client._request = mock.AsyncMock(return_value={"id": "398475938475"}) - - rest_client._entity_factory.deserialize_message = mock.Mock(return_value=message_obj) - - response = await rest_client.end_poll(StubModel(45874392), StubModel(398475938475)) - - rest_client._request.assert_awaited_once_with(expected_route) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value={"id": "398475938475"} + ) as patched__request, + mock.patch.object( + rest_client._entity_factory, "deserialize_message", return_value=message_obj + ) as patched_deserialize_message, + ): + response = await rest_client.end_poll(mock_guild_text_channel, mock_message) assert response is message_obj - async def test_fetch_auto_mod_rules(self, rest_client: rest.RESTClientImpl): + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_message.assert_called_once_with(patched__request.return_value) + + async def test_fetch_auto_mod_rules( + self, rest_client: rest.RESTClientImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_payload_1 = {"id": "432123"} mock_payload_2 = {"id": "949494994"} mock_result_1 = mock.Mock() mock_result_2 = mock.Mock() - rest_client._entity_factory.deserialize_auto_mod_rule.side_effect = [mock_result_1, mock_result_2] - expected_route = routes.GET_GUILD_AUTO_MODERATION_RULES.compile(guild=123321) - rest_client._request = mock.AsyncMock(return_value=[mock_payload_1, mock_payload_2]) + expected_route = routes.GET_GUILD_AUTO_MODERATION_RULES.compile(guild=123) - result = await rest_client.fetch_auto_mod_rules(StubModel(123321)) + with ( + mock.patch.object( + rest_client, "_request", new_callable=mock.AsyncMock, return_value=[mock_payload_1, mock_payload_2] + ) as patched__request, + mock.patch.object( + rest_client._entity_factory, "deserialize_auto_mod_rule", side_effect=[mock_result_1, mock_result_2] + ) as patched_deserialize_auto_mod_rule, + ): + result = await rest_client.fetch_auto_mod_rules(hikari_partial_guild) assert result == [mock_result_1, mock_result_2] - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_auto_mod_rule.assert_has_calls( - [mock.call(mock_payload_1), mock.call(mock_payload_2)] - ) + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_auto_mod_rule.assert_has_calls([mock.call(mock_payload_1), mock.call(mock_payload_2)]) - async def test_fetch_auto_mod_rule(self, rest_client: rest.RESTClientImpl): - expected_route = routes.GET_GUILD_AUTO_MODERATION_RULE.compile(guild=123321, rule=5443123) - rest_client._request = mock.AsyncMock(return_value={"id": "442123"}) + async def test_fetch_auto_mod_rule( + self, + rest_client: rest.RESTClientImpl, + hikari_partial_guild: guilds.PartialGuild, + mock_auto_mod_rule: auto_mod.AutoModRule, + ): + expected_route = routes.GET_GUILD_AUTO_MODERATION_RULE.compile(guild=123, rule=999) - result = await rest_client.fetch_auto_mod_rule(StubModel(123321), StubModel(5443123)) + with ( + mock.patch.object( + rest_client, "_request", new=mock.AsyncMock(return_value={"id": "442123"}) + ) as patched__request, + mock.patch.object( + rest_client._entity_factory, "deserialize_auto_mod_rule" + ) as patched_deserialize_auto_mod_rule, + ): + result = await rest_client.fetch_auto_mod_rule(hikari_partial_guild, mock_auto_mod_rule) - assert result is rest_client._entity_factory.deserialize_auto_mod_rule.return_value - rest_client._request.assert_awaited_once_with(expected_route) - rest_client._entity_factory.deserialize_auto_mod_rule.assert_called_once_with(rest_client._request.return_value) + assert result is patched_deserialize_auto_mod_rule.return_value + patched__request.assert_awaited_once_with(expected_route) + patched_deserialize_auto_mod_rule.assert_called_once_with(patched__request.return_value) - async def test_create_auto_mod_rule(self, rest_client: rest.RESTClientImpl): + async def test_create_auto_mod_rule( + self, rest_client: rest.RESTClientImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_action = mock.Mock(special_endpoints.AutoModBlockMessageActionBuilder) - expected_route = routes.POST_GUILD_AUTO_MODERATION_RULE.compile(guild=123321) - rest_client._request = mock.AsyncMock(return_value={"id": "494949494"}) - - result = await rest_client.create_auto_mod_rule( - StubModel(123321), - name="meow", - event_type=auto_mod.AutoModEventType.MESSAGE_SEND, - trigger=special_endpoints.AutoModKeywordTriggerBuilder(keyword_filter=["hello", "world"]), - actions=[mock_action], - enabled=False, - exempt_roles=[StubModel(4212), StubModel(43123)], - exempt_channels=[StubModel(566), StubModel(333), StubModel(222)], - reason="a reason meow", - ) + expected_route = routes.POST_GUILD_AUTO_MODERATION_RULE.compile(guild=123) - assert result is rest_client._entity_factory.deserialize_auto_mod_rule.return_value - rest_client._entity_factory.deserialize_auto_mod_rule.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with( + with ( + mock.patch.object( + rest_client, "_request", new=mock.AsyncMock(return_value={"id": "494949494"}) + ) as patched__request, + mock.patch.object( + rest_client._entity_factory, "deserialize_auto_mod_rule" + ) as patched_deserialize_auto_mod_rule, + ): + result = await rest_client.create_auto_mod_rule( + hikari_partial_guild, + name="meow", + event_type=auto_mod.AutoModEventType.MESSAGE_SEND, + trigger=special_endpoints.AutoModKeywordTriggerBuilder(keyword_filter=["hello", "world"]), + actions=[mock_action], + enabled=False, + exempt_roles=[make_partial_role(4212), make_partial_role(43123)], + exempt_channels=[ + make_guild_text_channel(566), + make_guild_text_channel(333), + make_guild_text_channel(222), + ], + reason="a reason meow", + ) + + assert result is patched_deserialize_auto_mod_rule.return_value + patched_deserialize_auto_mod_rule.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with( expected_route, json={ "name": "meow", @@ -7382,22 +9751,31 @@ async def test_create_auto_mod_rule(self, rest_client: rest.RESTClientImpl): ) mock_action.build.assert_called_once_with() - async def test_create_auto_mod_rule_partial(self, rest_client: rest.RESTClientImpl): + async def test_create_auto_mod_rule_partial( + self, rest_client: rest.RESTClientImpl, hikari_partial_guild: guilds.PartialGuild + ): mock_action = mock.Mock(special_endpoints.AutoModBlockMessageActionBuilder) - expected_route = routes.POST_GUILD_AUTO_MODERATION_RULE.compile(guild=123321) - rest_client._request = mock.AsyncMock(return_value={"id": "494949494"}) + expected_route = routes.POST_GUILD_AUTO_MODERATION_RULE.compile(guild=123) - result = await rest_client.create_auto_mod_rule( - StubModel(123321), - name="meow", - event_type=auto_mod.AutoModEventType.MESSAGE_SEND, - trigger=special_endpoints.AutoModKeywordTriggerBuilder(keyword_filter=["hello", "world"]), - actions=[mock_action], - ) + with ( + mock.patch.object( + rest_client, "_request", new=mock.AsyncMock(return_value={"id": "494949494"}) + ) as patched__request, + mock.patch.object( + rest_client._entity_factory, "deserialize_auto_mod_rule" + ) as patched_deserialize_auto_mod_rule, + ): + result = await rest_client.create_auto_mod_rule( + hikari_partial_guild, + name="meow", + event_type=auto_mod.AutoModEventType.MESSAGE_SEND, + trigger=special_endpoints.AutoModKeywordTriggerBuilder(keyword_filter=["hello", "world"]), + actions=[mock_action], + ) - assert result is rest_client._entity_factory.deserialize_auto_mod_rule.return_value - rest_client._entity_factory.deserialize_auto_mod_rule.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with( + assert result is patched_deserialize_auto_mod_rule.return_value + patched_deserialize_auto_mod_rule.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with( expected_route, json={ "name": "meow", @@ -7410,27 +9788,43 @@ async def test_create_auto_mod_rule_partial(self, rest_client: rest.RESTClientIm ) mock_action.build.assert_called_once_with() - async def test_edit_auto_mod_rule(self, rest_client: rest.RESTClientImpl): + async def test_edit_auto_mod_rule( + self, + rest_client: rest.RESTClientImpl, + hikari_partial_guild: guilds.PartialGuild, + mock_auto_mod_rule: auto_mod.AutoModRule, + ): mock_action = mock.Mock(special_endpoints.AutoModBlockMessageActionBuilder) - expected_route = routes.PATCH_GUILD_AUTO_MODERATION_RULE.compile(guild=123321, rule=5412123) - rest_client._request = mock.AsyncMock(return_value={"id": "494949494"}) - - result = await rest_client.edit_auto_mod_rule( - StubModel(123321), - StubModel(5412123), - name="meow", - event_type=auto_mod.AutoModEventType.MESSAGE_SEND, - trigger=special_endpoints.AutoModKeywordTriggerBuilder(keyword_filter=["hello", "world"]), - actions=[mock_action], - enabled=False, - exempt_roles=[StubModel(4545), StubModel(5656)], - exempt_channels=[StubModel(555), StubModel(666), StubModel(777)], - reason="nyaa nyaa", - ) + expected_route = routes.PATCH_GUILD_AUTO_MODERATION_RULE.compile(guild=123, rule=999) - assert result is rest_client._entity_factory.deserialize_auto_mod_rule.return_value - rest_client._entity_factory.deserialize_auto_mod_rule.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with( + with ( + mock.patch.object( + rest_client, "_request", new=mock.AsyncMock(return_value={"id": "494949494"}) + ) as patched__request, + mock.patch.object( + rest_client._entity_factory, "deserialize_auto_mod_rule" + ) as patched_deserialize_auto_mod_rule, + ): + result = await rest_client.edit_auto_mod_rule( + hikari_partial_guild, + mock_auto_mod_rule, + name="meow", + event_type=auto_mod.AutoModEventType.MESSAGE_SEND, + trigger=special_endpoints.AutoModKeywordTriggerBuilder(keyword_filter=["hello", "world"]), + actions=[mock_action], + enabled=False, + exempt_roles=[make_partial_role(4545), make_partial_role(5656)], + exempt_channels=[ + make_guild_text_channel(555), + make_guild_text_channel(666), + make_guild_text_channel(777), + ], + reason="nyaa nyaa", + ) + + assert result is patched_deserialize_auto_mod_rule.return_value + patched_deserialize_auto_mod_rule.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with( expected_route, json={ "name": "meow", @@ -7445,22 +9839,38 @@ async def test_edit_auto_mod_rule(self, rest_client: rest.RESTClientImpl): ) mock_action.build.assert_called_once_with() - async def test_edit_auto_mod_rule_partial(self, rest_client: rest.RESTClientImpl): - expected_route = routes.PATCH_GUILD_AUTO_MODERATION_RULE.compile(guild=123321, rule=44332222) - rest_client._request = mock.AsyncMock(return_value={"id": "494949494"}) + async def test_edit_auto_mod_rule_partial( + self, + rest_client: rest.RESTClientImpl, + hikari_partial_guild: guilds.PartialGuild, + mock_auto_mod_rule: auto_mod.AutoModRule, + ): + expected_route = routes.PATCH_GUILD_AUTO_MODERATION_RULE.compile(guild=123, rule=999) - result = await rest_client.edit_auto_mod_rule(StubModel(123321), StubModel(44332222)) + with ( + mock.patch.object( + rest_client, "_request", new=mock.AsyncMock(return_value={"id": "494949494"}) + ) as patched__request, + mock.patch.object( + rest_client._entity_factory, "deserialize_auto_mod_rule" + ) as patched_deserialize_auto_mod_rule, + ): + result = await rest_client.edit_auto_mod_rule(hikari_partial_guild, mock_auto_mod_rule) - assert result is rest_client._entity_factory.deserialize_auto_mod_rule.return_value - rest_client._entity_factory.deserialize_auto_mod_rule.assert_called_once_with(rest_client._request.return_value) - rest_client._request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) - rest_client._entity_factory.serialize_auto_mod_action.assert_not_called() + assert result is patched_deserialize_auto_mod_rule.return_value + patched_deserialize_auto_mod_rule.assert_called_once_with(patched__request.return_value) + patched__request.assert_awaited_once_with(expected_route, json={}, reason=undefined.UNDEFINED) - async def test_delete_auto_mod_rule(self, rest_client: rest.RESTClientImpl): - expected_route = routes.DELETE_GUILD_AUTO_MODERATION_RULE.compile(guild=54123, rule=651234) - rest_client._request = mock.AsyncMock() + async def test_delete_auto_mod_rule( + self, + rest_client: rest.RESTClientImpl, + hikari_partial_guild: guilds.PartialGuild, + mock_auto_mod_rule: auto_mod.AutoModRule, + ): + expected_route = routes.DELETE_GUILD_AUTO_MODERATION_RULE.compile(guild=123, rule=999) - result = await rest_client.delete_auto_mod_rule(StubModel(54123), StubModel(651234), reason="ok hi") + with mock.patch.object(rest_client, "_request", new=mock.AsyncMock()) as patched__request: + result = await rest_client.delete_auto_mod_rule(hikari_partial_guild, mock_auto_mod_rule, reason="ok hi") assert result is None - rest_client._request.assert_awaited_once_with(expected_route, reason="ok hi") + patched__request.assert_awaited_once_with(expected_route, reason="ok hi") diff --git a/tests/hikari/impl/test_rest_bot.py b/tests/hikari/impl/test_rest_bot.py index 5218de67de..a81ebff5e1 100644 --- a/tests/hikari/impl/test_rest_bot.py +++ b/tests/hikari/impl/test_rest_bot.py @@ -22,7 +22,6 @@ import asyncio import concurrent.futures -import contextlib import sys import mock @@ -42,52 +41,47 @@ class TestRESTBot: @pytest.fixture - def mock_interaction_server(self): + def mock_interaction_server(self) -> interaction_server_impl.InteractionServer: return mock.Mock(interaction_server_impl.InteractionServer) @pytest.fixture - def mock_rest_client(self): + def mock_rest_client(self) -> rest_impl.RESTClientImpl: return mock.Mock(rest_impl.RESTClientImpl) @pytest.fixture - def mock_entity_factory(self): + def mock_entity_factory(self) -> entity_factory_impl.EntityFactoryImpl: return mock.Mock(entity_factory_impl.EntityFactoryImpl) @pytest.fixture - def mock_http_settings(self): + def mock_http_settings(self) -> config.HTTPSettings: return mock.Mock(config.HTTPSettings) @pytest.fixture - def mock_proxy_settings(self): + def mock_proxy_settings(self) -> config.ProxySettings: return mock.Mock(config.ProxySettings) @pytest.fixture - def mock_executor(self): + def mock_executor(self) -> concurrent.futures.Executor: return mock.Mock(concurrent.futures.Executor) @pytest.fixture def mock_rest_bot( self, - mock_interaction_server, - mock_rest_client, - mock_entity_factory, - mock_http_settings, - mock_proxy_settings, - mock_executor, + mock_interaction_server: interaction_server_impl.InteractionServer, + mock_rest_client: rest_impl.RESTClientImpl, + mock_entity_factory: entity_factory_impl.EntityFactoryImpl, + mock_http_settings: config.HTTPSettings, + mock_proxy_settings: config.ProxySettings, + mock_executor: concurrent.futures.Executor, ): - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(ux, "init_logging")) - stack.enter_context(mock.patch.object(ux, "warn_if_not_optimized")) - stack.enter_context(mock.patch.object(rest_bot_impl.RESTBot, "print_banner")) - stack.enter_context( - mock.patch.object(entity_factory_impl, "EntityFactoryImpl", return_value=mock_entity_factory) - ) - stack.enter_context(mock.patch.object(rest_impl, "RESTClientImpl", return_value=mock_rest_client)) - stack.enter_context( - mock.patch.object(interaction_server_impl, "InteractionServer", return_value=mock_interaction_server) - ) - - with stack: + with ( + mock.patch.object(ux, "init_logging"), + mock.patch.object(ux, "warn_if_not_optimized"), + mock.patch.object(rest_bot_impl.RESTBot, "print_banner"), + mock.patch.object(entity_factory_impl, "EntityFactoryImpl", return_value=mock_entity_factory), + mock.patch.object(rest_impl, "RESTClientImpl", return_value=mock_rest_client), + mock.patch.object(interaction_server_impl, "InteractionServer", return_value=mock_interaction_server), + ): return hikari_test_helpers.mock_class_namespace(rest_bot_impl.RESTBot, slots_=False)( "token", http_settings=mock_http_settings, @@ -97,25 +91,27 @@ def mock_rest_bot( ) def test___init__( - self, mock_http_settings, mock_proxy_settings, mock_entity_factory, mock_rest_client, mock_interaction_server + self, + mock_http_settings: config.HTTPSettings, + mock_proxy_settings: config.ProxySettings, + mock_entity_factory: entity_factory_impl.EntityFactoryImpl, + mock_rest_client: rest_impl.RESTClientImpl, + mock_interaction_server: interaction_server_impl.InteractionServer, ): mock_executor = mock.Mock() - stack = contextlib.ExitStack() - patched_init_logging = stack.enter_context(mock.patch.object(ux, "init_logging")) - patched_warn_if_not_optimized = stack.enter_context(mock.patch.object(ux, "warn_if_not_optimized")) - patched_print_banner = stack.enter_context(mock.patch.object(rest_bot_impl.RESTBot, "print_banner")) - patched_entity_factory = stack.enter_context( - mock.patch.object(entity_factory_impl, "EntityFactoryImpl", return_value=mock_entity_factory) - ) - patched_rest_client = stack.enter_context( - mock.patch.object(rest_impl, "RESTClientImpl", return_value=mock_rest_client) - ) - patched_interaction_server = stack.enter_context( - mock.patch.object(interaction_server_impl, "InteractionServer", return_value=mock_interaction_server) - ) - - with stack: + with ( + mock.patch.object(ux, "init_logging") as patched_init_logging, + mock.patch.object(ux, "warn_if_not_optimized") as patched_warn_if_not_optimized, + mock.patch.object(rest_bot_impl.RESTBot, "print_banner") as patched_print_banner, + mock.patch.object( + entity_factory_impl, "EntityFactoryImpl", return_value=mock_entity_factory + ) as patched_entity_factory, + mock.patch.object(rest_impl, "RESTClientImpl", return_value=mock_rest_client) as patched_rest_client, + mock.patch.object( + interaction_server_impl, "InteractionServer", return_value=mock_interaction_server + ) as patched_interaction_server, + ): result = rest_bot_impl.RESTBot( "token", "token_type", @@ -160,34 +156,28 @@ def test___init__( assert result.executor is mock_executor def test___init___parses_string_public_key(self): - cls = hikari_test_helpers.mock_class_namespace(rest_bot_impl.RESTBot, print_banner=mock.Mock()) - - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(ux, "init_logging")) - stack.enter_context(mock.patch.object(ux, "warn_if_not_optimized")) - stack.enter_context(mock.patch.object(rest_bot_impl.RESTBot, "print_banner")) - stack.enter_context(mock.patch.object(interaction_server_impl, "InteractionServer")) - - with stack: - result = cls(object(), "token_type", "6f66646f646f646f6f") - - interaction_server_impl.InteractionServer.assert_called_once_with( - entity_factory=result.entity_factory, public_key=b"ofdododoo", rest_client=result.rest - ) + with ( + mock.patch.object(ux, "init_logging"), + mock.patch.object(ux, "warn_if_not_optimized"), + mock.patch.object(rest_bot_impl.RESTBot, "print_banner"), + mock.patch.object(interaction_server_impl, "InteractionServer") as interaction_server, + ): + result = rest_bot_impl.RESTBot(mock.Mock(), "token_type", "6f66646f646f646f6f") + + interaction_server.assert_called_once_with( + entity_factory=result.entity_factory, public_key=b"ofdododoo", rest_client=result.rest + ) def test___init___strips_token(self): - cls = hikari_test_helpers.mock_class_namespace(rest_bot_impl.RESTBot, print_banner=mock.Mock()) - - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(ux, "init_logging")) - stack.enter_context(mock.patch.object(ux, "warn_if_not_optimized")) - rest_client = stack.enter_context(mock.patch.object(rest_impl, "RESTClientImpl")) - http_settings = stack.enter_context(mock.patch.object(config, "HTTPSettings")) - proxy_settings = stack.enter_context(mock.patch.object(config, "ProxySettings")) - stack.enter_context(mock.patch.object(interaction_server_impl, "InteractionServer")) - - with stack: - result = cls("\n\r sddsa tokenoken \n", "token_type") + with ( + mock.patch.object(ux, "init_logging"), + mock.patch.object(ux, "warn_if_not_optimized"), + mock.patch.object(rest_impl, "RESTClientImpl") as rest_client, + mock.patch.object(config, "HTTPSettings") as http_settings, + mock.patch.object(config, "ProxySettings") as proxy_settings, + mock.patch.object(interaction_server_impl, "InteractionServer"), + ): + result = rest_bot_impl.RESTBot("\n\r sddsa tokenoken \n", "token_type") rest_client.assert_called_once_with( cache=None, @@ -203,43 +193,42 @@ def test___init___strips_token(self): ) def test___init___generates_default_settings(self): - cls = hikari_test_helpers.mock_class_namespace(rest_bot_impl.RESTBot, print_banner=mock.Mock()) - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(ux, "init_logging")) - stack.enter_context(mock.patch.object(ux, "warn_if_not_optimized")) - stack.enter_context(mock.patch.object(rest_bot_impl.RESTBot, "print_banner")) - stack.enter_context(mock.patch.object(rest_impl, "RESTClientImpl")) - stack.enter_context(mock.patch.object(config, "HTTPSettings")) - stack.enter_context(mock.patch.object(config, "ProxySettings")) - stack.enter_context(mock.patch.object(interaction_server_impl, "InteractionServer")) - - with stack: - result = cls("token") - - rest_impl.RESTClientImpl.assert_called_once_with( - cache=None, - entity_factory=result.entity_factory, - executor=None, - http_settings=config.HTTPSettings.return_value, - max_rate_limit=300.0, - max_retries=3, - proxy_settings=config.ProxySettings.return_value, - rest_url=None, - token="token", - token_type="Bot", - ) - - config.HTTPSettings.assert_called_once() - config.ProxySettings.assert_called_once() - assert result.http_settings is config.HTTPSettings.return_value - assert result.proxy_settings is config.ProxySettings.return_value + with ( + mock.patch.object(ux, "init_logging"), + mock.patch.object(ux, "warn_if_not_optimized"), + mock.patch.object(rest_bot_impl.RESTBot, "print_banner"), + mock.patch.object(rest_impl, "RESTClientImpl") as rest_client, + mock.patch.object(config, "HTTPSettings") as http_settings, + mock.patch.object(config, "ProxySettings") as proxy_settings, + mock.patch.object(interaction_server_impl, "InteractionServer"), + ): + result = rest_bot_impl.RESTBot("token") - @pytest.mark.parametrize(("close_event", "expected"), [(object(), True), (None, False)]) - def test_is_alive_property(self, mock_rest_bot, close_event, expected): + rest_client.assert_called_once_with( + cache=None, + entity_factory=result.entity_factory, + executor=None, + http_settings=http_settings.return_value, + max_rate_limit=300.0, + max_retries=3, + proxy_settings=proxy_settings.return_value, + rest_url=None, + token="token", + token_type="Bot", + ) + http_settings.assert_called_once_with() + proxy_settings.assert_called_once_with() + assert result.http_settings is http_settings.return_value + assert result.proxy_settings is proxy_settings.return_value + + @pytest.mark.parametrize(("close_event", "expected"), [(mock.Mock(), True), (None, False)]) + def test_is_alive_property( + self, mock_rest_bot: rest_bot_impl.RESTBot, close_event: asyncio.Event | None, expected: bool + ): mock_rest_bot._close_event = close_event assert mock_rest_bot.is_alive is expected - def test_print_banner(self, mock_rest_bot): + def test_print_banner(self, mock_rest_bot: rest_bot_impl.RESTBot): with mock.patch.object(ux, "print_banner") as print_banner: mock_rest_bot.print_banner( "okokok", allow_color=True, force_color=False, extra_args={"test_key": "test_value"} @@ -291,7 +280,10 @@ def test_remove_startup_callback_when_not_present(self, mock_rest_bot: rest_bot_ @pytest.mark.asyncio async def test_close( - self, mock_rest_bot: rest_bot_impl.RESTBot, mock_interaction_server: mock.Mock, mock_rest_client: mock.Mock + self, + mock_rest_bot: rest_bot_impl.RESTBot, + mock_interaction_server: interaction_server_impl.InteractionServer, + mock_rest_client: rest_impl.RESTClientImpl, ): mock_shutdown_1 = mock.AsyncMock() mock_shutdown_2 = mock.AsyncMock() @@ -312,7 +304,10 @@ async def test_close( @pytest.mark.asyncio async def test_close_when_shutdown_callback_raises( - self, mock_rest_bot: rest_bot_impl.RESTBot, mock_interaction_server: mock.Mock, mock_rest_client: mock.Mock + self, + mock_rest_bot: rest_bot_impl.RESTBot, + mock_interaction_server: interaction_server_impl.InteractionServer, + mock_rest_client: rest_impl.RESTClientImpl, ): mock_error = KeyError("Too many catgirls") mock_shutdown_1 = mock.AsyncMock(side_effect=mock_error) @@ -336,7 +331,10 @@ async def test_close_when_shutdown_callback_raises( @pytest.mark.asyncio async def test_close_when_is_closing( - self, mock_rest_bot: rest_bot_impl.RESTBot, mock_interaction_server: mock.Mock, mock_rest_client: mock.Mock + self, + mock_rest_bot: rest_bot_impl.RESTBot, + mock_interaction_server: interaction_server_impl.InteractionServer, + mock_rest_client: rest_impl.RESTClientImpl, ): mock_shutdown_1 = mock.AsyncMock() mock_shutdown_2 = mock.AsyncMock() @@ -358,12 +356,12 @@ async def test_close_when_is_closing( mock_shutdown_2.assert_not_called() @pytest.mark.asyncio - async def test_close_when_inactive(self, mock_rest_bot): + async def test_close_when_inactive(self, mock_rest_bot: rest_bot_impl.RESTBot): with pytest.raises(errors.ComponentStateConflictError): await mock_rest_bot.close() @pytest.mark.asyncio - async def test_join(self, mock_rest_bot): + async def test_join(self, mock_rest_bot: rest_bot_impl.RESTBot): mock_rest_bot._close_event = mock.AsyncMock() await mock_rest_bot.join() @@ -371,12 +369,14 @@ async def test_join(self, mock_rest_bot): mock_rest_bot._close_event.wait.assert_awaited_once() @pytest.mark.asyncio - async def test_join_when_not_alive(self, mock_rest_bot): + async def test_join_when_not_alive(self, mock_rest_bot: rest_bot_impl.RESTBot): with pytest.raises(errors.ComponentStateConflictError): await mock_rest_bot.join() @pytest.mark.asyncio - async def test_on_interaction(self, mock_rest_bot, mock_interaction_server): + async def test_on_interaction( + self, mock_rest_bot: rest_bot_impl.RESTBot, mock_interaction_server: interaction_server_impl.InteractionServer + ): mock_interaction_server.on_interaction = mock.AsyncMock() result = await mock_rest_bot.on_interaction(b"1", b"2", b"3") @@ -384,21 +384,20 @@ async def test_on_interaction(self, mock_rest_bot, mock_interaction_server): assert result is mock_interaction_server.on_interaction.return_value mock_interaction_server.on_interaction.assert_awaited_once_with(b"1", b"2", b"3") - def test_run(self, mock_rest_bot): - mock_socket = object() - mock_context = object() + def test_run(self, mock_rest_bot: rest_bot_impl.RESTBot): + mock_socket = mock.Mock() + mock_context = mock.Mock() mock_rest_bot._executor = None mock_rest_bot.start = mock.Mock() mock_rest_bot.join = mock.Mock() - stack = contextlib.ExitStack() - check_for_updates = stack.enter_context(mock.patch.object(ux, "check_for_updates")) - handle_interrupts = stack.enter_context( - mock.patch.object(signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock()) - ) - get_or_make_loop = stack.enter_context(mock.patch.object(aio, "get_or_make_loop")) - - with stack: + with ( + mock.patch.object(ux, "check_for_updates") as check_for_updates, + mock.patch.object( + signals, "handle_interrupts", return_value=hikari_test_helpers.ContextManagerMock() + ) as handle_interrupts, + mock.patch.object(aio, "get_or_make_loop") as get_or_make_loop, + ): mock_rest_bot.run( asyncio_debug=False, backlog=321, @@ -444,7 +443,7 @@ def test_run(self, mock_rest_bot): ) get_or_make_loop.return_value.close.assert_not_called() - def test_run_when_close_loop(self, mock_rest_bot): + def test_run_when_close_loop(self, mock_rest_bot: rest_bot_impl.RESTBot): mock_rest_bot.start = mock.Mock() mock_rest_bot.join = mock.Mock() @@ -454,7 +453,7 @@ def test_run_when_close_loop(self, mock_rest_bot): destroy_loop.assert_called_once_with(get_or_make_loop.return_value, rest_bot_impl._LOGGER) - def test_run_when_asyncio_debug(self, mock_rest_bot): + def test_run_when_asyncio_debug(self, mock_rest_bot: rest_bot_impl.RESTBot): mock_rest_bot.start = mock.Mock() mock_rest_bot.join = mock.Mock() @@ -463,7 +462,7 @@ def test_run_when_asyncio_debug(self, mock_rest_bot): get_or_make_loop.return_value.set_debug.assert_called_once_with(True) - def test_run_with_coroutine_tracking_depth(self, mock_rest_bot): + def test_run_with_coroutine_tracking_depth(self, mock_rest_bot: rest_bot_impl.RESTBot): mock_rest_bot.start = mock.Mock() mock_rest_bot.join = mock.Mock() @@ -475,13 +474,15 @@ def test_run_with_coroutine_tracking_depth(self, mock_rest_bot): set_tracking_depth.assert_called_once_with(42) - def test_run_when_already_running(self, mock_rest_bot): - mock_rest_bot._close_event = object() + def test_run_when_already_running(self, mock_rest_bot: rest_bot_impl.RESTBot): + mock_rest_bot._close_event = mock.Mock() with pytest.raises(errors.ComponentStateConflictError): mock_rest_bot.run() - def test_run_closes_executor_when_present(self, mock_rest_bot, mock_executor): + def test_run_closes_executor_when_present( + self, mock_rest_bot: rest_bot_impl.RESTBot, mock_executor: concurrent.futures.Executor + ): mock_rest_bot.start = mock.Mock() mock_rest_bot.join = mock.Mock() @@ -500,14 +501,14 @@ def test_run_closes_executor_when_present(self, mock_rest_bot, mock_executor): reuse_address=True, reuse_port=False, shutdown_timeout=534.534, - socket=object(), - ssl_context=object(), + socket=mock.Mock(), + ssl_context=mock.Mock(), ) mock_executor.shutdown.assert_called_once_with(wait=True) assert mock_rest_bot.executor is None - def test_run_ignores_close_executor_when_not_present(self, mock_rest_bot): + def test_run_ignores_close_executor_when_not_present(self, mock_rest_bot: rest_bot_impl.RESTBot): mock_rest_bot.start = mock.Mock() mock_rest_bot.join = mock.Mock() mock_rest_bot._executor = None @@ -527,25 +528,33 @@ def test_run_ignores_close_executor_when_not_present(self, mock_rest_bot): reuse_address=True, reuse_port=False, shutdown_timeout=534.534, - socket=object(), - ssl_context=object(), + socket=mock.Mock(), + ssl_context=mock.Mock(), ) assert mock_rest_bot.executor is None @pytest.mark.asyncio async def test_start( - self, mock_rest_bot: rest_bot_impl.RESTBot, mock_interaction_server: mock.Mock, mock_rest_client: mock.Mock + self, + mock_rest_bot: rest_bot_impl.RESTBot, + mock_interaction_server: interaction_server_impl.InteractionServer, + mock_rest_client: rest_impl.RESTClientImpl, ): - mock_socket = object() - mock_ssl_context = object() + mock_socket = mock.Mock() + mock_ssl_context = mock.Mock() mock_callback_1 = mock.AsyncMock() mock_callback_2 = mock.AsyncMock() mock_rest_bot.add_startup_callback(mock_callback_1) mock_rest_bot.add_startup_callback(mock_callback_2) mock_rest_bot._is_closing = True - with mock.patch.object(ux, "check_for_updates"): + with ( + mock.patch.object(mock_rest_client, "start") as patched_start, + mock.patch.object(mock_rest_client, "close") as patched_close, + mock.patch.object(mock_interaction_server, "start") as patched_interaction_server_start, + mock.patch.object(ux, "check_for_updates") as patched_check_for_updates, + ): await mock_rest_bot.start( backlog=34123, check_for_updates=False, @@ -559,9 +568,9 @@ async def test_start( ssl_context=mock_ssl_context, ) - ux.check_for_updates.assert_not_called() + patched_check_for_updates.assert_not_called() - mock_interaction_server.start.assert_awaited_once_with( + patched_interaction_server_start.assert_awaited_once_with( backlog=34123, host="hostostosot", port=123123123, @@ -572,18 +581,21 @@ async def test_start( shutdown_timeout=4312312.3132132, ssl_context=mock_ssl_context, ) - mock_rest_client.start.assert_called_once_with() - mock_rest_client.close.assert_not_called() + patched_start.assert_called_once_with() + patched_close.assert_not_called() assert mock_rest_bot._is_closing is False mock_callback_1.assert_awaited_once_with(mock_rest_bot) mock_callback_2.assert_awaited_once_with(mock_rest_bot) @pytest.mark.asyncio async def test_start_when_startup_callback_raises( - self, mock_rest_bot: rest_bot_impl.RESTBot, mock_interaction_server: mock.Mock, mock_rest_client: mock.Mock + self, + mock_rest_bot: rest_bot_impl.RESTBot, + mock_interaction_server: interaction_server_impl.InteractionServer, + mock_rest_client: rest_impl.RESTClientImpl, ): - mock_socket = object() - mock_ssl_context = object() + mock_socket = mock.Mock() + mock_ssl_context = mock.Mock() mock_rest_bot._is_closing = True mock_error = TypeError("Not a real catgirl") mock_callback_1 = mock.AsyncMock(side_effect=mock_error) @@ -591,7 +603,12 @@ async def test_start_when_startup_callback_raises( mock_rest_bot.add_startup_callback(mock_callback_1) mock_rest_bot.add_startup_callback(mock_callback_2) - with mock.patch.object(ux, "check_for_updates"): + with ( + mock.patch.object(mock_rest_client, "start") as patched_start, + mock.patch.object(mock_rest_client, "close") as patched_close, + mock.patch.object(mock_interaction_server, "start") as patched_interaction_server_start, + mock.patch.object(ux, "check_for_updates") as patched_check_for_updates, + ): with pytest.raises(TypeError) as exc_info: await mock_rest_bot.start( backlog=34123, @@ -607,22 +624,26 @@ async def test_start_when_startup_callback_raises( ) assert exc_info.value is mock_error - ux.check_for_updates.assert_not_called() + patched_check_for_updates.assert_not_called() - mock_interaction_server.start.assert_not_called() - mock_rest_client.start.assert_called_once_with() - mock_rest_client.close.assert_awaited_once_with() + patched_interaction_server_start.assert_not_called() + patched_start.assert_called_once_with() + patched_close.assert_awaited_once_with() assert mock_rest_bot._is_closing is False mock_callback_1.assert_awaited_once_with(mock_rest_bot) mock_callback_2.assert_not_called() @pytest.mark.asyncio - async def test_start_checks_for_update(self, mock_rest_bot, mock_http_settings, mock_proxy_settings): - stack = contextlib.ExitStack() - stack.enter_context(mock.patch.object(asyncio, "create_task")) - stack.enter_context(mock.patch.object(ux, "check_for_updates", new=mock.Mock())) - - with stack: + async def test_start_checks_for_update( + self, + mock_rest_bot: rest_bot_impl.RESTBot, + mock_http_settings: config.HTTPSettings, + mock_proxy_settings: config.ProxySettings, + ): + with ( + mock.patch.object(asyncio, "create_task") as patched_create_task, + mock.patch.object(ux, "check_for_updates", new=mock.Mock()), + ): await mock_rest_bot.start( backlog=34123, check_for_updates=True, @@ -631,19 +652,19 @@ async def test_start_checks_for_update(self, mock_rest_bot, mock_http_settings, path="patpatpapt", reuse_address=True, reuse_port=False, - socket=object(), + socket=mock.Mock(), shutdown_timeout=4312312.3132132, - ssl_context=object(), + ssl_context=mock.Mock(), ) - asyncio.create_task.assert_called_once_with( + patched_create_task.assert_called_once_with( ux.check_for_updates.return_value, name="check for package updates" ) ux.check_for_updates.assert_called_once_with(mock_http_settings, mock_proxy_settings) @pytest.mark.asyncio - async def test_start_when_is_alive(self, mock_rest_bot): - mock_rest_bot._close_event = object() + async def test_start_when_is_alive(self, mock_rest_bot: rest_bot_impl.RESTBot): + mock_rest_bot._close_event = mock.Mock() with mock.patch.object(ux, "check_for_updates", new=mock.Mock()) as check_for_updates: with pytest.raises(errors.ComponentStateConflictError): @@ -651,17 +672,21 @@ async def test_start_when_is_alive(self, mock_rest_bot): check_for_updates.assert_not_called() - def test_get_listener(self, mock_rest_bot, mock_interaction_server): - mock_type = object() + def test_get_listener( + self, mock_rest_bot: rest_bot_impl.RESTBot, mock_interaction_server: interaction_server_impl.InteractionServer + ): + mock_type = mock.Mock() result = mock_rest_bot.get_listener(mock_type) assert result is mock_interaction_server.get_listener.return_value mock_interaction_server.get_listener.assert_called_once_with(mock_type) - def test_set_listener(self, mock_rest_bot, mock_interaction_server): - mock_type = object() - mock_listener = object() + def test_set_listener( + self, mock_rest_bot: rest_bot_impl.RESTBot, mock_interaction_server: interaction_server_impl.InteractionServer + ): + mock_type = mock.Mock() + mock_listener = mock.Mock() mock_rest_bot.set_listener(mock_type, mock_listener, replace=True) diff --git a/tests/hikari/impl/test_shard.py b/tests/hikari/impl/test_shard.py index a480400645..12cb336545 100644 --- a/tests/hikari/impl/test_shard.py +++ b/tests/hikari/impl/test_shard.py @@ -27,6 +27,7 @@ import platform import re import sys +import typing import aiohttp import mock @@ -77,7 +78,7 @@ def test__serialize_activity_when_activity_is_not_None(): [("Testing!", None, "Custom Status", "Testing!"), ("Blah name!", "Testing!", "Blah name!", "Testing!")], ) def test__serialize_activity_custom_activity_syntactic_sugar( - activity_name, activity_state, expected_name, expected_state + activity_name: str, activity_state: str | None, expected_name: str, expected_state: str ): activity = presences.Activity(name=activity_name, state=activity_state, type=presences.ActivityType.CUSTOM) @@ -117,7 +118,7 @@ def __init__(self, *, type=None, data=None, extra=None): class TestGatewayTransport: @pytest.fixture - def transport_impl(self): + def transport_impl(self) -> shard._GatewayTransport: return hikari_test_helpers.mock_class_namespace(shard._GatewayTransport, slots_=False)( ws=mock.Mock(), exit_stack=mock.AsyncMock(), @@ -128,7 +129,7 @@ def transport_impl(self): ) @pytest.mark.asyncio - async def test_send_close(self, transport_impl): + async def test_send_close(self, transport_impl: shard._GatewayTransport): transport_impl._sent_close = False with mock.patch.object(asyncio, "wait_for", return_value=mock.AsyncMock()) as wait_for: @@ -141,7 +142,7 @@ async def test_send_close(self, transport_impl): sleep.assert_awaited_once_with(0.25) @pytest.mark.asyncio - async def test_send_close_when_TimeoutError(self, transport_impl): + async def test_send_close_when_TimeoutError(self, transport_impl: shard._GatewayTransport): transport_impl._sent_close = False transport_impl._ws.close.side_effect = asyncio.TimeoutError @@ -153,7 +154,7 @@ async def test_send_close_when_TimeoutError(self, transport_impl): sleep.assert_awaited_once_with(0.25) @pytest.mark.asyncio - async def test_send_close_when_already_sent(self, transport_impl): + async def test_send_close_when_already_sent(self, transport_impl: shard._GatewayTransport): transport_impl._sent_close = True with mock.patch.object(aiohttp.ClientWebSocketResponse, "close", side_effect=asyncio.TimeoutError) as close: @@ -163,7 +164,7 @@ async def test_send_close_when_already_sent(self, transport_impl): @pytest.mark.asyncio @pytest.mark.parametrize("trace", [True, False]) - async def test_receive_json(self, transport_impl, trace): + async def test_receive_json(self, transport_impl: shard._GatewayTransport, trace: bool): transport_impl._receive_and_check = mock.AsyncMock() transport_impl._logger = mock.Mock(enabled_for=mock.Mock(return_value=trace)) @@ -174,7 +175,7 @@ async def test_receive_json(self, transport_impl, trace): @pytest.mark.asyncio @pytest.mark.parametrize("trace", [True, False]) - async def test_send_json(self, transport_impl, trace): + async def test_send_json(self, transport_impl: shard._GatewayTransport, trace: bool): transport_impl._ws.send_bytes = mock.AsyncMock() transport_impl._logger = mock.Mock(enabled_for=mock.Mock(return_value=trace)) transport_impl._dumps = mock.Mock(return_value=b"some data") @@ -184,14 +185,14 @@ async def test_send_json(self, transport_impl, trace): transport_impl._ws.send_bytes.assert_awaited_once_with(b"some data") @pytest.mark.asyncio - async def test__handle_other_message_when_TEXT(self, transport_impl): + async def test__handle_other_message_when_TEXT(self, transport_impl: shard._GatewayTransport): stub_response = StubResponse(type=aiohttp.WSMsgType.TEXT) with pytest.raises(errors.GatewayError, match="Unexpected message type received TEXT, expected BINARY"): transport_impl._handle_other_message(stub_response) @pytest.mark.asyncio - async def test__handle_other_message_when_BINARY(self, transport_impl): + async def test__handle_other_message_when_BINARY(self, transport_impl: shard._GatewayTransport): stub_response = StubResponse(type=aiohttp.WSMsgType.BINARY) with pytest.raises(errors.GatewayError, match="Unexpected message type received BINARY, expected TEXT"): @@ -208,7 +209,9 @@ async def test__handle_other_message_when_BINARY(self, transport_impl): errors.ShardCloseCode.RATE_LIMITED, ], ) - def test__handle_other_message_when_message_type_is_CLOSE_and_should_reconnect(self, code, transport_impl): + def test__handle_other_message_when_message_type_is_CLOSE_and_should_reconnect( + self, code: int | errors.ShardCloseCode, transport_impl: shard._GatewayTransport + ): stub_response = StubResponse(type=aiohttp.WSMsgType.CLOSE, extra="some error extra", data=code) with pytest.raises(errors.GatewayServerClosedConnectionError) as exinfo: @@ -220,7 +223,9 @@ def test__handle_other_message_when_message_type_is_CLOSE_and_should_reconnect(s assert exception.can_reconnect is True @pytest.mark.parametrize("code", [*range(4010, 4020), 5000]) - def test__handle_other_message_when_message_type_is_CLOSE_and_should_not_reconnect(self, code, transport_impl): + def test__handle_other_message_when_message_type_is_CLOSE_and_should_not_reconnect( + self, code: int, transport_impl: shard._GatewayTransport + ): stub_response = StubResponse(type=aiohttp.WSMsgType.CLOSE, extra="don't reconnect", data=code) with pytest.raises(errors.GatewayServerClosedConnectionError) as exinfo: @@ -231,19 +236,19 @@ def test__handle_other_message_when_message_type_is_CLOSE_and_should_not_reconne assert exception.code == int(code) assert exception.can_reconnect is False - def test__handle_other_message_when_message_type_is_CLOSING(self, transport_impl): + def test__handle_other_message_when_message_type_is_CLOSING(self, transport_impl: shard._GatewayTransport): stub_response = StubResponse(type=aiohttp.WSMsgType.CLOSING) with pytest.raises(errors.GatewayError, match="Socket has closed"): transport_impl._handle_other_message(stub_response) - def test__handle_other_message_when_message_type_is_CLOSED(self, transport_impl): + def test__handle_other_message_when_message_type_is_CLOSED(self, transport_impl: shard._GatewayTransport): stub_response = StubResponse(type=aiohttp.WSMsgType.CLOSED) with pytest.raises(errors.GatewayError, match="Socket has closed"): transport_impl._handle_other_message(stub_response) - def test__handle_other_message_when_message_type_is_unknown(self, transport_impl): + def test__handle_other_message_when_message_type_is_unknown(self, transport_impl: shard._GatewayTransport): stub_response = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.ERROR)) exception = Exception("some error") transport_impl._ws.exception = mock.Mock(return_value=exception) @@ -267,7 +272,13 @@ def test__handle_other_message_when_message_type_is_unknown(self, transport_impl ], ) @pytest.mark.asyncio - async def test_connect(self, http_settings, proxy_settings, compression, expected_instance): + async def test_connect( + self, + http_settings: mock.Mock, + proxy_settings: mock.Mock, + compression: shard_api.GatewayCompression, + expected_instance: typing.Any, + ): logger = mock.Mock() log_filterer = mock.Mock() client_session = mock.Mock() @@ -326,7 +337,7 @@ async def test_connect(self, http_settings, proxy_settings, compression, expecte sleep.assert_not_called() @pytest.mark.asyncio - async def test_connect_when_error_while_connecting(self, http_settings, proxy_settings): + async def test_connect_when_error_while_connecting(self, http_settings: mock.Mock, proxy_settings: mock.Mock): logger = mock.Mock() log_filterer = mock.Mock() client_session = mock.Mock() @@ -346,8 +357,8 @@ async def test_connect_when_error_while_connecting(self, http_settings, proxy_se logger=logger, url="https://some.url", log_filterer=log_filterer, - loads=object(), - dumps=object(), + loads=mock.Mock(), + dumps=mock.Mock(), compression=True, ) @@ -367,7 +378,9 @@ async def test_connect_when_error_while_connecting(self, http_settings, proxy_se (asyncio.TimeoutError("some error"), "Timeout exceeded"), ], ) - async def test_connect_when_expected_error_while_connecting(self, http_settings, proxy_settings, error, reason): + async def test_connect_when_expected_error_while_connecting( + self, http_settings: mock.Mock, proxy_settings: mock.Mock, error: Exception, reason: str + ): logger = mock.Mock() log_filterer = mock.Mock() client_session = mock.Mock() @@ -390,8 +403,8 @@ async def test_connect_when_expected_error_while_connecting(self, http_settings, url="https://some.url", log_filterer=log_filterer, compression=True, - loads=object(), - dumps=object(), + loads=mock.Mock(), + dumps=mock.Mock(), ) exit_stack.aclose.assert_awaited_once_with() @@ -400,7 +413,7 @@ async def test_connect_when_expected_error_while_connecting(self, http_settings, class TestGatewayBasicTransport: @pytest.fixture - def transport_impl(self): + def transport_impl(self) -> shard._GatewayBasicTransport: return shard._GatewayBasicTransport( ws=mock.Mock(), exit_stack=mock.AsyncMock(), @@ -411,7 +424,7 @@ def transport_impl(self): ) @pytest.mark.asyncio - async def test__receive_and_check(self, transport_impl): + async def test__receive_and_check(self, transport_impl: shard._GatewayBasicTransport): transport_impl._ws.receive = mock.AsyncMock( return_value=StubResponse(type=aiohttp.WSMsgType.TEXT, data="some text") ) @@ -421,7 +434,7 @@ async def test__receive_and_check(self, transport_impl): transport_impl._ws.receive.assert_awaited_once_with() @pytest.mark.asyncio - async def test__receive_and_check_when_message_type_is_unknown(self, transport_impl): + async def test__receive_and_check_when_message_type_is_unknown(self, transport_impl: shard._GatewayBasicTransport): transport_impl._ws.receive = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.BINARY)) with pytest.raises( @@ -446,7 +459,9 @@ def transport_impl(self): ) @pytest.mark.asyncio - async def test__receive_and_check_when_payload_split_across_frames(self, transport_impl): + async def test__receive_and_check_when_payload_split_across_frames( + self, transport_impl: shard._GatewayZlibStreamTransport + ): response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xda\xf2H\xcd\xc9") response2 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\xc9W(\xcf/\xcaIQ\x04\x00\x00") response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\xff\xff") @@ -457,7 +472,9 @@ async def test__receive_and_check_when_payload_split_across_frames(self, transpo assert transport_impl._ws.receive.call_count == 3 @pytest.mark.asyncio - async def test__receive_and_check_when_full_payload_in_one_frame(self, transport_impl): + async def test__receive_and_check_when_full_payload_in_one_frame( + self, transport_impl: shard._GatewayZlibStreamTransport + ): response = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xdaJLD\x07\x00\x00\x00\x00\xff\xff") transport_impl._ws.receive = mock.AsyncMock(return_value=response) @@ -466,7 +483,9 @@ async def test__receive_and_check_when_full_payload_in_one_frame(self, transport transport_impl._ws.receive.assert_awaited_once_with() @pytest.mark.asyncio - async def test__receive_and_check_when_message_type_is_unknown(self, transport_impl): + async def test__receive_and_check_when_message_type_is_unknown( + self, transport_impl: shard._GatewayZlibStreamTransport + ): transport_impl._ws.receive = mock.AsyncMock(return_value=StubResponse(type=aiohttp.WSMsgType.TEXT)) with pytest.raises( @@ -476,7 +495,9 @@ async def test__receive_and_check_when_message_type_is_unknown(self, transport_i await transport_impl._receive_and_check() @pytest.mark.asyncio - async def test__receive_and_check_when_issue_during_reception_of_multiple_frames(self, transport_impl): + async def test__receive_and_check_when_issue_during_reception_of_multiple_frames( + self, transport_impl: shard._GatewayZlibStreamTransport + ): response1 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"x\xda\xf2H\xcd\xc9") response2 = StubResponse(type=aiohttp.WSMsgType.ERROR, data="Something broke!") response3 = StubResponse(type=aiohttp.WSMsgType.BINARY, data=b"\x00\xff\xff") @@ -490,7 +511,7 @@ async def test__receive_and_check_when_issue_during_reception_of_multiple_frames @pytest.fixture -def client(http_settings, proxy_settings): +def client(http_settings: mock.Mock, proxy_settings: mock.Mock): return shard.GatewayShardImpl( event_manager=mock.Mock(), event_factory=mock.Mock(), @@ -503,7 +524,7 @@ def client(http_settings, proxy_settings): class TestGatewayShardImpl: - def test_using_etf_is_unsupported(self, http_settings, proxy_settings): + def test_using_etf_is_unsupported(self, http_settings: mock.Mock, proxy_settings: mock.Mock): with pytest.raises(NotImplementedError, match="Unsupported gateway data format: etf"): shard.GatewayShardImpl( event_manager=mock.Mock(), @@ -517,21 +538,23 @@ def test_using_etf_is_unsupported(self, http_settings, proxy_settings): compression="testing", ) - def test_heartbeat_latency_property(self, client): + def test_heartbeat_latency_property(self, client: shard.GatewayShardImpl): client._heartbeat_latency = 420 assert client.heartbeat_latency == 420 - def test_id_property(self, client): + def test_id_property(self, client: shard.GatewayShardImpl): client._shard_id = 101 assert client.id == 101 - def test_intents_property(self, client): + def test_intents_property(self, client: shard.GatewayShardImpl): mock_intents = object() client._intents = mock_intents assert client.intents is mock_intents @pytest.mark.parametrize(("keep_alive_task", "expected"), [(None, False), ("some", True)]) - def test_is_alive_property(self, client, keep_alive_task, expected): + def test_is_alive_property( + self, client: shard.GatewayShardImpl, keep_alive_task: typing.Any | None, expected: bool + ): client._keep_alive_task = keep_alive_task assert client.is_alive is expected @@ -547,7 +570,9 @@ def test_is_alive_property(self, client, keep_alive_task, expected): ("something", False, False), ], ) - def test_is_connected_property(self, client, ws, handshake_event, expected): + def test_is_connected_property( + self, client: shard.GatewayShardImpl, ws: typing.Any | None, handshake_event: bool | None, expected: bool + ): client._ws = ws client._handshake_event = ( None if handshake_event is None else mock.Mock(is_set=mock.Mock(return_value=handshake_event)) @@ -555,16 +580,16 @@ def test_is_connected_property(self, client, ws, handshake_event, expected): assert client.is_connected is expected - def test_shard_count_property(self, client): + def test_shard_count_property(self, client: shard.GatewayShardImpl): client._shard_count = 69 assert client.shard_count == 69 - def test_shard__check_if_connected_when_not_alive(self, client): + def test_shard__check_if_connected_when_not_alive(self, client: shard.GatewayShardImpl): with mock.patch.object(shard.GatewayShardImpl, "is_connected", new=False): with pytest.raises(errors.ComponentStateConflictError): client._check_if_connected() - def test_shard__check_if_connected_when_alive(self, client): + def test_shard__check_if_connected_when_alive(self, client: shard.GatewayShardImpl): with mock.patch.object(shard.GatewayShardImpl, "is_connected", new=True): client._check_if_connected() @@ -576,7 +601,12 @@ def test_shard__check_if_connected_when_alive(self, client): ) @pytest.mark.parametrize("activity", [presences.Activity(name="foo"), None]) def test__serialize_and_store_presence_payload_when_all_args_undefined( - self, client, idle_since, afk, status, activity + self, + client: shard.GatewayShardImpl, + idle_since: datetime.datetime | None, + afk: bool, + status: presences.Status, + activity: presences.Activity | None, ): client._activity = activity client._idle_since = idle_since @@ -616,7 +646,14 @@ def test__serialize_and_store_presence_payload_when_all_args_undefined( [presences.Status.DO_NOT_DISTURB, presences.Status.IDLE, presences.Status.ONLINE, presences.Status.OFFLINE], ) @pytest.mark.parametrize("activity", [presences.Activity(name="foo"), None]) - def test__serialize_and_store_presence_payload_sets_state(self, client, idle_since, afk, status, activity): + def test__serialize_and_store_presence_payload_sets_state( + self, + client: shard.GatewayShardImpl, + idle_since: datetime.datetime | None, + afk: bool, + status: presences.Status, + activity: presences.Activity | None, + ): client._serialize_and_store_presence_payload(idle_since=idle_since, afk=afk, status=status, activity=activity) assert client._activity == activity @@ -624,7 +661,7 @@ def test__serialize_and_store_presence_payload_sets_state(self, client, idle_sin assert client._is_afk == afk assert client._status == status - def test_get_user_id(self, client): + def test_get_user_id(self, client: shard.GatewayShardImpl): client._user_id = 123 with mock.patch.object(shard.GatewayShardImpl, "_check_if_connected") as check_if_alive: @@ -635,13 +672,13 @@ def test_get_user_id(self, client): @pytest.mark.asyncio class TestGatewayShardImplAsync: - async def test_close_when_no_keep_alive_task(self, client): + async def test_close_when_no_keep_alive_task(self, client: shard.GatewayShardImpl): client._keep_alive_task = None with pytest.raises(errors.ComponentStateConflictError): await client.close() - async def test_close_when_closing_event_set(self, client): + async def test_close_when_closing_event_set(self, client: shard.GatewayShardImpl): client._keep_alive_task = mock.Mock(cancel=mock.AsyncMock()) client._non_priority_rate_limit = mock.Mock() client._total_rate_limit = mock.Mock() @@ -655,7 +692,7 @@ async def test_close_when_closing_event_set(self, client): client._non_priority_rate_limit.close.assert_not_called() client._total_rate_limit.close.assert_not_called() - async def test_close_when_closing_event_not_set(self, client): + async def test_close_when_closing_event_not_set(self, client: shard.GatewayShardImpl): cancel_async_mock = mock.Mock() class TaskMock: @@ -686,13 +723,13 @@ def assert_awaited_once(self): client._non_priority_rate_limit.close.assert_called_once_with() client._total_rate_limit.close.assert_called_once_with() - async def test_join_when_not_alive(self, client): + async def test_join_when_not_alive(self, client: shard.GatewayShardImpl): client._keep_alive_task = None with pytest.raises(errors.ComponentStateConflictError): await client.join() - async def test_join(self, client): + async def test_join(self, client: shard.GatewayShardImpl): client._keep_alive_task = object() with mock.patch.object(asyncio, "wait_for") as wait_for: @@ -702,7 +739,7 @@ async def test_join(self, client): shield.assert_called_once_with(client._keep_alive_task) wait_for.assert_awaited_once_with(shield.return_value, timeout=None) - async def test__send_json(self, client): + async def test__send_json(self, client: shard.GatewayShardImpl): client._total_rate_limit = mock.AsyncMock() client._non_priority_rate_limit = mock.AsyncMock() client._ws = mock.AsyncMock() @@ -714,7 +751,7 @@ async def test__send_json(self, client): client._total_rate_limit.acquire.assert_awaited_once_with() client._ws.send_json.assert_awaited_once_with(data) - async def test__send_json_when_priority(self, client): + async def test__send_json_when_priority(self, client: shard.GatewayShardImpl): client._total_rate_limit = mock.AsyncMock() client._non_priority_rate_limit = mock.AsyncMock() client._ws = mock.AsyncMock() @@ -726,7 +763,9 @@ async def test__send_json_when_priority(self, client): client._total_rate_limit.acquire.assert_awaited_once_with() client._ws.send_json.assert_awaited_once_with(data) - async def test_request_guild_members_when_no_query_and_no_limit_and_GUILD_MEMBERS_not_enabled(self, client): + async def test_request_guild_members_when_no_query_and_no_limit_and_GUILD_MEMBERS_not_enabled( + self, client: shard.GatewayShardImpl + ): client._intents = intents.Intents.GUILD_INTEGRATIONS with mock.patch.object(shard.GatewayShardImpl, "_check_if_connected") as check_if_alive: @@ -735,7 +774,9 @@ async def test_request_guild_members_when_no_query_and_no_limit_and_GUILD_MEMBER check_if_alive.assert_called_once_with() - async def test_request_guild_members_when_presences_and_GUILD_PRESENCES_not_enabled(self, client): + async def test_request_guild_members_when_presences_and_GUILD_PRESENCES_not_enabled( + self, client: shard.GatewayShardImpl + ): client._intents = intents.Intents.GUILD_INTEGRATIONS with mock.patch.object(shard.GatewayShardImpl, "_check_if_connected") as check_if_alive: @@ -744,7 +785,9 @@ async def test_request_guild_members_when_presences_and_GUILD_PRESENCES_not_enab check_if_alive.assert_called_once_with() - async def test_request_guild_members_when_presences_false_and_GUILD_PRESENCES_not_enabled(self, client): + async def test_request_guild_members_when_presences_false_and_GUILD_PRESENCES_not_enabled( + self, client: shard.GatewayShardImpl + ): client._intents = intents.Intents.GUILD_INTEGRATIONS with mock.patch.object(shard.GatewayShardImpl, "_send_json") as send_json: @@ -758,7 +801,9 @@ async def test_request_guild_members_when_presences_false_and_GUILD_PRESENCES_no check_if_alive.assert_called_once_with() @pytest.mark.parametrize("kwargs", [{"query": "some query"}, {"limit": 1}]) - async def test_request_guild_members_when_specifiying_users_with_limit_or_query(self, client, kwargs): + async def test_request_guild_members_when_specifiying_users_with_limit_or_query( + self, client: shard.GatewayShardImpl, kwargs: dict[str, typing.Any] + ): client._intents = intents.Intents.GUILD_INTEGRATIONS with mock.patch.object(shard.GatewayShardImpl, "_check_if_connected") as check_if_alive: @@ -768,7 +813,9 @@ async def test_request_guild_members_when_specifiying_users_with_limit_or_query( check_if_alive.assert_called_once_with() @pytest.mark.parametrize("limit", [-1, 101]) - async def test_request_guild_members_when_limit_under_0_or_over_100(self, client, limit): + async def test_request_guild_members_when_limit_under_0_or_over_100( + self, client: shard.GatewayShardImpl, limit: int + ): client._intents = intents.Intents.ALL with mock.patch.object(shard.GatewayShardImpl, "_check_if_connected") as check_if_alive: @@ -777,7 +824,7 @@ async def test_request_guild_members_when_limit_under_0_or_over_100(self, client check_if_alive.assert_called_once_with() - async def test_request_guild_members_when_users_over_100(self, client): + async def test_request_guild_members_when_users_over_100(self, client: shard.GatewayShardImpl): client._intents = intents.Intents.ALL with mock.patch.object(shard.GatewayShardImpl, "_check_if_connected") as check_if_alive: @@ -786,7 +833,7 @@ async def test_request_guild_members_when_users_over_100(self, client): check_if_alive.assert_called_once_with() - async def test_request_guild_members_when_nonce_over_32_chars(self, client): + async def test_request_guild_members_when_nonce_over_32_chars(self, client: shard.GatewayShardImpl): client._intents = intents.Intents.ALL with mock.patch.object(shard.GatewayShardImpl, "_check_if_connected") as check_if_alive: @@ -796,7 +843,7 @@ async def test_request_guild_members_when_nonce_over_32_chars(self, client): check_if_alive.assert_called_once_with() @pytest.mark.parametrize("include_presences", [True, False]) - async def test_request_guild_members(self, client, include_presences): + async def test_request_guild_members(self, client: shard.GatewayShardImpl, include_presences: bool): client._intents = intents.Intents.ALL with mock.patch.object(shard.GatewayShardImpl, "_send_json") as send_json: @@ -809,13 +856,13 @@ async def test_request_guild_members(self, client, include_presences): check_if_alive.assert_called_once_with() @pytest.mark.parametrize("attr", ["_keep_alive_task", "_handshake_event"]) - async def test_start_when_already_running(self, client, attr): + async def test_start_when_already_running(self, client: shard.GatewayShardImpl, attr: str): setattr(client, attr, object()) with pytest.raises(errors.ComponentStateConflictError): await client.start() - async def test_start_when_shard_closed_before_starting(self, client): + async def test_start_when_shard_closed_before_starting(self, client: shard.GatewayShardImpl): client._keep_alive_task = None client._shard_id = 20 handshake_event = mock.Mock(is_set=mock.Mock(return_value=False)) @@ -833,7 +880,7 @@ async def test_start_when_shard_closed_before_starting(self, client): assert client._keep_alive_task is None - async def test_start(self, client): + async def test_start(self, client: shard.GatewayShardImpl): client._keep_alive_task = None client._shard_id = 20 handshake_event = mock.Mock(is_set=mock.Mock(return_value=True)) @@ -855,7 +902,7 @@ async def test_start(self, client): shield.assert_called_once_with(create_task.return_value) first_completed.assert_awaited_once_with(handshake_event.wait.return_value, shield.return_value) - async def test_update_presence(self, client): + async def test_update_presence(self, client: shard.GatewayShardImpl): with mock.patch.object(shard.GatewayShardImpl, "_serialize_and_store_presence_payload") as presence: with mock.patch.object(shard.GatewayShardImpl, "_check_if_connected") as check_if_alive: with mock.patch.object(shard.GatewayShardImpl, "_send_json") as send_json: @@ -866,7 +913,7 @@ async def test_update_presence(self, client): send_json.assert_awaited_once_with({"op": 3, "d": presence.return_value}) check_if_alive.assert_called_once_with() - async def test_update_voice_state(self, client): + async def test_update_voice_state(self, client: shard.GatewayShardImpl): with mock.patch.object(shard.GatewayShardImpl, "_check_if_connected") as check_if_alive: with mock.patch.object(shard.GatewayShardImpl, "_send_json") as send_json: await client.update_voice_state(123456, 6969420, self_mute=False, self_deaf=True) @@ -876,7 +923,7 @@ async def test_update_voice_state(self, client): ) check_if_alive.assert_called_once_with() - async def test_update_voice_state_without_optionals(self, client): + async def test_update_voice_state_without_optionals(self, client: shard.GatewayShardImpl): with mock.patch.object(shard.GatewayShardImpl, "_check_if_connected") as check_if_alive: with mock.patch.object(shard.GatewayShardImpl, "_send_json") as send_json: await client.update_voice_state(123456, 6969420) @@ -885,7 +932,7 @@ async def test_update_voice_state_without_optionals(self, client): check_if_alive.assert_called_once_with() @hikari_test_helpers.timeout() - async def test__heartbeat(self, client): + async def test__heartbeat(self, client: shard.GatewayShardImpl): client._last_heartbeat_sent = 5 client._logger = mock.Mock() @@ -906,7 +953,7 @@ class ExitException(Exception): ... sleep.assert_has_awaits([mock.call(20), mock.call(20)]) @hikari_test_helpers.timeout() - async def test__heartbeat_when_zombie(self, client): + async def test__heartbeat_when_zombie(self, client: shard.GatewayShardImpl): client._last_heartbeat_sent = 10 client._logger = mock.Mock() @@ -916,13 +963,15 @@ async def test__heartbeat_when_zombie(self, client): sleep.assert_not_called() - async def test__connect_when_ws(self, client): + async def test__connect_when_ws(self, client: shard.GatewayShardImpl): client._ws = object() with pytest.raises(errors.ComponentStateConflictError): await client._connect() - async def test__connect_when_not_reconnecting(self, client, http_settings, proxy_settings): + async def test__connect_when_not_reconnecting( + self, client: shard.GatewayShardImpl, http_settings: mock.Mock, proxy_settings: mock.Mock + ): ws = mock.AsyncMock() ws.receive_json.return_value = {"op": 10, "d": {"heartbeat_interval": 10}} client._compression = shard_api.GatewayCompression.PAYLOAD_ZLIB_STREAM @@ -1016,7 +1065,9 @@ async def test__connect_when_not_reconnecting(self, client, http_settings, proxy client._handshake_event.wait.return_value, shielded_heartbeat_task, shielded_poll_events_task ) - async def test__connect_when_reconnecting(self, client, http_settings, proxy_settings): + async def test__connect_when_reconnecting( + self, client: shard.GatewayShardImpl, http_settings: mock.Mock, proxy_settings: mock.Mock + ): ws = mock.AsyncMock() ws.receive_json.return_value = {"op": 10, "d": {"heartbeat_interval": 10}} client._compression = shard_api.GatewayCompression.PAYLOAD_ZLIB_STREAM @@ -1086,7 +1137,7 @@ async def test__connect_when_reconnecting(self, client, http_settings, proxy_set client._handshake_event.wait.return_value, shielded_heartbeat_task, shielded_poll_events_task ) - async def test__connect_when_op_received_is_not_HELLO(self, client): + async def test__connect_when_op_received_is_not_HELLO(self, client: shard.GatewayShardImpl): ws = mock.AsyncMock() ws.receive_json.return_value = {"op": 0, "d": {"not": "hello"}} client._gateway_url = "somewhere.com" @@ -1107,9 +1158,9 @@ async def test__connect_when_op_received_is_not_HELLO(self, client): ) @pytest.mark.skip("TODO") - async def test__keep_alive(self, client): ... + async def test__keep_alive(self, client: shard.GatewayShardImpl): ... - async def test__send_heartbeat(self, client): + async def test__send_heartbeat(self, client: shard.GatewayShardImpl) -> None: client._last_heartbeat_sent = 0 client._seq = 10 @@ -1120,7 +1171,7 @@ async def test__send_heartbeat(self, client): send_json.assert_awaited_once_with({"op": 1, "d": 10}, priority=True) assert client._last_heartbeat_sent == 200 - async def test__poll_events_on_dispatch(self, client): + async def test__poll_events_on_dispatch(self, client: shard.GatewayShardImpl): payload = {"op": 0, "t": "SOMETHING", "d": {"some": "test"}, "s": 101} client._ws = mock.Mock(receive_json=mock.AsyncMock(side_effect=[payload, RuntimeError])) @@ -1136,7 +1187,7 @@ async def test__poll_events_on_dispatch(self, client): client._event_manager.consume_raw_event.assert_called_once_with("SOMETHING", client, {"some": "test"}) client._handshake_event.set.assert_not_called() - async def test__poll_events_on_dispatch_when_READY(self, client): + async def test__poll_events_on_dispatch_when_READY(self, client: shard.GatewayShardImpl): data = { "v": 10, "session_id": 100001, @@ -1166,7 +1217,7 @@ async def test__poll_events_on_dispatch_when_READY(self, client): client._event_manager.consume_raw_event.assert_called_once_with("READY", client, data) client._handshake_event.set.assert_called_once_with() - async def test__poll_events_on_dispatch_when_RESUMED(self, client): + async def test__poll_events_on_dispatch_when_RESUMED(self, client: shard.GatewayShardImpl): payload = {"op": 0, "t": "RESUMED", "d": {"some": "test"}, "s": 101} client._ws = mock.Mock(receive_json=mock.AsyncMock(side_effect=[payload, RuntimeError])) @@ -1182,7 +1233,7 @@ async def test__poll_events_on_dispatch_when_RESUMED(self, client): client._event_manager.consume_raw_event.assert_called_once_with("RESUMED", client, {"some": "test"}) client._handshake_event.set.assert_called_once_with() - async def test__poll_events_on_heartbeat_ack(self, client): + async def test__poll_events_on_heartbeat_ack(self, client: shard.GatewayShardImpl): payload = {"op": 11} client._ws = mock.Mock(receive_json=mock.AsyncMock(side_effect=[payload, RuntimeError])) @@ -1200,7 +1251,7 @@ async def test__poll_events_on_heartbeat_ack(self, client): assert client._heartbeat_latency == 1.5 client._handshake_event.set.assert_not_called() - async def test__poll_events_on_heartbeat(self, client): + async def test__poll_events_on_heartbeat(self, client: shard.GatewayShardImpl): payload = {"op": 1} client._ws = mock.Mock(receive_json=mock.AsyncMock(side_effect=[payload, RuntimeError])) @@ -1214,7 +1265,7 @@ async def test__poll_events_on_heartbeat(self, client): send_heartbeat.assert_awaited_once_with() client._handshake_event.set.assert_not_called() - async def test__poll_events_on_reconnect(self, client): + async def test__poll_events_on_reconnect(self, client: shard.GatewayShardImpl): payload = {"op": 7} client._ws = mock.Mock(receive_json=mock.AsyncMock(side_effect=[payload, RuntimeError])) @@ -1225,7 +1276,7 @@ async def test__poll_events_on_reconnect(self, client): assert client._ws.receive_json.await_count == 1 client._handshake_event.set.assert_not_called() - async def test__poll_events_on_invalid_session_when_can_resume(self, client): + async def test__poll_events_on_invalid_session_when_can_resume(self, client: shard.GatewayShardImpl): payload = {"op": 9, "d": True} client._seq = 123 @@ -1240,7 +1291,7 @@ async def test__poll_events_on_invalid_session_when_can_resume(self, client): assert client._session_id == 456 client._handshake_event.set.assert_not_called() - async def test__poll_events_on_invalid_session_when_cant_resume(self, client): + async def test__poll_events_on_invalid_session_when_cant_resume(self, client: shard.GatewayShardImpl): payload = {"op": 9, "d": False} client._seq = 123 diff --git a/tests/hikari/impl/test_special_endpoints.py b/tests/hikari/impl/test_special_endpoints.py index ce4b03b932..17a8d2609f 100644 --- a/tests/hikari/impl/test_special_endpoints.py +++ b/tests/hikari/impl/test_special_endpoints.py @@ -48,15 +48,15 @@ class TestTypingIndicator: @pytest.fixture - def typing_indicator(self): + def typing_indicator(self) -> special_endpoints.TypingIndicator: return hikari_test_helpers.mock_class_namespace(special_endpoints.TypingIndicator, init_=False) - def test___enter__(self, typing_indicator): - # ruff gets annoyed if we use "with" here so here's a hacky alternative + def test___enter__(self, typing_indicator: special_endpoints.TypingIndicator): + # flake8 gets annoyed if we use "with" here so here's a hacky alternative with pytest.raises(TypeError, match=" is async-only, did you mean 'async with'?"): typing_indicator().__enter__() - def test___exit__(self, typing_indicator): + def test___exit__(self, typing_indicator: special_endpoints.TypingIndicator): try: typing_indicator().__exit__(None, None, None) except AttributeError as exc: @@ -1005,7 +1005,7 @@ def test_set_flags(self): def test_build(self): builder = special_endpoints.InteractionDeferredBuilder(base_interactions.ResponseType.DEFERRED_MESSAGE_CREATE) - result, attachments = builder.build(object()) + result, attachments = builder.build(mock.Mock()) assert result == {"type": base_interactions.ResponseType.DEFERRED_MESSAGE_CREATE} assert attachments == () @@ -1015,7 +1015,7 @@ def test_build_with_flags(self): base_interactions.ResponseType.DEFERRED_MESSAGE_CREATE ).set_flags(64) - result, attachments = builder.build(object()) + result, attachments = builder.build(mock.Mock()) assert result == {"type": base_interactions.ResponseType.DEFERRED_MESSAGE_CREATE, "data": {"flags": 64}} assert attachments == () @@ -1050,7 +1050,7 @@ def test_attachments_property_when_undefined(self): assert builder.attachments is undefined.UNDEFINED def test_components_property(self): - mock_component = object() + mock_component = mock.Mock() builder = special_endpoints.InteractionMessageBuilder(4).add_component(mock_component) assert builder.components == [mock_component] @@ -1061,7 +1061,7 @@ def test_components_property_when_undefined(self): assert builder.components is undefined.UNDEFINED def test_embeds_property(self): - mock_embed = object() + mock_embed = mock.Mock() builder = special_endpoints.InteractionMessageBuilder(4).add_embed(mock_embed) assert builder.embeds == [mock_embed] @@ -1082,9 +1082,9 @@ def test_is_tts_property(self): assert builder.is_tts is False def test_mentions_everyone_property(self): - builder = special_endpoints.InteractionMessageBuilder(4).set_mentions_everyone([123, 453]) + builder = special_endpoints.InteractionMessageBuilder(4).set_mentions_everyone(True) - assert builder.mentions_everyone == [123, 453] + assert builder.mentions_everyone is True def test_role_mentions_property(self): builder = special_endpoints.InteractionMessageBuilder(4).set_role_mentions([999]) @@ -1099,8 +1099,8 @@ def test_user_mentions_property(self): def test_build(self): mock_entity_factory = mock.Mock() mock_component = mock.Mock() - mock_embed = object() - mock_serialized_embed = object() + mock_embed = mock.Mock() + mock_serialized_embed = mock.Mock() mock_entity_factory.serialize_embed.return_value = (mock_serialized_embed, []) builder = ( special_endpoints.InteractionMessageBuilder(base_interactions.ResponseType.MESSAGE_CREATE) @@ -1169,15 +1169,15 @@ def test_build_for_partial_when_empty_lists(self): def test_build_handles_attachments(self): mock_entity_factory = mock.Mock() mock_message_attachment = mock.Mock(messages.Attachment, id=123, filename="testing") - mock_file_attachment = object() - mock_embed = object() - mock_embed_attachment = object() + mock_file_attachment = mock.Mock() + mock_embed = mock.Mock() + mock_embed_attachment = mock.Mock() mock_entity_factory.serialize_embed.return_value = (mock_embed, [mock_embed_attachment]) builder = ( special_endpoints.InteractionMessageBuilder(base_interactions.ResponseType.MESSAGE_CREATE) .add_attachment(mock_file_attachment) .add_attachment(mock_message_attachment) - .add_embed(object()) + .add_embed(mock.Mock()) ) with mock.patch.object(files, "ensure_resource") as ensure_resource: @@ -1243,27 +1243,27 @@ class TestCommandBuilder: def stub_command(self) -> type[special_endpoints.CommandBuilder]: return hikari_test_helpers.mock_class_namespace(special_endpoints.CommandBuilder) - def test_name_property(self, stub_command): + def test_name_property(self, stub_command: type[special_endpoints.CommandBuilder]): builder = stub_command("NOOOOO").set_name("aaaaa") assert builder.name == "aaaaa" - def test_id_property(self, stub_command): + def test_id_property(self, stub_command: type[special_endpoints.CommandBuilder]): builder = stub_command("OKSKDKSDK").set_id(3212123) assert builder.id == 3212123 - def test_default_member_permissions(self, stub_command): + def test_default_member_permissions(self, stub_command: type[special_endpoints.CommandBuilder]): builder = stub_command("oksksksk").set_default_member_permissions(permissions.Permissions.ADMINISTRATOR) assert builder.default_member_permissions == permissions.Permissions.ADMINISTRATOR - def test_is_nsfw_property(self, stub_command): + def test_is_nsfw_property(self, stub_command: type[special_endpoints.CommandBuilder]): builder = stub_command("oksksksk").set_is_nsfw(True) assert builder.is_nsfw is True - def test_name_localizations_property(self, stub_command): + def test_name_localizations_property(self, stub_command: type[special_endpoints.CommandBuilder]): builder = stub_command("oksksksk").set_name_localizations({"aaa": "bbb", "ccc": "DDd"}) assert builder.name_localizations == {"aaa": "bbb", "ccc": "DDd"} @@ -1277,7 +1277,7 @@ def test_description_property(self): def test_options_property(self): builder = special_endpoints.SlashCommandBuilder("OKSKDKSDK", "inmjfdsmjiooikjsa") - mock_option = object() + mock_option = mock.Mock() assert builder.options == [] @@ -1287,7 +1287,7 @@ def test_options_property(self): def test_build_with_optional_data(self): mock_entity_factory = mock.Mock() - mock_option = object() + mock_option = mock.Mock() builder = ( special_endpoints.SlashCommandBuilder( "we are number", @@ -1393,7 +1393,7 @@ async def test_create_with_guild(self): class TestContextMenuBuilder: def test_build_with_optional_data(self): builder = ( - special_endpoints.ContextMenuCommandBuilder(commands.CommandType.USER, "we are number") + special_endpoints.ContextMenuCommandBuilder(name="we are number", type=commands.CommandType.USER) .set_id(3412312) .set_name_localizations({locales.Locale.TR: "merhaba"}) .set_default_member_permissions(permissions.Permissions.ADMINISTRATOR) @@ -1416,7 +1416,7 @@ def test_build_with_optional_data(self): } def test_build_without_optional_data(self): - builder = special_endpoints.ContextMenuCommandBuilder(commands.CommandType.MESSAGE, "nameeeee") + builder = special_endpoints.ContextMenuCommandBuilder(name="nameeeee", type=commands.CommandType.MESSAGE) result = builder.build(mock.Mock()) @@ -1425,7 +1425,7 @@ def test_build_without_optional_data(self): @pytest.mark.asyncio async def test_create(self): builder = ( - special_endpoints.ContextMenuCommandBuilder(commands.CommandType.USER, "we are number") + special_endpoints.ContextMenuCommandBuilder(name="we are number", type=commands.CommandType.USER) .set_default_member_permissions(permissions.Permissions.BAN_MEMBERS) .set_name_localizations({"meow": "nyan"}) .set_is_nsfw(True) @@ -1448,7 +1448,7 @@ async def test_create(self): @pytest.mark.asyncio async def test_create_with_guild(self): builder = ( - special_endpoints.ContextMenuCommandBuilder(commands.CommandType.USER, "we are number") + special_endpoints.ContextMenuCommandBuilder(name="we are number", type=commands.CommandType.USER) .set_default_member_permissions(permissions.Permissions.BAN_MEMBERS) .set_name_localizations({"en-ghibli": "meow"}) .set_is_nsfw(True) @@ -1501,20 +1501,22 @@ def style(self) -> components.ButtonStyle: return components.ButtonStyle.DANGER @pytest.fixture - def button(self): + def button(self) -> special_endpoints._ButtonBuilder: return Test_ButtonBuilder.ButtonBuilder(id=5855932, emoji=543123, label="a lebel", is_disabled=True) - def test_type_property(self, button): + def test_type_property(self, button: special_endpoints._ButtonBuilder): assert button.type is components.ComponentType.BUTTON - def test_style_property(self, button): + def test_style_property(self, button: special_endpoints._ButtonBuilder): assert button.style is components.ButtonStyle.DANGER - def test_emoji_property(self, button): + def test_emoji_property(self, button: special_endpoints._ButtonBuilder): assert button.emoji == 543123 @pytest.mark.parametrize("emoji", ["unicode", emojis.UnicodeEmoji("unicode")]) - def test_set_emoji_with_unicode_emoji(self, button, emoji): + def test_set_emoji_with_unicode_emoji( + self, button: special_endpoints._ButtonBuilder, emoji: str | emojis.UnicodeEmoji + ): result = button.set_emoji(emoji) assert result is button @@ -1522,8 +1524,12 @@ def test_set_emoji_with_unicode_emoji(self, button, emoji): assert button._emoji_id is undefined.UNDEFINED assert button._emoji_name == "unicode" - @pytest.mark.parametrize("emoji", [emojis.CustomEmoji(name="ok", id=34123123, is_animated=False), 34123123]) - def test_set_emoji_with_custom_emoji(self, button, emoji): + @pytest.mark.parametrize( + "emoji", [emojis.CustomEmoji(name="ok", id=snowflakes.Snowflake(34123123), is_animated=False), 34123123] + ) + def test_set_emoji_with_custom_emoji( + self, button: special_endpoints._ButtonBuilder, emoji: int | emojis.CustomEmoji + ): result = button.set_emoji(emoji) assert result is button @@ -1531,7 +1537,7 @@ def test_set_emoji_with_custom_emoji(self, button, emoji): assert button._emoji_id == "34123123" assert button._emoji_name is undefined.UNDEFINED - def test_set_emoji_with_undefined(self, button): + def test_set_emoji_with_undefined(self, button: special_endpoints._ButtonBuilder): result = button.set_emoji(undefined.UNDEFINED) assert result is button @@ -1539,17 +1545,19 @@ def test_set_emoji_with_undefined(self, button): assert button._emoji_name is undefined.UNDEFINED assert button._emoji is undefined.UNDEFINED - def test_set_label(self, button): + def test_set_label(self, button: special_endpoints._ButtonBuilder): assert button.set_label("hi hi") is button assert button.label == "hi hi" - def test_set_is_disabled(self, button): + def test_set_is_disabled(self, button: special_endpoints._ButtonBuilder): assert button.set_is_disabled(False) assert button.is_disabled is False - def test_build(self, button): + def test_build(self, button: special_endpoints._ButtonBuilder): payload, attachments = button.build() + assert attachments == [] + assert payload == { "id": 5855932, "type": components.ComponentType.BUTTON, @@ -1559,9 +1567,9 @@ def test_build(self, button): "disabled": True, } - assert attachments == [] - - @pytest.mark.parametrize("emoji", [123321, emojis.CustomEmoji(id=123321, name="", is_animated=True)]) + @pytest.mark.parametrize( + "emoji", [123321, emojis.CustomEmoji(id=snowflakes.Snowflake(123321), name="", is_animated=True)] + ) def test_build_with_custom_emoji(self, emoji: typing.Union[int, emojis.Emoji]): button = Test_ButtonBuilder.ButtonBuilder(emoji=emoji) @@ -1649,29 +1657,31 @@ def test_build(self): class TestSelectOptionBuilder: @pytest.fixture - def option(self): + def option(self) -> special_endpoints.SelectOptionBuilder: return special_endpoints.SelectOptionBuilder(label="ok", value="ok2") - def test_label_property(self, option): + def test_label_property(self, option: special_endpoints.SelectOptionBuilder): option.set_label("new_label") assert option.label == "new_label" - def test_value_property(self, option): + def test_value_property(self, option: special_endpoints.SelectOptionBuilder): option.set_value("aaaaaaaaaaaa") assert option.value == "aaaaaaaaaaaa" - def test_emoji_property(self, option): + def test_emoji_property(self, option: special_endpoints.SelectOptionBuilder): option._emoji = 123321 assert option.emoji == 123321 - def test_set_description(self, option): + def test_set_description(self, option: special_endpoints.SelectOptionBuilder): assert option.set_description("a desk") is option assert option.description == "a desk" @pytest.mark.parametrize("emoji", ["unicode", emojis.UnicodeEmoji("unicode")]) - def test_set_emoji_with_unicode_emoji(self, option, emoji): + def test_set_emoji_with_unicode_emoji( + self, option: special_endpoints.SelectOptionBuilder, emoji: str | emojis.UnicodeEmoji + ): result = option.set_emoji(emoji) assert result is option @@ -1679,8 +1689,12 @@ def test_set_emoji_with_unicode_emoji(self, option, emoji): assert option._emoji_id is undefined.UNDEFINED assert option._emoji_name == "unicode" - @pytest.mark.parametrize("emoji", [emojis.CustomEmoji(name="ok", id=34123123, is_animated=False), 34123123]) - def test_set_emoji_with_custom_emoji(self, option, emoji): + @pytest.mark.parametrize( + "emoji", [emojis.CustomEmoji(name="ok", id=snowflakes.Snowflake(34123123), is_animated=False), 34123123] + ) + def test_set_emoji_with_custom_emoji( + self, option: special_endpoints.SelectOptionBuilder, emoji: int | emojis.CustomEmoji + ): result = option.set_emoji(emoji) assert result is option @@ -1688,7 +1702,7 @@ def test_set_emoji_with_custom_emoji(self, option, emoji): assert option._emoji_id == "34123123" assert option._emoji_name is undefined.UNDEFINED - def test_set_emoji_with_undefined(self, option): + def test_set_emoji_with_undefined(self, option: special_endpoints.SelectOptionBuilder): result = option.set_emoji(undefined.UNDEFINED) assert result is option @@ -1696,7 +1710,7 @@ def test_set_emoji_with_undefined(self, option): assert option._emoji_name is undefined.UNDEFINED assert option._emoji is undefined.UNDEFINED - def test_set_is_default(self, option): + def test_set_is_default(self, option: special_endpoints.SelectOptionBuilder): assert option.set_is_default(True) is option assert option.is_default is True @@ -1742,24 +1756,24 @@ def test_type_property(self): assert menu.type == 123 - def test_custom_id_property(self, menu): + def test_custom_id_property(self, menu: special_endpoints.SelectMenuBuilder): menu.set_custom_id("ooooo") assert menu.custom_id == "ooooo" - def test_set_is_disabled(self, menu): + def test_set_is_disabled(self, menu: special_endpoints.SelectMenuBuilder): assert menu.set_is_disabled(True) is menu assert menu.is_disabled is True - def test_set_placeholder(self, menu): + def test_set_placeholder(self, menu: special_endpoints.SelectMenuBuilder): assert menu.set_placeholder("place") is menu assert menu.placeholder == "place" - def test_set_min_values(self, menu): + def test_set_min_values(self, menu: special_endpoints.SelectMenuBuilder): assert menu.set_min_values(1) is menu assert menu.min_values == 1 - def test_set_max_values(self, menu): + def test_set_max_values(self, menu: special_endpoints.SelectMenuBuilder): assert menu.set_max_values(25) is menu assert menu.max_values == 25 @@ -1806,11 +1820,11 @@ def test_build_without_optional_fields(self): class TestTextSelectMenuBuilder: @pytest.fixture - def menu(self): + def menu(self) -> special_endpoints.TextSelectMenuBuilder[typing.NoReturn]: return special_endpoints.TextSelectMenuBuilder(custom_id="o2o2o2") def test_parent_property(self): - mock_parent = object() + mock_parent = mock.Mock() menu = special_endpoints.TextSelectMenuBuilder(custom_id="o2o2o2", parent=mock_parent) assert menu.parent is mock_parent @@ -1835,8 +1849,8 @@ def test_add_option(self, menu: special_endpoints.TextSelectMenuBuilder[typing.N assert option.emoji == "e" assert option.is_default is True - def test_add_raw_option(self, menu): - mock_option = object() + def test_add_raw_option(self, menu: special_endpoints.TextSelectMenuBuilder[typing.NoReturn]): + mock_option = mock.Mock() menu.add_raw_option(mock_option) @@ -1942,41 +1956,41 @@ def test_build_without_optional_fields(self): class TestTextInput: @pytest.fixture - def text_input(self): + def text_input(self) -> special_endpoints.TextInputBuilder: return special_endpoints.TextInputBuilder(custom_id="o2o2o2", label="label") - def test_type_property(self, text_input): + def test_type_property(self, text_input: special_endpoints.TextInputBuilder): assert text_input.type is components.ComponentType.TEXT_INPUT - def test_set_style(self, text_input): + def test_set_style(self, text_input: special_endpoints.TextInputBuilder): assert text_input.set_style(components.TextInputStyle.PARAGRAPH) is text_input assert text_input.style == components.TextInputStyle.PARAGRAPH - def test_set_custom_id(self, text_input): + def test_set_custom_id(self, text_input: special_endpoints.TextInputBuilder): assert text_input.set_custom_id("custooom") is text_input assert text_input.custom_id == "custooom" - def test_set_label(self, text_input): + def test_set_label(self, text_input: special_endpoints.TextInputBuilder): assert text_input.set_label("labeeeel") is text_input assert text_input.label == "labeeeel" - def test_set_placeholder(self, text_input): + def test_set_placeholder(self, text_input: special_endpoints.TextInputBuilder): assert text_input.set_placeholder("place") is text_input assert text_input.placeholder == "place" - def test_set_required(self, text_input): + def test_set_required(self, text_input: special_endpoints.TextInputBuilder): assert text_input.set_required(True) is text_input assert text_input.is_required is True - def test_set_value(self, text_input): + def test_set_value(self, text_input: special_endpoints.TextInputBuilder): assert text_input.set_value("valueeeee") is text_input assert text_input.value == "valueeeee" - def test_set_min_length_(self, text_input): + def test_set_min_length_(self, text_input: special_endpoints.TextInputBuilder): assert text_input.set_min_length(10) is text_input assert text_input.min_length == 10 - def test_set_max_length(self, text_input): + def test_set_max_length(self, text_input: special_endpoints.TextInputBuilder): assert text_input.set_max_length(250) is text_input assert text_input.max_length == 250 @@ -2591,20 +2605,6 @@ def test_build(self): assert attachments == [files.ensure_resource("some-test-file.png"), files.ensure_resource("file.txt")] - def test_build_without_optional_fields(self): - container = special_endpoints.ContainerComponentBuilder(accent_color=None) - - payload, attachments = container.build() - - assert payload == { - "type": components.ComponentType.CONTAINER, - "accent_color": None, - "spoiler": False, - "components": [], - } - - assert attachments == [] - def test_build_without_undefined_fields(self): container = special_endpoints.ContainerComponentBuilder() diff --git a/tests/hikari/impl/test_voice.py b/tests/hikari/impl/test_voice.py index 7368b0fc0d..25ff66e67b 100644 --- a/tests/hikari/impl/test_voice.py +++ b/tests/hikari/impl/test_voice.py @@ -27,63 +27,68 @@ from hikari import errors from hikari import snowflakes +from hikari import traits +from hikari.api import voice as voice_api from hikari.events import voice_events from hikari.impl import voice -from tests.hikari import hikari_test_helpers class TestVoiceComponentImpl: @pytest.fixture - def mock_app(self): - return mock.Mock() + def mock_app(self) -> traits.GatewayBotAware: + return mock.Mock(traits.GatewayBotAware) @pytest.fixture - def voice_client(self, mock_app): - client = hikari_test_helpers.mock_class_namespace(voice.VoiceComponentImpl, slots_=False)(mock_app) + def voice_client(self, mock_app: traits.GatewayBotAware) -> voice.VoiceComponentImpl: + client = voice.VoiceComponentImpl(mock_app) client._is_alive = True return client - def test_is_alive_property(self, voice_client): - voice_client.is_alive is voice_client._is_alive + def test_is_alive_property(self, voice_client: voice.VoiceComponentImpl): + assert voice_client.is_alive is voice_client._is_alive - def test__check_if_alive_when_alive(self, voice_client): - voice_client._is_alive = True - voice_client._check_if_alive() + def test__check_if_alive_when_alive(self, voice_client: voice.VoiceComponentImpl): + with mock.patch.object(voice_client, "_is_alive", True): + assert voice_client._check_if_alive() is None - def test__check_if_alive_when_not_alive(self, voice_client): + def test__check_if_alive_when_not_alive(self, voice_client: voice.VoiceComponentImpl): voice_client._is_alive = False - with pytest.raises(errors.ComponentStateConflictError): + with mock.patch.object(voice_client, "_is_alive", False), pytest.raises(errors.ComponentStateConflictError): voice_client._check_if_alive() - def test__check_if_alive_when_closing(self, voice_client): - voice_client._is_alive = True - voice_client._is_closing = True - - with pytest.raises(errors.ComponentStateConflictError): + def test__check_if_alive_when_closing(self, voice_client: voice.VoiceComponentImpl): + with ( + mock.patch.object(voice_client, "_is_alive", True), + mock.patch.object(voice_client, "_is_closing", True), + pytest.raises(errors.ComponentStateConflictError), + ): voice_client._check_if_alive() @pytest.mark.asyncio - async def test_disconnect(self, voice_client): + async def test_disconnect(self, voice_client: voice.VoiceComponentImpl): mock_connection = mock.AsyncMock() mock_connection_2 = mock.AsyncMock() voice_client._connections = { snowflakes.Snowflake(123): mock_connection, snowflakes.Snowflake(5324): mock_connection_2, } - voice_client._check_if_alive = mock.Mock() - await voice_client.disconnect(123) + with mock.patch.object(voice.VoiceComponentImpl, "_check_if_alive") as patched__check_if_alive: + await voice_client.disconnect(123) - voice_client._check_if_alive.assert_called_once_with() - mock_connection.disconnect.assert_awaited_once_with() - mock_connection_2.disconnect.assert_not_called() + patched__check_if_alive.assert_called_once_with() + mock_connection.disconnect.assert_awaited_once_with() + mock_connection_2.disconnect.assert_not_called() @pytest.mark.asyncio - async def test_disconnect_when_guild_id_not_in_connections(self, voice_client): + async def test_disconnect_when_guild_id_not_in_connections(self, voice_client: voice.VoiceComponentImpl): mock_connection = mock.AsyncMock() mock_connection_2 = mock.AsyncMock() - voice_client._connections = {123: mock_connection, 5324: mock_connection_2} + voice_client._connections = { + snowflakes.Snowflake(123): mock_connection, + snowflakes.Snowflake(5324): mock_connection_2, + } with pytest.raises(errors.VoiceError): await voice_client.disconnect(1234567890) @@ -92,10 +97,13 @@ async def test_disconnect_when_guild_id_not_in_connections(self, voice_client): mock_connection_2.disconnect.assert_not_called() @pytest.mark.asyncio - async def test__disconnect_all(self, voice_client): + async def test__disconnect_all(self, voice_client: voice.VoiceComponentImpl): mock_connection = mock.AsyncMock() mock_connection_2 = mock.AsyncMock() - voice_client._connections = {123: mock_connection, 5324: mock_connection_2} + voice_client._connections = { + snowflakes.Snowflake(123): mock_connection, + snowflakes.Snowflake(5324): mock_connection_2, + } await voice_client._disconnect_all() @@ -103,60 +111,65 @@ async def test__disconnect_all(self, voice_client): mock_connection_2.disconnect.assert_awaited_once_with() @pytest.mark.asyncio - async def test_disconnect_all(self, voice_client): - voice_client._disconnect_all = mock.AsyncMock() - voice_client._check_if_alive = mock.Mock() - - await voice_client.disconnect_all() + async def test_disconnect_all(self, voice_client: voice.VoiceComponentImpl): + with ( + mock.patch.object(voice.VoiceComponentImpl, "_disconnect_all", mock.AsyncMock()) as patched__disconnect_all, + mock.patch.object(voice.VoiceComponentImpl, "_check_if_alive") as patched__check_if_alive, + ): + await voice_client.disconnect_all() - voice_client._check_if_alive.assert_called_once_with() - voice_client._disconnect_all.assert_awaited_once_with() + patched__check_if_alive.assert_called_once_with() + patched__disconnect_all.assert_awaited_once_with() @pytest.mark.asyncio @pytest.mark.parametrize("voice_listener", [True, False]) - async def test_close(self, voice_client, mock_app, voice_listener): - voice_client._disconnect_all = mock.AsyncMock() - voice_client._connections = {123: None} - voice_client._check_if_alive = mock.Mock() - voice_client._voice_listener = voice_listener + async def test_close(self, voice_client: voice.VoiceComponentImpl, voice_listener: bool): + voice_client._connections = {snowflakes.Snowflake(123): mock.Mock()} + + with ( + mock.patch.object(voice.VoiceComponentImpl, "_disconnect_all", mock.AsyncMock()) as patched__disconnect_all, + mock.patch.object(voice.VoiceComponentImpl, "_check_if_alive") as patched__check_if_alive, + mock.patch.object(voice_client, "_voice_listener", voice_listener), + mock.patch.object(voice_client, "_app", mock.Mock(traits.EventManagerAware)) as patched_app, + mock.patch.object(patched_app.event_manager, "unsubscribe") as patched_unsubscribe, + ): + await voice_client.close() - await voice_client.close() + if voice_listener: + patched_unsubscribe.assert_called_once_with(voice_events.VoiceEvent, voice_client._on_voice_event) + else: + patched_unsubscribe.assert_not_called() - if voice_listener: - mock_app.event_manager.unsubscribe.assert_called_once_with( - voice_events.VoiceEvent, voice_client._on_voice_event - ) - else: - mock_app.event_manager.unsubscribe.assert_not_called() - - voice_client._check_if_alive.assert_called_once_with() - voice_client._disconnect_all.assert_awaited_once_with() - assert voice_client._is_alive is False - assert voice_client._is_closing is False + patched__check_if_alive.assert_called_once_with() + patched__disconnect_all.assert_awaited_once_with() + assert voice_client._is_alive is False + assert voice_client._is_closing is False @pytest.mark.asyncio @pytest.mark.parametrize("voice_listener", [True, False]) - async def test_close_when_no_connections(self, voice_client, mock_app, voice_listener): - voice_client._disconnect_all = mock.AsyncMock() + async def test_close_when_no_connections(self, voice_client: voice.VoiceComponentImpl, voice_listener: bool): voice_client._connections = {} - voice_client._check_if_alive = mock.Mock() - voice_client._voice_listener = voice_listener - await voice_client.close() + with ( + mock.patch.object(voice.VoiceComponentImpl, "_disconnect_all", mock.AsyncMock()) as patched__disconnect_all, + mock.patch.object(voice.VoiceComponentImpl, "_check_if_alive") as patched__check_if_alive, + mock.patch.object(voice_client, "_voice_listener", voice_listener), + mock.patch.object(voice_client, "_app", mock.Mock(traits.EventManagerAware)) as patched_app, + mock.patch.object(patched_app.event_manager, "unsubscribe") as patched_unsubscribe, + ): + await voice_client.close() - if voice_listener: - mock_app.event_manager.unsubscribe.assert_called_once_with( - voice_events.VoiceEvent, voice_client._on_voice_event - ) - else: - mock_app.event_manager.unsubscribe.assert_not_called() + if voice_listener: + patched_unsubscribe.assert_called_once_with(voice_events.VoiceEvent, voice_client._on_voice_event) + else: + patched_unsubscribe.assert_not_called() - voice_client._check_if_alive.assert_called_once_with() - voice_client._disconnect_all.assert_not_called() - assert voice_client._is_alive is False - assert voice_client._is_closing is False + patched__check_if_alive.assert_called_once_with() + patched__disconnect_all.assert_not_called() + assert voice_client._is_alive is False + assert voice_client._is_closing is False - def test_start(self, voice_client, mock_app): + def test_start(self, voice_client: voice.VoiceComponentImpl): voice_client._is_alive = False voice_client.start() @@ -164,229 +177,286 @@ def test_start(self, voice_client, mock_app): assert voice_client._is_alive is True @pytest.mark.asyncio - async def test_start_when_already_alive(self, voice_client, mock_app): + async def test_start_when_already_alive(self, voice_client: voice.VoiceComponentImpl): voice_client._is_alive = True with pytest.raises(errors.ComponentStateConflictError): - await voice_client.start() + voice_client.start() @pytest.mark.asyncio @pytest.mark.parametrize("voice_listener", [True, False]) - async def test_connect_to(self, voice_client, mock_app, voice_listener): - voice_client._init_state_update_predicate = mock.Mock() - voice_client._init_server_update_predicate = mock.Mock() - mock_other_connection = object() - voice_client._connections = {555: mock_other_connection} + async def test_connect_to( + self, voice_client: voice.VoiceComponentImpl, mock_app: traits.RESTAware, voice_listener: bool + ): mock_shard = mock.AsyncMock(is_alive=True) - mock_app.event_manager.wait_for = mock.AsyncMock() - mock_app.shard_count = 42 - mock_app.shards = {0: mock_shard} - mock_connection_type = mock.AsyncMock() - voice_client._check_if_alive = mock.Mock() - voice_client._voice_listener = voice_listener - - result = await voice_client.connect_to(123, 4532, mock_connection_type, deaf=False, mute=True, timeout=None) - - voice_client._check_if_alive.assert_called_once_with() - mock_app.event_manager.wait_for.assert_has_awaits( - [ - mock.call( - voice_events.VoiceStateUpdateEvent, - timeout=None, - predicate=voice_client._init_state_update_predicate.return_value, - ), - mock.call( - voice_events.VoiceServerUpdateEvent, - timeout=None, - predicate=voice_client._init_server_update_predicate.return_value, - ), - ] - ) - mock_app.rest.fetch_my_user.assert_not_called() - mock_app.cache.get_me.assert_called_once_with() - voice_client._init_state_update_predicate.assert_called_once_with(123, mock_app.cache.get_me.return_value.id) - voice_client._init_server_update_predicate.assert_called_once_with(123) - if voice_listener: - mock_app.event_manager.subscribe.assert_not_called() - else: - mock_app.event_manager.subscribe.assert_called_once_with( - voice_events.VoiceEvent, voice_client._on_voice_event - ) - assert voice_client._voice_listener is True - mock_shard.update_voice_state.assert_awaited_once_with(123, 4532, self_deaf=False, self_mute=True) - assert voice_client._connections == { - 123: mock_connection_type.initialize.return_value, - 555: mock_other_connection, - } - assert result is mock_connection_type.initialize.return_value + with ( + mock.patch.object( + voice.VoiceComponentImpl, "_init_state_update_predicate" + ) as patched__init_state_update_predicate, + mock.patch.object( + voice.VoiceComponentImpl, "_init_server_update_predicate" + ) as patched__init_server_update_predicate, + mock.patch.object(mock_app, "shard_count", 42), + mock.patch.object(mock_app, "shards", {0: mock_shard}), + mock.patch.object(mock_app, "event_manager") as patched_event_manager, + mock.patch.object(patched_event_manager, "wait_for", mock.AsyncMock()) as patched_wait_for, + mock.patch.object(patched_event_manager, "subscribe") as patched_subscribe, + mock.patch.object(voice.VoiceComponentImpl, "_check_if_alive") as patched__check_if_alive, + mock.patch.object(mock_app.rest, "fetch_my_user") as patched_fetch_my_user, + mock.patch.object(mock_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_me") as patched_get_me, + ): + mock_other_connection = mock.Mock(voice_api.VoiceComponent) + voice_client._connections = {snowflakes.Snowflake(555): mock_other_connection} + + # FIXME: How is this even legal????? + mock_connection_type = mock.Mock + mock_connection_type.initialize = mock.AsyncMock() + + voice_client._voice_listener = voice_listener + + result = await voice_client.connect_to(123, 4532, mock_connection_type, deaf=False, mute=True, timeout=None) + + patched__check_if_alive.assert_called_once_with() + patched_wait_for.assert_has_awaits( + [ + mock.call( + voice_events.VoiceStateUpdateEvent, + timeout=None, + predicate=patched__init_state_update_predicate.return_value, + ), + mock.call( + voice_events.VoiceServerUpdateEvent, + timeout=None, + predicate=patched__init_server_update_predicate.return_value, + ), + ] + ) + patched_fetch_my_user.assert_not_called() + patched_get_me.assert_called_once_with() + patched__init_state_update_predicate.assert_called_once_with(123, patched_get_me.return_value.id) + patched__init_server_update_predicate.assert_called_once_with(123) + if voice_listener: + patched_subscribe.assert_not_called() + else: + patched_subscribe.assert_called_once_with(voice_events.VoiceEvent, voice_client._on_voice_event) + + assert voice_client._voice_listener is True + mock_shard.update_voice_state.assert_awaited_once_with(123, 4532, self_deaf=False, self_mute=True) + assert voice_client._connections == { + 123: mock_connection_type.initialize.return_value, + 555: mock_other_connection, + } + assert result is mock_connection_type.initialize.return_value @pytest.mark.asyncio - async def test_connect_to_fails_when_wait_for_timeout(self, voice_client, mock_app): + async def test_connect_to_fails_when_wait_for_timeout( + self, voice_client: voice.VoiceComponentImpl, mock_app: traits.RESTAware + ): mock_shard = mock.AsyncMock(is_alive=True) mock_wait_for = mock.AsyncMock() mock_wait_for.side_effect = asyncio.TimeoutError - mock_app.event_manager.wait_for = mock_wait_for - mock_app.shard_count = 42 - mock_app.shards = {0: mock_shard} - mock_connection_type = mock.AsyncMock() - - with pytest.raises(errors.VoiceError, match="Could not connect to voice channel 4532 in guild 123."): + mock_connection_type = mock.AsyncMock + + with ( + mock.patch.object(mock_app, "shard_count", 42), + mock.patch.object(mock_app, "shards", {0: mock_shard}), + mock.patch.object(mock_app, "event_manager") as patched_event_manager, + mock.patch.object(patched_event_manager, "wait_for", mock_wait_for), + pytest.raises(errors.VoiceError, match="Could not connect to voice channel 4532 in guild 123."), + ): await voice_client.connect_to(123, 4532, mock_connection_type) @pytest.mark.asyncio - async def test_connect_to_falls_back_to_rest_to_get_own_user(self, voice_client, mock_app): - voice_client._init_state_update_predicate = mock.Mock() - voice_client._init_server_update_predicate = mock.Mock() + async def test_connect_to_falls_back_to_rest_to_get_own_user( + self, voice_client: voice.VoiceComponentImpl, mock_app: traits.RESTAware + ): mock_shard = mock.AsyncMock(is_alive=True) - mock_app.event_manager.wait_for = mock.AsyncMock() - mock_app.shard_count = 42 - mock_app.shards = {0: mock_shard} - mock_app.cache.get_me.return_value = None - mock_app.rest = mock.AsyncMock() - mock_connection_type = mock.AsyncMock() - - await voice_client.connect_to(123, 4532, mock_connection_type, deaf=False, mute=True, timeout=None) - - mock_app.event_manager.wait_for.assert_has_awaits( - [ - mock.call( - voice_events.VoiceStateUpdateEvent, - timeout=None, - predicate=voice_client._init_state_update_predicate.return_value, - ), - mock.call( - voice_events.VoiceServerUpdateEvent, - timeout=None, - predicate=voice_client._init_server_update_predicate.return_value, - ), - ] - ) - mock_app.cache.get_me.assert_called_once_with() - mock_app.rest.fetch_my_user.assert_awaited_once_with() - voice_client._init_state_update_predicate.assert_called_once_with( - 123, mock_app.rest.fetch_my_user.return_value.id - ) + + with ( + mock.patch.object( + voice.VoiceComponentImpl, "_init_state_update_predicate" + ) as patched__init_state_update_predicate, + mock.patch.object( + voice.VoiceComponentImpl, "_init_server_update_predicate" + ) as patched__init_server_update_predicate, + mock.patch.object(mock_app, "shard_count", 42), + mock.patch.object(mock_app, "shards", {0: mock_shard}), + mock.patch.object(mock_app, "event_manager") as patched_event_manager, + mock.patch.object(patched_event_manager, "wait_for", mock.AsyncMock()) as patched_wait_for, + mock.patch.object(mock_app.rest, "fetch_my_user", mock.AsyncMock()) as patched_fetch_my_user, + mock.patch.object(mock_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_me", return_value=None) as patched_get_me, + ): + mock_connection_type = mock.Mock + mock_connection_type.initialize = mock.AsyncMock() + + await voice_client.connect_to(123, 4532, mock_connection_type, deaf=False, mute=True, timeout=None) + + patched_wait_for.assert_has_awaits( + [ + mock.call( + voice_events.VoiceStateUpdateEvent, + timeout=None, + predicate=patched__init_state_update_predicate.return_value, + ), + mock.call( + voice_events.VoiceServerUpdateEvent, + timeout=None, + predicate=patched__init_server_update_predicate.return_value, + ), + ] + ) + patched_get_me.assert_called_once_with() + patched_fetch_my_user.assert_awaited_once_with() + patched__init_state_update_predicate.assert_called_once_with(123, patched_fetch_my_user.return_value.id) @pytest.mark.asyncio - async def test_connect_to_when_connection_already_present(self, voice_client, mock_app): - voice_client._connections = {snowflakes.Snowflake(123): object()} + async def test_connect_to_when_connection_already_present( + self, voice_client: voice.VoiceComponentImpl, mock_app: traits.RESTAware + ): + voice_client._connections = {snowflakes.Snowflake(123): mock.Mock()} with pytest.raises( errors.VoiceError, match="Already in a voice channel for that guild. Disconnect before attempting to connect again", ): - await voice_client.connect_to(123, 4532, object()) + await voice_client.connect_to(123, 4532, mock.Mock) @pytest.mark.asyncio - async def test_connect_to_for_unknown_shard(self, voice_client, mock_app): - mock_app.shard_count = 42 - mock_app.shards = {} - - with pytest.raises( - errors.VoiceError, match="Cannot connect to shard 0 as it is not present in this application" + async def test_connect_to_for_unknown_shard( + self, voice_client: voice.VoiceComponentImpl, mock_app: traits.RESTAware + ): + with ( + mock.patch.object(mock_app, "shard_count", 42), + mock.patch.object(mock_app, "shards", {}), + pytest.raises( + errors.VoiceError, match="Cannot connect to shard 0 as it is not present in this application" + ), ): - await voice_client.connect_to(123, 4532, object()) + await voice_client.connect_to(123, 4532, mock.Mock) @pytest.mark.asyncio - async def test_connect_to_handles_failed_connection_initialise(self, voice_client, mock_app): - voice_client._init_state_update_predicate = mock.Mock() - voice_client._init_server_update_predicate = mock.Mock() + async def test_connect_to_handles_failed_connection_initialise( + self, voice_client: voice.VoiceComponentImpl, mock_app: traits.RESTAware + ): mock_shard = mock.Mock(is_alive=True) + update_voice_state_call_1 = mock.AsyncMock() update_voice_state_call_2 = mock.Mock() mock_shard.update_voice_state = mock.Mock( side_effect=[update_voice_state_call_1(), update_voice_state_call_2()] ) - mock_app.event_manager.wait_for = mock.AsyncMock() - mock_app.shard_count = 42 - mock_app.shards = {0: mock_shard} - - class StubError(Exception): ... - - mock_connection_type = mock.AsyncMock() - mock_connection_type.initialize.side_effect = StubError - - with mock.patch.object( - asyncio, "wait_for", new=mock.AsyncMock(side_effect=asyncio.TimeoutError) - ) as asyncio_wait_for: - with pytest.raises(StubError): - await voice_client.connect_to(123, 4532, mock_connection_type, deaf=False, mute=True, timeout=None) - - mock_app.event_manager.wait_for.assert_has_awaits( - [ - mock.call( - voice_events.VoiceStateUpdateEvent, - timeout=None, - predicate=voice_client._init_state_update_predicate.return_value, - ), - mock.call( - voice_events.VoiceServerUpdateEvent, - timeout=None, - predicate=voice_client._init_server_update_predicate.return_value, - ), - ] - ) - mock_app.cache.get_me.assert_called_once_with() - voice_client._init_state_update_predicate.assert_called_once_with(123, mock_app.cache.get_me.return_value.id) - voice_client._init_server_update_predicate.assert_called_once_with(123) - mock_shard.update_voice_state.assert_has_calls( - [mock.call(123, 4532, self_deaf=False, self_mute=True), mock.call(123, None)] - ) - update_voice_state_call_1.assert_awaited_once() - asyncio_wait_for.assert_awaited_once_with(update_voice_state_call_2.return_value, timeout=5.0) + + with ( + mock.patch.object( + voice.VoiceComponentImpl, "_init_state_update_predicate" + ) as patched__init_state_update_predicate, + mock.patch.object( + voice.VoiceComponentImpl, "_init_server_update_predicate" + ) as patched__init_server_update_predicate, + mock.patch.object(mock_app, "shard_count", 42), + mock.patch.object(mock_app, "shards", {0: mock_shard}), + mock.patch.object(mock_app, "event_manager") as patched_event_manager, + mock.patch.object(patched_event_manager, "wait_for", mock.AsyncMock()) as patched_wait_for, + mock.patch.object(mock_app.rest, "fetch_my_user", mock.AsyncMock()), + mock.patch.object(mock_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_me") as patched_get_me, + ): + + class StubError(Exception): ... + + mock_connection_type = mock.Mock + mock_connection_type.initialize = mock.AsyncMock(side_effect=StubError) + + with mock.patch.object( + asyncio, "wait_for", new=mock.AsyncMock(side_effect=asyncio.TimeoutError) + ) as asyncio_wait_for: + with pytest.raises(StubError): + await voice_client.connect_to(123, 4532, mock_connection_type, deaf=False, mute=True, timeout=None) + + patched_wait_for.assert_has_awaits( + [ + mock.call( + voice_events.VoiceStateUpdateEvent, + timeout=None, + predicate=patched__init_state_update_predicate.return_value, + ), + mock.call( + voice_events.VoiceServerUpdateEvent, + timeout=None, + predicate=patched__init_server_update_predicate.return_value, + ), + ] + ) + patched_get_me.assert_called_once_with() + patched__init_state_update_predicate.assert_called_once_with(123, patched_get_me.return_value.id) + patched__init_server_update_predicate.assert_called_once_with(123) + mock_shard.update_voice_state.assert_has_calls( + [mock.call(123, 4532, self_deaf=False, self_mute=True), mock.call(123, None)] + ) + update_voice_state_call_1.assert_awaited_once() + asyncio_wait_for.assert_awaited_once_with(update_voice_state_call_2.return_value, timeout=5.0) @pytest.mark.asyncio @pytest.mark.parametrize("more_connections", [True, False]) - async def test__on_connection_close(self, voice_client, mock_app, more_connections): + async def test__on_connection_close( + self, voice_client: voice.VoiceComponentImpl, mock_app: traits.RESTAware, more_connections: bool + ): mock_shard = mock.AsyncMock() - mock_app.shards = {69: mock_shard} - voice_client._connections = {65234123: object()} + voice_client._connections = {snowflakes.Snowflake(65234123): mock.Mock()} expected_connections = {} if more_connections: - mock_connection = object() - voice_client._connections[123] = mock_connection + mock_connection = mock.Mock() + voice_client._connections[snowflakes.Snowflake(123)] = mock_connection expected_connections[123] = mock_connection - await voice_client._on_connection_close(mock.Mock(guild_id=65234123, shard_id=69)) + with ( + mock.patch.object(mock_app, "shards", {69: mock_shard}), + mock.patch.object(mock_app, "event_manager") as patched_event_manager, + mock.patch.object(patched_event_manager, "unsubscribe") as patched_unsubscribe, + ): + await voice_client._on_connection_close(mock.Mock(guild_id=65234123, shard_id=69)) if more_connections: - mock_app.event_manager.unsubscribe.assert_not_called() + patched_unsubscribe.assert_not_called() else: - mock_app.event_manager.unsubscribe.assert_called_once_with( - voice_events.VoiceEvent, voice_client._on_voice_event - ) + patched_unsubscribe.assert_called_once_with(voice_events.VoiceEvent, voice_client._on_voice_event) mock_shard.update_voice_state.assert_awaited_once_with(guild=65234123, channel=None) assert voice_client._connections == expected_connections - def test__init_state_update_predicate_matches(self, voice_client): - predicate = voice_client._init_state_update_predicate(42069, 696969) + def test__init_state_update_predicate_matches(self, voice_client: voice.VoiceComponentImpl): + predicate = voice_client._init_state_update_predicate(snowflakes.Snowflake(42069), snowflakes.Snowflake(696969)) mock_voice_state = mock.Mock(state=mock.Mock(guild_id=42069, user_id=696969)) assert predicate(mock_voice_state) is True - def test__init_state_update_predicate_ignores(self, voice_client): - predicate = voice_client._init_state_update_predicate(999, 420) + def test__init_state_update_predicate_ignores(self, voice_client: voice.VoiceComponentImpl): + predicate = voice_client._init_state_update_predicate(snowflakes.Snowflake(999), snowflakes.Snowflake(420)) mock_voice_state = mock.Mock(state=mock.Mock(guild_id=6969, user_id=3333)) assert predicate(mock_voice_state) is False - def test__init_server_update_predicate_matches(self, voice_client): - predicate = voice_client._init_server_update_predicate(696969) + def test__init_server_update_predicate_matches(self, voice_client: voice.VoiceComponentImpl): + predicate = voice_client._init_server_update_predicate(snowflakes.Snowflake(696969)) mock_voice_state = mock.Mock(guild_id=696969) assert predicate(mock_voice_state) is True - def test__init_server_update_predicate_ignores(self, voice_client): - predicate = voice_client._init_server_update_predicate(321231) + def test__init_server_update_predicate_ignores(self, voice_client: voice.VoiceComponentImpl): + predicate = voice_client._init_server_update_predicate(snowflakes.Snowflake(321231)) mock_voice_state = mock.Mock(guild_id=123123123) assert predicate(mock_voice_state) is False @pytest.mark.asyncio - async def test__on_connection_close_ignores_unknown_voice_state(self, voice_client): - connections = {123132: object(), 65234234: object()} + async def test__on_connection_close_ignores_unknown_voice_state(self, voice_client: voice.VoiceComponentImpl): + connections: dict[snowflakes.Snowflake, voice_api.VoiceConnection] = { + snowflakes.Snowflake(123132): mock.Mock(voice_api.VoiceConnection), + snowflakes.Snowflake(65234234): mock.Mock(voice_api.VoiceConnection), + } voice_client._connections = connections.copy() await voice_client._on_connection_close(mock.Mock(guild_id=-1)) @@ -394,9 +464,9 @@ async def test__on_connection_close_ignores_unknown_voice_state(self, voice_clie assert voice_client._connections == connections @pytest.mark.asyncio - async def test__on_voice_event(self, voice_client): + async def test__on_voice_event(self, voice_client: voice.VoiceComponentImpl): mock_connection = mock.AsyncMock() - voice_client._connections = {6633: mock_connection} + voice_client._connections = {snowflakes.Snowflake(6633): mock_connection} mock_event = mock.Mock(guild_id=6633) await voice_client._on_voice_event(mock_event) @@ -404,7 +474,7 @@ async def test__on_voice_event(self, voice_client): mock_connection.notify.assert_awaited_once_with(mock_event) @pytest.mark.asyncio - async def test__on_voice_event_for_untracked_guild(self, voice_client): + async def test__on_voice_event_for_untracked_guild(self, voice_client: voice.VoiceComponentImpl): mock_event = mock.Mock(guild_id=44444) await voice_client._on_voice_event(mock_event) diff --git a/tests/hikari/integration/test_equality_comparisons.py b/tests/hikari/integration/test_equality_comparisons.py index 75c83f4140..60cd330e0f 100644 --- a/tests/hikari/integration/test_equality_comparisons.py +++ b/tests/hikari/integration/test_equality_comparisons.py @@ -31,7 +31,7 @@ from hikari import users -def make_user(user_id): +def make_user(user_id: int) -> users.UserImpl: return users.UserImpl( app=mock.Mock(), id=snowflakes.Snowflake(user_id), @@ -49,7 +49,7 @@ def make_user(user_id): ) -def make_team_member(user_id): +def make_team_member(user_id: int) -> applications.TeamMember: user = make_user(user_id) return applications.TeamMember( membership_state=applications.TeamMembershipState.ACCEPTED, @@ -59,7 +59,7 @@ def make_team_member(user_id): ) -def make_guild_member(user_id): +def make_guild_member(user_id: int) -> guilds.Member: user = make_user(user_id) return guilds.Member( user=user, @@ -79,15 +79,15 @@ def make_guild_member(user_id): ) -def make_unicode_emoji(): +def make_unicode_emoji() -> emojis.UnicodeEmoji: return emojis.UnicodeEmoji("\N{OK HAND SIGN}") -def make_custom_emoji(emoji_id): +def make_custom_emoji(emoji_id: snowflakes.Snowflake) -> emojis.CustomEmoji: return emojis.CustomEmoji(id=emoji_id, name="testing", is_animated=False) -def make_known_custom_emoji(emoji_id): +def make_known_custom_emoji(emoji_id: snowflakes.Snowflake) -> emojis.KnownCustomEmoji: return emojis.KnownCustomEmoji( app=mock.Mock(), id=emoji_id, @@ -111,10 +111,10 @@ def make_known_custom_emoji(emoji_id): (make_user(1), make_guild_member(2), False), (make_team_member(1), make_guild_member(1), True), (make_team_member(1), make_guild_member(2), False), - (make_custom_emoji(1), make_known_custom_emoji(1), True), - (make_custom_emoji(1), make_known_custom_emoji(2), False), - (make_unicode_emoji(), make_custom_emoji(1), False), - (make_unicode_emoji(), make_known_custom_emoji(2), False), + (make_custom_emoji(snowflakes.Snowflake(1)), make_known_custom_emoji(snowflakes.Snowflake(1)), True), + (make_custom_emoji(snowflakes.Snowflake(1)), make_known_custom_emoji(snowflakes.Snowflake(2)), False), + (make_unicode_emoji(), make_custom_emoji(snowflakes.Snowflake(1)), False), + (make_unicode_emoji(), make_known_custom_emoji(snowflakes.Snowflake(2)), False), ], ids=[ "User == Team Member", @@ -129,7 +129,7 @@ def make_known_custom_emoji(emoji_id): "Unicode Emoji != Known Custom Emoji", ], ) -def test_comparison(a: object, b: object, eq: bool) -> None: +def test_comparison(a: users.UserImpl, b: applications.TeamMember, eq: bool) -> None: if eq: assert a == b assert b == a diff --git a/tests/hikari/interactions/test_base_interactions.py b/tests/hikari/interactions/test_base_interactions.py index 6b123eefa4..34c1ba1936 100644 --- a/tests/hikari/interactions/test_base_interactions.py +++ b/tests/hikari/interactions/test_base_interactions.py @@ -20,39 +20,37 @@ # SOFTWARE. from __future__ import annotations +import typing + import mock import pytest from hikari import applications from hikari import monetization +from hikari import permissions from hikari import snowflakes from hikari import traits from hikari import undefined from hikari.interactions import base_interactions -@pytest.fixture -def mock_app(): - return mock.Mock(traits.CacheAware, rest=mock.AsyncMock()) - - class TestPartialInteraction: @pytest.fixture - def mock_partial_interaction(self, mock_app): + def mock_partial_interaction(self, hikari_app: traits.RESTAware) -> base_interactions.PartialInteraction: return base_interactions.PartialInteraction( - app=mock_app, - id=34123, - application_id=651231, + app=hikari_app, + id=snowflakes.Snowflake(34123), + application_id=snowflakes.Snowflake(651231), type=base_interactions.InteractionType.APPLICATION_COMMAND, token="399393939doodsodso", version=3122312, guild_id=snowflakes.Snowflake(5412231), channel=mock.Mock(id=3123123), - member=object(), - user=object(), + member=mock.Mock(), + user=mock.Mock(), locale="es-ES", guild_locale="en-US", - app_permissions=123321, + app_permissions=permissions.Permissions.NONE, authorizing_integration_owners={ applications.ApplicationIntegrationType.GUILD_INSTALL: snowflakes.Snowflake(123) }, @@ -73,66 +71,75 @@ def mock_partial_interaction(self, mock_app): context=applications.ApplicationContextType.PRIVATE_CHANNEL, ) - def test_webhook_id_property(self, mock_partial_interaction): + def test_webhook_id_property(self, mock_partial_interaction: base_interactions.PartialInteraction): assert mock_partial_interaction.webhook_id is mock_partial_interaction.application_id @pytest.mark.asyncio - async def test_fetch_guild(self, mock_partial_interaction, mock_app): - mock_partial_interaction.guild_id = 43123123 + async def test_fetch_guild( + self, mock_partial_interaction: base_interactions.PartialInteraction, hikari_app: traits.RESTAware + ): + mock_partial_interaction.guild_id = snowflakes.Snowflake(43123123) - assert await mock_partial_interaction.fetch_guild() is mock_app.rest.fetch_guild.return_value + assert await mock_partial_interaction.fetch_guild() is hikari_app.rest.fetch_guild.return_value - mock_app.rest.fetch_guild.assert_awaited_once_with(43123123) + hikari_app.rest.fetch_guild.assert_awaited_once_with(43123123) @pytest.mark.asyncio - async def test_fetch_guild_for_dm_interaction(self, mock_partial_interaction, mock_app): + async def test_fetch_guild_for_dm_interaction( + self, mock_partial_interaction: base_interactions.PartialInteraction, hikari_app: traits.RESTAware + ): mock_partial_interaction.guild_id = None assert await mock_partial_interaction.fetch_guild() is None - mock_app.rest.fetch_guild.assert_not_called() + hikari_app.rest.fetch_guild.assert_not_called() - def test_get_guild(self, mock_partial_interaction, mock_app): - mock_partial_interaction.guild_id = 874356 + def test_get_guild(self, mock_partial_interaction: base_interactions.PartialInteraction): + mock_partial_interaction.guild_id = snowflakes.Snowflake(874356) - assert mock_partial_interaction.get_guild() is mock_app.cache.get_guild.return_value + with mock.patch.object( + mock_partial_interaction, "app", mock.Mock(traits.CacheAware, rest=mock.AsyncMock()) + ) as patched_hikari_app: + assert mock_partial_interaction.get_guild() is patched_hikari_app.cache.get_guild.return_value - mock_app.cache.get_guild.assert_called_once_with(874356) + patched_hikari_app.cache.get_guild.assert_called_once_with(874356) - def test_get_guild_for_dm_interaction(self, mock_partial_interaction, mock_app): + def test_get_guild_for_dm_interaction(self, mock_partial_interaction: base_interactions.PartialInteraction): mock_partial_interaction.guild_id = None - assert mock_partial_interaction.get_guild() is None + with mock.patch.object( + mock_partial_interaction, "app", mock.Mock(traits.CacheAware, rest=mock.AsyncMock()) + ) as patched_hikari_app: + assert mock_partial_interaction.get_guild() is None - mock_app.cache.get_guild.assert_not_called() + patched_hikari_app.cache.get_guild.assert_not_called() - def test_get_guild_when_cacheless(self, mock_partial_interaction, mock_app): - mock_partial_interaction.guild_id = 321123 - mock_partial_interaction.app = mock.Mock(traits.RESTAware) + def test_get_guild_when_cacheless(self, mock_partial_interaction: base_interactions.PartialInteraction): + mock_partial_interaction.guild_id = snowflakes.Snowflake(321123) assert mock_partial_interaction.get_guild() is None - mock_app.cache.get_guild.assert_not_called() - class TestMessageResponseMixin: @pytest.fixture - def mock_message_response_mixin(self, mock_app): - return base_interactions.MessageResponseMixin( - app=mock_app, - id=34123, - application_id=651231, + def mock_message_response_mixin( + self, hikari_app: traits.RESTAware + ) -> base_interactions.MessageResponseMixin[typing.Any]: + return base_interactions.MessageResponseMixin[typing.Any]( + app=hikari_app, + id=snowflakes.Snowflake(34123), + application_id=snowflakes.Snowflake(651231), type=base_interactions.InteractionType.APPLICATION_COMMAND, token="399393939doodsodso", version=3122312, context=applications.ApplicationContextType.PRIVATE_CHANNEL, guild_id=snowflakes.Snowflake(5412231), channel=mock.Mock(id=3123123), - member=object(), - user=object(), + member=mock.Mock(), + user=mock.Mock(), locale="es-ES", guild_locale="en-US", - app_permissions=123321, + app_permissions=permissions.Permissions.NONE, authorizing_integration_owners={ applications.ApplicationIntegrationType.GUILD_INSTALL: snowflakes.Snowflake(123) }, @@ -153,64 +160,80 @@ def mock_message_response_mixin(self, mock_app): ) @pytest.mark.asyncio - async def test_fetch_initial_response(self, mock_message_response_mixin, mock_app): - result = await mock_message_response_mixin.fetch_initial_response() + async def test_fetch_initial_response( + self, mock_message_response_mixin: base_interactions.MessageResponseMixin[typing.Any] + ): + with mock.patch.object( + mock_message_response_mixin.app.rest, "fetch_interaction_response", mock.AsyncMock() + ) as patched_fetch_interaction_response: + result = await mock_message_response_mixin.fetch_initial_response() - assert result is mock_app.rest.fetch_interaction_response.return_value - mock_app.rest.fetch_interaction_response.assert_awaited_once_with(651231, "399393939doodsodso") + assert result is patched_fetch_interaction_response.return_value + patched_fetch_interaction_response.assert_awaited_once_with(651231, "399393939doodsodso") @pytest.mark.asyncio - async def test_create_initial_response_with_optional_args(self, mock_message_response_mixin, mock_app): - mock_embed_1 = object() - mock_embed_2 = object() - mock_poll = object() - mock_component = object() - mock_components = object(), object() - mock_attachment = object() - mock_attachments = object(), object() - await mock_message_response_mixin.create_initial_response( - base_interactions.ResponseType.MESSAGE_CREATE, - "content", - tts=True, - embed=mock_embed_1, - flags=64, - embeds=[mock_embed_2], - poll=mock_poll, - component=mock_component, - components=mock_components, - attachment=mock_attachment, - attachments=mock_attachments, - mentions_everyone=False, - user_mentions=[123432], - role_mentions=[6324523], - ) - - mock_app.rest.create_interaction_response.assert_awaited_once_with( - 34123, - "399393939doodsodso", - base_interactions.ResponseType.MESSAGE_CREATE, - "content", - tts=True, - flags=64, - embed=mock_embed_1, - embeds=[mock_embed_2], - poll=mock_poll, - component=mock_component, - components=mock_components, - attachment=mock_attachment, - attachments=mock_attachments, - mentions_everyone=False, - user_mentions=[123432], - role_mentions=[6324523], - ) + async def test_create_initial_response_with_optional_args( + self, mock_message_response_mixin: base_interactions.MessageResponseMixin[typing.Any] + ): + mock_embed_1 = mock.Mock() + mock_embed_2 = mock.Mock() + mock_poll = mock.Mock() + mock_component = mock.Mock() + mock_components = mock.Mock(), mock.Mock() + mock_attachment = mock.Mock() + mock_attachments = mock.Mock(), mock.Mock() + + with mock.patch.object( + mock_message_response_mixin.app.rest, "create_interaction_response", mock.AsyncMock() + ) as patched_create_interaction_response: + await mock_message_response_mixin.create_initial_response( + base_interactions.ResponseType.MESSAGE_CREATE, + "content", + tts=True, + flags=64, + embed=mock_embed_1, + embeds=[mock_embed_2], + poll=mock_poll, + component=mock_component, + components=mock_components, + attachment=mock_attachment, + attachments=mock_attachments, + mentions_everyone=False, + user_mentions=[123432], + role_mentions=[6324523], + ) + + patched_create_interaction_response.assert_awaited_once_with( + 34123, + "399393939doodsodso", + base_interactions.ResponseType.MESSAGE_CREATE, + "content", + tts=True, + flags=64, + embed=mock_embed_1, + embeds=[mock_embed_2], + poll=mock_poll, + component=mock_component, + components=mock_components, + attachment=mock_attachment, + attachments=mock_attachments, + mentions_everyone=False, + user_mentions=[123432], + role_mentions=[6324523], + ) @pytest.mark.asyncio - async def test_create_initial_response_without_optional_args(self, mock_message_response_mixin, mock_app): - await mock_message_response_mixin.create_initial_response( - base_interactions.ResponseType.DEFERRED_MESSAGE_CREATE - ) - - mock_app.rest.create_interaction_response.assert_awaited_once_with( + async def test_create_initial_response_without_optional_args( + self, mock_message_response_mixin: base_interactions.MessageResponseMixin[typing.Any] + ): + with mock.patch.object( + mock_message_response_mixin.app.rest, "create_interaction_response", mock.AsyncMock() + ) as patched_create_interaction_response: + await mock_message_response_mixin.create_initial_response( + base_interactions.ResponseType.DEFERRED_MESSAGE_CREATE + ) + + patched_create_interaction_response.assert_awaited_once_with( 34123, "399393939doodsodso", base_interactions.ResponseType.DEFERRED_MESSAGE_CREATE, @@ -230,87 +253,102 @@ async def test_create_initial_response_without_optional_args(self, mock_message_ ) @pytest.mark.asyncio - async def test_edit_initial_response_with_optional_args(self, mock_message_response_mixin, mock_app): - mock_embed_1 = object() - mock_embed_2 = object() - mock_attachment_1 = object() - mock_attachment_2 = object() - mock_component = object() - mock_components = object(), object() - result = await mock_message_response_mixin.edit_initial_response( - "new content", - embed=mock_embed_1, - embeds=[mock_embed_2], - attachment=mock_attachment_1, - attachments=[mock_attachment_2], - component=mock_component, - components=mock_components, - mentions_everyone=False, - user_mentions=[123123], - role_mentions=[562134], - ) - - assert result is mock_app.rest.edit_interaction_response.return_value - mock_app.rest.edit_interaction_response.assert_awaited_once_with( - 651231, - "399393939doodsodso", - "new content", - embed=mock_embed_1, - embeds=[mock_embed_2], - attachment=mock_attachment_1, - attachments=[mock_attachment_2], - component=mock_component, - components=mock_components, - mentions_everyone=False, - user_mentions=[123123], - role_mentions=[562134], - ) + async def test_edit_initial_response_with_optional_args( + self, mock_message_response_mixin: base_interactions.MessageResponseMixin[typing.Any] + ): + mock_embed_1 = mock.Mock() + mock_embed_2 = mock.Mock() + mock_attachment_1 = mock.Mock() + mock_attachment_2 = mock.Mock() + mock_component = mock.Mock() + mock_components = mock.Mock(), mock.Mock() + + with mock.patch.object( + mock_message_response_mixin.app.rest, "edit_interaction_response", mock.AsyncMock() + ) as patched_edit_interaction_response: + result = await mock_message_response_mixin.edit_initial_response( + "new content", + embed=mock_embed_1, + embeds=[mock_embed_2], + attachment=mock_attachment_1, + attachments=[mock_attachment_2], + component=mock_component, + components=mock_components, + mentions_everyone=False, + user_mentions=[123123], + role_mentions=[562134], + ) + + assert result is patched_edit_interaction_response.return_value + patched_edit_interaction_response.assert_awaited_once_with( + 651231, + "399393939doodsodso", + "new content", + embed=mock_embed_1, + embeds=[mock_embed_2], + attachment=mock_attachment_1, + attachments=[mock_attachment_2], + component=mock_component, + components=mock_components, + mentions_everyone=False, + user_mentions=[123123], + role_mentions=[562134], + ) @pytest.mark.asyncio - async def test_edit_initial_response_without_optional_args(self, mock_message_response_mixin, mock_app): - result = await mock_message_response_mixin.edit_initial_response() - - assert result is mock_app.rest.edit_interaction_response.return_value - mock_app.rest.edit_interaction_response.assert_awaited_once_with( - 651231, - "399393939doodsodso", - undefined.UNDEFINED, - embed=undefined.UNDEFINED, - embeds=undefined.UNDEFINED, - attachment=undefined.UNDEFINED, - attachments=undefined.UNDEFINED, - component=undefined.UNDEFINED, - components=undefined.UNDEFINED, - mentions_everyone=undefined.UNDEFINED, - user_mentions=undefined.UNDEFINED, - role_mentions=undefined.UNDEFINED, - ) + async def test_edit_initial_response_without_optional_args( + self, mock_message_response_mixin: base_interactions.MessageResponseMixin[typing.Any] + ): + with mock.patch.object( + mock_message_response_mixin.app.rest, "edit_interaction_response", mock.AsyncMock() + ) as patched_edit_interaction_response: + result = await mock_message_response_mixin.edit_initial_response() + + assert result is patched_edit_interaction_response.return_value + patched_edit_interaction_response.assert_awaited_once_with( + 651231, + "399393939doodsodso", + undefined.UNDEFINED, + embed=undefined.UNDEFINED, + embeds=undefined.UNDEFINED, + attachment=undefined.UNDEFINED, + attachments=undefined.UNDEFINED, + component=undefined.UNDEFINED, + components=undefined.UNDEFINED, + mentions_everyone=undefined.UNDEFINED, + user_mentions=undefined.UNDEFINED, + role_mentions=undefined.UNDEFINED, + ) @pytest.mark.asyncio - async def test_delete_initial_response(self, mock_message_response_mixin, mock_app): - await mock_message_response_mixin.delete_initial_response() - - mock_app.rest.delete_interaction_response.assert_awaited_once_with(651231, "399393939doodsodso") + async def test_delete_initial_response( + self, mock_message_response_mixin: base_interactions.MessageResponseMixin[typing.Any] + ): + with mock.patch.object( + mock_message_response_mixin.app.rest, "delete_interaction_response", mock.AsyncMock() + ) as patched_delete_interaction_response: + await mock_message_response_mixin.delete_initial_response() + patched_delete_interaction_response.assert_awaited_once_with(651231, "399393939doodsodso") class TestModalResponseMixin: @pytest.fixture - def mock_modal_response_mixin(self, mock_app): + def mock_modal_response_mixin(self, hikari_app: traits.RESTAware) -> base_interactions.ModalResponseMixin: return base_interactions.ModalResponseMixin( - app=mock_app, - id=34123, - application_id=651231, + app=hikari_app, + id=snowflakes.Snowflake(34123), + application_id=snowflakes.Snowflake(651231), type=base_interactions.InteractionType.APPLICATION_COMMAND, token="399393939doodsodso", version=3122312, context=applications.ApplicationContextType.PRIVATE_CHANNEL, guild_id=snowflakes.Snowflake(5412231), channel=mock.Mock(id=3123123), - member=object(), - user=object(), + member=None, + user=mock.Mock(), locale="es-ES", guild_locale="en-US", - app_permissions=123321, + app_permissions=None, authorizing_integration_owners={ applications.ApplicationIntegrationType.GUILD_INSTALL: snowflakes.Snowflake(123) }, @@ -331,16 +369,26 @@ def mock_modal_response_mixin(self, mock_app): ) @pytest.mark.asyncio - async def test_create_modal_response(self, mock_modal_response_mixin, mock_app): - await mock_modal_response_mixin.create_modal_response("title", "custom_id", None, []) - - mock_app.rest.create_modal_response.assert_awaited_once_with( - 34123, "399393939doodsodso", title="title", custom_id="custom_id", component=None, components=[] - ) - - def test_build_response(self, mock_modal_response_mixin, mock_app): - mock_app.rest.interaction_modal_builder = mock.Mock() + async def test_create_modal_response(self, mock_modal_response_mixin: base_interactions.ModalResponseMixin): + with mock.patch.object( + mock_modal_response_mixin.app.rest, "create_modal_response", mock.AsyncMock() + ) as patched_create_modal_response: + await mock_modal_response_mixin.create_modal_response("title", "custom_id", undefined.UNDEFINED, []) + + patched_create_modal_response.assert_awaited_once_with( + 34123, + "399393939doodsodso", + title="title", + custom_id="custom_id", + component=undefined.UNDEFINED, + components=[], + ) + + def test_build_response( + self, mock_modal_response_mixin: base_interactions.ModalResponseMixin, hikari_app: traits.RESTAware + ): + hikari_app.rest.interaction_modal_builder = mock.Mock() builder = mock_modal_response_mixin.build_modal_response("title", "custom_id") - assert builder is mock_app.rest.interaction_modal_builder.return_value - mock_app.rest.interaction_modal_builder.assert_called_once_with(title="title", custom_id="custom_id") + assert builder is hikari_app.rest.interaction_modal_builder.return_value + hikari_app.rest.interaction_modal_builder.assert_called_once_with(title="title", custom_id="custom_id") diff --git a/tests/hikari/interactions/test_command_interactions.py b/tests/hikari/interactions/test_command_interactions.py index 40879fd3ba..eb83beb542 100644 --- a/tests/hikari/interactions/test_command_interactions.py +++ b/tests/hikari/interactions/test_command_interactions.py @@ -20,11 +20,14 @@ # SOFTWARE. from __future__ import annotations +import typing + import mock import pytest from hikari import applications from hikari import monetization +from hikari import permissions from hikari import snowflakes from hikari import traits from hikari.impl import special_endpoints @@ -32,22 +35,17 @@ from hikari.interactions import command_interactions -@pytest.fixture -def mock_app(): - return mock.Mock(traits.CacheAware, rest=mock.AsyncMock()) - - class TestCommandInteraction: @pytest.fixture - def mock_command_interaction(self, mock_app): + def mock_command_interaction(self, hikari_app: traits.RESTAware) -> command_interactions.CommandInteraction: return command_interactions.CommandInteraction( - app=mock_app, + app=hikari_app, id=snowflakes.Snowflake(2312312), type=base_interactions.InteractionType.APPLICATION_COMMAND, channel=mock.Mock(id=3123123), guild_id=snowflakes.Snowflake(5412231), - member=object(), - user=object(), + member=mock.Mock(), + user=mock.Mock(), token="httptptptptptptptp", version=1, application_id=snowflakes.Snowflake(43123), @@ -58,7 +56,7 @@ def mock_command_interaction(self, mock_app): resolved=None, locale="es-ES", guild_locale="en-US", - app_permissions=543123, + app_permissions=permissions.Permissions.NONE, registered_guild_id=snowflakes.Snowflake(12345678), entitlements=[ monetization.Entitlement( @@ -80,40 +78,45 @@ def mock_command_interaction(self, mock_app): context=applications.ApplicationContextType.PRIVATE_CHANNEL, ) - def test_channel_id_property(self, mock_command_interaction): - assert mock_command_interaction.channel_id == 3123123 - - def test_build_response(self, mock_command_interaction, mock_app): - mock_app.rest.interaction_message_builder = mock.Mock() + def test_build_response( + self, mock_command_interaction: command_interactions.CommandInteraction, hikari_app: mock.Mock + ): + hikari_app.rest.interaction_message_builder = mock.Mock() builder = mock_command_interaction.build_response() - assert builder is mock_app.rest.interaction_message_builder.return_value - mock_app.rest.interaction_message_builder.assert_called_once_with(base_interactions.ResponseType.MESSAGE_CREATE) + assert builder is hikari_app.rest.interaction_message_builder.return_value + hikari_app.rest.interaction_message_builder.assert_called_once_with( + base_interactions.ResponseType.MESSAGE_CREATE + ) - def test_build_deferred_response(self, mock_command_interaction, mock_app): - mock_app.rest.interaction_deferred_builder = mock.Mock() + def test_build_deferred_response( + self, mock_command_interaction: command_interactions.CommandInteraction, hikari_app: traits.RESTAware + ): + hikari_app.rest.interaction_deferred_builder = mock.Mock() builder = mock_command_interaction.build_deferred_response() - assert builder is mock_app.rest.interaction_deferred_builder.return_value - mock_app.rest.interaction_deferred_builder.assert_called_once_with( + assert builder is hikari_app.rest.interaction_deferred_builder.return_value + hikari_app.rest.interaction_deferred_builder.assert_called_once_with( base_interactions.ResponseType.DEFERRED_MESSAGE_CREATE ) class TestAutocompleteInteraction: @pytest.fixture - def mock_autocomplete_interaction(self, mock_app): + def mock_autocomplete_interaction( + self, hikari_app: traits.RESTAware + ) -> command_interactions.AutocompleteInteraction: return command_interactions.AutocompleteInteraction( - app=mock_app, + app=hikari_app, id=snowflakes.Snowflake(2312312), type=base_interactions.InteractionType.APPLICATION_COMMAND, channel=mock.Mock(3123123), guild_id=snowflakes.Snowflake(5412231), guild_locale="en-US", locale="en-US", - app_permissions=123321, - member=object(), - user=object(), + app_permissions=None, + member=mock.Mock(), + user=mock.Mock(), token="httptptptptptptptp", version=1, application_id=snowflakes.Snowflake(43123), @@ -143,28 +146,35 @@ def mock_autocomplete_interaction(self, mock_app): ) @pytest.fixture - def mock_command_choices(self): + def mock_command_choices(self) -> typing.Sequence[special_endpoints.AutocompleteChoiceBuilder]: return [ special_endpoints.AutocompleteChoiceBuilder(name="a", value="b"), special_endpoints.AutocompleteChoiceBuilder(name="foo", value="bar"), ] - def test_build_response(self, mock_autocomplete_interaction, mock_app, mock_command_choices): - mock_app.rest.interaction_autocomplete_builder = mock.Mock() + def test_build_response( + self, + mock_autocomplete_interaction: command_interactions.AutocompleteInteraction, + hikari_app: traits.RESTAware, + mock_command_choices: typing.Sequence[special_endpoints.AutocompleteChoiceBuilder], + ): + hikari_app.rest.interaction_autocomplete_builder = mock.Mock() builder = mock_autocomplete_interaction.build_response(mock_command_choices) - assert builder is mock_app.rest.interaction_autocomplete_builder.return_value - mock_app.rest.interaction_autocomplete_builder.assert_called_once_with(mock_command_choices) + assert builder is hikari_app.rest.interaction_autocomplete_builder.return_value + hikari_app.rest.interaction_autocomplete_builder.assert_called_once_with(mock_command_choices) @pytest.mark.asyncio async def test_create_response( self, mock_autocomplete_interaction: command_interactions.AutocompleteInteraction, - mock_app, - mock_command_choices, + mock_command_choices: typing.Sequence[special_endpoints.AutocompleteChoiceBuilder], ): - await mock_autocomplete_interaction.create_response(mock_command_choices) - - mock_app.rest.create_autocomplete_response.assert_awaited_once_with( - 2312312, "httptptptptptptptp", mock_command_choices - ) + with mock.patch.object( + mock_autocomplete_interaction.app.rest, "create_autocomplete_response", mock.AsyncMock() + ) as patched_create_autocomplete_response: + await mock_autocomplete_interaction.create_response(mock_command_choices) + + patched_create_autocomplete_response.assert_awaited_once_with( + 2312312, "httptptptptptptptp", mock_command_choices + ) diff --git a/tests/hikari/interactions/test_component_interactions.py b/tests/hikari/interactions/test_component_interactions.py index cbcb764038..a154cfbb2f 100644 --- a/tests/hikari/interactions/test_component_interactions.py +++ b/tests/hikari/interactions/test_component_interactions.py @@ -20,42 +20,43 @@ # SOFTWARE. from __future__ import annotations +import typing + import mock import pytest from hikari import applications from hikari import monetization +from hikari import permissions from hikari import snowflakes from hikari.interactions import base_interactions from hikari.interactions import component_interactions - -@pytest.fixture -def mock_app(): - return mock.Mock(rest=mock.AsyncMock()) +if typing.TYPE_CHECKING: + from hikari import traits class TestComponentInteraction: @pytest.fixture - def mock_component_interaction(self, mock_app): + def mock_component_interaction(self, hikari_app: traits.RESTAware) -> component_interactions.ComponentInteraction: return component_interactions.ComponentInteraction( - app=mock_app, + app=hikari_app, id=snowflakes.Snowflake(2312312), type=base_interactions.InteractionType.MESSAGE_COMPONENT, guild_id=snowflakes.Snowflake(5412231), - channel=object(), - member=object(), - user=object(), + channel=mock.Mock(), + member=mock.Mock(), + user=mock.Mock(), token="httptptptptptptptp", version=1, application_id=snowflakes.Snowflake(43123), component_type=2, values=(), custom_id="OKOKOK", - message=object(), + message=mock.Mock(), locale="es-ES", guild_locale="en-US", - app_permissions=123321, + app_permissions=permissions.Permissions.NONE, resolved=None, entitlements=[ monetization.Entitlement( @@ -77,24 +78,32 @@ def mock_component_interaction(self, mock_app): context=applications.ApplicationContextType.PRIVATE_CHANNEL, ) - def test_build_response(self, mock_component_interaction, mock_app): - mock_app.rest.interaction_message_builder = mock.Mock() + def test_build_response( + self, mock_component_interaction: component_interactions.ComponentInteraction, hikari_app: traits.RESTAware + ): + hikari_app.rest.interaction_message_builder = mock.Mock() response = mock_component_interaction.build_response(4) - assert response is mock_app.rest.interaction_message_builder.return_value - mock_app.rest.interaction_message_builder.assert_called_once_with(4) + assert response is hikari_app.rest.interaction_message_builder.return_value + hikari_app.rest.interaction_message_builder.assert_called_once_with(4) - def test_build_response_with_invalid_type(self, mock_component_interaction): + def test_build_response_with_invalid_type( + self, mock_component_interaction: component_interactions.ComponentInteraction + ): with pytest.raises(ValueError, match="Invalid type passed for an immediate response"): - mock_component_interaction.build_response(999) + mock_component_interaction.build_response(999) # pyright: ignore [reportArgumentType] - def test_build_deferred_response(self, mock_component_interaction, mock_app): - mock_app.rest.interaction_deferred_builder = mock.Mock() + def test_build_deferred_response( + self, mock_component_interaction: component_interactions.ComponentInteraction, hikari_app: traits.RESTAware + ): + hikari_app.rest.interaction_deferred_builder = mock.Mock() response = mock_component_interaction.build_deferred_response(5) - assert response is mock_app.rest.interaction_deferred_builder.return_value - mock_app.rest.interaction_deferred_builder.assert_called_once_with(5) + assert response is hikari_app.rest.interaction_deferred_builder.return_value + hikari_app.rest.interaction_deferred_builder.assert_called_once_with(5) - def test_build_deferred_response_with_invalid_type(self, mock_component_interaction): + def test_build_deferred_response_with_invalid_type( + self, mock_component_interaction: component_interactions.ComponentInteraction + ): with pytest.raises(ValueError, match="Invalid type passed for a deferred response"): mock_component_interaction.build_deferred_response(33333) diff --git a/tests/hikari/interactions/test_modal_interactions.py b/tests/hikari/interactions/test_modal_interactions.py index a9deebac83..1d7d96c827 100644 --- a/tests/hikari/interactions/test_modal_interactions.py +++ b/tests/hikari/interactions/test_modal_interactions.py @@ -26,6 +26,7 @@ from hikari import applications from hikari import components from hikari import monetization +from hikari import permissions from hikari import snowflakes from hikari import traits from hikari.impl import special_endpoints @@ -33,38 +34,35 @@ from hikari.interactions import modal_interactions -@pytest.fixture -def mock_app(): - return mock.Mock(rest=mock.AsyncMock()) - - class TestModalInteraction: @pytest.fixture - def mock_modal_interaction(self, mock_app): + def mock_modal_interaction(self, hikari_app: traits.RESTAware) -> modal_interactions.ModalInteraction: return modal_interactions.ModalInteraction( - app=mock_app, + app=hikari_app, id=snowflakes.Snowflake(2312312), type=base_interactions.InteractionType.APPLICATION_COMMAND, - channel=object(), + channel=mock.Mock(), guild_id=snowflakes.Snowflake(5412231), - member=object(), - user=object(), + member=mock.Mock(), + user=mock.Mock(), token="httptptptptptptptp", version=1, application_id=snowflakes.Snowflake(43123), custom_id="OKOKOK", - message=object(), + message=mock.Mock(), locale="es-ES", guild_locale="en-US", - app_permissions=543123, - components=special_endpoints.ModalActionRowBuilder( - id=9817398, - components=[ - components.TextInputComponent( - type=components.ComponentType.TEXT_INPUT, id=9817398, custom_id="le id", value="le value" - ) - ], - ), + app_permissions=None, + components=[ + special_endpoints.ModalActionRowBuilder( + id=9817398, + components=[ + special_endpoints.TextInputBuilder( + id=9817398, custom_id="le id", label="le label", value="le value" + ) + ], + ) + ], entitlements=[ monetization.Entitlement( id=snowflakes.Snowflake(123123), @@ -85,54 +83,73 @@ def mock_modal_interaction(self, mock_app): context=applications.ApplicationContextType.PRIVATE_CHANNEL, ) - def test_build_response(self, mock_modal_interaction, mock_app): - mock_app.rest.interaction_message_builder = mock.Mock() + def test_build_response( + self, mock_modal_interaction: modal_interactions.ModalInteraction, hikari_app: traits.RESTAware + ): + hikari_app.rest.interaction_message_builder = mock.Mock() response = mock_modal_interaction.build_response() - assert response is mock_app.rest.interaction_message_builder.return_value - mock_app.rest.interaction_message_builder.assert_called_once() + assert response is hikari_app.rest.interaction_message_builder.return_value + hikari_app.rest.interaction_message_builder.assert_called_once() - def test_build_deferred_response(self, mock_modal_interaction, mock_app): - mock_app.rest.interaction_deferred_builder = mock.Mock() + def test_build_deferred_response( + self, mock_modal_interaction: modal_interactions.ModalInteraction, hikari_app: traits.RESTAware + ): + hikari_app.rest.interaction_deferred_builder = mock.Mock() response = mock_modal_interaction.build_deferred_response() - assert response is mock_app.rest.interaction_deferred_builder.return_value - mock_app.rest.interaction_deferred_builder.assert_called_once() + assert response is hikari_app.rest.interaction_deferred_builder.return_value + hikari_app.rest.interaction_deferred_builder.assert_called_once() @pytest.mark.asyncio - async def test_fetch_guild(self, mock_modal_interaction, mock_app): - mock_modal_interaction.guild_id = 43123123 - - assert await mock_modal_interaction.fetch_guild() is mock_app.rest.fetch_guild.return_value - - mock_app.rest.fetch_guild.assert_awaited_once_with(43123123) + async def test_fetch_guild( + self, mock_modal_interaction: modal_interactions.ModalInteraction, hikari_app: traits.RESTAware + ): + with ( + mock.patch.object(mock_modal_interaction, "guild_id", snowflakes.Snowflake(43123123)), + mock.patch.object(hikari_app.rest, "fetch_guild", mock.AsyncMock()) as patched_fetch_guild, + ): + assert await mock_modal_interaction.fetch_guild() is patched_fetch_guild.return_value + patched_fetch_guild.assert_awaited_once_with(43123123) @pytest.mark.asyncio - async def test_fetch_guild_for_dm_interaction(self, mock_modal_interaction, mock_app): - mock_modal_interaction.guild_id = None - - assert await mock_modal_interaction.fetch_guild() is None - - mock_app.rest.fetch_guild.assert_not_called() - - def test_get_guild(self, mock_modal_interaction, mock_app): - mock_modal_interaction.guild_id = 874356 - - assert mock_modal_interaction.get_guild() is mock_app.cache.get_guild.return_value - - mock_app.cache.get_guild.assert_called_once_with(874356) - - def test_get_guild_for_dm_interaction(self, mock_modal_interaction, mock_app): - mock_modal_interaction.guild_id = None - - assert mock_modal_interaction.get_guild() is None - - mock_app.cache.get_guild.assert_not_called() - - def test_get_guild_when_cacheless(self, mock_modal_interaction, mock_app): - mock_modal_interaction.guild_id = 321123 + async def test_fetch_guild_for_dm_interaction( + self, mock_modal_interaction: modal_interactions.ModalInteraction, hikari_app: traits.RESTAware + ): + with ( + mock.patch.object(mock_modal_interaction, "guild_id", None), + mock.patch.object(hikari_app.rest, "fetch_guild") as patched_fetch_guild, + ): + assert await mock_modal_interaction.fetch_guild() is None + + patched_fetch_guild.assert_not_called() + + def test_get_guild(self, mock_modal_interaction: modal_interactions.ModalInteraction): + with ( + mock.patch.object(mock_modal_interaction, "app", mock.Mock(traits.CacheAware)) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_guild") as patched_get_guild, + ): + assert mock_modal_interaction.get_guild() is patched_get_guild.return_value + patched_get_guild.assert_called_once_with(5412231) + + def test_get_guild_for_dm_interaction(self, mock_modal_interaction: modal_interactions.ModalInteraction): + with ( + mock.patch.object(mock_modal_interaction, "guild_id", None), + mock.patch.object(mock_modal_interaction, "app", mock.Mock(traits.CacheAware)) as patched_app, + mock.patch.object(patched_app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_guild") as patched_get_guild, + ): + assert mock_modal_interaction.get_guild() is None + + patched_get_guild.assert_not_called() + + def test_get_guild_when_cacheless( + self, mock_modal_interaction: modal_interactions.ModalInteraction, hikari_app: traits.RESTAware + ): + mock_modal_interaction.guild_id = snowflakes.Snowflake(321123) mock_modal_interaction.app = mock.Mock(traits.RESTAware) assert mock_modal_interaction.get_guild() is None - mock_app.cache.get_guild.assert_not_called() + # hikari_app.cache.get_guild.assert_not_called() # FIXME: This isn't an easy thing to patch, because it complains that the mock app does not have the attribute cache anyways, so it can never be called. diff --git a/tests/hikari/internal/test_aio.py b/tests/hikari/internal/test_aio.py index d2b772aab6..b04c05101b 100644 --- a/tests/hikari/internal/test_aio.py +++ b/tests/hikari/internal/test_aio.py @@ -21,6 +21,7 @@ from __future__ import annotations import asyncio +import typing import mock.mock import pytest @@ -30,12 +31,12 @@ class CoroutineStub: - def __init__(self, *args, **kwargs): + def __init__(self, *args: typing.Any, **kwargs: typing.Any): self.awaited = False self.args = args self.kwargs = kwargs - def __eq__(self, other): + def __eq__(self, other: typing.Any): return isinstance(other, CoroutineStub) and self.args == other.args and self.kwargs == other.kwargs def __await__(self): @@ -50,7 +51,7 @@ def __repr__(self): class CoroutineFunctionStub: - def __call__(self, *args, **kwargs): + def __call__(self, *args: typing.Any, **kwargs: typing.Any): return CoroutineStub(*args, **kwargs) @@ -66,12 +67,12 @@ def test_coro_stub_neq(self): class TestCompletedFuture: @pytest.mark.asyncio @pytest.mark.parametrize("args", [(), (12,)]) - async def test_is_awaitable(self, args): + async def test_is_awaitable(self, args: tuple[int, ...]): await aio.completed_future(*args) @pytest.mark.asyncio @pytest.mark.parametrize("args", [(), (12,)]) - async def test_is_completed(self, args): + async def test_is_completed(self, args: tuple[int, ...]): future = aio.completed_future(*args) assert future.done() @@ -81,7 +82,8 @@ async def test_default_result_is_none(self): @pytest.mark.asyncio async def test_non_default_result(self): - assert aio.completed_future(...).result() is ... + obj = mock.Mock + assert aio.completed_future(obj).result() is obj @pytest.mark.asyncio @@ -214,7 +216,7 @@ async def test_waits_for_all(self): f2 = event_loop.create_future() f3 = event_loop.create_future() - async def quickly_run_task(task): + async def quickly_run_task(task: asyncio.Task[typing.Sequence[typing.Any]]): try: await asyncio.wait_for(asyncio.shield(task), timeout=0.01) except asyncio.TimeoutError: diff --git a/tests/hikari/internal/test_attr_extensions.py b/tests/hikari/internal/test_attr_extensions.py index e3062d595f..95fa45c565 100644 --- a/tests/hikari/internal/test_attr_extensions.py +++ b/tests/hikari/internal/test_attr_extensions.py @@ -22,6 +22,7 @@ import contextlib import copy as stdlib_copy +import typing import attrs import mock @@ -30,13 +31,13 @@ def test_invalidate_shallow_copy_cache(): - attrs_extensions._SHALLOW_COPIERS = {int: object(), str: object()} + attrs_extensions._SHALLOW_COPIERS = {int: mock.Mock(), str: mock.Mock()} assert attrs_extensions.invalidate_shallow_copy_cache() is None assert attrs_extensions._SHALLOW_COPIERS == {} def test_invalidate_deep_copy_cache(): - attrs_extensions._DEEP_COPIERS = {str: object(), int: object(), object: object()} + attrs_extensions._DEEP_COPIERS = {str: mock.Mock(), int: mock.Mock(), object: mock.Mock()} assert attrs_extensions.invalidate_deep_copy_cache() is None assert attrs_extensions._DEEP_COPIERS == {} @@ -139,36 +140,38 @@ class StubModel: ... def test_get_or_generate_shallow_copier_for_cached_copier(): - mock_copier = object() + mock_copier = mock.Mock() @attrs.define() class StubModel: ... attrs_extensions._SHALLOW_COPIERS = { - type("b", (), {}): object(), + type("b", (), {}): mock.Mock(), StubModel: mock_copier, - type("a", (), {}): object(), + type("a", (), {}): mock.Mock(), } assert attrs_extensions.get_or_generate_shallow_copier(StubModel) is mock_copier def test_get_or_generate_shallow_copier_for_uncached_copier(): - mock_copier = object() + mock_copier = mock.Mock() @attrs.define() class StubModel: ... - with mock.patch.object(attrs_extensions, "generate_shallow_copier", return_value=mock_copier): + with mock.patch.object( + attrs_extensions, "generate_shallow_copier", return_value=mock_copier + ) as patched_generate_shallow_copier: assert attrs_extensions.get_or_generate_shallow_copier(StubModel) is mock_copier - attrs_extensions.generate_shallow_copier.assert_called_once_with(StubModel) + patched_generate_shallow_copier.assert_called_once_with(StubModel) assert attrs_extensions._SHALLOW_COPIERS[StubModel] is mock_copier def test_copy_attrs(): - mock_result = object() + mock_result = mock.Mock() mock_copier = mock.Mock(return_value=mock_result) @attrs.define() @@ -176,10 +179,12 @@ class StubModel: ... model = StubModel() - with mock.patch.object(attrs_extensions, "get_or_generate_shallow_copier", return_value=mock_copier): + with mock.patch.object( + attrs_extensions, "get_or_generate_shallow_copier", return_value=mock_copier + ) as patched_get_or_generate_shallow_copier: assert attrs_extensions.copy_attrs(model) is mock_result - attrs_extensions.get_or_generate_shallow_copier.assert_called_once_with(StubModel) + patched_get_or_generate_shallow_copier.assert_called_once_with(StubModel) mock_copier.assert_called_once_with(model) @@ -195,21 +200,21 @@ class StubBaseClass: model = StubBaseClass(recursor=431, field=True, foo="blam") model.end = "the way" - model._blam = "555555" + model._blam = True old_model_fields = stdlib_copy.copy(model) - copied_recursor = object() - copied_field = object() - copied_foo = object() - copied_end = object() - copied_blam = object() - memo = {123: object()} + copied_recursor = mock.Mock() + copied_field = mock.Mock() + copied_foo = mock.Mock() + copied_end = mock.Mock() + copied_blam = mock.Mock() + memo = {123: mock.Mock()} with mock.patch.object( stdlib_copy, "deepcopy", side_effect=[copied_recursor, copied_field, copied_foo, copied_end, copied_blam] - ): + ) as patched_deepcopy: attrs_extensions.generate_deep_copier(StubBaseClass)(model, memo) - stdlib_copy.deepcopy.assert_has_calls( + patched_deepcopy.assert_has_calls( [ mock.call(old_model_fields.recursor, memo), mock.call(old_model_fields._field, memo), @@ -235,15 +240,17 @@ class StubBaseClass: model = StubBaseClass(recursor=431, field=True, foo="blam") old_model_fields = stdlib_copy.copy(model) - copied_recursor = object() - copied_field = object() - copied_foo = object() - memo = {123: object()} + copied_recursor = mock.Mock() + copied_field = mock.Mock() + copied_foo = mock.Mock() + memo = {123: mock.Mock()} - with mock.patch.object(stdlib_copy, "deepcopy", side_effect=[copied_recursor, copied_field, copied_foo]): + with mock.patch.object( + stdlib_copy, "deepcopy", side_effect=[copied_recursor, copied_field, copied_foo] + ) as patched_deepcopy: attrs_extensions.generate_deep_copier(StubBaseClass)(model, memo) - stdlib_copy.deepcopy.assert_has_calls( + patched_deepcopy.assert_has_calls( [ mock.call(old_model_fields.recursor, memo), mock.call(old_model_fields._field, memo), @@ -264,16 +271,16 @@ class StubBaseClass: model = StubBaseClass() model.end = "the way" - model._blam = "555555" + model._blam = False old_model_fields = stdlib_copy.copy(model) - copied_end = object() - copied_blam = object() - memo = {123: object()} + copied_end = mock.Mock() + copied_blam = mock.Mock() + memo = {123: mock.Mock()} - with mock.patch.object(stdlib_copy, "deepcopy", side_effect=[copied_end, copied_blam]): + with mock.patch.object(stdlib_copy, "deepcopy", side_effect=[copied_end, copied_blam]) as patched_deepcopy: attrs_extensions.generate_deep_copier(StubBaseClass)(model, memo) - stdlib_copy.deepcopy.assert_has_calls( + patched_deepcopy.assert_has_calls( [mock.call(old_model_fields.end, memo), mock.call(old_model_fields._blam, memo)] ) @@ -286,24 +293,26 @@ def test_generate_deep_copier_with_no_attributes(): class StubBaseClass: ... model = StubBaseClass() - memo = {123: object()} + memo = {123: mock.Mock()} - with mock.patch.object(stdlib_copy, "deepcopy", side_effect=NotImplementedError): + with mock.patch.object(stdlib_copy, "deepcopy", side_effect=NotImplementedError) as patched_deepcopy: attrs_extensions.generate_deep_copier(StubBaseClass)(model, memo) - stdlib_copy.deepcopy.assert_not_called() + patched_deepcopy.assert_not_called() def test_get_or_generate_deep_copier_for_cached_function(): class StubClass: ... - mock_copier = object() + mock_copier = mock.Mock() attrs_extensions._DEEP_COPIERS = {} - with mock.patch.object(attrs_extensions, "generate_deep_copier", return_value=mock_copier): + with mock.patch.object( + attrs_extensions, "generate_deep_copier", return_value=mock_copier + ) as patched_generate_deep_copier: assert attrs_extensions.get_or_generate_deep_copier(StubClass) is mock_copier - attrs_extensions.generate_deep_copier.assert_called_once_with(StubClass) + patched_generate_deep_copier.assert_called_once_with(StubClass) assert attrs_extensions._DEEP_COPIERS[StubClass] is mock_copier @@ -311,13 +320,13 @@ class StubClass: ... def test_get_or_generate_deep_copier_for_uncached_function(): class StubClass: ... - mock_copier = object() + mock_copier = mock.Mock() attrs_extensions._DEEP_COPIERS = {StubClass: mock_copier} - with mock.patch.object(attrs_extensions, "generate_deep_copier"): + with mock.patch.object(attrs_extensions, "generate_deep_copier") as patched_generate_deep_copier: assert attrs_extensions.get_or_generate_deep_copier(StubClass) is mock_copier - attrs_extensions.generate_deep_copier.assert_not_called() + patched_generate_deep_copier.assert_not_called() def test_deep_copy_attrs_without_memo(): @@ -346,7 +355,7 @@ class StubClass: ... mock_object = StubClass() mock_result = object() mock_copier = mock.Mock(mock_result) - mock_other_object = object() + mock_other_object = mock.Mock() stack = contextlib.ExitStack() stack.enter_context(mock.patch.object(attrs_extensions, "get_or_generate_deep_copier", return_value=mock_copier)) @@ -363,7 +372,7 @@ class StubClass: ... class TestCopyDecorator: def test___copy__(self): - mock_result = object() + mock_result = mock.Mock() mock_copier = mock.Mock(return_value=mock_result) @attrs.define() @@ -372,21 +381,23 @@ class StubClass: ... model = StubClass() - with mock.patch.object(attrs_extensions, "get_or_generate_shallow_copier", return_value=mock_copier): + with mock.patch.object( + attrs_extensions, "get_or_generate_shallow_copier", return_value=mock_copier + ) as patched_get_or_generate_shallow_copier: assert stdlib_copy.copy(model) is mock_result - attrs_extensions.get_or_generate_shallow_copier.assert_called_once_with(StubClass) + patched_get_or_generate_shallow_copier.assert_called_once_with(StubClass) mock_copier.assert_called_once_with(model) def test___deep__copy(self): class CopyingMock(mock.Mock): - def __call__(self, /, *args, **kwargs): - args = list(args) - args[1] = dict(args[1]) - return super().__call__(*args, **kwargs) + def __call__(self, /, *args: typing.Any, **kwargs: typing.Any): + new_args = list(args) + new_args[1] = dict(args[1]) + return super().__call__(*new_args, **kwargs) - mock_result = object() + mock_result = mock.Mock() mock_copier = CopyingMock(return_value=mock_result) @attrs.define() diff --git a/tests/hikari/internal/test_cache.py b/tests/hikari/internal/test_cache.py index 3d51ba0494..c3156744d4 100644 --- a/tests/hikari/internal/test_cache.py +++ b/tests/hikari/internal/test_cache.py @@ -31,7 +31,7 @@ class TestStickerData: def test_from_entity(self) -> None: - mock_user = object() + mock_user = mock.Mock() mock_sticker = stickers.GuildSticker( id=snowflakes.Snowflake(69420), name="lulzor", @@ -55,7 +55,7 @@ def test_from_entity(self) -> None: assert data.user is mock_user def test_from_entity_when_user_not_passed(self) -> None: - mock_user = object() + mock_user = mock.Mock() mock_sticker = mock_sticker = stickers.GuildSticker( id=snowflakes.Snowflake(69420), name="lulzor", diff --git a/tests/hikari/internal/test_collections.py b/tests/hikari/internal/test_collections.py index f5b9159069..f64b9eed79 100644 --- a/tests/hikari/internal/test_collections.py +++ b/tests/hikari/internal/test_collections.py @@ -23,6 +23,7 @@ import array as array_ import operator import sys +import typing import mock import pytest @@ -76,9 +77,9 @@ def test___len__(self): def test___setitem__(self): mock_map = collections.FreezableDict({"hmm": "forearm", "cat": "bag", "ok": "bye"}) - mock_map["bye"] = 4 + mock_map["bye"] = "goobyyy" - assert mock_map == {"hmm": "forearm", "cat": "bag", "ok": "bye", "bye": 4} + assert mock_map == {"hmm": "forearm", "cat": "bag", "ok": "bye", "bye": "goobyyy"} class TestLimitedCapacityCacheMap: @@ -109,45 +110,45 @@ def test_freeze(self): assert result == {"o": "no", "good": "bye"} def test___delitem___for_existing_entry(self): - mock_map = collections.LimitedCapacityCacheMap(limit=50) + mock_map: collections.LimitedCapacityCacheMap[str, typing.Any] = collections.LimitedCapacityCacheMap(limit=50) mock_map["Ok"] = 42 del mock_map["Ok"] assert "Ok" not in mock_map def test___delitem___for_non_existing_entry(self): - mock_map = collections.LimitedCapacityCacheMap(limit=50) + mock_map: collections.LimitedCapacityCacheMap[str, typing.Any] = collections.LimitedCapacityCacheMap(limit=50) with pytest.raises(KeyError): del mock_map["Blam"] def test___getitem___for_existing_entry(self): - mock_map = collections.LimitedCapacityCacheMap(limit=50) + mock_map: collections.LimitedCapacityCacheMap[str, typing.Any] = collections.LimitedCapacityCacheMap(limit=50) mock_map["blat"] = 42 assert mock_map["blat"] == 42 def test___getitem___for_non_existing_entry(self): - mock_map = collections.LimitedCapacityCacheMap(limit=50) + mock_map: collections.LimitedCapacityCacheMap[str, typing.Any] = collections.LimitedCapacityCacheMap(limit=50) with pytest.raises(KeyError): mock_map["CIA"] def test___iter___(self): - mock_map = collections.LimitedCapacityCacheMap(limit=50) + mock_map: collections.LimitedCapacityCacheMap[str, typing.Any] = collections.LimitedCapacityCacheMap(limit=50) mock_map.update({"OK": "blam", "blaaa": "neoeo", "neon": "genesis", "evangelion": None}) assert list(mock_map) == ["OK", "blaaa", "neon", "evangelion"] def test___len___(self): - mock_map = collections.LimitedCapacityCacheMap(limit=50) + mock_map: collections.LimitedCapacityCacheMap[str, typing.Any] = collections.LimitedCapacityCacheMap(limit=50) mock_map.update({"ooga": "blam", "blaaa": "neoeo", "the": "boys", "neon": "genesis", "evangelion": None}) assert len(mock_map) == 5 def test___setitem___when_limit_not_reached(self): - mock_map = collections.LimitedCapacityCacheMap(limit=50) + mock_map: collections.LimitedCapacityCacheMap[str, typing.Any] = collections.LimitedCapacityCacheMap(limit=50) mock_map["OK"] = 523 mock_map["blam"] = 512387 mock_map.update({"bll": "no", "ieiei": "lslsl"}) assert mock_map == {"OK": 523, "blam": 512387, "bll": "no", "ieiei": "lslsl"} def test___setitem___when_limit_reached(self): - mock_map = collections.LimitedCapacityCacheMap(limit=4) + mock_map: collections.LimitedCapacityCacheMap[str, typing.Any] = collections.LimitedCapacityCacheMap(limit=4) mock_map.update({"bll": "no", "ieiei": "lslsl", "pacify": "me", "qt": "pie"}) mock_map["eva"] = "Rei" mock_map.update({"shinji": "ikari"}) @@ -155,7 +156,9 @@ def test___setitem___when_limit_reached(self): def test___setitem___when_limit_reached_and_expire_callback_set(self): expire_callback = mock.Mock() - mock_map = collections.LimitedCapacityCacheMap(limit=4, on_expire=expire_callback) + mock_map: collections.LimitedCapacityCacheMap[str, typing.Any] = collections.LimitedCapacityCacheMap( + limit=4, on_expire=expire_callback + ) mock_map.update({"bll": "no", "ieiei": "lslsl", "pacify": "me", "qt": "pie"}) mock_map["eva"] = "Rei" mock_map.update({"shinji": "ikari"}) @@ -192,7 +195,9 @@ def test_init_creates_array(self): ([0, 122], [123, 121, 999991, 121, 121, 124, 120], [0, 120, 121, 122, 123, 124, 999991]), ], ) - def test_add_inserts_items(self, start_with, add_items, expect): + def test_add_inserts_items( + self, start_with: typing.Sequence[int], add_items: typing.Sequence[int], expect: typing.Sequence[int] + ): # given sfs = collections.SnowflakeSet() sfs._ids.extend(start_with) @@ -216,7 +221,9 @@ def test_add_inserts_items(self, start_with, add_items, expect): ([0, 122], [123, 121, 999991, 121, 121, 124, 120], [0, 120, 121, 122, 123, 124, 999991]), ], ) - def test_add_all_inserts_items(self, start_with, add_items, expect): + def test_add_all_inserts_items( + self, start_with: typing.Sequence[int], add_items: typing.Sequence[int], expect: typing.Sequence[int] + ): # given sfs = collections.SnowflakeSet() sfs._ids.extend(start_with) @@ -264,7 +271,9 @@ def test_clear_empties_empty_buffer(self): ([9, 18, 27, 36, 45, 54, 63], [18, 27, 18, 18, 36, 64, 63], [9, 45, 54]), ], ) - def test_discard(self, start_with, discard, expect): + def test_discard( + self, start_with: typing.Sequence[int], discard: typing.Sequence[int], expect: typing.Sequence[int] + ): # given sfs = collections.SnowflakeSet() sfs._ids.extend(start_with) @@ -292,7 +301,7 @@ def test_discard(self, start_with, discard, expect): ([12], "12", False), ], ) - def test_contains(self, start_with, look_for, expect): + def test_contains(self, start_with: typing.Sequence[int], look_for: int, expect: bool): # given sfs = collections.SnowflakeSet() sfs._ids.extend(start_with) @@ -308,7 +317,7 @@ def test_iter(self): assert list(sfs) == [9, 18, 27, 36, 45, 54, 63] @pytest.mark.parametrize("items", [*range(0, 10)]) - def test_len(self, items): + def test_len(self, items: int): # given sfs = collections.SnowflakeSet() sfs._ids.extend(i for i in range(items)) @@ -316,7 +325,7 @@ def test_len(self, items): assert len(sfs) == items @pytest.mark.parametrize("items", [*range(0, 10)]) - def test_length_hint(self, items): + def test_length_hint(self, items: int): # given sfs = collections.SnowflakeSet() sfs._ids.extend(i for i in range(items)) diff --git a/tests/hikari/internal/test_data_binding.py b/tests/hikari/internal/test_data_binding.py index 3625f9e3e4..38cea0b687 100644 --- a/tests/hikari/internal/test_data_binding.py +++ b/tests/hikari/internal/test_data_binding.py @@ -40,10 +40,10 @@ class MyUnique(snowflakes.Unique): class TestURLEncodedFormBuilder: @pytest.fixture - def form_builder(self): + def form_builder(self) -> data_binding.URLEncodedFormBuilder: return data_binding.URLEncodedFormBuilder() - def test_add_field(self, form_builder): + def test_add_field(self, form_builder: data_binding.URLEncodedFormBuilder): class TestBytesPayload: def __init__(self, value: bytes) -> None: self.inner = value @@ -67,15 +67,15 @@ def __repr__(self) -> str: ("test_name2", TestBytesPayload(b"test_data2"), "mimetype2"), ] - def test_add_resource(self, form_builder): - mock_resource = object() + def test_add_resource(self, form_builder: data_binding.URLEncodedFormBuilder): + mock_resource = mock.Mock() form_builder.add_resource("lick", mock_resource) assert form_builder._resources == [("lick", mock_resource)] @pytest.mark.asyncio - async def test_build(self, form_builder): + async def test_build(self, form_builder: data_binding.URLEncodedFormBuilder): resource1 = mock.Mock() resource2 = mock.Mock() stream1 = mock.Mock(filename="testing1", mimetype="text") @@ -83,7 +83,7 @@ async def test_build(self, form_builder): data1 = aiohttp.BytesPayload(b"data1") data2 = aiohttp.BytesPayload(b"data2") mock_stack = mock.AsyncMock(enter_async_context=mock.AsyncMock(side_effect=[stream1, stream2])) - executor = object() + executor = mock.Mock() form_builder._fields = [("test_name", data1, "mimetype"), ("test_name2", data2, "mimetype2")] form_builder._resources = [("aye", resource1), ("lmao", resource2)] @@ -150,13 +150,13 @@ def test_put_int(self): @pytest.mark.parametrize( ("name", "input_val", "expect"), [("a", True, "true"), ("b", False, "false"), ("c", None, "null")] ) - def test_put_py_singleton(self, name, input_val, expect): + def test_put_py_singleton(self, name: str, input_val: typing.Optional[typing.Union[str, bool]], expect: str): mapping = data_binding.StringMapBuilder() mapping.put(name, input_val) assert dict(mapping) == {name: expect} def test_put_with_conversion_uses_return_value(self): - def convert(_): + def convert(_: str): return "yeah, i got called" mapping = data_binding.StringMapBuilder() @@ -167,12 +167,12 @@ def test_put_with_conversion_passes_raw_input_to_converter(self): mapping = data_binding.StringMapBuilder() convert = mock.Mock() - expect = object() + expect = mock.Mock() mapping.put("yaskjgakljglak", expect, conversion=convert) convert.assert_called_once_with(expect) def test_put_py_singleton_conversion_runs_before_check(self): - def convert(_): + def convert(_: str): return True mapping = data_binding.StringMapBuilder() @@ -246,7 +246,7 @@ def test_put_array_with_conversion_uses_conversion_result(self): convert = mock.Mock(side_effect=[r1, r2, r3]) builder = data_binding.JSONObjectBuilder() - builder.put_array("www", [object(), object(), object()], conversion=convert) + builder.put_array("www", [mock.Mock(), mock.Mock(), mock.Mock()], conversion=convert) assert builder == {"www": [r1, r2, r3]} def test_put_array_with_conversion_passes_raw_input_to_converter(self): @@ -277,7 +277,7 @@ def test_put_snowflake_undefined(self): (snowflakes.Snowflake("100126"), "100126"), ], ) - def test_put_snowflake(self, input_value, expected_str): + def test_put_snowflake(self, input_value: snowflakes.Snowflake, expected_str: str): builder = data_binding.JSONObjectBuilder() builder.put_snowflake("WAWAWA!", input_value) assert builder == {"WAWAWA!": expected_str} @@ -298,7 +298,7 @@ def test_put_snowflake_none(self): (snowflakes.Snowflake("100126"), "100126"), ], ) - def test_put_snowflake_array_conversions(self, input_value, expected_str): + def test_put_snowflake_array_conversions(self, input_value: snowflakes.Snowflake, expected_str: str): builder = data_binding.JSONObjectBuilder() builder.put_snowflake_array("WAWAWAH!", [input_value] * 5) assert builder == {"WAWAWAH!": [expected_str] * 5} diff --git a/tests/hikari/internal/test_enums.py b/tests/hikari/internal/test_enums.py index 14ec284027..22411296a0 100644 --- a/tests/hikari/internal/test_enums.py +++ b/tests/hikari/internal/test_enums.py @@ -22,6 +22,7 @@ import builtins import operator +import typing import warnings import mock @@ -62,7 +63,9 @@ class Enum(metaclass=enums._EnumMeta): ("args", "kwargs"), [([str], {"metaclass": enums._EnumMeta}), ([enums.Enum], {"metaclass": enums._EnumMeta}), ([enums.Enum], {})], ) - def test_init_enum_type_with_one_base_is_TypeError(self, args, kwargs): + def test_init_enum_type_with_one_base_is_TypeError( + self, args: typing.Sequence[type], kwargs: dict[str, typing.Any] + ): with pytest.raises(TypeError): class Enum(*args, **kwargs): @@ -71,7 +74,9 @@ class Enum(*args, **kwargs): @pytest.mark.parametrize( ("args", "kwargs"), [([enums.Enum, str], {"metaclass": enums._EnumMeta}), ([enums.Enum, str], {})] ) - def test_init_enum_type_with_bases_in_wrong_order_is_TypeError(self, args, kwargs): + def test_init_enum_type_with_bases_in_wrong_order_is_TypeError( + self, args: typing.Sequence[type], kwargs: dict[str, typing.Any] + ): with pytest.raises(TypeError): class Enum(*args, **kwargs): @@ -282,7 +287,7 @@ def __str__(self) -> str: assert str(TestEnum1.FOO) == "Ok" @pytest.mark.parametrize(("type_", "value"), [(int, 42), (str, "ok"), (bytes, b"no"), (float, 4.56), (complex, 3j)]) - def test_inherits_type_dunder_method_behaviour(self, type_, value): + def test_inherits_type_dunder_method_behaviour(self, type_: type, value: typing.Union[int, str]): class TestEnum(type_, enums.Enum): BAR = value diff --git a/tests/hikari/internal/test_fast_protocols.py b/tests/hikari/internal/test_fast_protocols.py index cf6787eac8..8f8f35a1cc 100644 --- a/tests/hikari/internal/test_fast_protocols.py +++ b/tests/hikari/internal/test_fast_protocols.py @@ -91,10 +91,10 @@ class MyProtocol(fast_protocol.FastProtocolChecking): ... def test_new(self): class MyProtocol(fast_protocol.FastProtocolChecking, typing.Protocol): - def test1(): ... + def test1(self): ... class OtherProtocol(MyProtocol, fast_protocol.FastProtocolChecking, typing.Protocol): - def test2(): ... + def test2(self): ... assert sorted(OtherProtocol._attributes_) == ["test1", "test2"] @@ -121,7 +121,7 @@ class Class: ... def test_isinstance_fastfail(self): class MyProtocol(fast_protocol.FastProtocolChecking, typing.Protocol): - def test(): ... + def test(self): ... class Class: ... @@ -129,10 +129,10 @@ class Class: ... def test_isinstance(self): class MyProtocol(fast_protocol.FastProtocolChecking, typing.Protocol): - def test(): ... + def test(self): ... class Class: - def test(): ... + def test(self): ... assert isinstance(Class(), MyProtocol) is True @@ -150,7 +150,7 @@ class Class: ... def test_issubclass_fastfail(self): class MyProtocol(fast_protocol.FastProtocolChecking, typing.Protocol): - def test(): ... + def test(self): ... class Class: ... @@ -158,9 +158,9 @@ class Class: ... def test_issubclass(self): class MyProtocol(fast_protocol.FastProtocolChecking, typing.Protocol): - def test(): ... + def test(self): ... class Class: - def test(): ... + def test(self): ... assert issubclass(Class, MyProtocol) is True diff --git a/tests/hikari/internal/test_mentions.py b/tests/hikari/internal/test_mentions.py index 379f5ecd14..1d8ef4f9f7 100644 --- a/tests/hikari/internal/test_mentions.py +++ b/tests/hikari/internal/test_mentions.py @@ -20,6 +20,8 @@ # SOFTWARE. from __future__ import annotations +import typing + import pytest from hikari import undefined @@ -40,7 +42,7 @@ ), ], ) -def test_generate_allowed_mentions(function_input, expected_output): +def test_generate_allowed_mentions(function_input: tuple[bool, ...], expected_output: dict[str, typing.Any]): returned = mentions.generate_allowed_mentions(*function_input) for k, v in expected_output.items(): if isinstance(v, list): diff --git a/tests/hikari/internal/test_net.py b/tests/hikari/internal/test_net.py index f7fbbc1325..b0f07895be 100644 --- a/tests/hikari/internal/test_net.py +++ b/tests/hikari/internal/test_net.py @@ -21,6 +21,7 @@ from __future__ import annotations import http +import typing import aiohttp import mock @@ -42,7 +43,7 @@ ], ) @pytest.mark.asyncio -async def test_generate_error_response(status_, expected_error): +async def test_generate_error_response(status_: http.HTTPStatus, expected_error: str): class StubResponse: real_url = "https://some.url" status = status_ @@ -95,7 +96,7 @@ async def read(self): ], ) @pytest.mark.asyncio -async def test_generate_error_response_with_non_conforming_status_code(status_, expected_error): +async def test_generate_error_response_with_non_conforming_status_code(status_: int, expected_error: str): class StubResponse: real_url = "https://some.url" status = status_ @@ -121,7 +122,7 @@ async def read(self): ], ) @pytest.mark.asyncio -async def test_generate_error_when_error_without_json(status_, expected_error): +async def test_generate_error_when_error_without_json(status_: http.HTTPStatus, expected_error: str): class StubResponse: real_url = "https://some.url" status = status_ @@ -143,7 +144,7 @@ class StubResponse: real_url = "https://some.url" status = http.HTTPStatus.BAD_REQUEST headers = {} - json = mock.AsyncMock(side_effect=aiohttp.ContentTypeError(None, None)) + json = mock.AsyncMock(side_effect=aiohttp.ContentTypeError(mock.Mock(), ())) async def read(self): return "some raw body" @@ -164,7 +165,9 @@ async def read(self): ], ) @pytest.mark.asyncio -async def test_generate_bad_request_error_with_json_response(data, expected_errors): +async def test_generate_bad_request_error_with_json_response( + data: str, expected_errors: typing.Optional[dict[str, typing.Any]] +): class StubResponse: real_url = "https://some.url" status = http.HTTPStatus.BAD_REQUEST diff --git a/tests/hikari/internal/test_reflect.py b/tests/hikari/internal/test_reflect.py index c233ff27d9..465d1973bb 100644 --- a/tests/hikari/internal/test_reflect.py +++ b/tests/hikari/internal/test_reflect.py @@ -41,9 +41,9 @@ def foo(bar: str, bat: int) -> str: ... assert signature.return_annotation is str def test_handles_normal_no_annotations(self): - def foo(bar, bat): ... + def foo(bar, bat): ... # pyright: ignore [reportMissingParameterType, reportUnknownParameterType] - signature = reflect.resolve_signature(foo) + signature = reflect.resolve_signature(foo) # pyright: ignore [reportUnknownArgumentType] assert signature.parameters["bar"].annotation is reflect.EMPTY assert signature.parameters["bat"].annotation is reflect.EMPTY assert signature.return_annotation is reflect.EMPTY @@ -90,14 +90,14 @@ def foo(bar: None) -> None: ... def test_handles_NoneType(self): def foo(bar: type(None)) -> type(None): ... - signature = reflect.resolve_signature(foo) + signature = reflect.resolve_signature(foo) # pyright: ignore [reportUnknownArgumentType] assert signature.parameters["bar"].annotation is None assert signature.return_annotation is None def test_handles_only_return_annotated(self): - def foo(bar, bat) -> str: ... + def foo(bar, bat) -> str: ... # pyright: ignore [reportMissingParameterType, reportUnknownParameterType] - signature = reflect.resolve_signature(foo) + signature = reflect.resolve_signature(foo) # pyright: ignore [reportUnknownArgumentType] assert signature.parameters["bar"].annotation is reflect.EMPTY assert signature.parameters["bat"].annotation is reflect.EMPTY assert signature.return_annotation is str @@ -111,7 +111,7 @@ def foo(bar: typing.Optional[typing.Iterator[int]]): ... @pytest.mark.skipif(sys.version_info < (3, 10), reason="This strategy is specific to 3.10 >= versions") def test_resolve_signature(): - foo = object() + foo = mock.Mock() with mock.patch.object(inspect, "signature") as signature: sig = reflect.resolve_signature(foo) diff --git a/tests/hikari/internal/test_routes.py b/tests/hikari/internal/test_routes.py index dc7d745a84..a6a77856b4 100644 --- a/tests/hikari/internal/test_routes.py +++ b/tests/hikari/internal/test_routes.py @@ -20,12 +20,13 @@ # SOFTWARE. from __future__ import annotations +import typing + import mock import pytest from hikari import files from hikari.internal import routes -from tests.hikari import hikari_test_helpers class TestCompiledRoute: @@ -35,16 +36,16 @@ def compiled_route(self): major_param_hash="abc123", route=mock.Mock(method="GET"), compiled_path="/some/endpoint" ) - def test_method(self, compiled_route): + def test_method(self, compiled_route: routes.CompiledRoute): assert compiled_route.method == "GET" - def test_create_url(self, compiled_route): + def test_create_url(self, compiled_route: routes.CompiledRoute): assert compiled_route.create_url("https://some.url/api") == "https://some.url/api/some/endpoint" - def test_create_real_bucket_hash(self, compiled_route): + def test_create_real_bucket_hash(self, compiled_route: routes.CompiledRoute): assert compiled_route.create_real_bucket_hash("UNKNOWN", "AUTH_HASH") == "UNKNOWN;AUTH_HASH;abc123" - def test__str__(self, compiled_route): + def test__str__(self, compiled_route: routes.CompiledRoute): assert str(compiled_route) == "GET /some/endpoint" @@ -59,7 +60,7 @@ class TestRoute: (routes.GET_INVITE, None), ], ) - def test_major_params(self, route, params): + def test_major_params(self, route: routes.Route, params: typing.Optional[frozenset[tuple[str, ...]]]): assert route.major_params == params def test_compile_with_no_major_params(self): @@ -148,7 +149,7 @@ def test_hash_operator_considers_path_template_only(self): ("LOTTIE", "json"), ], ) - def test_compile_uses_correct_extensions(self, input_file_format, expected_file_format): + def test_compile_uses_correct_extensions(self, input_file_format: str, expected_file_format: str): route = routes.CDNRoute("/foo/bar", {"PNG", "JPG", "JPEG", "WEBP", "AWEBP", "APNG", "GIF", "LOTTIE"}) compiled_url = route.compile("http://example.com", file_format=input_file_format) @@ -172,7 +173,7 @@ def test_compile_includes_lossless_on_webp(self): @pytest.mark.parametrize( ("input_file_format", "expected_file_format"), [("JPG", "jpg"), ("jpg", "jpg"), ("PNG", "png"), ("PNG", "png")] ) - def test_compile_uses_lowercase_file_format_always(self, input_file_format, expected_file_format): + def test_compile_uses_lowercase_file_format_always(self, input_file_format: str, expected_file_format: str): route = routes.CDNRoute("/foo/bar", {"PNG", "JPG"}) compiled_url = route.compile("http://example.com", file_format=input_file_format) @@ -220,7 +221,7 @@ def test_passing_negative_sizes_to_sizable_raises_ValueError(self): route.compile("http://example.com", file_format="PNG", hash="boooob", size=-10) @pytest.mark.parametrize("size", [int(2**size) for size in range(4, 13)]) - def test_passing_valid_sizes_to_sizable_does_not_raise_ValueError(self, size): + def test_passing_valid_sizes_to_sizable_does_not_raise_ValueError(self, size: int): route = routes.CDNRoute("/foo/bar", {"PNG", "JPG", "GIF"}) route.compile("http://example.com", file_format="PNG", hash="boooob", size=size) @@ -286,7 +287,16 @@ def test_passing_no_size_does_not_add_query_string(self): ), ], ) - def test_compile_generates_expected_url(self, base_url, template, file_format, kwds, foo, bar, expected_url): + def test_compile_generates_expected_url( + self, + base_url: str, + template: str, + file_format: str, + kwds: dict[str, typing.Any], + foo: str, + bar: str | int, + expected_url: str, + ): route = routes.CDNRoute(template, {"PNG", "GIF", "JPG", "WEBP", "APNG"}) actual_url = route.compile(base_url=base_url, file_format=file_format, foo=foo, bar=bar, **kwds) diff --git a/tests/hikari/internal/test_signals.py b/tests/hikari/internal/test_signals.py index 836ccc1579..22f59c6d28 100644 --- a/tests/hikari/internal/test_signals.py +++ b/tests/hikari/internal/test_signals.py @@ -20,7 +20,6 @@ # SOFTWARE. from __future__ import annotations -import contextlib import signal import mock @@ -37,7 +36,7 @@ def test__raise_interrupt(): @pytest.mark.parametrize("trace", [True, False]) -def test__interrupt_handler(trace): +def test__interrupt_handler(trace: bool): loop = mock.Mock() with mock.patch.object(signals, "_LOGGER", new=mock.Mock(isEnabledFor=mock.Mock(return_value=trace))): @@ -50,16 +49,15 @@ def test__interrupt_handler(trace): class TestHandleInterrupt: def test_behaviour(self): - loop = object() - - stack = contextlib.ExitStack() - register_signal_handler = stack.enter_context(mock.patch.object(signal, "signal")) - interrupt_handler = stack.enter_context(mock.patch.object(signals, "_interrupt_handler")) - stack.enter_context(mock.patch.object(signal, "SIGINT", new=2, create=True)) - stack.enter_context(mock.patch.object(signal, "SIGTERM", new=15, create=True)) - stack.enter_context(mock.patch.object(signals, "_INTERRUPT_SIGNALS", ("SIGINT", "SIGTERM", "UNIMPLEMENTED"))) - - with stack: + loop = mock.Mock() + + with ( + mock.patch.object(signal, "signal") as register_signal_handler, + mock.patch.object(signals, "_interrupt_handler") as interrupt_handler, + mock.patch.object(signal, "SIGINT", new=2, create=True), + mock.patch.object(signal, "SIGTERM", new=15, create=True), + mock.patch.object(signals, "_INTERRUPT_SIGNALS", ("SIGINT", "SIGTERM", "UNIMPLEMENTED")), + ): with signals.handle_interrupts(loop, propagate_interrupts=True, enabled=True): interrupt_handler.assert_called_once_with(loop) diff --git a/tests/hikari/internal/test_time.py b/tests/hikari/internal/test_time.py index 63b9185254..a846afd897 100644 --- a/tests/hikari/internal/test_time.py +++ b/tests/hikari/internal/test_time.py @@ -21,6 +21,7 @@ from __future__ import annotations import datetime +import typing import mock import pytest @@ -38,6 +39,7 @@ def test_parse_iso_8601_date_with_negative_timezone(): assert date.minute == 22 assert date.second == 33 assert date.microsecond == 23456 + assert date.tzinfo is not None offset = date.tzinfo.utcoffset(None) assert offset == datetime.timedelta(hours=-2, minutes=-30) @@ -52,6 +54,7 @@ def test_slow_parse_iso_8601_date_with_positive_timezone(): assert date.minute == 22 assert date.second == 33 assert date.microsecond == 23456 + assert date.tzinfo is not None offset = date.tzinfo.utcoffset(None) assert offset == datetime.timedelta(hours=2, minutes=30) @@ -66,6 +69,7 @@ def test_parse_iso_8601_date_with_zulu(): assert date.minute == 22 assert date.second == 33 assert date.microsecond == 23456 + assert date.tzinfo is not None offset = date.tzinfo.utcoffset(None) assert offset == datetime.timedelta(seconds=0) @@ -146,7 +150,7 @@ def test_unix_epoch_to_datetime_with_out_of_range_negative_timestamp(): (datetime.timedelta(days=-5, seconds=-3, milliseconds=12), 0), ], ) -def test_timespan_to_int(input_value, expected_result): +def test_timespan_to_int(input_value: typing.Union[int, float, datetime.timedelta], expected_result: int): assert time.timespan_to_int(input_value) == expected_result diff --git a/tests/hikari/internal/test_ux.py b/tests/hikari/internal/test_ux.py index 6e2de1d6a7..84b890682b 100644 --- a/tests/hikari/internal/test_ux.py +++ b/tests/hikari/internal/test_ux.py @@ -22,6 +22,7 @@ import contextlib import importlib +import importlib.resources import logging import logging.config import os @@ -30,8 +31,11 @@ import string import sys import time +import typing import colorlog +import colorlog.escape_codes +import colorlog.formatter import mock import pytest @@ -91,7 +95,7 @@ def test_when_flavour_is_a_dict_and_is_not_incremental(self): def test_when_flavour_is_a_dict_and_is_incremental(self): # This will emulate it being populated during the basicConfig call - def _basicConfig(*args, **kwargs): + def _basicConfig(*args: typing.Any, **kwargs: typing.Any): logging_basic_config(*args, **kwargs) stack.enter_context(mock.patch.object(logging.root, "handlers", new=[handler])) @@ -118,7 +122,7 @@ def _basicConfig(*args, **kwargs): def test_when_supports_color(self): # This will emulate it being populated during the basicConfig call - def _basicConfig(*args, **kwargs): + def _basicConfig(*args: typing.Any, **kwargs: typing.Any): logging_basic_config(*args, **kwargs) stack.enter_context(mock.patch.object(logging.root, "handlers", new=[handler])) @@ -231,18 +235,18 @@ class MockTraversable: mock_file = None open_encoding = None - def joinpath(self, path): + def joinpath(self, path: str): self.joint_path = path return self - def open(self, mode, encoding): + def open(self, mode: str, encoding: str): self.open_mode = mode self.open_encoding = encoding return self.mock_file traversable = MockTraversable() traversable.mock_file = MockFile() - read = object() + read = mock.Mock() with mock.patch.object(importlib.resources, "files", return_value=traversable, create=True) as read_text: assert ux._read_banner("hikaru") is read @@ -263,7 +267,7 @@ def test_when_package_is_none(self): write.assert_not_called() @pytest.fixture - def mock_args(self): + def mock_args(self) -> typing.Generator[None, None, None]: _about_path = str(pathlib.Path("/", "somewhere", "..", "coding", "hikari", "_about.py")) stack = contextlib.ExitStack() @@ -466,7 +470,7 @@ def test_when_CLICOLOR_is_0(self): assert ux.supports_color(allow_color=True, force_color=False) is False @pytest.mark.parametrize("colorterm", ["truecolor", "24bit", "TRUECOLOR", "24BIT"]) - def test_when_COLORTERM_has_correct_value(self, colorterm): + def test_when_COLORTERM_has_correct_value(self, colorterm: str): with mock.patch.dict(os.environ, {"COLORTERM": colorterm}, clear=True): assert ux.supports_color(allow_color=True, force_color=False) is True @@ -490,7 +494,7 @@ def test_when_plat_is_Pocket_PC(self): ("Terminus", True, False, False), ], ) - def test_when_plat_is_win32(self, term_program, ansicon, isatty, expected): + def test_when_plat_is_win32(self, term_program: str, ansicon: bool, isatty: bool, expected: bool): environ = {"TERM_PROGRAM": term_program} if ansicon: environ["ANSICON"] = "ooga booga" @@ -504,7 +508,7 @@ def test_when_plat_is_win32(self, term_program, ansicon, isatty, expected): assert ux.supports_color(allow_color=True, force_color=False) is expected @pytest.mark.parametrize("isatty", [True, False]) - def test_when_plat_is_not_win32(self, isatty): + def test_when_plat_is_not_win32(self, isatty: bool): stack = contextlib.ExitStack() stack.enter_context(mock.patch.dict(os.environ, {}, clear=True)) stack.enter_context(mock.patch.object(sys.stdout, "isatty", return_value=isatty)) @@ -515,7 +519,7 @@ def test_when_plat_is_not_win32(self, isatty): @pytest.mark.parametrize("isatty", [True, False]) @pytest.mark.parametrize("plat", ["linux", "win32"]) - def test_when_PYCHARM_HOSTED(self, isatty, plat): + def test_when_PYCHARM_HOSTED(self, isatty: bool, plat: str): stack = contextlib.ExitStack() stack.enter_context(mock.patch.dict(os.environ, {"PYCHARM_HOSTED": "OOGA BOOGA"}, clear=True)) stack.enter_context(mock.patch.object(sys.stdout, "isatty", return_value=isatty)) @@ -526,7 +530,7 @@ def test_when_PYCHARM_HOSTED(self, isatty, plat): @pytest.mark.parametrize("isatty", [True, False]) @pytest.mark.parametrize("plat", ["linux", "win32"]) - def test_when_WT_SESSION(self, isatty, plat): + def test_when_WT_SESSION(self, isatty: bool, plat: str): stack = contextlib.ExitStack() stack.enter_context(mock.patch.dict(os.environ, {"WT_SESSION": "OOGA BOOGA"}, clear=True)) stack.enter_context(mock.patch.object(sys.stdout, "isatty", return_value=isatty)) @@ -538,7 +542,7 @@ def test_when_WT_SESSION(self, isatty, plat): class TestHikariVersion: @pytest.mark.parametrize("v", ["1", "1.0.0dev2"]) - def test_init_when_version_number_is_invalid(self, v): + def test_init_when_version_number_is_invalid(self, v: str): with pytest.raises(ValueError, match=rf"Invalid version: '{v}'"): ux.HikariVersion(v) @@ -566,7 +570,7 @@ def test_repr(self): (ux.HikariVersion("1.2.3"), False), ], ) - def test_eq(self, other, result): + def test_eq(self, other: ux.HikariVersion, result: bool): assert (ux.HikariVersion("1.2.3.dev99") == other) is result @pytest.mark.parametrize( @@ -578,7 +582,7 @@ def test_eq(self, other, result): (ux.HikariVersion("1.2.3"), True), ], ) - def test_ne(self, other, result): + def test_ne(self, other: ux.HikariVersion, result: bool): assert (ux.HikariVersion("1.2.3.dev99") != other) is result @pytest.mark.parametrize( @@ -590,7 +594,7 @@ def test_ne(self, other, result): (ux.HikariVersion("1.2.3"), True), ], ) - def test_lt(self, other, result): + def test_lt(self, other: ux.HikariVersion, result: bool): assert (ux.HikariVersion("1.2.3.dev99") < other) is result @pytest.mark.parametrize( @@ -602,7 +606,7 @@ def test_lt(self, other, result): (ux.HikariVersion("1.2.3"), True), ], ) - def test_le(self, other, result): + def test_le(self, other: ux.HikariVersion, result: bool): assert (ux.HikariVersion("1.2.3.dev99") <= other) is result @pytest.mark.parametrize( @@ -614,7 +618,7 @@ def test_le(self, other, result): (ux.HikariVersion("1.2.3"), False), ], ) - def test_ge(self, other, result): + def test_ge(self, other: ux.HikariVersion, result: bool): assert (ux.HikariVersion("1.2.3.dev99") > other) is result @pytest.mark.parametrize( @@ -626,21 +630,23 @@ def test_ge(self, other, result): (ux.HikariVersion("1.2.3"), False), ], ) - def test_gt(self, other, result): + def test_gt(self, other: ux.HikariVersion, result: bool): assert (ux.HikariVersion("1.2.3.dev99") >= other) is result @pytest.mark.asyncio class TestCheckForUpdates: @pytest.fixture - def http_settings(self): + def http_settings(self) -> config.HTTPSettings: return mock.Mock(spec_set=config.HTTPSettings) @pytest.fixture - def proxy_settings(self): + def proxy_settings(self) -> config.ProxySettings: return mock.Mock(spec_set=config.ProxySettings) - async def test_when_not_official_pypi_release(self, http_settings, proxy_settings): + async def test_when_not_official_pypi_release( + self, http_settings: config.HTTPSettings, proxy_settings: config.ProxySettings + ): stack = contextlib.ExitStack() logger = stack.enter_context(mock.patch.object(ux, "_LOGGER")) create_client_session = stack.enter_context(mock.patch.object(net, "create_client_session")) @@ -653,7 +659,7 @@ async def test_when_not_official_pypi_release(self, http_settings, proxy_setting logger.info.assert_not_called() create_client_session.assert_not_called() - async def test_when_error_fetching(self, http_settings, proxy_settings): + async def test_when_error_fetching(self, http_settings: config.HTTPSettings, proxy_settings: config.ProxySettings): ex = RuntimeError("testing") stack = contextlib.ExitStack() logger = stack.enter_context(mock.patch.object(ux, "_LOGGER")) @@ -674,7 +680,9 @@ async def test_when_error_fetching(self, http_settings, proxy_settings): trust_env=proxy_settings.trust_env, ) - async def test_when_no_new_available_releases(self, http_settings, proxy_settings): + async def test_when_no_new_available_releases( + self, http_settings: config.HTTPSettings, proxy_settings: config.ProxySettings + ): data = { "releases": { "0.1.0": [{"yanked": False}], @@ -721,7 +729,9 @@ async def test_when_no_new_available_releases(self, http_settings, proxy_setting ) @pytest.mark.parametrize("v", ["1.0.1", "1.0.1.dev10"]) - async def test_check_for_updates(self, v, http_settings, proxy_settings): + async def test_check_for_updates( + self, v: str, http_settings: config.HTTPSettings, proxy_settings: config.ProxySettings + ): data = { "releases": { v: [{"yanked": False}, {"yanked": True}], diff --git a/tests/hikari/test_applications.py b/tests/hikari/test_applications.py index f349732d3f..9e642ae19c 100644 --- a/tests/hikari/test_applications.py +++ b/tests/hikari/test_applications.py @@ -26,12 +26,11 @@ import pytest from hikari import applications +from hikari import snowflakes +from hikari import traits from hikari import urls from hikari import users -from hikari.errors import ForbiddenError -from hikari.errors import UnauthorizedError from hikari.internal import routes -from tests.hikari import hikari_test_helpers class TestTeamMember: @@ -40,57 +39,59 @@ def model(self): user = mock.Mock(users.User) user.avatar_hash = "a_test" user.banner_hash = "a_test2" - return applications.TeamMember(membership_state=4, permissions=["*"], team_id=34123, user=user) + return applications.TeamMember( + membership_state=4, permissions=["*"], team_id=snowflakes.Snowflake(34123), user=user + ) - def test_app_property(self, model): + def test_app_property(self, model: applications.TeamMember): assert model.app is model.user.app - def test_avatar_decoration_property(self, model): + def test_avatar_decoration_property(self, model: applications.TeamMember): assert model.avatar_decoration is model.user.avatar_decoration - def test_avatar_hash_property(self, model): + def test_avatar_hash_property(self, model: applications.TeamMember): assert model.avatar_hash is model.user.avatar_hash - def test_avatar_url_property(self, model): + def test_avatar_url_property(self, model: applications.TeamMember): assert model.avatar_url is model.user.make_avatar_url() - def test_banner_hash_property(self, model): + def test_banner_hash_property(self, model: applications.TeamMember): assert model.banner_hash is model.user.banner_hash - def test_banner_url_property(self, model): + def test_banner_url_property(self, model: applications.TeamMember): assert model.banner_url is model.user.make_banner_url() - def test_make_avatar_url(self, model): + def test_make_avatar_url(self, model: applications.TeamMember): assert model.make_avatar_url() is model.user.make_avatar_url() - def test_make_banner_url(self, model): + def test_make_banner_url(self, model: applications.TeamMember): assert model.make_banner_url() is model.user.make_banner_url() - def test_accent_color_property(self, model): + def test_accent_color_property(self, model: applications.TeamMember): assert model.accent_color is model.user.accent_color - def test_default_avatar_url_property(self, model): + def test_default_avatar_url_property(self, model: applications.TeamMember): assert model.default_avatar_url is model.user.default_avatar_url - def test_discriminator_property(self, model): + def test_discriminator_property(self, model: applications.TeamMember): assert model.discriminator is model.user.discriminator - def test_flags_property(self, model): + def test_flags_property(self, model: applications.TeamMember): assert model.flags is model.user.flags - def test_id_property(self, model): + def test_id_property(self, model: applications.TeamMember): assert model.id is model.user.id - def test_is_bot_property(self, model): + def test_is_bot_property(self, model: applications.TeamMember): assert model.is_bot is model.user.is_bot - def test_is_system_property(self, model): + def test_is_system_property(self, model: applications.TeamMember): assert model.is_system is model.user.is_system - def test_mention_property(self, model): + def test_mention_property(self, model: applications.TeamMember): assert model.mention is model.user.mention - def test_username_property(self, model): + def test_username_property(self, model: applications.TeamMember): assert model.username is model.user.username def test_str_operator(self): @@ -102,128 +103,119 @@ def test_str_operator(self): class TestTeam: @pytest.fixture - def model(self): - return hikari_test_helpers.mock_class_namespace( - applications.Team, slots_=False, init_=False, id=123, icon_hash="ahashicon" - )() + def team(self) -> applications.Team: + return applications.Team( + app=mock.Mock(traits.RESTAware), + id=snowflakes.Snowflake(123), + name="beanos", + icon_hash="icon_hash", + members={}, + owner_id=snowflakes.Snowflake(456), + ) - def test_str_operator(self): - team = applications.Team(id=696969, app=object(), name="test", icon_hash="", members=[], owner_id=0) - assert str(team) == "Team test (696969)" + def test_str_operator(self, team: applications.Team): + assert str(team) == "Team beanos (123)" - def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided(self, team: applications.Team): with mock.patch.object( routes, "CDN_TEAM_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(ext="JPEG") == "file" + assert team.make_icon_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, team_id=123, hash="ahashicon", size=4096, file_format="JPEG", lossless=True + urls.CDN_URL, team_id=123, hash="icon_hash", size=4096, file_format="JPEG", lossless=True ) - def test_icon_url_property(self, model): - model.make_icon_url = mock.Mock(return_value="url") + def test_icon_url_property(self, team: applications.Team): + with mock.patch.object(applications.Team, "make_icon_url", return_value="url"): + assert team.icon_url == "url" - assert model.icon_url == "url" + def test_make_icon_url_when_hash_is_None(self, team: applications.Team): + with ( + mock.patch.object(team, "icon_hash", None), + mock.patch.object( + routes, "CDN_TEAM_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) + ) as patched_cdn_team_icon_route, + ): + assert team.make_icon_url(file_format="JPEG", size=1) is None - model.make_icon_url.assert_called_once_with() - - def test_make_icon_url_when_hash_is_None(self, model): - model.icon_hash = None + patched_cdn_team_icon_route.compile_to_file.assert_not_called() + def test_make_icon_url_when_hash_is_not_None(self, team: applications.Team): with mock.patch.object( routes, "CDN_TEAM_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(file_format="JPEG", size=1) is None - - route.compile_to_file.assert_not_called() - - def test_make_icon_url_when_hash_is_not_None(self, model): - with mock.patch.object( - routes, "CDN_TEAM_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) - ) as route: - assert model.make_icon_url(file_format="JPEG", size=1) == "file" + assert team.make_icon_url(file_format="JPEG", size=1) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, team_id=123, hash="ahashicon", size=1, file_format="JPEG", lossless=True + urls.CDN_URL, team_id=123, hash="icon_hash", size=1, file_format="JPEG", lossless=True ) class TestApplication: @pytest.fixture - def model(self): - return hikari_test_helpers.mock_class_namespace( - applications.Application, - init_=False, - slots_=False, - id=123, - icon_hash="ahashicon", - cover_image_hash="ahashcover", - )() - - def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def application(self) -> applications.Application: + return applications.Application( + id=snowflakes.Snowflake(123), + name="name", + description="description", + icon_hash="icon_hash", + app=mock.Mock(), + is_bot_public=False, + is_bot_code_grant_required=False, + owner=mock.Mock(), + rpc_origins=None, + flags=applications.ApplicationFlags.EMBEDDED, + public_key=b"public_key", + team=None, + cover_image_hash="cover_image_hash", + terms_of_service_url=None, + privacy_policy_url=None, + role_connections_verification_url=None, + custom_install_url=None, + tags=[], + install_parameters=None, + approximate_guild_count=0, + approximate_user_install_count=0, + integration_types_config={}, + ) + + def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided( + self, application: applications.Application + ): with mock.patch.object( routes, "CDN_APPLICATION_COVER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_cover_image_url(ext="JPEG") == "file" + assert application.make_cover_image_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, application_id=123, hash="ahashcover", size=4096, file_format="JPEG", lossless=True + urls.CDN_URL, application_id=123, hash="cover_image_hash", size=4096, file_format="JPEG", lossless=True ) - def test_cover_image_url_property(self, model): - model.make_cover_image_url = mock.Mock(return_value="url") - - assert model.cover_image_url == "url" - - model.make_cover_image_url.assert_called_once_with() - - def test_make_cover_image_url_when_hash_is_None(self, model): - model.cover_image_hash = None - + def test_cover_image_url_property(self, application: applications.Application): with mock.patch.object( routes, "CDN_APPLICATION_COVER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) - ) as route: - assert model.make_cover_image_url(file_format="JPEG", size=1) is None - - route.compile_to_file.assert_not_called() + ) as patched_cdn_application_cover: + assert application.make_cover_image_url(file_format="JPEG", size=1) == "file" + + patched_cdn_application_cover.compile_to_file.assert_called_once_with( + "https://cdn.discordapp.com", + application_id=123, + hash="cover_image_hash", + size=1, + file_format="JPEG", + lossless=True, + ) - def test_make_cover_image_url_when_hash_is_not_None(self, model): + def test_make_cover_image_url_when_hash_is_not_None(self, application: applications.Application): with mock.patch.object( routes, "CDN_APPLICATION_COVER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) - ) as route: - assert model.make_cover_image_url(file_format="JPEG", size=1) == "file" - - route.compile_to_file.assert_called_once_with( - urls.CDN_URL, application_id=123, hash="ahashcover", size=1, file_format="JPEG", lossless=True - ) - - @pytest.mark.asyncio - async def test_fetch_guild(self, model): - model.guild_id = 1234 - model.fetch_guild = mock.AsyncMock() - - model.fetch_guild.return_value.id = model.guild_id - assert (await model.fetch_guild()).id == model.guild_id - - model.fetch_guild.side_effect = UnauthorizedError( - "blah blah", "interesting", "foo bar", "this is an error", 403 - ) - with pytest.raises(UnauthorizedError): - await model.fetch_guild() - - @pytest.mark.asyncio - async def test_fetch_guild_preview(self, model): - model.fetch_guild_preview = mock.AsyncMock() - - model.fetch_guild_preview.return_value.description = "poggers" - assert (await model.fetch_guild_preview()).description == "poggers" + ) as patched_cdn_application_cover: + assert application.make_cover_image_url(file_format="JPEG", size=1) == "file" - model.fetch_guild_preview.side_effect = ForbiddenError( - "blah blah", "interesting", "foo bar", "this is an error", 403 + patched_cdn_application_cover.compile_to_file.assert_called_once_with( + urls.CDN_URL, application_id=123, hash="cover_image_hash", size=1, file_format="JPEG", lossless=True ) - with pytest.raises(ForbiddenError): - await model.fetch_guild_preview() class TestPartialOAuth2Token: @@ -247,6 +239,6 @@ def test_get_token_id_adds_padding(): @pytest.mark.parametrize("token", ["______.222222.dessddssd", "", "b2tva29r.b2tva29r.b2tva29r"]) -def test_get_token_id_handles_invalid_token(token): +def test_get_token_id_handles_invalid_token(token: str): with pytest.raises(ValueError, match="Unexpected token format"): applications.get_token_id(token) diff --git a/tests/hikari/test_audit_logs.py b/tests/hikari/test_audit_logs.py index 5764a3638f..75a6c40d65 100644 --- a/tests/hikari/test_audit_logs.py +++ b/tests/hikari/test_audit_logs.py @@ -26,94 +26,107 @@ from hikari import audit_logs from hikari import channels from hikari import snowflakes +from hikari import traits @pytest.mark.asyncio class TestMessagePinEntryInfo: - async def test_fetch_channel(self): - app = mock.AsyncMock() - app.rest.fetch_channel.return_value = mock.Mock(spec_set=channels.GuildTextChannel) - model = audit_logs.MessagePinEntryInfo(app=app, channel_id=123, message_id=456) - - assert await model.fetch_channel() is model.app.rest.fetch_channel.return_value - - model.app.rest.fetch_channel.assert_awaited_once_with(123) - - async def test_fetch_message(self): - model = audit_logs.MessagePinEntryInfo(app=mock.AsyncMock(), channel_id=123, message_id=456) - - assert await model.fetch_message() is model.app.rest.fetch_message.return_value + @pytest.fixture + def message_pin_entry_info(hikari_app: traits.RESTAware) -> audit_logs.MessagePinEntryInfo: + return audit_logs.MessagePinEntryInfo( + app=hikari_app, channel_id=snowflakes.Snowflake(123), message_id=snowflakes.Snowflake(456) + ) - model.app.rest.fetch_message.assert_awaited_once_with(123, 456) + async def test_fetch_channel(self, message_pin_entry_info: audit_logs.MessagePinEntryInfo): + with ( + mock.patch.object(message_pin_entry_info, "app") as patched_app, + mock.patch.object( + patched_app.rest, + "fetch_channel", + mock.AsyncMock(return_value=mock.Mock(spec_set=channels.GuildTextChannel)), + ) as patched_fetch_channel, + ): + assert await message_pin_entry_info.fetch_channel() is patched_fetch_channel.return_value + patched_fetch_channel.assert_awaited_once_with(123) + + async def test_fetch_message(self, message_pin_entry_info: audit_logs.MessagePinEntryInfo): + with ( + mock.patch.object(message_pin_entry_info, "app") as patched_app, + mock.patch.object(patched_app.rest, "fetch_message", new_callable=mock.AsyncMock) as patched_fetch_message, + ): + assert await message_pin_entry_info.fetch_message() is patched_fetch_message.return_value + patched_fetch_message.assert_awaited_once_with(123, 456) @pytest.mark.asyncio class TestMessageDeleteEntryInfo: async def test_fetch_channel(self): app = mock.AsyncMock() - app.rest.fetch_channel.return_value = mock.Mock(spec_set=channels.GuildTextChannel) - model = audit_logs.MessageDeleteEntryInfo(app=app, count=1, channel_id=123) + model = audit_logs.MessageDeleteEntryInfo(app=app, count=1, channel_id=snowflakes.Snowflake(123)) - assert await model.fetch_channel() is model.app.rest.fetch_channel.return_value + with mock.patch.object( + app.rest, "fetch_channel", new=mock.AsyncMock(return_value=mock.Mock(spec_set=channels.GuildTextChannel)) + ) as patched_fetch_channel: + assert await model.fetch_channel() is patched_fetch_channel.return_value - model.app.rest.fetch_channel.assert_awaited_once_with(123) + patched_fetch_channel.assert_awaited_once_with(123) @pytest.mark.asyncio class TestMemberMoveEntryInfo: async def test_fetch_channel(self): app = mock.AsyncMock() - app.rest.fetch_channel.return_value = mock.Mock(spec_set=channels.GuildVoiceChannel) - model = audit_logs.MemberMoveEntryInfo(app=app, count=1, channel_id=123) + model = audit_logs.MemberMoveEntryInfo(app=app, count=1, channel_id=snowflakes.Snowflake(123)) - assert await model.fetch_channel() is model.app.rest.fetch_channel.return_value + with mock.patch.object( + app.rest, "fetch_channel", new=mock.AsyncMock(return_value=mock.Mock(spec_set=channels.GuildVoiceChannel)) + ) as patched_fetch_channel: + assert await model.fetch_channel() is patched_fetch_channel.return_value - model.app.rest.fetch_channel.assert_awaited_once_with(123) + patched_fetch_channel.assert_awaited_once_with(123) class TestAuditLogEntry: - @pytest.mark.asyncio - async def test_fetch_user_when_no_user(self): - model = audit_logs.AuditLogEntry( - app=mock.AsyncMock(), - id=123, + @pytest.fixture + def audit_log_entry(hikari_app: traits.RESTAware) -> audit_logs.AuditLogEntry: + return audit_logs.AuditLogEntry( + app=hikari_app, + id=snowflakes.Snowflake(123), target_id=None, changes=[], - user_id=None, + user_id=snowflakes.Snowflake(456), action_type=0, options=None, reason=None, guild_id=snowflakes.Snowflake(34123123), ) - assert await model.fetch_user() is None - - model.app.rest.fetch_user.assert_not_called() - @pytest.mark.asyncio - async def test_fetch_user_when_user(self): - model = audit_logs.AuditLogEntry( - app=mock.AsyncMock(), - id=123, - target_id=None, - changes=[], - user_id=456, - action_type=0, - options=None, - reason=None, - guild_id=snowflakes.Snowflake(123321123), - ) + async def test_fetch_user_when_no_user(self, audit_log_entry: audit_logs.AuditLogEntry): + with ( + mock.patch.object(audit_log_entry, "user_id", None), + mock.patch.object(audit_log_entry, "app") as patched_app, + mock.patch.object(patched_app.rest, "fetch_user") as patched_fetch_user, + ): + assert await audit_log_entry.fetch_user() is None - assert await model.fetch_user() is model.app.rest.fetch_user.return_value + patched_fetch_user.assert_not_called() - model.app.rest.fetch_user.assert_awaited_once_with(456) + @pytest.mark.asyncio + async def test_fetch_user_when_user(self, audit_log_entry: audit_logs.AuditLogEntry): + with ( + mock.patch.object(audit_log_entry, "app") as patched_app, + mock.patch.object(patched_app.rest, "fetch_user", new_callable=mock.AsyncMock) as patched_fetch_user, + ): + assert await audit_log_entry.fetch_user() is patched_fetch_user.return_value + patched_fetch_user.assert_awaited_once_with(456) class TestAuditLog: def test_iter(self): - entry_1 = object() - entry_2 = object() - entry_3 = object() + entry_1 = mock.Mock() + entry_2 = mock.Mock() + entry_3 = mock.Mock() audit_log = audit_logs.AuditLog( auto_mod_rules={}, entries={ @@ -129,15 +142,15 @@ def test_iter(self): assert list(audit_log) == [entry_1, entry_2, entry_3] def test_get_item_with_index(self): - entry = object() - entry_2 = object() + entry = mock.Mock() + entry_2 = mock.Mock() audit_log = audit_logs.AuditLog( auto_mod_rules={}, entries={ - snowflakes.Snowflake(432123): object(), + snowflakes.Snowflake(432123): mock.Mock(), snowflakes.Snowflake(432654): entry, - snowflakes.Snowflake(432888): object(), - snowflakes.Snowflake(677777): object(), + snowflakes.Snowflake(432888): mock.Mock(), + snowflakes.Snowflake(677777): mock.Mock(), snowflakes.Snowflake(999999): entry_2, }, integrations={}, @@ -149,16 +162,16 @@ def test_get_item_with_index(self): assert audit_log[4] is entry_2 def test_get_item_with_slice(self): - entry_1 = object() - entry_2 = object() + entry_1 = mock.Mock() + entry_2 = mock.Mock() audit_log = audit_logs.AuditLog( auto_mod_rules={}, entries={ - snowflakes.Snowflake(432123): object(), + snowflakes.Snowflake(432123): mock.Mock(), snowflakes.Snowflake(432654): entry_1, - snowflakes.Snowflake(432888): object(), + snowflakes.Snowflake(432888): mock.Mock(), snowflakes.Snowflake(666666): entry_2, - snowflakes.Snowflake(783452): object(), + snowflakes.Snowflake(783452): mock.Mock(), }, integrations={}, threads={}, @@ -171,10 +184,10 @@ def test_len(self): audit_log = audit_logs.AuditLog( auto_mod_rules={}, entries={ - snowflakes.Snowflake(432123): object(), - snowflakes.Snowflake(432654): object(), - snowflakes.Snowflake(432888): object(), - snowflakes.Snowflake(783452): object(), + snowflakes.Snowflake(432123): mock.Mock(), + snowflakes.Snowflake(432654): mock.Mock(), + snowflakes.Snowflake(432888): mock.Mock(), + snowflakes.Snowflake(783452): mock.Mock(), }, integrations={}, threads={}, diff --git a/tests/hikari/test_channels.py b/tests/hikari/test_channels.py index 5a3bef12c4..e5d2dd7dcf 100644 --- a/tests/hikari/test_channels.py +++ b/tests/hikari/test_channels.py @@ -29,58 +29,60 @@ from hikari import files from hikari import permissions from hikari import snowflakes +from hikari import traits from hikari import users from hikari import webhooks from hikari.internal import routes from tests.hikari import hikari_test_helpers -@pytest.fixture -def mock_app(): - return mock.Mock() - - class TestChannelFollow: @pytest.mark.asyncio - async def test_fetch_channel(self, mock_app): + async def test_fetch_channel(self, hikari_app: traits.RESTAware): mock_channel = mock.Mock(spec=channels.GuildNewsChannel) - mock_app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) + hikari_app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) follow = channels.ChannelFollow( - channel_id=snowflakes.Snowflake(9459234123), app=mock_app, webhook_id=snowflakes.Snowflake(3123123) + channel_id=snowflakes.Snowflake(9459234123), app=hikari_app, webhook_id=snowflakes.Snowflake(3123123) ) result = await follow.fetch_channel() assert result is mock_channel - mock_app.rest.fetch_channel.assert_awaited_once_with(9459234123) + hikari_app.rest.fetch_channel.assert_awaited_once_with(9459234123) @pytest.mark.asyncio - async def test_fetch_webhook(self, mock_app): - mock_app.rest.fetch_webhook = mock.AsyncMock(return_value=mock.Mock(webhooks.ChannelFollowerWebhook)) + async def test_fetch_webhook(self, hikari_app: traits.RESTAware): + hikari_app.rest.fetch_webhook = mock.AsyncMock(return_value=mock.Mock(webhooks.ChannelFollowerWebhook)) follow = channels.ChannelFollow( - webhook_id=snowflakes.Snowflake(54123123), app=mock_app, channel_id=snowflakes.Snowflake(94949494) + webhook_id=snowflakes.Snowflake(54123123), app=hikari_app, channel_id=snowflakes.Snowflake(94949494) ) result = await follow.fetch_webhook() - assert result is mock_app.rest.fetch_webhook.return_value - mock_app.rest.fetch_webhook.assert_awaited_once_with(54123123) + assert result is hikari_app.rest.fetch_webhook.return_value + hikari_app.rest.fetch_webhook.assert_awaited_once_with(54123123) - def test_get_channel(self, mock_app): + def test_get_channel(self): mock_channel = mock.Mock(spec=channels.GuildNewsChannel) - mock_app.cache.get_guild_channel = mock.Mock(return_value=mock_channel) - follow = channels.ChannelFollow( - webhook_id=snowflakes.Snowflake(993883), app=mock_app, channel_id=snowflakes.Snowflake(696969) - ) - result = follow.get_channel() + app = mock.Mock(traits.CacheAware, rest=mock.Mock()) + with mock.patch.object( + app.cache, "get_guild_channel", mock.Mock(return_value=mock_channel) + ) as patched_get_guild_channel: + follow = channels.ChannelFollow( + webhook_id=snowflakes.Snowflake(993883), app=app, channel_id=snowflakes.Snowflake(696969) + ) - assert result is mock_channel - mock_app.cache.get_guild_channel.assert_called_once_with(696969) + result = follow.get_channel() + + assert result is mock_channel + patched_get_guild_channel.assert_called_once_with(696969) def test_get_channel_when_no_cache_trait(self): follow = channels.ChannelFollow( - webhook_id=snowflakes.Snowflake(993883), app=object(), channel_id=snowflakes.Snowflake(696969) + webhook_id=snowflakes.Snowflake(993883), + app=mock.Mock(traits.RESTAware), + channel_id=snowflakes.Snowflake(696969), ) assert follow.get_channel() is None @@ -98,54 +100,52 @@ def test_unset(self): class TestPartialChannel: @pytest.fixture - def model(self, mock_app): + def partial_channel(self, hikari_app: traits.RESTAware) -> channels.PartialChannel: return hikari_test_helpers.mock_class_namespace(channels.PartialChannel, rename_impl_=False)( - app=mock_app, id=snowflakes.Snowflake(1234567), name="foo", type=channels.ChannelType.GUILD_NEWS + app=hikari_app, id=snowflakes.Snowflake(1234567), name="foo", type=channels.ChannelType.GUILD_NEWS ) - def test_str_operator(self, model): - assert str(model) == "foo" + def test_str_operator(self, partial_channel: channels.PartialChannel): + assert str(partial_channel) == "foo" - def test_str_operator_when_name_is_None(self, model): - model.name = None - assert str(model) == "Unnamed PartialChannel ID 1234567" + def test_str_operator_when_name_is_None(self, partial_channel: channels.PartialChannel): + partial_channel.name = None + assert str(partial_channel) == "Unnamed PartialChannel ID 1234567" - def test_mention_property(self, model): - assert model.mention == "<#1234567>" + def test_mention_property(self, partial_channel: channels.PartialChannel): + assert partial_channel.mention == "<#1234567>" @pytest.mark.asyncio - async def test_delete(self, model): - model.app.rest.delete_channel = mock.AsyncMock() - - assert await model.delete() is model.app.rest.delete_channel.return_value - - model.app.rest.delete_channel.assert_called_once_with(1234567) + async def test_delete(self, partial_channel: channels.PartialChannel): + with mock.patch.object(partial_channel.app.rest, "delete_channel", mock.AsyncMock()) as patched_delete_channel: + assert await partial_channel.delete() is patched_delete_channel.return_value + patched_delete_channel.assert_called_once_with(1234567) class TestDMChannel: @pytest.fixture - def model(self, mock_app): + def dm_channel(self, hikari_app: traits.RESTAware) -> channels.DMChannel: return channels.DMChannel( id=snowflakes.Snowflake(12345), name="steve", type=channels.ChannelType.DM, last_message_id=snowflakes.Snowflake(12345), recipient=mock.Mock(spec_set=users.UserImpl, __str__=mock.Mock(return_value="snoop#0420")), - app=mock_app, + app=hikari_app, ) - def test_str_operator(self, model): - assert str(model) == "DMChannel with: snoop#0420" + def test_str_operator(self, dm_channel: channels.DMChannel): + assert str(dm_channel) == "DMChannel with: snoop#0420" - def test_shard_id(self, model): - assert model.shard_id == 0 + def test_shard_id(self, dm_channel: channels.DMChannel): + assert dm_channel.shard_id == 0 class TestGroupDMChannel: @pytest.fixture - def model(self, mock_app): + def group_dm_channel(self, hikari_app: traits.RESTAware) -> channels.GroupDMChannel: return channels.GroupDMChannel( - app=mock_app, + app=hikari_app, id=snowflakes.Snowflake(136134), name="super cool group dm", type=channels.ChannelType.DM, @@ -161,18 +161,20 @@ def model(self, mock_app): application_id=None, ) - def test_str_operator(self, model): - assert str(model) == "super cool group dm" + def test_str_operator(self, group_dm_channel: channels.GroupDMChannel): + assert str(group_dm_channel) == "super cool group dm" - def test_str_operator_when_name_is_None(self, model): - model.name = None - assert str(model) == "GroupDMChannel with: snoop#0420, yeet#1012, nice#6969" + def test_str_operator_when_name_is_None(self, group_dm_channel: channels.GroupDMChannel): + group_dm_channel.name = None + assert str(group_dm_channel) == "GroupDMChannel with: snoop#0420, yeet#1012, nice#6969" - def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided( + self, group_dm_channel: channels.GroupDMChannel + ): with mock.patch.object( routes, "CDN_CHANNEL_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(ext="JPEG") == "file" + assert group_dm_channel.make_icon_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, channel_id=136134, hash="1a2b3c", size=4096, file_format="JPEG", lossless=True @@ -185,39 +187,39 @@ def test_icon_url(self): assert channel.icon_url == "icon-url-here.com" channel.make_icon_url.assert_called_once() - def test_make_icon_url(self, model): - assert model.make_icon_url(file_format="JPEG", size=16) == files.URL( + def test_make_icon_url(self, group_dm_channel: channels.GroupDMChannel): + assert group_dm_channel.make_icon_url(file_format="JPEG", size=16) == files.URL( "https://cdn.discordapp.com/channel-icons/136134/1a2b3c.jpeg?size=16" ) - def test_make_icon_url_without_optional_params(self, model): - assert model.make_icon_url() == files.URL( + def test_make_icon_url_without_optional_params(self, group_dm_channel: channels.GroupDMChannel): + assert group_dm_channel.make_icon_url() == files.URL( "https://cdn.discordapp.com/channel-icons/136134/1a2b3c.png?size=4096" ) - def test_make_icon_url_when_hash_is_None(self, model): - model.icon_hash = None - assert model.make_icon_url() is None + def test_make_icon_url_when_hash_is_None(self, group_dm_channel: channels.GroupDMChannel): + group_dm_channel.icon_hash = None + assert group_dm_channel.make_icon_url() is None class TestTextChannel: @pytest.fixture - def model(self, mock_app): + def text_channel(self, hikari_app: traits.RESTAware) -> channels.TextableChannel: return hikari_test_helpers.mock_class_namespace(channels.TextableChannel)( - app=mock_app, id=snowflakes.Snowflake(12345679), name="foo1", type=channels.ChannelType.GUILD_TEXT + app=hikari_app, id=snowflakes.Snowflake(12345679), name="foo1", type=channels.ChannelType.GUILD_TEXT ) @pytest.mark.asyncio - async def test_fetch_history(self, model): - model.app.rest.fetch_messages = mock.AsyncMock() + async def test_fetch_history(self, text_channel: channels.TextableChannel): + text_channel.app.rest.fetch_messages = mock.AsyncMock() - await model.fetch_history( + await text_channel.fetch_history( before=datetime.datetime(2020, 4, 1, 1, 0, 0), after=datetime.datetime(2020, 4, 1, 0, 0, 0), around=datetime.datetime(2020, 4, 1, 0, 30, 0), ) - model.app.rest.fetch_messages.assert_awaited_once_with( + text_channel.app.rest.fetch_messages.assert_awaited_once_with( 12345679, before=datetime.datetime(2020, 4, 1, 1, 0, 0), after=datetime.datetime(2020, 4, 1, 0, 0, 0), @@ -225,57 +227,57 @@ async def test_fetch_history(self, model): ) @pytest.mark.asyncio - async def test_fetch_message(self, model): - model.app.rest.fetch_message = mock.AsyncMock() + async def test_fetch_message(self, text_channel: channels.TextableChannel): + text_channel.app.rest.fetch_message = mock.AsyncMock() - assert await model.fetch_message(133742069) is model.app.rest.fetch_message.return_value + assert await text_channel.fetch_message(133742069) is text_channel.app.rest.fetch_message.return_value - model.app.rest.fetch_message.assert_awaited_once_with(12345679, 133742069) + text_channel.app.rest.fetch_message.assert_awaited_once_with(12345679, 133742069) @pytest.mark.asyncio - async def test_fetch_pins(self, model): - model.app.rest.fetch_pins = mock.AsyncMock() + async def test_fetch_pins(self, text_channel: channels.TextableChannel): + text_channel.app.rest.fetch_pins = mock.AsyncMock() - await model.fetch_pins() + await text_channel.fetch_pins() - model.app.rest.fetch_pins.assert_awaited_once_with(12345679) + text_channel.app.rest.fetch_pins.assert_awaited_once_with(12345679) @pytest.mark.asyncio - async def test_pin_message(self, model): - model.app.rest.pin_message = mock.AsyncMock() + async def test_pin_message(self, text_channel: channels.TextableChannel): + text_channel.app.rest.pin_message = mock.AsyncMock() - assert await model.pin_message(77790) is model.app.rest.pin_message.return_value + assert await text_channel.pin_message(77790) is text_channel.app.rest.pin_message.return_value - model.app.rest.pin_message.assert_awaited_once_with(12345679, 77790) + text_channel.app.rest.pin_message.assert_awaited_once_with(12345679, 77790) @pytest.mark.asyncio - async def test_unpin_message(self, model): - model.app.rest.unpin_message = mock.AsyncMock() + async def test_unpin_message(self, text_channel: channels.TextableChannel): + text_channel.app.rest.unpin_message = mock.AsyncMock() - assert await model.unpin_message(77790) is model.app.rest.unpin_message.return_value + assert await text_channel.unpin_message(77790) is text_channel.app.rest.unpin_message.return_value - model.app.rest.unpin_message.assert_awaited_once_with(12345679, 77790) + text_channel.app.rest.unpin_message.assert_awaited_once_with(12345679, 77790) @pytest.mark.asyncio - async def test_delete_messages(self, model): - model.app.rest.delete_messages = mock.AsyncMock() + async def test_delete_messages(self, text_channel: channels.TextableChannel): + text_channel.app.rest.delete_messages = mock.AsyncMock() - await model.delete_messages([77790, 88890, 1800], 1337) + await text_channel.delete_messages([77790, 88890, 1800], 1337) - model.app.rest.delete_messages.assert_awaited_once_with(12345679, [77790, 88890, 1800], 1337) + text_channel.app.rest.delete_messages.assert_awaited_once_with(12345679, [77790, 88890, 1800], 1337) @pytest.mark.asyncio - async def test_send(self, model): - model.app.rest.create_message = mock.AsyncMock() - mock_attachment = object() - mock_component = object() - mock_components = [object(), object()] - mock_embed = object() - mock_embeds = object() - mock_attachments = [object(), object(), object()] - mock_reply = object() - - await model.send( + async def test_send(self, text_channel: channels.TextableChannel): + text_channel.app.rest.create_message = mock.AsyncMock() + mock_attachment = mock.Mock() + mock_component = mock.Mock() + mock_components = [mock.Mock(), mock.Mock()] + mock_embed = mock.Mock() + mock_embeds = mock.Mock() + mock_attachments = [mock.Mock(), mock.Mock(), mock.Mock()] + mock_reply = mock.Mock() + + await text_channel.send( content="test content", tts=True, attachment=mock_attachment, @@ -295,7 +297,7 @@ async def test_send(self, model): flags=6969, ) - model.app.rest.create_message.assert_awaited_once_with( + text_channel.app.rest.create_message.assert_awaited_once_with( channel=12345679, content="test content", tts=True, @@ -316,19 +318,19 @@ async def test_send(self, model): flags=6969, ) - def test_trigger_typing(self, model): - model.app.rest.trigger_typing = mock.Mock() + def test_trigger_typing(self, text_channel: channels.TextableChannel): + text_channel.app.rest.trigger_typing = mock.Mock() - model.trigger_typing() + text_channel.trigger_typing() - model.app.rest.trigger_typing.assert_called_once_with(12345679) + text_channel.app.rest.trigger_typing.assert_called_once_with(12345679) class TestGuildChannel: @pytest.fixture - def model(self, mock_app): - return hikari_test_helpers.mock_class_namespace(channels.GuildChannel)( - app=mock_app, + def guild_channel(self, hikari_app: traits.RESTAware) -> channels.GuildChannel: + return channels.GuildChannel( + app=hikari_app, id=snowflakes.Snowflake(69420), name="foo1", type=channels.ChannelType.GUILD_VOICE, @@ -336,28 +338,32 @@ def model(self, mock_app): parent_id=None, ) - def test_shard_id_property_when_not_shard_aware(self, model): - model.app = None + def test_shard_id_property_when_not_shard_aware(self, guild_channel: channels.GuildChannel): + with mock.patch.object(guild_channel, "app", None): + assert guild_channel.shard_id is None - assert model.shard_id is None - - def test_shard_id_property_when_guild_id_is_not_None(self, model): - model.app.shard_count = 3 - assert model.shard_id == 2 + def test_shard_id_property_when_guild_id_is_not_None(self, guild_channel: channels.GuildChannel): + with ( + mock.patch.object(guild_channel, "app", mock.Mock(traits.ShardAware, rest=mock.Mock())) as patched_app, + mock.patch.object(patched_app, "shard_count", 3), + ): + assert guild_channel.shard_id == 2 @pytest.mark.asyncio - async def test_fetch_guild(self, model): - model.app.rest.fetch_guild = mock.AsyncMock() + async def test_fetch_guild(self, guild_channel: channels.GuildChannel): + guild_channel.app.rest.fetch_guild = mock.AsyncMock() - assert await model.fetch_guild() is model.app.rest.fetch_guild.return_value + assert await guild_channel.fetch_guild() is guild_channel.app.rest.fetch_guild.return_value - model.app.rest.fetch_guild.assert_awaited_once_with(123456789) + guild_channel.app.rest.fetch_guild.assert_awaited_once_with(123456789) @pytest.mark.asyncio - async def test_edit(self, model): - model.app.rest.edit_channel = mock.AsyncMock() + async def test_edit(self, guild_channel: channels.GuildChannel): + guild_channel.app.rest.edit_channel = mock.AsyncMock() + + permission_overwrite = mock.Mock(channels.PermissionOverwrite, id=123) - result = await model.edit( + result = await guild_channel.edit( name="Supa fast boike", bitrate=420, reason="left right", @@ -370,8 +376,8 @@ async def test_edit(self, model): rate_limit_per_user=54123123, region="us-west", parent_category=341123123123, - permission_overwrites={123: "123"}, - flags=12, + permission_overwrites=[permission_overwrite], + flags=channels.ChannelFlag.REQUIRE_TAG, archived=True, auto_archive_duration=1234, locked=True, @@ -379,8 +385,8 @@ async def test_edit(self, model): applied_tags=[12345, 54321], ) - assert result is model.app.rest.edit_channel.return_value - model.app.rest.edit_channel.assert_awaited_once_with( + assert result is guild_channel.app.rest.edit_channel.return_value + guild_channel.app.rest.edit_channel.assert_awaited_once_with( 69420, name="Supa fast boike", position=4423, @@ -391,10 +397,10 @@ async def test_edit(self, model): user_limit=123321, rate_limit_per_user=54123123, region="us-west", - permission_overwrites={123: "123"}, + permission_overwrites=[permission_overwrite], parent_category=341123123123, default_auto_archive_duration=123312, - flags=12, + flags=channels.ChannelFlag.REQUIRE_TAG, archived=True, auto_archive_duration=1234, locked=True, @@ -406,9 +412,9 @@ async def test_edit(self, model): class TestPermissibleGuildChannel: @pytest.fixture - def model(self, mock_app): + def permissible_guild_channel(self, hikari_app: traits.RESTAware) -> channels.PermissibleGuildChannel: return hikari_test_helpers.mock_class_namespace(channels.PermissibleGuildChannel)( - app=mock_app, + app=hikari_app, id=snowflakes.Snowflake(69420), name="foo1", type=channels.ChannelType.GUILD_VOICE, @@ -416,14 +422,14 @@ def model(self, mock_app): is_nsfw=True, parent_id=None, position=54, - permission_overwrites=[], + permission_overwrites={}, ) @pytest.mark.asyncio - async def test_edit_overwrite(self, model): - model.app.rest.edit_permission_overwrite = mock.AsyncMock() + async def test_edit_overwrite(self, permissible_guild_channel: channels.PermissibleGuildChannel): + permissible_guild_channel.app.rest.edit_permission_overwrite = mock.AsyncMock() user = mock.Mock(users.PartialUser) - await model.edit_overwrite( + await permissible_guild_channel.edit_overwrite( 333, target_type=user, allow=permissions.Permissions.BAN_MEMBERS, @@ -431,7 +437,7 @@ async def test_edit_overwrite(self, model): reason="vrooom vroom", ) - model.app.rest.edit_permission_overwrite.assert_called_once_with( + permissible_guild_channel.app.rest.edit_permission_overwrite.assert_called_once_with( 69420, 333, target_type=user, @@ -441,14 +447,14 @@ async def test_edit_overwrite(self, model): ) @pytest.mark.asyncio - async def test_edit_overwrite_target_type_none(self, model): - model.app.rest.edit_permission_overwrite = mock.AsyncMock() + async def test_edit_overwrite_target_type_none(self, permissible_guild_channel: channels.PermissibleGuildChannel): + permissible_guild_channel.app.rest.edit_permission_overwrite = mock.AsyncMock() user = mock.Mock(users.PartialUser) - await model.edit_overwrite( + await permissible_guild_channel.edit_overwrite( user, allow=permissions.Permissions.BAN_MEMBERS, deny=permissions.Permissions.CONNECT, reason="vrooom vroom" ) - model.app.rest.edit_permission_overwrite.assert_called_once_with( + permissible_guild_channel.app.rest.edit_permission_overwrite.assert_called_once_with( 69420, user, allow=permissions.Permissions.BAN_MEMBERS, @@ -457,40 +463,49 @@ async def test_edit_overwrite_target_type_none(self, model): ) @pytest.mark.asyncio - async def test_remove_overwrite(self, model): - model.app.rest.delete_permission_overwrite = mock.AsyncMock() + async def test_remove_overwrite(self, permissible_guild_channel: channels.PermissibleGuildChannel): + permissible_guild_channel.app.rest.delete_permission_overwrite = mock.AsyncMock() - await model.remove_overwrite(333) + await permissible_guild_channel.remove_overwrite(333) - model.app.rest.delete_permission_overwrite.assert_called_once_with(69420, 333) + permissible_guild_channel.app.rest.delete_permission_overwrite.assert_called_once_with(69420, 333) - def test_get_guild(self, model): + def test_get_guild(self, permissible_guild_channel: channels.PermissibleGuildChannel): guild = mock.Mock(id=123456789) - model.app.cache.get_guild.side_effect = [guild] - - assert model.get_guild() == guild - - model.app.cache.get_guild.assert_called_once_with(123456789) - - def test_get_guild_when_guild_not_in_cache(self, model): - model.app.cache.get_guild.side_effect = [None] - assert model.get_guild() is None - - model.app.cache.get_guild.assert_called_once_with(123456789) - - def test_get_guild_when_no_cache_trait(self, model): - model.app = object() - - assert model.get_guild() is None + with ( + mock.patch.object( + permissible_guild_channel, "app", mock.Mock(traits.CacheAware, rest=mock.Mock()) + ) as patched_app, + mock.patch.object(patched_app.cache, "get_guild", side_effect=[guild]) as patched_get_guild, + ): + assert permissible_guild_channel.get_guild() == guild + patched_get_guild.assert_called_once_with(123456789) + + def test_get_guild_when_guild_not_in_cache(self, permissible_guild_channel: channels.PermissibleGuildChannel): + with ( + mock.patch.object( + permissible_guild_channel, "app", mock.Mock(traits.CacheAware, rest=mock.Mock()) + ) as patched_app, + mock.patch.object(patched_app.cache, "get_guild", side_effect=[None]) as patched_get_guild, + ): + assert permissible_guild_channel.get_guild() is None + patched_get_guild.assert_called_once_with(123456789) + + def test_get_guild_when_no_cache_trait(self, permissible_guild_channel: channels.PermissibleGuildChannel): + permissible_guild_channel.app = mock.Mock(traits.RESTAware) + + assert permissible_guild_channel.get_guild() is None @pytest.mark.asyncio - async def test_fetch_guild(self, model): - model.app.rest.fetch_guild = mock.AsyncMock() + async def test_fetch_guild(self, permissible_guild_channel: channels.PermissibleGuildChannel): + permissible_guild_channel.app.rest.fetch_guild = mock.AsyncMock() - assert await model.fetch_guild() == model.app.rest.fetch_guild.return_value + assert ( + await permissible_guild_channel.fetch_guild() == permissible_guild_channel.app.rest.fetch_guild.return_value + ) - model.app.rest.fetch_guild.assert_awaited_once_with(123456789) + permissible_guild_channel.app.rest.fetch_guild.assert_awaited_once_with(123456789) class TestForumTag: @@ -498,7 +513,9 @@ class TestForumTag: ("emoji", "expected_unicode_emoji", "expected_emoji_id"), [(123, None, 123), ("emoji", "emoji", None), (None, None, None)], ) - def test_emoji_parameters(self, emoji, expected_emoji_id, expected_unicode_emoji): + def test_emoji_parameters( + self, emoji: int | str | None, expected_emoji_id: str | None, expected_unicode_emoji: int | None + ): tag = channels.ForumTag(name="testing", emoji=emoji) assert tag.emoji_id == expected_emoji_id diff --git a/tests/hikari/test_colors.py b/tests/hikari/test_colors.py index 4a6570cf8f..af0f5bbe8f 100644 --- a/tests/hikari/test_colors.py +++ b/tests/hikari/test_colors.py @@ -155,20 +155,20 @@ class TestColor: @pytest.mark.parametrize("i", [0, 0x1, 0x11, 0x111, 0x1111, 0x11111, 0xFFFFFF]) - def test_Color_validates_constructor_and_passes_for_valid_values(self, i): + def test_Color_validates_constructor_and_passes_for_valid_values(self, i: int): assert colors.Color(i) is not None @pytest.mark.parametrize("i", [-1, 0x1000000]) - def test_Color_validates_constructor_and_fails_for_out_of_range_values(self, i): + def test_Color_validates_constructor_and_fails_for_out_of_range_values(self, i: int): with pytest.raises(ValueError, match=r"raw_rgb must be in the exclusive range of 0 and 16777215"): colors.Color(i) @pytest.mark.parametrize("i", [0, 0x1, 0x11, 0x111, 0x1111, 0x11111, 0xFFFFFF]) - def test_Color_from_int_passes_for_valid_values(self, i): + def test_Color_from_int_passes_for_valid_values(self, i: int): assert colors.Color.from_int(i) is not None @pytest.mark.parametrize("i", [-1, 0x1000000]) - def test_Color_from_int_fails_for_out_of_range_values(self, i): + def test_Color_from_int_fails_for_out_of_range_values(self, i: int): with pytest.raises(ValueError, match=r"raw_rgb must be in the exclusive range of 0 and 16777215"): colors.Color.from_int(i) @@ -181,29 +181,29 @@ def test_cast_to_int(self): @pytest.mark.parametrize( ("i", "string"), [(0x1A2B3C, "Color(r=0x1a, g=0x2b, b=0x3c)"), (0x1A2, "Color(r=0x0, g=0x1, b=0xa2)")] ) - def test_Color_repr_operator(self, i, string): + def test_Color_repr_operator(self, i: int, string: str): assert repr(colors.Color(i)) == string @pytest.mark.parametrize(("i", "string"), [(0x1A2B3C, "#1A2B3C"), (0x1A2, "#0001A2")]) - def test_Color_str_operator(self, i, string): + def test_Color_str_operator(self, i: int, string: str): assert str(colors.Color(i)) == string @pytest.mark.parametrize(("i", "string"), [(0x1A2B3C, "#1A2B3C"), (0x1A2, "#0001A2")]) - def test_Color_hex_code(self, i, string): + def test_Color_hex_code(self, i: int, string: str): assert colors.Color(i).hex_code == string @pytest.mark.parametrize(("i", "string"), [(0x1A2B3C, "1A2B3C"), (0x1A2, "0001A2")]) - def test_Color_raw_hex_code(self, i, string): + def test_Color_raw_hex_code(self, i: int, string: str): assert colors.Color(i).raw_hex_code == string @pytest.mark.parametrize( ("i", "expected_outcome"), [(0x1A2B3C, False), (0x1AAA2B, False), (0x0, True), (0x11AA33, True)] ) - def test_Color_is_web_safe(self, i, expected_outcome): + def test_Color_is_web_safe(self, i: int, expected_outcome: bool): assert colors.Color(i).is_web_safe is expected_outcome @pytest.mark.parametrize(("r", "g", "b", "expected"), [(0x9, 0x18, 0x27, 0x91827), (0x55, 0x1A, 0xFF, 0x551AFF)]) - def test_Color_from_rgb(self, r, g, b, expected): + def test_Color_from_rgb(self, r: int, g: int, b: int, expected: int): assert colors.Color.from_rgb(r, g, b) == expected def test_color_from_rgb_raises_value_error_on_invalid_red(self): @@ -222,7 +222,7 @@ def test_color_from_rgb_raises_value_error_on_invalid_blue(self): ("r", "g", "b", "expected"), [(0x09 / 0xFF, 0x18 / 0xFF, 0x27 / 0xFF, 0x91827), (0x55 / 0xFF, 0x1A / 0xFF, 0xFF / 0xFF, 0x551AFF)], ) - def test_Color_from_rgb_float(self, r, g, b, expected): + def test_Color_from_rgb_float(self, r: int, g: int, b: int, expected: int): assert math.isclose(colors.Color.from_rgb_float(r, g, b), expected, abs_tol=1) def test_color_from_rgb_float_raises_value_error_on_invalid_red(self): @@ -238,21 +238,21 @@ def test_color_from_rgb_float_raises_value_error_on_invalid_blue(self): colors.Color.from_rgb_float(0.5, 0.5, 1.5) @pytest.mark.parametrize(("input", "r", "g", "b"), [(0x91827, 0x9, 0x18, 0x27), (0x551AFF, 0x55, 0x1A, 0xFF)]) - def test_Color_rgb(self, input, r, g, b): + def test_Color_rgb(self, input: int, r: int, g: int, b: int): assert colors.Color(input).rgb == (r, g, b) @pytest.mark.parametrize( ("input", "r", "g", "b"), [(0x91827, 0x09 / 0xFF, 0x18 / 0xFF, 0x27 / 0xFF), (0x551AFF, 0x55 / 0xFF, 0x1A / 0xFF, 0xFF / 0xFF)], ) - def test_Color_rgb_float(self, input, r, g, b): + def test_Color_rgb_float(self, input: int, r: int, g: int, b: int): assert colors.Color(input).rgb_float == (r, g, b) @pytest.mark.parametrize("prefix", ["0x", "0X", "#", ""]) @pytest.mark.parametrize( ("expected", "string"), [(0x1A2B3C, "1A2B3C"), (0x1A2, "0001A2"), (0xAABBCC, "ABC"), (0x00AA00, "0A0")] ) - def test_Color_from_hex_code(self, prefix, string, expected): + def test_Color_from_hex_code(self, prefix: str, string: str, expected: int): actual_string = prefix + string assert colors.Color.from_hex_code(actual_string) == expected @@ -292,13 +292,15 @@ def test_Color_to_bytes(self): *tuple_str_happy_test_data, ], ) - def test_Color_of_happy_path(self, input, expected_result): + def test_Color_of_happy_path( + self, input: colors.Color | int | str | tuple[int | float], expected_result: colors.Color + ): result = colors.Color.of(input) assert result == expected_result, f"{input}" @pytest.mark.parametrize( ("input_string", "value_error_match"), - [ + [ # FIXME: This is a weird issue. It does not like the one with set(). ("blah", r"Could not transform 'blah' into a Color object"), ("0xfff1", r"Color code is invalid length\. Must be 3 or 6 digits"), (lambda: 22, r"Could not transform at 0x[a-zA-Z0-9]+> into a Color object"), @@ -318,7 +320,7 @@ def test_Color_of_happy_path(self, input, expected_result): *tuple_str_sad_test_data, ], ) - def test_Color_of_sad_path(self, input_string, value_error_match): + def test_Color_of_sad_path(self, input_string: str, value_error_match: str): with pytest.raises(ValueError, match=value_error_match): colors.Color.of(input_string) @@ -327,10 +329,10 @@ def test_Color_of_with_multiple_args(self): assert result == colors.Color(0xFF051A) @pytest.mark.parametrize(("input_string", "expected_color"), tuple_str_happy_test_data) - def test_from_tuple_string_happy_path(self, input_string, expected_color): + def test_from_tuple_string_happy_path(self, input_string: str, expected_color: colors.Color): assert colors.Color.from_tuple_string(input_string) == expected_color @pytest.mark.parametrize(("input_string", "value_error_match"), tuple_str_sad_test_data) - def test_from_tuple_string_sad_path(self, input_string, value_error_match): + def test_from_tuple_string_sad_path(self, input_string: str, value_error_match: str): with pytest.raises(ValueError, match=value_error_match): colors.Color.from_tuple_string(input_string) diff --git a/tests/hikari/test_commands.py b/tests/hikari/test_commands.py index 37f980d0f6..685dfb9e4c 100644 --- a/tests/hikari/test_commands.py +++ b/tests/hikari/test_commands.py @@ -25,27 +25,27 @@ from hikari import applications from hikari import commands +from hikari import permissions from hikari import snowflakes from hikari import traits from hikari import undefined from tests.hikari import hikari_test_helpers - -@pytest.fixture -def mock_app(): - return mock.Mock(traits.CacheAware, rest=mock.AsyncMock()) +# @pytest.fixture +# def hikari_app() -> traits.RESTAware: +# return mock.Mock(traits.CacheAware, rest=mock.AsyncMock()) class TestPartialCommand: @pytest.fixture - def mock_command(self, mock_app): + def mock_command(self, hikari_app: traits.RESTAware) -> commands.PartialCommand: return hikari_test_helpers.mock_class_namespace(commands.PartialCommand)( - app=mock_app, + app=hikari_app, id=snowflakes.Snowflake(34123123), type=commands.CommandType.SLASH, application_id=snowflakes.Snowflake(65234123), name="Name", - default_member_permissions=None, + default_member_permissions=permissions.Permissions.NONE, is_nsfw=True, guild_id=snowflakes.Snowflake(31231235), version=snowflakes.Snowflake(43123123), @@ -55,91 +55,135 @@ def mock_command(self, mock_app): ) @pytest.mark.asyncio - async def test_fetch_self(self, mock_command, mock_app): - result = await mock_command.fetch_self() + async def test_fetch_self(self, hikari_app: traits.RESTAware, mock_command: commands.PartialCommand): + with mock.patch.object( + hikari_app.rest, "fetch_application_command", mock.AsyncMock() + ) as patched_fetch_application_command: + result = await mock_command.fetch_self() - assert result is mock_app.rest.fetch_application_command.return_value - mock_app.rest.fetch_application_command.assert_awaited_once_with(65234123, 34123123, 31231235) + assert result is patched_fetch_application_command.return_value + patched_fetch_application_command.assert_awaited_once_with(65234123, 34123123, 31231235) @pytest.mark.asyncio - async def test_fetch_self_when_guild_id_is_none(self, mock_command, mock_app): - mock_command.guild_id = None - - result = await mock_command.fetch_self() - - assert result is mock_app.rest.fetch_application_command.return_value - mock_app.rest.fetch_application_command.assert_awaited_once_with(65234123, 34123123, undefined.UNDEFINED) + async def test_fetch_self_when_guild_id_is_none( + self, hikari_app: traits.RESTAware, mock_command: commands.PartialCommand + ): + with ( + mock.patch.object(mock_command, "guild_id", None), + mock.patch.object( + hikari_app.rest, "fetch_application_command", mock.AsyncMock() + ) as patched_fetch_application_command, + ): + result = await mock_command.fetch_self() + + assert result is patched_fetch_application_command.return_value + patched_fetch_application_command.assert_awaited_once_with(65234123, 34123123, undefined.UNDEFINED) @pytest.mark.asyncio - async def test_edit_without_optional_args(self, mock_command, mock_app): - result = await mock_command.edit() - - assert result is mock_app.rest.edit_application_command.return_value - mock_app.rest.edit_application_command.assert_awaited_once_with( - 65234123, - 34123123, - 31231235, - name=undefined.UNDEFINED, - description=undefined.UNDEFINED, - options=undefined.UNDEFINED, - ) + async def test_edit_without_optional_args( + self, hikari_app: traits.RESTAware, mock_command: commands.PartialCommand + ): + with mock.patch.object( + hikari_app.rest, "edit_application_command", mock.AsyncMock() + ) as patched_edit_application_command: + result = await mock_command.edit() + + assert result is patched_edit_application_command.return_value + patched_edit_application_command.assert_awaited_once_with( + 65234123, + 34123123, + 31231235, + name=undefined.UNDEFINED, + description=undefined.UNDEFINED, + options=undefined.UNDEFINED, + ) @pytest.mark.asyncio - async def test_edit_with_optional_args(self, mock_command, mock_app): - mock_option = object() - result = await mock_command.edit(name="new name", description="very descrypt", options=[mock_option]) + async def test_edit_with_optional_args(self, hikari_app: traits.RESTAware, mock_command: commands.PartialCommand): + mock_option = mock.Mock() - assert result is mock_app.rest.edit_application_command.return_value - mock_app.rest.edit_application_command.assert_awaited_once_with( - 65234123, 34123123, 31231235, name="new name", description="very descrypt", options=[mock_option] - ) + with mock.patch.object( + hikari_app.rest, "edit_application_command", mock.AsyncMock() + ) as patched_edit_application_command: + result = await mock_command.edit(name="new name", description="very descrypt", options=[mock_option]) + + assert result is patched_edit_application_command.return_value + patched_edit_application_command.assert_awaited_once_with( + 65234123, 34123123, 31231235, name="new name", description="very descrypt", options=[mock_option] + ) @pytest.mark.asyncio - async def test_edit_when_guild_id_is_none(self, mock_command, mock_app): + async def test_edit_when_guild_id_is_none( + self, hikari_app: traits.RESTAware, mock_command: commands.PartialCommand + ): mock_command.guild_id = None - result = await mock_command.edit() - - assert result is mock_app.rest.edit_application_command.return_value - mock_app.rest.edit_application_command.assert_awaited_once_with( - 65234123, - 34123123, - undefined.UNDEFINED, - name=undefined.UNDEFINED, - description=undefined.UNDEFINED, - options=undefined.UNDEFINED, - ) + with ( + mock.patch.object(mock_command, "guild_id", None), + mock.patch.object( + hikari_app.rest, "edit_application_command", mock.AsyncMock() + ) as patched_edit_application_command, + ): + result = await mock_command.edit() + + assert result is patched_edit_application_command.return_value + patched_edit_application_command.assert_awaited_once_with( + 65234123, + 34123123, + undefined.UNDEFINED, + name=undefined.UNDEFINED, + description=undefined.UNDEFINED, + options=undefined.UNDEFINED, + ) @pytest.mark.asyncio - async def test_delete(self, mock_command, mock_app): - await mock_command.delete() + async def test_delete(self, hikari_app: traits.RESTAware, mock_command: commands.PartialCommand): + with mock.patch.object( + hikari_app.rest, "delete_application_command", mock.AsyncMock() + ) as patched_delete_application_command: + await mock_command.delete() - mock_app.rest.delete_application_command.assert_awaited_once_with(65234123, 34123123, 31231235) + patched_delete_application_command.assert_awaited_once_with(65234123, 34123123, 31231235) @pytest.mark.asyncio - async def test_delete_when_guild_id_is_none(self, mock_command, mock_app): - mock_command.guild_id = None - - await mock_command.delete() - - mock_app.rest.delete_application_command.assert_awaited_once_with(65234123, 34123123, undefined.UNDEFINED) + async def test_delete_when_guild_id_is_none( + self, hikari_app: traits.RESTAware, mock_command: commands.PartialCommand + ): + with ( + mock.patch.object(mock_command, "guild_id", None), + mock.patch.object( + hikari_app.rest, "delete_application_command", mock.AsyncMock() + ) as patched_delete_application_command, + ): + await mock_command.delete() + + patched_delete_application_command.assert_awaited_once_with(65234123, 34123123, undefined.UNDEFINED) @pytest.mark.asyncio - async def test_fetch_guild_permissions(self, mock_command, mock_app): - result = await mock_command.fetch_guild_permissions(123321) + async def test_fetch_guild_permissions(self, hikari_app: traits.RESTAware, mock_command: commands.PartialCommand): + with mock.patch.object( + hikari_app.rest, "fetch_application_command_permissions", mock.AsyncMock() + ) as patched_fetch_application_command_permissions: + result = await mock_command.fetch_guild_permissions(123321) - assert result is mock_app.rest.fetch_application_command_permissions.return_value - mock_app.rest.fetch_application_command_permissions.assert_awaited_once_with( - application=mock_command.application_id, guild=123321, command=mock_command.id - ) + assert result is patched_fetch_application_command_permissions.return_value + patched_fetch_application_command_permissions.assert_awaited_once_with( + application=mock_command.application_id, guild=123321, command=mock_command.id + ) @pytest.mark.asyncio - async def test_set_guild_permissions(self, mock_command, mock_app): - mock_permissions = object() - - result = await mock_command.set_guild_permissions(312123, mock_permissions) - - assert result is mock_app.rest.set_application_command_permissions.return_value - mock_app.rest.set_application_command_permissions.assert_awaited_once_with( - application=mock_command.application_id, guild=312123, command=mock_command.id, permissions=mock_permissions - ) + async def test_set_guild_permissions(self, hikari_app: traits.RESTAware, mock_command: commands.PartialCommand): + mock_permissions = mock.Mock() + + with mock.patch.object( + hikari_app.rest, "set_application_command_permissions", mock.AsyncMock() + ) as patched_set_application_command_permissions: + result = await mock_command.set_guild_permissions(312123, mock_permissions) + + assert result is patched_set_application_command_permissions.return_value + patched_set_application_command_permissions.assert_awaited_once_with( + application=mock_command.application_id, + guild=312123, + command=mock_command.id, + permissions=mock_permissions, + ) diff --git a/tests/hikari/test_components.py b/tests/hikari/test_components.py index 7365645b7f..bfe75eb00e 100644 --- a/tests/hikari/test_components.py +++ b/tests/hikari/test_components.py @@ -20,33 +20,35 @@ # SOFTWARE. from __future__ import annotations +import mock + from hikari import components class TestActionRowComponent: def test_getitem_operator_with_index(self): - mock_component = object() - row = components.ActionRowComponent(type=1, id=5855932, components=[object(), mock_component, object()]) + mock_component = mock.Mock() + row = components.ActionRowComponent(type=1, id=5855932, components=[mock.Mock(), mock_component, mock.Mock()]) assert row[1] is mock_component def test_getitem_operator_with_slice(self): - mock_component_1 = object() - mock_component_2 = object() + mock_component_1 = mock.Mock() + mock_component_2 = mock.Mock() row = components.ActionRowComponent( - type=1, id=5855932, components=[object(), mock_component_1, object(), mock_component_2] + type=1, id=5855932, components=[mock.Mock(), mock_component_1, mock.Mock(), mock_component_2] ) assert row[1:4:2] == [mock_component_1, mock_component_2] def test_iter_operator(self): - mock_component_1 = object() - mock_component_2 = object() + mock_component_1 = mock.Mock() + mock_component_2 = mock.Mock() row = components.ActionRowComponent(type=1, id=5855932, components=[mock_component_1, mock_component_2]) assert list(row) == [mock_component_1, mock_component_2] def test_len_operator(self): - row = components.ActionRowComponent(type=1, id=5855932, components=[object(), object()]) + row = components.ActionRowComponent(type=1, id=5855932, components=[mock.Mock(), mock.Mock()]) assert len(row) == 2 diff --git a/tests/hikari/test_embeds.py b/tests/hikari/test_embeds.py index 8aee13fa73..4952372f03 100644 --- a/tests/hikari/test_embeds.py +++ b/tests/hikari/test_embeds.py @@ -28,39 +28,42 @@ class TestEmbedResource: @pytest.fixture - def resource(self): + def resource(self) -> embeds.EmbedResource: return embeds.EmbedResource(resource=mock.Mock()) - def test_url(self, resource): + def test_url(self, resource: embeds.EmbedResource): assert resource.url is resource.resource.url - def test_filename(self, resource): + def test_filename(self, resource: embeds.EmbedResource): assert resource.filename is resource.resource.filename - def test_stream(self, resource): - mock_executor = object() + def test_stream(self, resource: embeds.EmbedResource): + mock_executor = mock.Mock() - assert resource.stream(executor=mock_executor, head_only=True) is resource.resource.stream.return_value + with mock.patch.object(resource.resource, "stream") as patched_stream: + assert resource.stream(executor=mock_executor, head_only=True) is patched_stream.return_value - resource.resource.stream.assert_called_once_with(executor=mock_executor, head_only=True) + patched_stream.assert_called_once_with(executor=mock_executor, head_only=True) class TestEmbedResourceWithProxy: @pytest.fixture - def resource_with_proxy(self): + def resource_with_proxy(self) -> embeds.EmbedResourceWithProxy: return embeds.EmbedResourceWithProxy(resource=mock.Mock(), proxy_resource=mock.Mock()) - def test_proxy_url(self, resource_with_proxy): + def test_proxy_url(self, resource_with_proxy: embeds.EmbedResourceWithProxy): + assert resource_with_proxy.proxy_resource is not None assert resource_with_proxy.proxy_url is resource_with_proxy.proxy_resource.url - def test_proxy_url_when_resource_is_none(self, resource_with_proxy): + def test_proxy_url_when_resource_is_none(self, resource_with_proxy: embeds.EmbedResourceWithProxy): resource_with_proxy.proxy_resource = None assert resource_with_proxy.proxy_url is None - def test_proxy_filename(self, resource_with_proxy): + def test_proxy_filename(self, resource_with_proxy: embeds.EmbedResourceWithProxy): + assert resource_with_proxy.proxy_resource is not None assert resource_with_proxy.proxy_filename is resource_with_proxy.proxy_resource.filename - def test_proxy_filename_when_resource_is_none(self, resource_with_proxy): + def test_proxy_filename_when_resource_is_none(self, resource_with_proxy: embeds.EmbedResourceWithProxy): resource_with_proxy.proxy_resource = None assert resource_with_proxy.proxy_filename is None diff --git a/tests/hikari/test_emojis.py b/tests/hikari/test_emojis.py index 5c1f42a1a6..1205ddb0d4 100644 --- a/tests/hikari/test_emojis.py +++ b/tests/hikari/test_emojis.py @@ -20,6 +20,8 @@ # SOFTWARE. from __future__ import annotations +import typing + import pytest from hikari import emojis @@ -40,25 +42,25 @@ class TestEmoji: ), ], ) - def test_parse(self, input, output): + def test_parse(self, input: str, output: emojis.Emoji): assert emojis.Emoji.parse(input) == output class TestUnicodeEmoji: @pytest.fixture - def emoji(self): + def emoji(self) -> emojis.UnicodeEmoji: return emojis.UnicodeEmoji("\N{OK HAND SIGN}") - def test_name_property(self, emoji): + def test_name_property(self, emoji: emojis.UnicodeEmoji): assert emoji.name == emoji - def test_url_name_property(self, emoji): + def test_url_name_property(self, emoji: emojis.UnicodeEmoji): assert emoji.url_name == emoji - def test_mention_property(self, emoji): + def test_mention_property(self, emoji: emojis.UnicodeEmoji): assert emoji.mention == emoji - def test_codepoints_property(self, emoji): + def test_codepoints_property(self, emoji: emojis.UnicodeEmoji): assert emoji.codepoints == [128076] @pytest.mark.parametrize( @@ -78,23 +80,23 @@ def test_codepoints_property(self, emoji): ([0x200D, 0xFE0F, 0x1F44C, 0x1F44C], "200d-1f44c-1f44c"), ], ) - def test_filename_property(self, codepoints, expected_filename): + def test_filename_property(self, codepoints: typing.Sequence[int], expected_filename: str): emoji = emojis.UnicodeEmoji.parse_codepoints(*codepoints) assert emoji.filename == f"{expected_filename}.png" - def test_url_property(self, emoji): + def test_url_property(self, emoji: emojis.UnicodeEmoji): assert emoji.url == "https://raw.githubusercontent.com/discord/twemoji/master/assets/72x72/1f44c.png" - def test_unicode_escape_property(self, emoji): + def test_unicode_escape_property(self, emoji: emojis.UnicodeEmoji): assert emoji.unicode_escape == "\\U0001f44c" - def test_parse_codepoints(self, emoji): + def test_parse_codepoints(self, emoji: emojis.UnicodeEmoji): assert emojis.UnicodeEmoji.parse_codepoints(128076) == emoji - def test_parse_unicode_escape(self, emoji): + def test_parse_unicode_escape(self, emoji: emojis.UnicodeEmoji): assert emojis.UnicodeEmoji.parse_unicode_escape("\\U0001f44c") == emoji - def test_str_operator(self, emoji): + def test_str_operator(self, emoji: emojis.UnicodeEmoji): assert str(emoji) == emoji @pytest.mark.parametrize( @@ -107,34 +109,34 @@ def test_str_operator(self, emoji): ), ], ) - def test_parse(self, input, output): + def test_parse(self, input: str, output: emojis.UnicodeEmoji): assert emojis.UnicodeEmoji.parse(input) == output class TestCustomEmoji: @pytest.fixture - def emoji(self): - return emojis.CustomEmoji(id=3213452, name="ok", is_animated=False) + def emoji(self) -> emojis.CustomEmoji: + return emojis.CustomEmoji(id=snowflakes.Snowflake(3213452), name="ok", is_animated=False) - def test_filename_property(self, emoji): + def test_filename_property(self, emoji: emojis.CustomEmoji): assert emoji.filename == "3213452.png" - def test_filename_property_when_animated(self, emoji): + def test_filename_property_when_animated(self, emoji: emojis.CustomEmoji): emoji.is_animated = True assert emoji.filename == "3213452.gif" - def test_url_name_property(self, emoji): + def test_url_name_property(self, emoji: emojis.CustomEmoji): assert emoji.url_name == "ok:3213452" - def test_mention_property(self, emoji): + def test_mention_property(self, emoji: emojis.CustomEmoji): assert emoji.mention == "<:ok:3213452>" - def test_mention_property_when_animated(self, emoji): + def test_mention_property_when_animated(self, emoji: emojis.CustomEmoji): emoji.is_animated = True assert emoji.mention == "" - def test_url_property(self, emoji): + def test_url_property(self, emoji: emojis.CustomEmoji): assert emoji.url == "https://cdn.discordapp.com/emojis/3213452.png" def test_str_operator_when_populated_name(self): @@ -149,7 +151,7 @@ def test_str_operator_when_populated_name(self): ("", emojis.CustomEmoji(id=snowflakes.Snowflake(12345), name="foo", is_animated=True)), ], ) - def test_parse(self, input, output): + def test_parse(self, input: str, output: emojis.CustomEmoji): assert emojis.CustomEmoji.parse(input) == output def test_parse_unhappy_path(self): diff --git a/tests/hikari/test_errors.py b/tests/hikari/test_errors.py index 5aaf181037..94a4155742 100644 --- a/tests/hikari/test_errors.py +++ b/tests/hikari/test_errors.py @@ -22,6 +22,7 @@ import http import inspect +import typing import mock import pytest @@ -32,96 +33,96 @@ class TestShardCloseCode: @pytest.mark.parametrize(("code", "expected"), [(1000, True), (1001, True), (4000, False), (4014, False)]) - def test_is_standard_property(self, code, expected): + def test_is_standard_property(self, code: int, expected: bool): assert errors.ShardCloseCode(code).is_standard is expected class TestComponentStateConflictError: @pytest.fixture - def error(self): + def error(self) -> errors.ComponentStateConflictError: return errors.ComponentStateConflictError("some reason") - def test_str(self, error): + def test_str(self, error: errors.ComponentStateConflictError): assert str(error) == "some reason" class TestUnrecognisedEntityError: @pytest.fixture - def error(self): + def error(self) -> errors.UnrecognisedEntityError: return errors.UnrecognisedEntityError("some reason") - def test_str(self, error): + def test_str(self, error: errors.UnrecognisedEntityError): assert str(error) == "some reason" class TestGatewayError: @pytest.fixture - def error(self): + def error(self) -> errors.GatewayError: return errors.GatewayError("some reason") - def test_str(self, error): + def test_str(self, error: errors.GatewayError): assert str(error) == "some reason" class TestGatewayServerClosedConnectionError: @pytest.fixture - def error(self): + def error(self) -> errors.GatewayServerClosedConnectionError: return errors.GatewayServerClosedConnectionError("some reason", 123) - def test_str(self, error): + def test_str(self, error: errors.GatewayServerClosedConnectionError): assert str(error) == "Server closed connection with code 123 (some reason)" class TestHTTPResponseError: @pytest.fixture - def error(self): + def error(self) -> errors.HTTPResponseError: return errors.HTTPResponseError( - "https://some.url", http.HTTPStatus.BAD_REQUEST, {}, "raw body", "message", 12345 + url="https://some.url", + status=http.HTTPStatus.BAD_REQUEST, + headers={}, + raw_body="raw body", + message="message", + code=12345, ) - def test_str(self, error): + def test_str(self, error: errors.HTTPResponseError): assert str(error) == "Bad Request 400: (12345) 'message' for https://some.url" - def test_str_when_int_status_code(self, error): + def test_str_when_int_status_code(self, error: errors.HTTPResponseError): error.status = 699 assert str(error) == "Unknown Status 699: (12345) 'message' for https://some.url" - def test_str_when_message_is_None(self, error): - error.message = None - assert str(error) == "Bad Request 400: (12345) 'raw body' for https://some.url" + def test_str_when_message_is_None(self, error: errors.HTTPResponseError): + with mock.patch.object(error, "message", None): + assert str(error) == "Bad Request 400: (12345) 'raw body' for https://some.url" - def test_str_when_code_is_zero(self, error): + def test_str_when_code_is_zero(self, error: errors.HTTPResponseError): error.code = 0 assert str(error) == "Bad Request 400: 'message' for https://some.url" - def test_str_when_code_is_not_zero(self, error): + def test_str_when_code_is_not_zero(self, error: errors.HTTPResponseError): error.code = 100 assert str(error) == "Bad Request 400: (100) 'message' for https://some.url" class TestBadRequestError: @pytest.fixture - def error(self): - return errors.BadRequestError( - "https://some.url", - http.HTTPStatus.BAD_REQUEST, - {}, - "raw body", - errors={ - "": [{"code": "012", "message": "test error"}], - "components": { - "0": { - "_errors": [ - {"code": "123", "message": "something went wrong"}, - {"code": "456", "message": "but more things too!"}, - ] - } - }, - "attachments": {"1": {"_errors": [{"code": "789", "message": "at this point, all wrong!"}]}}, + def error(self) -> errors.BadRequestError: + errors_payload: dict[str, typing.Any] = { + "": [{"code": "012", "message": "test error"}], + "components": { + "0": { + "_errors": [ + {"code": "123", "message": "something went wrong"}, + {"code": "456", "message": "but more things too!"}, + ] + } }, - ) + "attachments": {"1": {"_errors": [{"code": "789", "message": "at this point, all wrong!"}]}}, + } + return errors.BadRequestError(url="https://some.url", headers={}, raw_body="raw body", errors=errors_payload) - def test_str(self, error): + def test_str(self, error: errors.BadRequestError): assert str(error) == inspect.cleandoc( """ Bad Request 400: 'raw body' for https://some.url @@ -138,7 +139,7 @@ def test_str(self, error): """ ) - def test_str_when_dump_error_errors(self, error): + def test_str_when_dump_error_errors(self, error: errors.BadRequestError): with mock.patch.object(errors, "_dump_errors", side_effect=KeyError): string = str(error) @@ -181,15 +182,13 @@ def test_str_when_dump_error_errors(self, error): """ ) - def test_str_when_cached(self, error): - error._cached_str = "ok" - - with mock.patch.object(errors, "_dump_errors") as dump_errors: + def test_str_when_cached(self, error: errors.BadRequestError): + with mock.patch.object(error, "_cached_str", "ok"), mock.patch.object(errors, "_dump_errors") as dump_errors: assert str(error) == "ok" dump_errors.assert_not_called() - def test_str_when_no_errors(self, error): + def test_str_when_no_errors(self, error: errors.BadRequestError): error.errors = None with mock.patch.object(errors, "_dump_errors") as dump_errors: @@ -200,34 +199,40 @@ def test_str_when_no_errors(self, error): class TestRateLimitTooLongError: @pytest.fixture - def error(self): + def error(self) -> errors.RateLimitTooLongError: return errors.RateLimitTooLongError( - route="some route", is_global=False, retry_after=0, max_retry_after=60, reset_at=0, limit=0, period=0 + route=mock.PropertyMock(return_value="some route"), + is_global=False, + retry_after=0, + max_retry_after=60, + reset_at=0, + limit=0, + period=0, ) - def test_remaining(self, error): + def test_remaining(self, error: errors.RateLimitTooLongError): assert error.remaining == 0 - def test_str(self, error): + def test_str(self, error: errors.RateLimitTooLongError): assert str(error) == ( "The request has been rejected, as you would be waiting for more than " - "the max retry-after (60) on route 'some route' [is_global=False]" + f"the max retry-after (60) on route '{error.route}' [is_global=False]" ) class TestBulkDeleteError: @pytest.fixture - def error(self): + def error(self) -> errors.BulkDeleteError: return errors.BulkDeleteError(range(10)) - def test_str(self, error): + def test_str(self, error: errors.BulkDeleteError): assert str(error) == "Error encountered when bulk deleting messages (10 messages deleted)" class TestMissingIntentError: @pytest.fixture - def error(self): + def error(self) -> errors.MissingIntentError: return errors.MissingIntentError(intents.Intents.GUILD_MEMBERS | intents.Intents.GUILD_EMOJIS) - def test_str(self, error): + def test_str(self, error: errors.MissingIntentError): assert str(error) == "You are missing the following intent(s): GUILD_EMOJIS, GUILD_MEMBERS" diff --git a/tests/hikari/test_files.py b/tests/hikari/test_files.py index 0d11c08ddf..595ed96ddd 100644 --- a/tests/hikari/test_files.py +++ b/tests/hikari/test_files.py @@ -23,6 +23,7 @@ import asyncio import pathlib import shutil +import typing import mock import pytest @@ -45,14 +46,14 @@ def test_set_filename(self): class TestAsyncReaderContextManager: @pytest.fixture - def reader(self): + def reader(self) -> typing.Type[files.AsyncReaderContextManager[files.AsyncReader]]: return hikari_test_helpers.mock_class_namespace(files.AsyncReaderContextManager) - def test___enter__(self, reader): + def test___enter__(self, reader: files.AsyncReaderContextManager[files.AsyncReader]): with pytest.raises(TypeError, match=" is async-only, did you mean 'async with'?"): reader().__enter__() - def test___exit__(self, reader): + def test___exit__(self, reader: files.AsyncReaderContextManager[files.AsyncReader]): try: reader().__exit__(None, None, None) except AttributeError as exc: @@ -86,7 +87,7 @@ async def test_exit_dunder_method_when_not_open(self): @pytest.mark.asyncio async def test_context_manager(self): mock_file = mock.Mock() - executor = object() + executor = mock.Mock() path = pathlib.Path("test/path/") loop = mock.Mock(run_in_executor=mock.AsyncMock(side_effect=[mock_file, None])) @@ -152,14 +153,14 @@ def test_open_write_path(): class TestResource: @pytest.fixture - def resource(self): + def resource(self) -> files.Resource[files.AsyncReader]: class MockReader: data = iter(("never", "gonna", "give", "you", "up")) async def __aenter__(self): return self - async def __aexit__(self, *args, **kwargs): + async def __aexit__(self, *args: typing.Any, **kwargs: typing.Any): return def __aiter__(self): @@ -171,16 +172,22 @@ async def __anext__(self): except StopIteration: raise StopAsyncIteration from None - class ResourceImpl(files.Resource): + class ResourceImpl(files.Resource[typing.Any]): stream = mock.Mock(return_value=MockReader()) - url = "https://myspace.com/rickastley/lyrics.txt" - filename = "lyrics.txt" + + @property + def url(self) -> str: + return "https://myspace.com/rickastley/lyrics.txt" + + @property + def filename(self) -> str: + return "lyrics.txt" return ResourceImpl() @pytest.mark.asyncio - async def test_save(self, resource): - executor = object() + async def test_save(self, resource: files.Resource[files.AsyncReader]): + executor = mock.Mock() file_open = mock.Mock() file_open.write = mock.Mock() loop = mock.Mock(run_in_executor=mock.AsyncMock(side_effect=[file_open, None, None, None, None, None, None])) @@ -203,12 +210,13 @@ async def test_save(self, resource): def test_copy_to_path(): + original_path = pathlib.Path("original_path") with mock.patch.object(files, "_to_write_path") as to_write_path: with mock.patch.object(shutil, "copy2") as copy2: - files._copy_to_path("original_path", "path", "some_filename.png", False) + files._copy_to_path(original_path, "path", "some_filename.png", False) to_write_path.assert_called_once_with("path", "some_filename.png", force=False) - copy2.assert_called_once_with("original_path", to_write_path.return_value) + copy2.assert_called_once_with(original_path, to_write_path.return_value) class TestFile: @@ -217,8 +225,8 @@ def file_obj(self): return files.File("one/path/something.txt") @pytest.mark.asyncio - async def test_save(self, file_obj): - mock_executor = object() + async def test_save(self, file_obj: files.File): + mock_executor = mock.Mock() loop = mock.Mock(run_in_executor=mock.AsyncMock()) with mock.patch.object(asyncio, "get_running_loop", return_value=loop): @@ -246,14 +254,16 @@ def test_write_bytes(): class TestBytes: @pytest.fixture - def bytes_obj(self): + def bytes_obj(self) -> files.Bytes: return files.Bytes(b"some data", "something.txt") @pytest.mark.parametrize("data_type", [bytes, bytearray, memoryview]) @pytest.mark.asyncio - async def test_save(self, bytes_obj, data_type): + async def test_save( + self, bytes_obj: files.Bytes, data_type: type[bytes] | type[bytearray] | type[memoryview[typing.Any]] + ): bytes_obj.data = mock.Mock(data_type) - mock_executor = object() + mock_executor = mock.Mock() loop = mock.Mock(run_in_executor=mock.AsyncMock()) with mock.patch.object(asyncio, "get_running_loop", return_value=loop): @@ -266,9 +276,9 @@ async def test_save(self, bytes_obj, data_type): ) @pytest.mark.asyncio - async def test_save_when_data_is_not_bytes(self, bytes_obj): - bytes_obj.data = object() - mock_executor = object() + async def test_save_when_data_is_not_bytes(self, bytes_obj: files.Bytes): + bytes_obj.data = mock.Mock() + mock_executor = mock.Mock() with mock.patch.object(asyncio, "get_running_loop") as get_running_loop: with mock.patch.object(files.Resource, "save") as super_save: diff --git a/tests/hikari/test_guilds.py b/tests/hikari/test_guilds.py index 5db4f89f60..feafafd203 100644 --- a/tests/hikari/test_guilds.py +++ b/tests/hikari/test_guilds.py @@ -30,6 +30,7 @@ from hikari import guilds from hikari import permissions from hikari import snowflakes +from hikari import traits from hikari import undefined from hikari import urls from hikari import users @@ -40,19 +41,19 @@ @pytest.fixture -def mock_app(): +def mock_app() -> traits.RESTAware: return mock.Mock(spec_set=gateway_bot.GatewayBot) class TestPartialRole: @pytest.fixture - def model(self, mock_app): + def model(self, mock_app: traits.RESTAware) -> guilds.PartialRole: return guilds.PartialRole(app=mock_app, id=snowflakes.Snowflake(1106913972), name="The Big Cool") - def test_str_operator(self, model): + def test_str_operator(self, model: guilds.PartialRole): assert str(model) == "The Big Cool" - def test_mention_property(self, model): + def test_mention_property(self, model: guilds.PartialRole): assert model.mention == "<@&1106913972>" @@ -64,61 +65,66 @@ def test_PartialApplication_str_operator(): class TestPartialApplication: @pytest.fixture - def model(self): - return hikari_test_helpers.mock_class_namespace( - guilds.PartialApplication, init_=False, slots_=False, id=123, icon_hash="ahashicon" - )() + def partial_application(self) -> guilds.PartialApplication: + return guilds.PartialApplication( + id=snowflakes.Snowflake(123), + name="partial_application", + description="partial_application_description", + icon_hash="icon_hash", + ) - def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided( + self, partial_application: guilds.PartialApplication + ): with mock.patch.object( routes, "CDN_APPLICATION_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(ext="JPEG") == "file" + assert partial_application.make_icon_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, application_id=123, hash="ahashicon", size=4096, file_format="JPEG", lossless=True + urls.CDN_URL, application_id=123, hash="icon_hash", size=4096, file_format="JPEG", lossless=True ) - def test_icon_url_property(self, model): - model.make_icon_url = mock.Mock(return_value="url") + def test_icon_url_property(self, partial_application: guilds.PartialApplication): + with mock.patch.object(guilds.PartialApplication, "make_icon_url", return_value="url"): + assert partial_application.icon_url == "url" - assert model.icon_url == "url" + def test_make_icon_url_when_hash_is_None(self, partial_application: guilds.PartialApplication): + partial_application.icon_hash = None - model.make_icon_url.assert_called_once_with() + with ( + mock.patch.object(routes, "CDN_APPLICATION_ICON") as patched_route, + mock.patch.object( + patched_route, "compile_to_file", mock.Mock(return_value="file") + ) as patched_compile_to_file, + ): + assert partial_application.make_icon_url(ext="jpeg", size=1) is None - def test_make_icon_url_when_hash_is_None(self, model): - model.icon_hash = None - - with mock.patch.object( - routes, "CDN_APPLICATION_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) - ) as route: - assert model.make_icon_url(file_format="JPEG", size=1) is None - - route.compile_to_file.assert_not_called() + patched_compile_to_file.assert_not_called() - def test_make_icon_url_when_hash_is_not_None(self, model): + def test_make_icon_url_when_hash_is_not_None(self, partial_application: guilds.PartialApplication): with mock.patch.object( routes, "CDN_APPLICATION_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(file_format="JPEG", size=1) == "file" + assert partial_application.make_icon_url(file_format="JPEG", size=1) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, application_id=123, hash="ahashicon", size=1, file_format="JPEG", lossless=True + urls.CDN_URL, application_id=123, hash="icon_hash", size=1, file_format="JPEG", lossless=True ) class TestIntegrationAccount: @pytest.fixture - def model(self, mock_app): + def integration_account(self) -> guilds.IntegrationAccount: return guilds.IntegrationAccount(id="foo", name="bar") - def test_str_operator(self, model): - assert str(model) == "bar" + def test_str_operator(self, integration_account: guilds.IntegrationAccount): + assert str(integration_account) == "bar" class TestPartialIntegration: @pytest.fixture - def model(self, mock_app): + def model(self) -> guilds.PartialIntegration: return guilds.PartialIntegration( account=mock.Mock(return_value=guilds.IntegrationAccount), id=snowflakes.Snowflake(69420), @@ -126,13 +132,13 @@ def model(self, mock_app): type="illegal", ) - def test_str_operator(self, model): + def test_str_operator(self, model: guilds.PartialIntegration): assert str(model) == "nice" class TestRole: @pytest.fixture - def model(self, mock_app): + def role(self, mock_app: traits.RESTAware) -> guilds.Role: return guilds.Role( app=mock_app, id=snowflakes.Snowflake(979899100), @@ -154,47 +160,47 @@ def model(self, mock_app): is_available_for_purchase=True, ) - def test_colour_property(self, model): - assert model.colour == colors.Color(0x1A2B3C) + def test_colour_property(self, role: guilds.Role): + assert role.colour == colors.Color(0x1A2B3C) - def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided(self, role: guilds.Role): with mock.patch.object( routes, "CDN_ROLE_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(ext="JPEG") == "file" + assert role.make_icon_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, role_id=979899100, hash="icon_hash", size=4096, file_format="JPEG", lossless=True ) - def test_icon_url_property(self, model): - with mock.patch.object(guilds.Role, "make_icon_url") as make_icon_url: - assert model.icon_url == make_icon_url.return_value + def test_icon_url_property(self, role: guilds.Role): + with mock.patch.object(guilds.Role, "make_icon_url") as patched_make_icon_url: + assert role.icon_url == patched_make_icon_url.return_value - model.make_icon_url.assert_called_once_with() + patched_make_icon_url.assert_called_once_with() - def test_mention_property(self, model): - assert model.mention == "<@&979899100>" + def test_mention_property(self, role: guilds.Role): + assert role.mention == "<@&979899100>" - def test_mention_property_when_is_everyone_role(self, model): - model.id = model.guild_id - assert model.mention == "@everyone" + def test_mention_property_when_is_everyone_role(self, role: guilds.Role): + role.id = role.guild_id + assert role.mention == "@everyone" - def test_make_icon_url_when_hash_is_None(self, model): - model.icon_hash = None + def test_make_icon_url_when_hash_is_None(self, role: guilds.Role): + role.icon_hash = None with mock.patch.object( routes, "CDN_ROLE_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(file_format="JPEG", size=1) is None + assert role.make_icon_url(file_format="JPEG", size=1) is None route.compile_to_file.assert_not_called() - def test_make_icon_url_when_hash_is_not_None(self, model): + def test_make_icon_url_when_hash_is_not_None(self, role: guilds.Role): with mock.patch.object( routes, "CDN_ROLE_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(file_format="JPEG", size=1) == "file" + assert role.make_icon_url(file_format="JPEG", size=1) == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, role_id=979899100, hash="icon_hash", size=1, file_format="JPEG", lossless=True @@ -203,41 +209,44 @@ def test_make_icon_url_when_hash_is_not_None(self, model): class TestGuildWidget: @pytest.fixture - def model(self, mock_app): + def guild_widget(self, mock_app: traits.RESTAware) -> guilds.GuildWidget: return guilds.GuildWidget(app=mock_app, channel_id=snowflakes.Snowflake(420), is_enabled=True) - def test_app_property(self, model, mock_app): - assert model.app is mock_app + def test_app_property(self, guild_widget: guilds.GuildWidget, mock_app: traits.RESTAware): + assert guild_widget.app is mock_app - def test_channel_property(self, model): - assert model.channel_id == snowflakes.Snowflake(420) + def test_channel_property(self, guild_widget: guilds.GuildWidget): + assert guild_widget.channel_id == snowflakes.Snowflake(420) - def test_is_enabled_property(self, model): - assert model.is_enabled is True + def test_is_enabled_property(self, guild_widget: guilds.GuildWidget): + assert guild_widget.is_enabled is True @pytest.mark.asyncio - async def test_fetch_channel(self, model): + async def test_fetch_channel(self, guild_widget: guilds.GuildWidget): mock_channel = mock.Mock(channels_.GuildChannel) - model.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) - assert await model.fetch_channel() is model.app.rest.fetch_channel.return_value - model.app.rest.fetch_channel.assert_awaited_once_with(420) + with mock.patch.object( + guild_widget.app.rest, "fetch_channel", new_callable=mock.AsyncMock, return_value=mock_channel + ) as patched_fetch_channel: + assert await guild_widget.fetch_channel() is patched_fetch_channel.return_value + + patched_fetch_channel.assert_awaited_once_with(420) @pytest.mark.asyncio - async def test_fetch_channel_when_None(self, model): - model.app.rest.fetch_channel = mock.AsyncMock() - model.channel_id = None + async def test_fetch_channel_when_None(self, guild_widget: guilds.GuildWidget): + guild_widget.app.rest.fetch_channel = mock.AsyncMock() + guild_widget.channel_id = None - assert await model.fetch_channel() is None + assert await guild_widget.fetch_channel() is None class TestMember: @pytest.fixture - def mock_user(self): + def mock_user(self) -> users.User: return mock.Mock(id=snowflakes.Snowflake(123)) @pytest.fixture - def model(self, mock_user): + def member(self, mock_user: users.User) -> guilds.Member: return guilds.Member( guild_id=snowflakes.Snowflake(456), is_deaf=True, @@ -259,319 +268,323 @@ def model(self, mock_user): guild_flags=guilds.GuildMemberFlags.NONE, ) - def test_str_operator(self, model, mock_user): - assert str(model) == str(mock_user) + def test_str_operator(self, member: guilds.Member, mock_user: users.User): + assert str(member) == str(mock_user) - def test_app_property(self, model): - assert model.app is model.user.app + def test_app_property(self, member: guilds.Member): + assert member.app is member.user.app - def test_id_property(self, model): - assert model.id is model.user.id + def test_id_property(self, member: guilds.Member): + assert member.id is member.user.id - def test_username_property(self, model): - assert model.username is model.user.username + def test_username_property(self, member: guilds.Member): + assert member.username is member.user.username - def test_discriminator_property(self, model): - assert model.discriminator is model.user.discriminator + def test_discriminator_property(self, member: guilds.Member): + assert member.discriminator is member.user.discriminator - def test_avatar_hash_property(self, model): - assert model.avatar_hash is model.user.avatar_hash + def test_avatar_hash_property(self, member: guilds.Member): + assert member.avatar_hash is member.user.avatar_hash - def test_is_bot_property(self, model): - assert model.is_bot is model.user.is_bot + def test_is_bot_property(self, member: guilds.Member): + assert member.is_bot is member.user.is_bot - def test_is_system_property(self, model): - assert model.is_system is model.user.is_system + def test_is_system_property(self, member: guilds.Member): + assert member.is_system is member.user.is_system - def test_flags_property(self, model): - assert model.flags is model.user.flags + def test_flags_property(self, member: guilds.Member): + assert member.flags is member.user.flags - def test_display_avatar_decoration_property_when_guild_avatar_decoration_is_set(self, model): - assert model.display_avatar_decoration is model.guild_avatar_decoration + def test_display_avatar_decoration_property_when_guild_avatar_decoration_is_set(self, member: guilds.Member): + assert member.display_avatar_decoration is member.guild_avatar_decoration - def test_display_avatar_decoration_property_when_guild_avatar_decoration_is_None(self, model): - model.guild_avatar_decoration = None - assert model.display_avatar_decoration is model.user.display_avatar_decoration + def test_display_avatar_decoration_property_when_guild_avatar_decoration_is_None(self, member: guilds.Member): + member.guild_avatar_decoration = None + assert member.display_avatar_decoration is member.user.display_avatar_decoration - def test_avatar_url_property(self, model): - assert model.avatar_url is model.user.make_avatar_url() + def test_avatar_url_property(self, member: guilds.Member): + assert member.avatar_url is member.user.make_avatar_url() - def test_display_avatar_url_when_guild_hash_is_None(self, model): + def test_display_avatar_url_when_guild_hash_is_None(self, member: guilds.Member): with mock.patch.object(guilds.Member, "make_guild_avatar_url") as mock_make_guild_avatar_url: - assert model.display_avatar_url is mock_make_guild_avatar_url.return_value + assert member.display_avatar_url is mock_make_guild_avatar_url.return_value - def test_display_guild_avatar_url_when_guild_hash_is_not_None(self, model): + def test_display_guild_avatar_url_when_guild_hash_is_not_None(self, member: guilds.Member): with mock.patch.object(guilds.Member, "make_guild_avatar_url", return_value=None): - assert model.display_avatar_url is model.user.display_avatar_url + assert member.display_avatar_url is member.user.display_avatar_url - def test_display_banner_url_when_guild_hash_is_None(self, model): + def test_display_banner_url_when_guild_hash_is_None(self, member: guilds.Member): with mock.patch.object(guilds.Member, "make_guild_banner_url") as mock_make_guild_banner_url: - assert model.display_banner_url is mock_make_guild_banner_url.return_value + assert member.display_banner_url is mock_make_guild_banner_url.return_value - def test_display_banner_url_when_guild_hash_is_not_None(self, model): + def test_display_banner_url_when_guild_hash_is_not_None(self, member: guilds.Member): with mock.patch.object(guilds.Member, "make_guild_banner_url", return_value=None): - assert model.display_banner_url is model.user.display_banner_url + assert member.display_banner_url is member.user.display_banner_url - def test_banner_hash_property(self, model): - assert model.banner_hash is model.user.banner_hash + def test_banner_hash_property(self, member: guilds.Member): + assert member.banner_hash is member.user.banner_hash - def test_banner_url_property(self, model): - assert model.banner_url is model.user.make_banner_url() + def test_banner_url_property(self, member: guilds.Member): + assert member.banner_url is member.user.make_banner_url() - def test_accent_color_property(self, model): - assert model.accent_color is model.user.accent_color + def test_accent_color_property(self, member: guilds.Member): + assert member.accent_color is member.user.accent_color - def test_guild_avatar_url_property(self, model): + def test_guild_avatar_url_property(self, member: guilds.Member): with mock.patch.object(guilds.Member, "make_guild_avatar_url") as make_guild_avatar_url: - assert model.guild_avatar_url is make_guild_avatar_url.return_value + assert member.guild_avatar_url is make_guild_avatar_url.return_value - def test_guild_banner_url_property(self, model): + def test_guild_banner_url_property(self, member: guilds.Member): with mock.patch.object(guilds.Member, "make_guild_banner_url") as make_guild_banner_url: - assert model.guild_banner_url is make_guild_banner_url.return_value + assert member.guild_banner_url is make_guild_banner_url.return_value - def test_communication_disabled_until(self, model): - model.raw_communication_disabled_until = datetime.datetime(2021, 11, 22) + def test_communication_disabled_until(self, member: guilds.Member): + member.raw_communication_disabled_until = datetime.datetime(2021, 11, 22) with mock.patch.object(time, "utc_datetime", return_value=datetime.datetime(2021, 10, 18)): - assert model.communication_disabled_until() == datetime.datetime(2021, 11, 22) + assert member.communication_disabled_until() == datetime.datetime(2021, 11, 22) - def test_communication_disabled_until_when_raw_communication_disabled_until_is_None(self, model): - model.raw_communication_disabled_until = None + def test_communication_disabled_until_when_raw_communication_disabled_until_is_None(self, member: guilds.Member): + member.raw_communication_disabled_until = None with mock.patch.object(time, "utc_datetime", return_value=datetime.datetime(2021, 10, 18)): - assert model.communication_disabled_until() is None + assert member.communication_disabled_until() is None - def test_communication_disabled_until_when_raw_communication_disabled_until_is_in_the_past(self, model): - model.raw_communication_disabled_until = datetime.datetime(2021, 10, 18) + def test_communication_disabled_until_when_raw_communication_disabled_until_is_in_the_past( + self, member: guilds.Member + ): + member.raw_communication_disabled_until = datetime.datetime(2021, 10, 18) with mock.patch.object(time, "utc_datetime", return_value=datetime.datetime(2021, 11, 22)): - assert model.communication_disabled_until() is None + assert member.communication_disabled_until() is None - def test_make_avatar_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_avatar_url_format_set_to_deprecated_ext_argument_if_provided(self, member: guilds.Member): with mock.patch.object( routes, "CDN_MEMBER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_guild_avatar_url(ext="JPEG") == "file" + assert member.make_guild_avatar_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=456, user_id=123, hash="dab", size=4096, file_format="JPEG", lossless=True ) - def test_make_avatar_url(self, model): - result = model.make_avatar_url(file_format="PNG", size=4096) + def test_make_avatar_url(self, member: guilds.Member): + with mock.patch.object(member.user, "make_avatar_url") as patched_make_avatar_url: + result = member.make_avatar_url(file_format="PNG", size=4096) - model.user.make_avatar_url.assert_called_once_with( + patched_make_avatar_url.assert_called_once_with( file_format="PNG", size=4096, lossless=True, ext=undefined.UNDEFINED ) - assert result is model.user.make_avatar_url.return_value + assert result is patched_make_avatar_url.return_value - def test_make_guild_avatar_url_when_no_hash(self, model): - model.guild_avatar_hash = None - assert model.make_guild_avatar_url(file_format="PNG", size=1024) is None + def test_make_guild_avatar_url_when_no_hash(self, member: guilds.Member): + member.guild_avatar_hash = None + assert member.make_guild_avatar_url(file_format="PNG", size=1024) is None - def test_make_guild_avatar_url_when_format_is_None_and_avatar_hash_is_for_animated(self, model): - model.guild_avatar_hash = "a_18dnf8dfbakfdh" + def test_make_guild_avatar_url_when_format_is_None_and_avatar_hash_is_for_animated(self, member: guilds.Member): + member.guild_avatar_hash = "a_18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_MEMBER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_guild_avatar_url(file_format=None, size=4096) == "file" + assert member.make_guild_avatar_url(file_format="PNG", size=4096) == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, - guild_id=model.guild_id, - user_id=model.id, - hash=model.guild_avatar_hash, + guild_id=member.guild_id, + user_id=member.id, + hash=member.guild_avatar_hash, size=4096, - file_format="GIF", + file_format="PNG", lossless=True, ) - def test_make_guild_avatar_url_when_format_is_None_and_avatar_hash_is_not_for_animated(self, model): - model.guild_avatar_hash = "18dnf8dfbakfdh" + def test_make_guild_avatar_url_when_format_is_None_and_avatar_hash_is_not_for_animated(self, member: guilds.Member): + member.guild_avatar_hash = "18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_MEMBER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_guild_avatar_url(file_format=None, size=4096) == "file" + assert member.make_guild_avatar_url(file_format="PNG", size=4096) == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, - guild_id=model.guild_id, - user_id=model.id, - hash=model.guild_avatar_hash, + guild_id=member.guild_id, + user_id=member.id, + hash=member.guild_avatar_hash, size=4096, file_format="PNG", lossless=True, ) - def test_make_guild_avatar_url_with_all_args(self, model): - model.guild_avatar_hash = "18dnf8dfbakfdh" + def test_make_guild_avatar_url_with_all_args(self, member: guilds.Member): + member.guild_avatar_hash = "18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_MEMBER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_guild_avatar_url(file_format="URL", size=4096) == "file" + assert member.make_guild_avatar_url(file_format="PNG", size=4096) == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, - guild_id=model.guild_id, - user_id=model.id, - hash=model.guild_avatar_hash, + guild_id=member.guild_id, + user_id=member.id, + hash=member.guild_avatar_hash, size=4096, - file_format="URL", + file_format="PNG", lossless=True, ) - def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, member: guilds.Member): with mock.patch.object( routes, "CDN_MEMBER_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_guild_banner_url(ext="JPEG") == "file" + assert member.make_guild_banner_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, - user_id=model.id, - guild_id=model.guild_id, - hash=model.guild_banner_hash, + user_id=member.id, + guild_id=member.guild_id, + hash=member.guild_banner_hash, size=4096, file_format="JPEG", lossless=True, ) - def test_make_banner_url(self, model): - result = model.make_banner_url(file_format="PNG", size=4096) + def test_make_banner_url(self, member: guilds.Member): + with mock.patch.object(member.user, "make_banner_url") as patched_make_banner_url: + result = member.make_banner_url(file_format="PNG", size=4096) - model.user.make_banner_url.assert_called_once_with( + patched_make_banner_url.assert_called_once_with( file_format="PNG", size=4096, lossless=True, ext=undefined.UNDEFINED ) - assert result is model.user.make_banner_url.return_value + assert result is patched_make_banner_url.return_value - def test_make_guild_banner_url_when_no_hash(self, model): - model.guild_banner_hash = None - assert model.make_guild_banner_url(file_format="PNG", size=1024) is None + def test_make_guild_banner_url_when_no_hash(self, member: guilds.Member): + member.guild_banner_hash = None + assert member.make_guild_banner_url(file_format="PNG", size=1024) is None - def test_make_guild_banner_url_when_format_is_None_and_banner_hash_is_for_animated(self, model): - model.guild_banner_hash = "a_18dnf8dfbakfdh" + def test_make_guild_banner_url_when_format_is_None_and_banner_hash_is_for_animated(self, member: guilds.Member): + member.guild_banner_hash = "a_18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_MEMBER_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_guild_banner_url(file_format=None, size=4096) == "file" + assert member.make_guild_banner_url(file_format="GIF", size=4096) == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, - user_id=model.id, - guild_id=model.guild_id, - hash=model.guild_banner_hash, + user_id=member.id, + guild_id=member.guild_id, + hash=member.guild_banner_hash, size=4096, file_format="GIF", lossless=True, ) - def test_make_guild_banner_url_when_format_is_None_and_banner_hash_is_not_for_animated(self, model): - model.guild_banner_hash = "18dnf8dfbakfdh" + def test_make_guild_banner_url_when_format_is_None_and_banner_hash_is_not_for_animated(self, member: guilds.Member): + member.guild_banner_hash = "18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_MEMBER_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_guild_banner_url(file_format=None, size=4096) == "file" + assert member.make_guild_banner_url(file_format="PNG", size=4096) == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, - user_id=model.id, - guild_id=model.guild_id, - hash=model.guild_banner_hash, + user_id=member.id, + guild_id=member.guild_id, + hash=member.guild_banner_hash, size=4096, file_format="PNG", lossless=True, ) - def test_make_guild_banner_url_with_all_args(self, model): - model.guild_banner_hash = "18dnf8dfbakfdh" + def test_make_guild_banner_url_with_all_args(self, member: guilds.Member): + member.guild_banner_hash = "18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_MEMBER_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_guild_banner_url(file_format="URL", size=4096) == "file" + assert member.make_guild_banner_url(file_format="PNG", size=4096) == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, - guild_id=model.guild_id, - user_id=model.id, - hash=model.guild_banner_hash, + guild_id=member.guild_id, + user_id=member.id, + hash=member.guild_banner_hash, size=4096, - file_format="URL", + file_format="PNG", lossless=True, ) @pytest.mark.asyncio - async def test_fetch_dm_channel(self, model): - model.user.fetch_dm_channel = mock.AsyncMock() + async def test_fetch_dm_channel(self, member: guilds.Member): + member.user.fetch_dm_channel = mock.AsyncMock() - assert await model.fetch_dm_channel() is model.user.fetch_dm_channel.return_value + assert await member.fetch_dm_channel() is member.user.fetch_dm_channel.return_value - model.user.fetch_dm_channel.assert_awaited_once_with() + member.user.fetch_dm_channel.assert_awaited_once_with() @pytest.mark.asyncio - async def test_fetch_self(self, model): - model.user.app.rest.fetch_member = mock.AsyncMock() + async def test_fetch_self(self, member: guilds.Member): + member.user.app.rest.fetch_member = mock.AsyncMock() - assert await model.fetch_self() is model.user.app.rest.fetch_member.return_value + assert await member.fetch_self() is member.user.app.rest.fetch_member.return_value - model.user.app.rest.fetch_member.assert_awaited_once_with(456, 123) + member.user.app.rest.fetch_member.assert_awaited_once_with(456, 123) @pytest.mark.asyncio - async def test_fetch_roles(self, model): - model.user.app.rest.fetch_roles = mock.AsyncMock() - await model.fetch_roles() - model.user.app.rest.fetch_roles.assert_awaited_once_with(456) + async def test_fetch_roles(self, member: guilds.Member): + member.user.app.rest.fetch_roles = mock.AsyncMock() + await member.fetch_roles() + member.user.app.rest.fetch_roles.assert_awaited_once_with(456) @pytest.mark.asyncio - async def test_ban(self, model): - model.app.rest.ban_user = mock.AsyncMock() + async def test_ban(self, member: guilds.Member): + member.app.rest.ban_user = mock.AsyncMock() - await model.ban(delete_message_seconds=600, reason="bored") + await member.ban(delete_message_seconds=600, reason="bored") - model.app.rest.ban_user.assert_awaited_once_with(456, 123, delete_message_seconds=600, reason="bored") + member.app.rest.ban_user.assert_awaited_once_with(456, 123, delete_message_seconds=600, reason="bored") @pytest.mark.asyncio - async def test_unban(self, model): - model.app.rest.unban_user = mock.AsyncMock() + async def test_unban(self, member: guilds.Member): + member.app.rest.unban_user = mock.AsyncMock() - await model.unban(reason="Unbored") + await member.unban(reason="Unbored") - model.app.rest.unban_user.assert_awaited_once_with(456, 123, reason="Unbored") + member.app.rest.unban_user.assert_awaited_once_with(456, 123, reason="Unbored") @pytest.mark.asyncio - async def test_kick(self, model): - model.app.rest.kick_user = mock.AsyncMock() + async def test_kick(self, member: guilds.Member): + member.app.rest.kick_user = mock.AsyncMock() - await model.kick(reason="bored") + await member.kick(reason="bored") - model.app.rest.kick_user.assert_awaited_once_with(456, 123, reason="bored") + member.app.rest.kick_user.assert_awaited_once_with(456, 123, reason="bored") @pytest.mark.asyncio - async def test_add_role(self, model): - model.app.rest.add_role_to_member = mock.AsyncMock() + async def test_add_role(self, member: guilds.Member): + member.app.rest.add_role_to_member = mock.AsyncMock() - await model.add_role(563412, reason="Promoted") + await member.add_role(563412, reason="Promoted") - model.app.rest.add_role_to_member.assert_awaited_once_with(456, 123, 563412, reason="Promoted") + member.app.rest.add_role_to_member.assert_awaited_once_with(456, 123, 563412, reason="Promoted") @pytest.mark.asyncio - async def test_remove_role(self, model): - model.app.rest.remove_role_from_member = mock.AsyncMock() + async def test_remove_role(self, member: guilds.Member): + member.app.rest.remove_role_from_member = mock.AsyncMock() - await model.remove_role(563412, reason="Demoted") + await member.remove_role(563412, reason="Demoted") - model.app.rest.remove_role_from_member.assert_awaited_once_with(456, 123, 563412, reason="Demoted") + member.app.rest.remove_role_from_member.assert_awaited_once_with(456, 123, 563412, reason="Demoted") @pytest.mark.asyncio - async def test_edit(self, model): - model.app.rest.edit_member = mock.AsyncMock() + async def test_edit(self, member: guilds.Member): + member.app.rest.edit_member = mock.AsyncMock() disabled_until = datetime.datetime(2021, 11, 17) - edit = await model.edit( + edit = await member.edit( nickname="Imposter", roles=[123, 432, 345], mute=False, @@ -581,7 +594,7 @@ async def test_edit(self, model): reason="I'm God", ) - model.app.rest.edit_member.assert_awaited_once_with( + member.app.rest.edit_member.assert_awaited_once_with( 456, 123, nickname="Imposter", @@ -593,189 +606,213 @@ async def test_edit(self, model): reason="I'm God", ) - assert edit == model.app.rest.edit_member.return_value + assert edit == member.app.rest.edit_member.return_value - def test_default_avatar_url_property(self, model): - assert model.default_avatar_url is model.user.default_avatar_url + def test_default_avatar_url_property(self, member: guilds.Member): + assert member.default_avatar_url is member.user.default_avatar_url - def test_display_name_property_when_nickname(self, model): - assert model.display_name == "davb" + def test_display_name_property_when_nickname(self, member: guilds.Member): + assert member.display_name == "davb" - def test_display_name_property_when_no_nickname(self, model): - model.nickname = None - assert model.display_name is model.user.global_name + def test_display_name_property_when_no_nickname(self, member: guilds.Member): + member.nickname = None + assert member.display_name is member.user.global_name - def test_mention_property(self, model): - assert model.mention == model.user.mention + def test_mention_property(self, member: guilds.Member): + assert member.mention == member.user.mention - def test_get_guild(self, model): + def test_get_guild(self, member: guilds.Member): guild = mock.Mock(id=456) - model.user.app.cache.get_guild.side_effect = [guild] - - assert model.get_guild() == guild - - model.user.app.cache.get_guild.assert_has_calls([mock.call(456)]) - def test_get_guild_when_guild_not_in_cache(self, model): - model.user.app.cache.get_guild.side_effect = [None] + with ( + mock.patch.object(member.user.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_guild", side_effect=[guild]) as patched_get_guild, + ): + assert member.get_guild() == guild - assert model.get_guild() is None + patched_get_guild.assert_has_calls([mock.call(456)]) - model.user.app.cache.get_guild.assert_has_calls([mock.call(456)]) + def test_get_guild_when_guild_not_in_cache(self, member: guilds.Member): + with ( + mock.patch.object(member.user.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_guild", side_effect=[None]) as patched_get_guild, + ): + assert member.get_guild() is None - def test_get_guild_when_no_cache_trait(self, model): - model.user.app = object() + patched_get_guild.assert_has_calls([mock.call(456)]) - assert model.get_guild() is None + def test_get_guild_when_no_cache_trait(self, member: guilds.Member): + with ( + mock.patch.object(member.user.app, "cache", mock.Mock()) as mocked_cache, + mock.patch.object(mocked_cache, "get_guild", mock.Mock(return_value=None)), + ): + assert member.get_guild() is None - def test_get_roles(self, model): + def test_get_roles(self, member: guilds.Member): role1 = mock.Mock(id=321, position=2) role2 = mock.Mock(id=654, position=1) - model.user.app.cache.get_role.side_effect = [role1, role2] - model.role_ids = [321, 654] + with ( + mock.patch.object(member.user.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_role", side_effect=[role1, role2]) as patched_get_role, + ): + member.role_ids = [snowflakes.Snowflake(321), snowflakes.Snowflake(654)] - assert model.get_roles() == [role1, role2] + assert member.get_roles() == [role1, role2] - model.user.app.cache.get_role.assert_has_calls([mock.call(321), mock.call(654)]) + patched_get_role.assert_has_calls([mock.call(321), mock.call(654)]) - def test_get_roles_when_role_ids_not_in_cache(self, model): + def test_get_roles_when_role_ids_not_in_cache(self, member: guilds.Member): role = mock.Mock(id=456, position=1) - model.user.app.cache.get_role.side_effect = [None, role] - model.role_ids = [321, 456] + with ( + mock.patch.object(member.user.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_role", side_effect=[None, role]) as patched_get_role, + ): + member.role_ids = [snowflakes.Snowflake(321), snowflakes.Snowflake(456)] - assert model.get_roles() == [role] + assert member.get_roles() == [role] - model.user.app.cache.get_role.assert_has_calls([mock.call(321), mock.call(456)]) + patched_get_role.assert_has_calls([mock.call(321), mock.call(456)]) - def test_get_roles_when_empty_cache(self, model): - model.role_ids = [132, 432] - model.user.app.cache.get_role.side_effect = [None, None] + def test_get_roles_when_empty_cache(self, member: guilds.Member): + member.role_ids = [snowflakes.Snowflake(132), snowflakes.Snowflake(432)] + with ( + mock.patch.object(member.user.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_role", side_effect=[None, None]) as patched_get_role, + ): + assert member.get_roles() == [] - assert model.get_roles() == [] + patched_get_role.assert_has_calls([mock.call(132), mock.call(432)]) - model.user.app.cache.get_role.assert_has_calls([mock.call(132), mock.call(432)]) + def test_get_roles_when_no_cache_trait(self, member: guilds.Member): + with mock.patch.object(member.user, "app", mock.Mock(traits.RESTAware)): + assert member.get_roles() == [] - def test_get_roles_when_no_cache_trait(self, model): - model.user.app = object() - - assert model.get_roles() == [] - - def test_get_top_role(self, model): + def test_get_top_role(self, member: guilds.Member): role1 = mock.Mock(id=321, position=2) role2 = mock.Mock(id=654, position=1) with mock.patch.object(guilds.Member, "get_roles", return_value=[role1, role2]): - assert model.get_top_role() is role1 + assert member.get_top_role() is role1 - def test_get_top_role_when_roles_is_empty(self, model): + def test_get_top_role_when_roles_is_empty(self, member: guilds.Member): with mock.patch.object(guilds.Member, "get_roles", return_value=[]): - assert model.get_top_role() is None + assert member.get_top_role() is None - def test_get_presence(self, model): - assert model.get_presence() is model.user.app.cache.get_presence.return_value - model.user.app.cache.get_presence.assert_called_once_with(456, 123) + def test_get_presence(self, member: guilds.Member): + with ( + mock.patch.object(member.user.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_presence") as patched_get_presence, + ): + assert member.get_presence() is patched_get_presence.return_value + patched_get_presence.assert_called_once_with(456, 123) - def test_get_presence_when_no_cache_trait(self, model): - model.user.app = object() - assert model.get_presence() is None + def test_get_presence_when_no_cache_trait(self, member: guilds.Member): + with mock.patch.object(member.user, "app", mock.Mock(traits.RESTAware)): + assert member.get_presence() is None class TestPartialGuild: @pytest.fixture - def model(self, mock_app): + def partial_guild(self, mock_app: traits.RESTAware) -> guilds.PartialGuild: return guilds.PartialGuild(app=mock_app, id=snowflakes.Snowflake(90210), icon_hash="yeet", name="hikari") - def test_str_operator(self, model): - assert str(model) == "hikari" + def test_str_operator(self, partial_guild: guilds.PartialGuild): + assert str(partial_guild) == "hikari" - def test_shard_id_property(self, model): - model.app.shard_count = 4 - assert model.shard_id == 0 + def test_shard_id_property(self, partial_guild: guilds.PartialGuild): + with mock.patch.object(partial_guild.app, "shard_count", 4): + assert partial_guild.shard_id == 0 - def test_shard_id_when_not_shard_aware(self, model): - model.app = object() + def test_shard_id_when_not_shard_aware(self, partial_guild: guilds.PartialGuild): + partial_guild.app = mock.Mock(traits.RESTAware) - assert model.shard_id is None + assert partial_guild.shard_id is None - def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_icon_url_format_set_to_deprecated_ext_argument_if_provided(self, partial_guild: guilds.PartialGuild): with mock.patch.object( routes, "CDN_GUILD_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(ext="JPEG") == "file" + assert partial_guild.make_icon_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=90210, hash="yeet", size=4096, file_format="JPEG", lossless=True ) - def test_icon_url(self, model): + def test_icon_url(self, partial_guild: guilds.PartialGuild): icon = object() with mock.patch.object(guilds.PartialGuild, "make_icon_url", return_value=icon): - assert model.icon_url is icon + assert partial_guild.icon_url is icon - def test_make_icon_url_when_no_hash(self, model): - model.icon_hash = None + def test_make_icon_url_when_no_hash(self, partial_guild: guilds.PartialGuild): + partial_guild.icon_hash = None - assert model.make_icon_url(file_format="PNG", size=2048) is None + assert partial_guild.make_icon_url(file_format="PNG", size=2048) is None - def test_make_icon_url_when_format_is_None_and_avatar_hash_is_for_animated(self, model): - model.icon_hash = "a_yeet" + def test_make_icon_url_when_format_is_None_and_avatar_hash_is_for_animated( + self, partial_guild: guilds.PartialGuild + ): + partial_guild.icon_hash = "a_yeet" with mock.patch.object( routes, "CDN_GUILD_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(file_format=None, size=1024) == "file" + assert partial_guild.make_icon_url(file_format="GIF", size=1024) == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=90210, hash="a_yeet", size=1024, file_format="GIF", lossless=True ) - def test_make_icon_url_when_format_is_None_and_avatar_hash_is_not_for_animated(self, model): + def test_make_icon_url_when_format_is_None_and_avatar_hash_is_not_for_animated( + self, partial_guild: guilds.PartialGuild + ): with mock.patch.object( routes, "CDN_GUILD_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(file_format=None, size=4096) == "file" + assert partial_guild.make_icon_url(file_format="PNG", size=4096) == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=90210, hash="yeet", size=4096, file_format="PNG", lossless=True ) - def test_make_icon_url_with_all_args(self, model): + def test_make_icon_url_with_all_args(self, partial_guild: guilds.PartialGuild): with mock.patch.object( routes, "CDN_GUILD_ICON", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_icon_url(file_format="URL", size=2048) == "file" + assert partial_guild.make_icon_url(file_format="PNG", size=2048) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, guild_id=90210, hash="yeet", size=2048, file_format="URL", lossless=True + urls.CDN_URL, guild_id=90210, hash="yeet", size=2048, file_format="PNG", lossless=True ) @pytest.mark.asyncio - async def test_kick(self, model): - model.app.rest.kick_user = mock.AsyncMock() - await model.kick(4321, reason="Go away!") + async def test_kick(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.kick_user = mock.AsyncMock() + await partial_guild.kick(4321, reason="Go away!") - model.app.rest.kick_user.assert_awaited_once_with(90210, 4321, reason="Go away!") + partial_guild.app.rest.kick_user.assert_awaited_once_with(90210, 4321, reason="Go away!") @pytest.mark.asyncio - async def test_ban(self, model): - model.app.rest.ban_user = mock.AsyncMock() + async def test_ban(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.ban_user = mock.AsyncMock() - await model.ban(4321, delete_message_seconds=864000, reason="Go away!") + await partial_guild.ban(4321, delete_message_seconds=864000, reason="Go away!") - model.app.rest.ban_user.assert_awaited_once_with(90210, 4321, delete_message_seconds=864000, reason="Go away!") + partial_guild.app.rest.ban_user.assert_awaited_once_with( + 90210, 4321, delete_message_seconds=864000, reason="Go away!" + ) @pytest.mark.asyncio - async def test_unban(self, model): - model.app.rest.unban_user = mock.AsyncMock() - await model.unban(4321, reason="Comeback!!") + async def test_unban(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.unban_user = mock.AsyncMock() + await partial_guild.unban(4321, reason="Comeback!!") - model.app.rest.unban_user.assert_awaited_once_with(90210, 4321, reason="Comeback!!") + partial_guild.app.rest.unban_user.assert_awaited_once_with(90210, 4321, reason="Comeback!!") @pytest.mark.asyncio - async def test_edit(self, model): - model.app.rest.edit_guild = mock.AsyncMock() - edited_guild = await model.edit( + async def test_edit(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.edit_guild = mock.AsyncMock() + edited_guild = await partial_guild.edit( name="chad server", verification_level=guilds.GuildVerificationLevel.LOW, default_message_notifications=guilds.GuildMessageNotificationsLevel.ALL_MESSAGES, @@ -787,7 +824,7 @@ async def test_edit(self, model): reason="beep boop", ) - model.app.rest.edit_guild.assert_awaited_once_with( + partial_guild.app.rest.edit_guild.assert_awaited_once_with( 90210, name="chad server", verification_level=guilds.GuildVerificationLevel.LOW, @@ -807,102 +844,102 @@ async def test_edit(self, model): reason="beep boop", ) - assert edited_guild is model.app.rest.edit_guild.return_value + assert edited_guild is partial_guild.app.rest.edit_guild.return_value @pytest.mark.asyncio - async def test_set_incident_actions(self, model: guilds.PartialGuild): - model.app.rest.set_guild_incident_actions = mock.AsyncMock() + async def test_set_incident_actions(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.set_guild_incident_actions = mock.AsyncMock() - updated_incident_data = await model.set_incident_actions( + updated_incident_data = await partial_guild.set_incident_actions( invites_disabled_until=datetime.datetime(2021, 11, 17), dms_disabled_until=datetime.datetime(2021, 11, 18) ) - assert updated_incident_data is model.app.rest.set_guild_incident_actions.return_value - model.app.rest.set_guild_incident_actions.assert_awaited_once_with( + assert updated_incident_data is partial_guild.app.rest.set_guild_incident_actions.return_value + partial_guild.app.rest.set_guild_incident_actions.assert_awaited_once_with( 90210, invites_disabled_until=datetime.datetime(2021, 11, 17), dms_disabled_until=datetime.datetime(2021, 11, 18), ) @pytest.mark.asyncio - async def test_fetch_emojis(self, model): - model.app.rest.fetch_guild_emojis = mock.AsyncMock() + async def test_fetch_emojis(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.fetch_guild_emojis = mock.AsyncMock() - emojis = await model.fetch_emojis() + emojis = await partial_guild.fetch_emojis() - model.app.rest.fetch_guild_emojis.assert_awaited_once_with(model.id) - assert emojis is model.app.rest.fetch_guild_emojis.return_value + partial_guild.app.rest.fetch_guild_emojis.assert_awaited_once_with(partial_guild.id) + assert emojis is partial_guild.app.rest.fetch_guild_emojis.return_value @pytest.mark.asyncio - async def test_fetch_emoji(self, model): - model.app.rest.fetch_emoji = mock.AsyncMock() + async def test_fetch_emoji(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.fetch_emoji = mock.AsyncMock() - emoji = await model.fetch_emoji(349) + emoji = await partial_guild.fetch_emoji(349) - model.app.rest.fetch_emoji.assert_awaited_once_with(model.id, 349) - assert emoji is model.app.rest.fetch_emoji.return_value + partial_guild.app.rest.fetch_emoji.assert_awaited_once_with(partial_guild.id, 349) + assert emoji is partial_guild.app.rest.fetch_emoji.return_value @pytest.mark.asyncio - async def test_fetch_stickers(self, model): - model.app.rest.fetch_guild_stickers = mock.AsyncMock() + async def test_fetch_stickers(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.fetch_guild_stickers = mock.AsyncMock() - stickers = await model.fetch_stickers() + stickers = await partial_guild.fetch_stickers() - model.app.rest.fetch_guild_stickers.assert_awaited_once_with(model.id) - assert stickers is model.app.rest.fetch_guild_stickers.return_value + partial_guild.app.rest.fetch_guild_stickers.assert_awaited_once_with(partial_guild.id) + assert stickers is partial_guild.app.rest.fetch_guild_stickers.return_value @pytest.mark.asyncio - async def test_fetch_sticker(self, model): - model.app.rest.fetch_guild_sticker = mock.AsyncMock() + async def test_fetch_sticker(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.fetch_guild_sticker = mock.AsyncMock() - sticker = await model.fetch_sticker(6969) + sticker = await partial_guild.fetch_sticker(6969) - model.app.rest.fetch_guild_sticker.assert_awaited_once_with(model.id, 6969) - assert sticker is model.app.rest.fetch_guild_sticker.return_value + partial_guild.app.rest.fetch_guild_sticker.assert_awaited_once_with(partial_guild.id, 6969) + assert sticker is partial_guild.app.rest.fetch_guild_sticker.return_value @pytest.mark.asyncio - async def test_create_sticker(self, model): - model.app.rest.create_sticker = mock.AsyncMock() - file = object() + async def test_create_sticker(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.create_sticker = mock.AsyncMock() + file = mock.Mock() - sticker = await model.create_sticker( + sticker = await partial_guild.create_sticker( "NewSticker", "funny", file, description="A sticker", reason="blah blah blah" ) - assert sticker is model.app.rest.create_sticker.return_value + assert sticker is partial_guild.app.rest.create_sticker.return_value - model.app.rest.create_sticker.assert_awaited_once_with( + partial_guild.app.rest.create_sticker.assert_awaited_once_with( 90210, "NewSticker", "funny", file, description="A sticker", reason="blah blah blah" ) @pytest.mark.asyncio - async def test_edit_sticker(self, model): - model.app.rest.edit_sticker = mock.AsyncMock() + async def test_edit_sticker(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.edit_sticker = mock.AsyncMock() - sticker = await model.edit_sticker(4567, name="Brilliant", tag="parmesan", description="amazing") + sticker = await partial_guild.edit_sticker(4567, name="Brilliant", tag="parmesan", description="amazing") - model.app.rest.edit_sticker.assert_awaited_once_with( + partial_guild.app.rest.edit_sticker.assert_awaited_once_with( 90210, 4567, name="Brilliant", tag="parmesan", description="amazing", reason=undefined.UNDEFINED ) - assert sticker is model.app.rest.edit_sticker.return_value + assert sticker is partial_guild.app.rest.edit_sticker.return_value @pytest.mark.asyncio - async def test_delete_sticker(self, model): - model.app.rest.delete_sticker = mock.AsyncMock() + async def test_delete_sticker(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.delete_sticker = mock.AsyncMock() - sticker = await model.delete_sticker(951) + sticker = await partial_guild.delete_sticker(951) - model.app.rest.delete_sticker.assert_awaited_once_with(90210, 951, reason=undefined.UNDEFINED) + partial_guild.app.rest.delete_sticker.assert_awaited_once_with(90210, 951, reason=undefined.UNDEFINED) - assert sticker is model.app.rest.delete_sticker.return_value + assert sticker is partial_guild.app.rest.delete_sticker.return_value @pytest.mark.asyncio - async def test_create_category(self, model): - model.app.rest.create_guild_category = mock.AsyncMock() + async def test_create_category(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.create_guild_category = mock.AsyncMock() - category = await model.create_category("very cool category", position=2) + category = await partial_guild.create_category("very cool category", position=2) - model.app.rest.create_guild_category.assert_awaited_once_with( + partial_guild.app.rest.create_guild_category.assert_awaited_once_with( 90210, "very cool category", position=2, @@ -910,17 +947,17 @@ async def test_create_category(self, model): reason=undefined.UNDEFINED, ) - assert category is model.app.rest.create_guild_category.return_value + assert category is partial_guild.app.rest.create_guild_category.return_value @pytest.mark.asyncio - async def test_create_text_channel(self, model): - model.app.rest.create_guild_text_channel = mock.AsyncMock() + async def test_create_text_channel(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.create_guild_text_channel = mock.AsyncMock() - text_channel = await model.create_text_channel( + text_channel = await partial_guild.create_text_channel( "cool text channel", position=3, nsfw=False, rate_limit_per_user=30 ) - model.app.rest.create_guild_text_channel.assert_awaited_once_with( + partial_guild.app.rest.create_guild_text_channel.assert_awaited_once_with( 90210, "cool text channel", position=3, @@ -932,17 +969,17 @@ async def test_create_text_channel(self, model): reason=undefined.UNDEFINED, ) - assert text_channel is model.app.rest.create_guild_text_channel.return_value + assert text_channel is partial_guild.app.rest.create_guild_text_channel.return_value @pytest.mark.asyncio - async def test_create_news_channel(self, model): - model.app.rest.create_guild_news_channel = mock.AsyncMock() + async def test_create_news_channel(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.create_guild_news_channel = mock.AsyncMock() - news_channel = await model.create_news_channel( + news_channel = await partial_guild.create_news_channel( "cool news channel", position=1, nsfw=False, rate_limit_per_user=420 ) - model.app.rest.create_guild_news_channel.assert_awaited_once_with( + partial_guild.app.rest.create_guild_news_channel.assert_awaited_once_with( 90210, "cool news channel", position=1, @@ -954,17 +991,17 @@ async def test_create_news_channel(self, model): reason=undefined.UNDEFINED, ) - assert news_channel is model.app.rest.create_guild_news_channel.return_value + assert news_channel is partial_guild.app.rest.create_guild_news_channel.return_value @pytest.mark.asyncio - async def test_create_forum_channel(self, model): - model.app.rest.create_guild_forum_channel = mock.AsyncMock() + async def test_create_forum_channel(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.create_guild_forum_channel = mock.AsyncMock() - forum_channel = await model.create_forum_channel( + forum_channel = await partial_guild.create_forum_channel( "cool forum channel", position=1, nsfw=False, rate_limit_per_user=420 ) - model.app.rest.create_guild_forum_channel.assert_awaited_once_with( + partial_guild.app.rest.create_guild_forum_channel.assert_awaited_once_with( 90210, "cool forum channel", position=1, @@ -982,17 +1019,17 @@ async def test_create_forum_channel(self, model): default_reaction_emoji=undefined.UNDEFINED, ) - assert forum_channel is model.app.rest.create_guild_forum_channel.return_value + assert forum_channel is partial_guild.app.rest.create_guild_forum_channel.return_value @pytest.mark.asyncio - async def test_create_voice_channel(self, model): - model.app.rest.create_guild_voice_channel = mock.AsyncMock() + async def test_create_voice_channel(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.create_guild_voice_channel = mock.AsyncMock() - voice_channel = await model.create_voice_channel( + voice_channel = await partial_guild.create_voice_channel( "cool voice channel", position=1, bitrate=3200, video_quality_mode=2 ) - model.app.rest.create_guild_voice_channel.assert_awaited_once_with( + partial_guild.app.rest.create_guild_voice_channel.assert_awaited_once_with( 90210, "cool voice channel", position=1, @@ -1005,15 +1042,17 @@ async def test_create_voice_channel(self, model): reason=undefined.UNDEFINED, ) - assert voice_channel is model.app.rest.create_guild_voice_channel.return_value + assert voice_channel is partial_guild.app.rest.create_guild_voice_channel.return_value @pytest.mark.asyncio - async def test_create_stage_channel(self, model): - model.app.rest.create_guild_stage_channel = mock.AsyncMock() + async def test_create_stage_channel(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.create_guild_stage_channel = mock.AsyncMock() - stage_channel = await model.create_stage_channel("cool stage channel", position=1, bitrate=3200, user_limit=100) + stage_channel = await partial_guild.create_stage_channel( + "cool stage channel", position=1, bitrate=3200, user_limit=100 + ) - model.app.rest.create_guild_stage_channel.assert_awaited_once_with( + partial_guild.app.rest.create_guild_stage_channel.assert_awaited_once_with( 90210, "cool stage channel", position=1, @@ -1025,38 +1064,38 @@ async def test_create_stage_channel(self, model): reason=undefined.UNDEFINED, ) - assert stage_channel is model.app.rest.create_guild_stage_channel.return_value + assert stage_channel is partial_guild.app.rest.create_guild_stage_channel.return_value @pytest.mark.asyncio - async def test_delete_channel(self, model): + async def test_delete_channel(self, partial_guild: guilds.PartialGuild): mock_channel = mock.Mock(channels_.GuildChannel) - model.app.rest.delete_channel = mock.AsyncMock(return_value=mock_channel) + partial_guild.app.rest.delete_channel = mock.AsyncMock(return_value=mock_channel) - deleted_channel = await model.delete_channel(1288820) + deleted_channel = await partial_guild.delete_channel(1288820) - model.app.rest.delete_channel.assert_awaited_once_with(1288820) - assert deleted_channel is model.app.rest.delete_channel.return_value + partial_guild.app.rest.delete_channel.assert_awaited_once_with(1288820) + assert deleted_channel is partial_guild.app.rest.delete_channel.return_value @pytest.mark.asyncio - async def test_fetch_guild(self, model): - model.app.rest.fetch_guild = mock.AsyncMock(return_value=model) + async def test_fetch_guild(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.fetch_guild = mock.AsyncMock(return_value=partial_guild) - assert await model.fetch_self() is model.app.rest.fetch_guild.return_value - model.app.rest.fetch_guild.assert_awaited_once_with(model.id) + assert await partial_guild.fetch_self() is partial_guild.app.rest.fetch_guild.return_value + partial_guild.app.rest.fetch_guild.assert_awaited_once_with(partial_guild.id) @pytest.mark.asyncio - async def test_fetch_roles(self, model): - model.app.rest.fetch_roles = mock.AsyncMock() + async def test_fetch_roles(self, partial_guild: guilds.PartialGuild): + partial_guild.app.rest.fetch_roles = mock.AsyncMock() - roles = await model.fetch_roles() + roles = await partial_guild.fetch_roles() - model.app.rest.fetch_roles.assert_awaited_once_with(90210) - assert roles is model.app.rest.fetch_roles.return_value + partial_guild.app.rest.fetch_roles.assert_awaited_once_with(90210) + assert roles is partial_guild.app.rest.fetch_roles.return_value class TestGuildPreview: @pytest.fixture - def model(self, mock_app): + def guild_preview(self, mock_app: traits.RESTAware) -> guilds.GuildPreview: return guilds.GuildPreview( app=mock_app, features=["huge super secret nsfw channel"], @@ -1071,76 +1110,80 @@ def model(self, mock_app): description="the place for quality shitposting!", ) - def test_make_splash_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_splash_url_format_set_to_deprecated_ext_argument_if_provided( + self, guild_preview: guilds.GuildPreview + ): with mock.patch.object( routes, "CDN_GUILD_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_splash_url(ext="JPEG") == "file" + assert guild_preview.make_splash_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=123, hash="dis is also mah splash hash", size=4096, file_format="JPEG", lossless=True ) - def test_make_discovery_splash_url_format_set_to_deprecated_ext_argument_if_provided(self, model): - model.discovery_splash_hash = "18dnf8dfbakfdh" + def test_make_discovery_splash_url_format_set_to_deprecated_ext_argument_if_provided( + self, guild_preview: guilds.GuildPreview + ): + guild_preview.discovery_splash_hash = "18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_GUILD_DISCOVERY_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_discovery_splash_url(ext="JPEG") == "file" + assert guild_preview.make_discovery_splash_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=123, hash="18dnf8dfbakfdh", size=4096, file_format="JPEG", lossless=True ) - def test_splash_url(self, model): + def test_splash_url(self, guild_preview: guilds.GuildPreview): splash = object() with mock.patch.object(guilds.GuildPreview, "make_splash_url", return_value=splash): - assert model.splash_url is splash + assert guild_preview.splash_url is splash - def test_make_splash_url_when_hash(self, model): - model.splash_hash = "18dnf8dfbakfdh" + def test_make_splash_url_when_hash(self, guild_preview: guilds.GuildPreview): + guild_preview.splash_hash = "18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_GUILD_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_splash_url(file_format="URL", size=1024) == "file" + assert guild_preview.make_splash_url(file_format="PNG", size=1024) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, guild_id=123, hash="18dnf8dfbakfdh", size=1024, file_format="URL", lossless=True + urls.CDN_URL, guild_id=123, hash="18dnf8dfbakfdh", size=1024, file_format="PNG", lossless=True ) - def test_make_splash_url_when_no_hash(self, model): - model.splash_hash = None - assert model.make_splash_url(ext="png", size=512) is None + def test_make_splash_url_when_no_hash(self, guild_preview: guilds.GuildPreview): + guild_preview.splash_hash = None + assert guild_preview.make_splash_url(ext="png", size=512) is None - def test_discovery_splash_url(self, model): - discovery_splash = object() + def test_discovery_splash_url(self, guild_preview: guilds.GuildPreview): + discovery_splash = mock.Mock() with mock.patch.object(guilds.GuildPreview, "make_discovery_splash_url", return_value=discovery_splash): - assert model.discovery_splash_url is discovery_splash + assert guild_preview.discovery_splash_url is discovery_splash - def test_make_discovery_splash_url_when_hash(self, model): - model.discovery_splash_hash = "18dnf8dfbakfdh" + def test_make_discovery_splash_url_when_hash(self, guild_preview: guilds.GuildPreview): + guild_preview.discovery_splash_hash = "18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_GUILD_DISCOVERY_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_discovery_splash_url(file_format="URL", size=2048) == "file" + assert guild_preview.make_discovery_splash_url(file_format="PNG", size=2048) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, guild_id=123, hash="18dnf8dfbakfdh", size=2048, file_format="URL", lossless=True + urls.CDN_URL, guild_id=123, hash="18dnf8dfbakfdh", size=2048, file_format="PNG", lossless=True ) - def test_make_discovery_splash_url_when_no_hash(self, model): - model.discovery_splash_hash = None - assert model.make_discovery_splash_url(file_format="PNG", size=4096) is None + def test_make_discovery_splash_url_when_no_hash(self, guild_preview: guilds.GuildPreview): + guild_preview.discovery_splash_hash = None + assert guild_preview.make_discovery_splash_url(file_format="PNG", size=4096) is None class TestGuild: @pytest.fixture - def model(self, mock_app): + def guild(self, mock_app: traits.RESTAware) -> guilds.Guild: return hikari_test_helpers.mock_class_namespace(guilds.Guild)( app=mock_app, id=snowflakes.Snowflake(123), @@ -1179,389 +1222,465 @@ def model(self, mock_app): system_channel_flags=guilds.GuildSystemChannelFlag.SUPPRESS_PREMIUM_SUBSCRIPTION, ) - def test_get_channels(self, model): - assert model.get_channels() is model.app.cache.get_guild_channels_view_for_guild.return_value - model.app.cache.get_guild_channels_view_for_guild.assert_called_once_with(123) - - def test_get_channels_when_no_cache_trait(self, model): - model.app = object() - assert model.get_channels() == {} - - def test_get_members(self, model): - assert model.get_members() is model.app.cache.get_members_view_for_guild.return_value - model.app.cache.get_members_view_for_guild.assert_called_once_with(123) - - def test_get_members_when_no_cache_trait(self, model): - model.app = object() - assert model.get_members() == {} - - def test_get_presences(self, model): - assert model.get_presences() is model.app.cache.get_presences_view_for_guild.return_value - model.app.cache.get_presences_view_for_guild.assert_called_once_with(123) - - def test_get_presences_when_no_cache_trait(self, model): - model.app = object() - assert model.get_presences() == {} - - def test_get_voice_states(self, model): - assert model.get_voice_states() is model.app.cache.get_voice_states_view_for_guild.return_value - model.app.cache.get_voice_states_view_for_guild.assert_called_once_with(123) - - def test_get_voice_states_when_no_cache_trait(self, model): - model.app = object() - assert model.get_voice_states() == {} - - def test_get_emojis(self, model): - assert model.get_emojis() is model.app.cache.get_emojis_view_for_guild.return_value - model.app.cache.get_emojis_view_for_guild.assert_called_once_with(123) - - def test_emojis_when_no_cache_trait(self, model): - model.app = object() - assert model.get_emojis() == {} - - def test_get_sticker(self, model): - model.app.cache.get_sticker.return_value.guild_id = model.id - assert model.get_sticker(456) is model.app.cache.get_sticker.return_value - model.app.cache.get_sticker.assert_called_once_with(456) - - def test_get_sticker_when_not_from_guild(self, model): - model.app.cache.get_sticker.return_value.guild_id = 546123123433 - - assert model.get_sticker(456) is None - - model.app.cache.get_sticker.assert_called_once_with(456) - - def test_get_sticker_when_no_cache_trait(self, model): - model.app = object() - assert model.get_sticker(1234) is None - - def test_get_stickers(self, model): - assert model.get_stickers() is model.app.cache.get_stickers_view_for_guild.return_value - model.app.cache.get_stickers_view_for_guild.assert_called_once_with(123) - - def test_get_stickers_when_no_cache_trait(self, model): - model.app = object() - assert model.get_stickers() == {} - - def test_roles(self, model): - assert model.get_roles() is model.app.cache.get_roles_view_for_guild.return_value - model.app.cache.get_roles_view_for_guild.assert_called_once_with(123) - - def test_get_roles_when_no_cache_trait(self, model): - model.app = object() - assert model.get_roles() == {} - - def test_get_emoji(self, model): - model.app.cache.get_emoji.return_value.guild_id = model.id - assert model.get_emoji(456) is model.app.cache.get_emoji.return_value - model.app.cache.get_emoji.assert_called_once_with(456) - - def test_get_emoji_when_not_from_guild(self, model): - model.app.cache.get_emoji.return_value.guild_id = 1233212 - - assert model.get_emoji(456) is None - - model.app.cache.get_emoji.assert_called_once_with(456) - - def test_get_emoji_when_no_cache_trait(self, model): - model.app = object() - assert model.get_emoji(456) is None - - def test_get_role(self, model): - model.app.cache.get_role.return_value.guild_id = model.id - assert model.get_role(456) is model.app.cache.get_role.return_value - model.app.cache.get_role.assert_called_once_with(456) - - def test_get_role_when_not_from_guild(self, model): - model.app.cache.get_role.return_value.guild_id = 7623123321123 - - assert model.get_role(456) is None - - model.app.cache.get_role.assert_called_once_with(456) - - def test_get_role_when_no_cache_trait(self, model): - model.app = object() - assert model.get_role(456) is None + def test_get_channels(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_guild_channels_view_for_guild" + ) as patched_get_guild_channels_view_for_guild, + ): + assert guild.get_channels() is patched_get_guild_channels_view_for_guild.return_value + patched_get_guild_channels_view_for_guild.assert_called_once_with(123) + + def test_get_channels_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock(traits.RESTAware) + assert guild.get_channels() == {} + + def test_get_members(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_members_view_for_guild") as patched_get_members_view_for_guild, + ): + assert guild.get_members() is patched_get_members_view_for_guild.return_value + patched_get_members_view_for_guild.assert_called_once_with(123) + + def test_get_members_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock(traits.RESTAware) + assert guild.get_members() == {} + + def test_get_presences(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_presences_view_for_guild") as patched_get_presences_view_for_guild, + ): + assert guild.get_presences() is patched_get_presences_view_for_guild.return_value + patched_get_presences_view_for_guild.assert_called_once_with(123) + + def test_get_presences_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock(traits.RESTAware) + assert guild.get_presences() == {} + + def test_get_voice_states(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_voice_states_view_for_guild" + ) as patched_get_voice_states_view_for_guild, + ): + assert guild.get_voice_states() is patched_get_voice_states_view_for_guild.return_value + patched_get_voice_states_view_for_guild.assert_called_once_with(123) + + def test_get_voice_states_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock(traits.RESTAware) + assert guild.get_voice_states() == {} + + def test_get_emojis(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_emojis_view_for_guild") as patched_get_emojis_view_for_guild, + ): + assert guild.get_emojis() is patched_get_emojis_view_for_guild.return_value + patched_get_emojis_view_for_guild.assert_called_once_with(123) + + def test_emojis_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock(traits.RESTAware) + assert guild.get_emojis() == {} + + def test_get_sticker(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_sticker", mock.Mock(return_value=mock.Mock(guild_id=guild.id)) + ) as patched_get_sticker, + ): + assert guild.get_sticker(456) is patched_get_sticker.return_value + patched_get_sticker.assert_called_once_with(456) + + def test_get_sticker_when_not_from_guild(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_sticker", mock.Mock(return_value=mock.Mock(guild_id=546123123433)) + ) as patched_get_sticker, + ): + assert guild.get_sticker(456) is None + + patched_get_sticker.assert_called_once_with(456) + + def test_get_sticker_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock() + assert guild.get_sticker(1234) is None + + def test_get_stickers(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_stickers_view_for_guild") as patched_get_stickers_view_for_guild, + ): + assert guild.get_stickers() is patched_get_stickers_view_for_guild.return_value + patched_get_stickers_view_for_guild.assert_called_once_with(123) + + def test_get_stickers_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock(traits.RESTAware) + assert guild.get_stickers() == {} + + def test_roles(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_roles_view_for_guild") as patched_get_roles_view_for_guild, + ): + assert guild.get_roles() is patched_get_roles_view_for_guild.return_value + patched_get_roles_view_for_guild.assert_called_once_with(123) + + def test_get_roles_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock(traits.RESTAware) + assert guild.get_roles() == {} + + def test_get_emoji(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_emoji", mock.Mock(return_value=mock.Mock(guild_id=guild.id)) + ) as patched_get_emoji, + ): + assert guild.get_emoji(456) is patched_get_emoji.return_value + patched_get_emoji.assert_called_once_with(456) + + def test_get_emoji_when_not_from_guild(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_emoji", mock.Mock(return_value=mock.Mock(guild_id=1233212)) + ) as patched_get_emoji, + ): + assert guild.get_emoji(456) is None + + patched_get_emoji.assert_called_once_with(456) + + def test_get_emoji_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock() + assert guild.get_emoji(456) is None + + def test_get_role(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_role", mock.Mock(return_value=mock.Mock(guild_id=guild.id)) + ) as patched_get_role, + ): + assert guild.get_role(456) is patched_get_role.return_value + patched_get_role.assert_called_once_with(456) + + def test_get_role_when_not_from_guild(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_role", mock.Mock(return_value=mock.Mock(guild_id=7623123321123)) + ) as patched_get_role, + ): + assert guild.get_role(456) is None + patched_get_role.assert_called_once_with(456) + + def test_get_role_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock() + assert guild.get_role(456) is None @pytest.mark.asyncio - async def test_invites_disabled_default(self, model): - assert model.invites_disabled is False + async def test_invites_disabled_default(self, guild: guilds.Guild): + assert guild.invites_disabled is False @pytest.mark.asyncio - async def test_invites_disabled_via_incidents(self, model): - model.incidents = guilds.GuildIncidents( + async def test_invites_disabled_via_incidents(self, guild: guilds.Guild): + guild.incidents = guilds.GuildIncidents( invites_disabled_until=datetime.datetime(2021, 11, 17), dms_disabled_until=None, dm_spam_detected_at=None, raid_detected_at=None, ) - assert model.invites_disabled is True + assert guild.invites_disabled is True @pytest.mark.asyncio - async def test_invites_disabled_via_feature(self, model): - model.features.append(guilds.GuildFeature.INVITES_DISABLED) - assert model.invites_disabled is True + async def test_invites_disabled_via_feature(self, guild: guilds.Guild): + guild.features.append(guilds.GuildFeature.INVITES_DISABLED) + assert guild.invites_disabled is True - def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, guild: guilds.Guild): with mock.patch.object( routes, "CDN_GUILD_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_banner_url(ext="JPEG") == "file" + assert guild.make_banner_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=123, hash="banner_hash", size=4096, file_format="JPEG", lossless=True ) - def test_make_discovery_splash_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_discovery_splash_url_format_set_to_deprecated_ext_argument_if_provided(self, guild: guilds.Guild): with mock.patch.object( routes, "CDN_GUILD_DISCOVERY_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_discovery_splash_url(ext="JPEG") == "file" + assert guild.make_discovery_splash_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=123, hash="discovery_splash_hash", size=4096, file_format="JPEG", lossless=True ) - def test_make_splash_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_splash_url_format_set_to_deprecated_ext_argument_if_provided(self, guild: guilds.Guild): with mock.patch.object( routes, "CDN_GUILD_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_splash_url(ext="JPEG") == "file" + assert guild.make_splash_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=123, hash="splash_hash", size=4096, file_format="JPEG", lossless=True ) - def test_splash_url(self, model): + def test_splash_url(self, guild: guilds.Guild): splash = object() with mock.patch.object(guilds.Guild, "make_splash_url", return_value=splash): - assert model.splash_url is splash + assert guild.splash_url is splash - def test_make_splash_url_when_hash(self, model): - model.splash_hash = "18dnf8dfbakfdh" + def test_make_splash_url_when_hash(self, guild: guilds.Guild): + guild.splash_hash = "18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_GUILD_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_splash_url(file_format="URL", size=2) == "file" + assert guild.make_splash_url(file_format="PNG", size=2) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, guild_id=123, hash="18dnf8dfbakfdh", size=2, file_format="URL", lossless=True + urls.CDN_URL, guild_id=123, hash="18dnf8dfbakfdh", size=2, file_format="PNG", lossless=True ) - def test_make_splash_url_when_no_hash(self, model): - model.splash_hash = None - assert model.make_splash_url(ext="png", size=1024) is None + def test_make_splash_url_when_no_hash(self, guild: guilds.Guild): + guild.splash_hash = None + assert guild.make_splash_url(ext="png", size=1024) is None - def test_discovery_splash_url(self, model): - discovery_splash = object() + def test_discovery_splash_url(self, guild: guilds.Guild): + discovery_splash = mock.Mock() with mock.patch.object(guilds.Guild, "make_discovery_splash_url", return_value=discovery_splash): - assert model.discovery_splash_url is discovery_splash + assert guild.discovery_splash_url is discovery_splash - def test_make_discovery_splash_url_when_hash(self, model): - model.discovery_splash_hash = "18dnf8dfbakfdh" + def test_make_discovery_splash_url_when_hash(self, guild: guilds.Guild): + guild.discovery_splash_hash = "18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_GUILD_DISCOVERY_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_discovery_splash_url(file_format="URL", size=1024) == "file" + assert guild.make_discovery_splash_url(file_format="PNG", size=1024) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, guild_id=123, hash="18dnf8dfbakfdh", size=1024, file_format="URL", lossless=True + urls.CDN_URL, guild_id=123, hash="18dnf8dfbakfdh", size=1024, file_format="PNG", lossless=True ) - def test_make_discovery_splash_url_when_no_hash(self, model): - model.discovery_splash_hash = None - assert model.make_discovery_splash_url(ext="png", size=2048) is None + def test_make_discovery_splash_url_when_no_hash(self, guild: guilds.Guild): + guild.discovery_splash_hash = None + assert guild.make_discovery_splash_url(ext="png", size=2048) is None - def test_banner_url(self, model): - banner = object() + def test_banner_url(self, guild: guilds.Guild): + banner = mock.Mock() with mock.patch.object(guilds.Guild, "make_banner_url", return_value=banner): - assert model.banner_url is banner + assert guild.banner_url is banner - def test_make_banner_url_when_hash(self, model): + def test_make_banner_url_when_hash(self, guild: guilds.Guild): with mock.patch.object( routes, "CDN_GUILD_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_banner_url(file_format="URL", size=512) == "file" + assert guild.make_banner_url(file_format="PNG", size=512) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, guild_id=123, hash="banner_hash", size=512, file_format="URL", lossless=True + urls.CDN_URL, guild_id=123, hash="banner_hash", size=512, file_format="PNG", lossless=True ) - def test_make_banner_url_when_format_is_None_and_banner_hash_is_for_animated(self, model): - model.banner_hash = "a_18dnf8dfbakfdh" + def test_make_banner_url_when_format_is_None_and_banner_hash_is_for_animated(self, guild: guilds.Guild): + guild.banner_hash = "a_18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_GUILD_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_banner_url(file_format=None, size=4096) == "file" + assert guild.make_banner_url(file_format="GIF", size=4096) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, guild_id=model.id, hash="a_18dnf8dfbakfdh", size=4096, file_format="GIF", lossless=True + urls.CDN_URL, guild_id=guild.id, hash="a_18dnf8dfbakfdh", size=4096, file_format="GIF", lossless=True ) - def test_make_banner_url_when_format_is_None_and_banner_hash_is_not_for_animated(self, model): - model.banner_hash = "18dnf8dfbakfdh" + def test_make_banner_url_when_format_is_None_and_banner_hash_is_not_for_animated(self, guild: guilds.Guild): + guild.banner_hash = "18dnf8dfbakfdh" with mock.patch.object( routes, "CDN_GUILD_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_banner_url(file_format=None, size=4096) == "file" + assert guild.make_banner_url(file_format="PNG", size=4096) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, guild_id=model.id, hash=model.banner_hash, size=4096, file_format="PNG", lossless=True + urls.CDN_URL, guild_id=guild.id, hash=guild.banner_hash, size=4096, file_format="PNG", lossless=True ) - def test_make_banner_url_when_no_hash(self, model): - model.banner_hash = None - assert model.make_banner_url(file_format="PNG", size=2048) is None + def test_make_banner_url_when_no_hash(self, guild: guilds.Guild): + guild.banner_hash = None + assert guild.make_banner_url(file_format="PNG", size=2048) is None @pytest.mark.asyncio - async def test_fetch_owner(self, model): - model.app.rest.fetch_member = mock.AsyncMock() + async def test_fetch_owner(self, guild: guilds.Guild): + guild.app.rest.fetch_member = mock.AsyncMock() - assert await model.fetch_owner() is model.app.rest.fetch_member.return_value - model.app.rest.fetch_member.assert_awaited_once_with(123, 1111) + assert await guild.fetch_owner() is guild.app.rest.fetch_member.return_value + guild.app.rest.fetch_member.assert_awaited_once_with(123, 1111) @pytest.mark.asyncio - async def test_fetch_widget_channel(self, model): + async def test_fetch_widget_channel(self, guild: guilds.Guild): mock_channel = mock.Mock(channels_.GuildChannel) - model.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) + guild.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) - assert await model.fetch_widget_channel() is model.app.rest.fetch_channel.return_value - model.app.rest.fetch_channel.assert_awaited_once_with(192729) + assert await guild.fetch_widget_channel() is guild.app.rest.fetch_channel.return_value + guild.app.rest.fetch_channel.assert_awaited_once_with(192729) @pytest.mark.asyncio - async def test_fetch_widget_channel_when_None(self, model): - model.widget_channel_id = None + async def test_fetch_widget_channel_when_None(self, guild: guilds.Guild): + guild.widget_channel_id = None - assert await model.fetch_widget_channel() is None + assert await guild.fetch_widget_channel() is None @pytest.mark.asyncio - async def test_fetch_rules_channel(self, model): + async def test_fetch_rules_channel(self, guild: guilds.Guild): mock_channel = mock.Mock(channels_.GuildTextChannel) - model.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) + guild.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) - assert await model.fetch_rules_channel() is model.app.rest.fetch_channel.return_value - model.app.rest.fetch_channel.assert_awaited_once_with(123445) + assert await guild.fetch_rules_channel() is guild.app.rest.fetch_channel.return_value + guild.app.rest.fetch_channel.assert_awaited_once_with(123445) @pytest.mark.asyncio - async def test_fetch_rules_channel_when_None(self, model): - model.rules_channel_id = None + async def test_fetch_rules_channel_when_None(self, guild: guilds.Guild): + guild.rules_channel_id = None - assert await model.fetch_rules_channel() is None + assert await guild.fetch_rules_channel() is None @pytest.mark.asyncio - async def test_fetch_system_channel(self, model): + async def test_fetch_system_channel(self, guild: guilds.Guild): mock_channel = mock.Mock(channels_.GuildTextChannel) - model.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) + guild.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) - assert await model.fetch_system_channel() is model.app.rest.fetch_channel.return_value - model.app.rest.fetch_channel.assert_awaited_once_with(123888) + assert await guild.fetch_system_channel() is guild.app.rest.fetch_channel.return_value + guild.app.rest.fetch_channel.assert_awaited_once_with(123888) @pytest.mark.asyncio - async def test_fetch_system_channel_when_None(self, model): - model.system_channel_id = None + async def test_fetch_system_channel_when_None(self, guild: guilds.Guild): + guild.system_channel_id = None - assert await model.fetch_system_channel() is None + assert await guild.fetch_system_channel() is None @pytest.mark.asyncio - async def test_fetch_public_updates_channel(self, model): + async def test_fetch_public_updates_channel(self, guild: guilds.Guild): mock_channel = mock.Mock(channels_.GuildTextChannel) - model.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) + guild.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) - assert await model.fetch_public_updates_channel() is model.app.rest.fetch_channel.return_value - model.app.rest.fetch_channel.assert_awaited_once_with(99699) + assert await guild.fetch_public_updates_channel() is guild.app.rest.fetch_channel.return_value + guild.app.rest.fetch_channel.assert_awaited_once_with(99699) @pytest.mark.asyncio - async def test_fetch_public_updates_channel_when_None(self, model): - model.public_updates_channel_id = None + async def test_fetch_public_updates_channel_when_None(self, guild: guilds.Guild): + guild.public_updates_channel_id = None - assert await model.fetch_public_updates_channel() is None + assert await guild.fetch_public_updates_channel() is None @pytest.mark.asyncio - async def test_fetch_afk_channel(self, model): + async def test_fetch_afk_channel(self, guild: guilds.Guild): mock_channel = mock.Mock(channels_.GuildVoiceChannel) - model.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) + guild.app.rest.fetch_channel = mock.AsyncMock(return_value=mock_channel) - assert await model.fetch_afk_channel() is model.app.rest.fetch_channel.return_value - model.app.rest.fetch_channel.assert_awaited_once_with(1234) + assert await guild.fetch_afk_channel() is guild.app.rest.fetch_channel.return_value + guild.app.rest.fetch_channel.assert_awaited_once_with(1234) @pytest.mark.asyncio - async def test_fetch_afk_channel_when_None(self, model): - model.afk_channel_id = None - - assert await model.fetch_afk_channel() is None - - def test_get_channel(self, model): - model.app.cache.get_guild_channel.return_value.guild_id = model.id - assert model.get_channel(456) is model.app.cache.get_guild_channel.return_value - model.app.cache.get_guild_channel.assert_called_once_with(456) - - def test_get_channel_when_not_from_guild(self, model): - model.app.cache.get_guild_channel.return_value.guild_id = 654523123 - - assert model.get_channel(456) is None - - model.app.cache.get_guild_channel.assert_called_once_with(456) - - def test_get_channel_when_no_cache_trait(self, model): - model.app = object() - assert model.get_channel(456) is None - - def test_get_member(self, model): - assert model.get_member(456) is model.app.cache.get_member.return_value - model.app.cache.get_member.assert_called_once_with(123, 456) - - def test_get_member_when_no_cache_trait(self, model): - model.app = object() - assert model.get_member(456) is None - - def test_get_presence(self, model): - assert model.get_presence(456) is model.app.cache.get_presence.return_value - model.app.cache.get_presence.assert_called_once_with(123, 456) - - def test_get_presence_when_no_cache_trait(self, model): - model.app = object() - assert model.get_presence(456) is None - - def test_get_voice_state(self, model): - assert model.get_voice_state(456) is model.app.cache.get_voice_state.return_value - model.app.cache.get_voice_state.assert_called_once_with(123, 456) - - def test_get_voice_state_when_no_cache_trait(self, model): - model.app = object() - assert model.get_voice_state(456) is None - - def test_get_my_member_when_not_shardaware(self, model): - model.app = object() - assert model.get_my_member() is None - - def test_get_my_member_when_no_me(self, model): - model.app.get_me = mock.Mock(return_value=None) - - assert model.get_my_member() is None - - model.app.get_me.assert_called_once_with() - - def test_get_my_member(self, model): - model.app.get_me = mock.Mock() - model.app.get_me.return_value.id = 123 - - with mock.patch.object(guilds.Guild, "get_member") as get_member: - assert model.get_my_member() is get_member.return_value - - get_member.assert_called_once_with(123) - model.app.get_me.assert_called_once_with() + async def test_fetch_afk_channel_when_None(self, guild: guilds.Guild): + guild.afk_channel_id = None + + assert await guild.fetch_afk_channel() is None + + def test_get_channel(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_guild_channel", mock.Mock(return_value=mock.Mock(guild_id=guild.id)) + ) as patched_get_guild_channel, + ): + assert guild.get_channel(456) is patched_get_guild_channel.return_value + patched_get_guild_channel.assert_called_once_with(456) + + def test_get_channel_when_not_from_guild(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object( + patched_cache, "get_guild_channel", mock.Mock(return_value=mock.Mock(guild_id=654523123)) + ) as patched_get_guild_channel, + ): + assert guild.get_channel(456) is None + patched_get_guild_channel.assert_called_once_with(456) + + def test_get_channel_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock() + assert guild.get_channel(456) is None + + def test_get_member(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_member") as patched_get_member, + ): + assert guild.get_member(456) is patched_get_member.return_value + patched_get_member.assert_called_once_with(123, 456) + + def test_get_member_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock(traits.RESTAware) + assert guild.get_member(456) is None + + def test_get_presence(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_presence") as patched_get_presence, + ): + assert guild.get_presence(456) is patched_get_presence.return_value + patched_get_presence.assert_called_once_with(123, 456) + + def test_get_presence_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock(traits.RESTAware) + assert guild.get_presence(456) is None + + def test_get_voice_state(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "cache") as patched_cache, + mock.patch.object(patched_cache, "get_voice_state") as patched_get_voice_state, + ): + assert guild.get_voice_state(456) is patched_get_voice_state.return_value + patched_get_voice_state.assert_called_once_with(123, 456) + + def test_get_voice_state_when_no_cache_trait(self, guild: guilds.Guild): + guild.app = mock.Mock(traits.RESTAware) + assert guild.get_voice_state(456) is None + + def test_get_my_member_when_not_shardaware(self, guild: guilds.Guild): + guild.app = mock.Mock(traits.RESTAware) + assert guild.get_my_member() is None + + def test_get_my_member_when_no_me(self, guild: guilds.Guild): + with mock.patch.object(guild.app, "get_me", mock.Mock(return_value=None)) as patched_get_me: + assert guild.get_my_member() is None + patched_get_me.assert_called_once_with() + + def test_get_my_member(self, guild: guilds.Guild): + with ( + mock.patch.object(guild.app, "get_me", mock.Mock(return_value=mock.Mock(id=123))) as patched_get_me, + mock.patch.object(guilds.Guild, "get_member") as patched_get_member, + ): + assert guild.get_my_member() is patched_get_member.return_value + + patched_get_member.assert_called_once_with(123) + patched_get_me.assert_called_once_with() class TestRestGuild: @pytest.fixture - def model(self, mock_app): + def rest_guild(self, mock_app: traits.RESTAware) -> guilds.RESTGuild: return guilds.RESTGuild( app=mock_app, id=snowflakes.Snowflake(123), diff --git a/tests/hikari/test_invites.py b/tests/hikari/test_invites.py index ae1f4ec105..afa903e105 100644 --- a/tests/hikari/test_invites.py +++ b/tests/hikari/test_invites.py @@ -20,29 +20,41 @@ # SOFTWARE. from __future__ import annotations +import datetime + import mock import pytest +from hikari import guilds from hikari import invites +from hikari import snowflakes from hikari import urls from hikari.internal import routes -from tests.hikari import hikari_test_helpers class TestInviteCode: - def test_str_operator(self): - mock_invite = hikari_test_helpers.mock_class_namespace( - invites.InviteCode, code=mock.PropertyMock(return_value="hikari") - )() - assert str(mock_invite) == "https://discord.gg/hikari" + class MockInviteCode(invites.InviteCode): + def __init__(self): + self._code = "hikari" + + @property + def code(self) -> str: + return self._code + + @pytest.fixture + def invite_code(self) -> invites.InviteCode: + return TestInviteCode.MockInviteCode() + + def test_str_operator(self, invite_code: invites.InviteCode): + assert str(invite_code) == "https://discord.gg/hikari" class TestInviteGuild: @pytest.fixture - def model(self): + def model(self) -> invites.InviteGuild: return invites.InviteGuild( app=mock.Mock(), - id=123321, + id=snowflakes.Snowflake(123321), icon_hash="hi", name="bye", features=[], @@ -52,10 +64,10 @@ def model(self): verification_level=1, vanity_url_code=None, welcome_screen=None, - nsfw_level=2, + nsfw_level=guilds.GuildNSFWLevel.SAFE, ) - def test_make_splash_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_splash_url_format_set_to_deprecated_ext_argument_if_provided(self, model: invites.InviteGuild): with mock.patch.object( routes, "CDN_GUILD_SPLASH", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: @@ -66,7 +78,7 @@ def test_make_splash_url_format_set_to_deprecated_ext_argument_if_provided(self, ) def test_splash_url(self, model: invites.InviteGuild): - splash = object() + splash = mock.Mock() with mock.patch.object(invites.InviteGuild, "make_splash_url", return_value=splash): assert model.splash_url is splash @@ -87,7 +99,7 @@ def test_make_splash_url_when_no_hash(self, model: invites.InviteGuild): model.splash_hash = None assert model.make_splash_url(ext="png", size=1024) is None - def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, model: invites.InviteGuild): with mock.patch.object( routes, "CDN_GUILD_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: @@ -98,7 +110,7 @@ def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, ) def test_banner_url(self, model: invites.InviteGuild): - banner = object() + banner = mock.Mock() with mock.patch.object(invites.InviteGuild, "make_banner_url", return_value=banner): assert model.banner_url is banner @@ -143,14 +155,33 @@ def test_make_banner_url_when_no_hash(self, model: invites.InviteGuild): class TestInviteWithMetadata: - def test_uses_left(self): - mock_invite = hikari_test_helpers.mock_class_namespace( - invites.InviteWithMetadata, init_=False, max_uses=123, uses=55 - )() + @pytest.fixture + def invite_with_metadata(self): + return invites.InviteWithMetadata( + app=mock.Mock(), + code="invite_code", + guild=mock.Mock(), + guild_id=snowflakes.Snowflake(12345), + channel=mock.Mock(), + channel_id=snowflakes.Snowflake(54321), + inviter=mock.Mock(), + target_type=invites.TargetType.EMBEDDED_APPLICATION, + target_user=mock.Mock(), + target_application=mock.Mock(), + approximate_active_member_count=3, + expires_at=datetime.datetime.fromtimestamp(3000), + approximate_member_count=10, + uses=55, + max_uses=123, + max_age=datetime.timedelta(3), + is_temporary=True, + created_at=datetime.datetime.fromtimestamp(2000), + ) - assert mock_invite.uses_left == 68 + def test_uses_left(self, invite_with_metadata: invites.InviteWithMetadata): + assert invite_with_metadata.uses_left == 68 - def test_uses_left_when_none(self): - mock_invite = hikari_test_helpers.mock_class_namespace(invites.InviteWithMetadata, init_=False, max_uses=None)() + def test_uses_left_when_none(self, invite_with_metadata: invites.InviteWithMetadata): + invite_with_metadata.max_uses = None - assert mock_invite.uses_left is None + assert invite_with_metadata.uses_left is None diff --git a/tests/hikari/test_iterators.py b/tests/hikari/test_iterators.py index 2ee1604f80..776f4eab34 100644 --- a/tests/hikari/test_iterators.py +++ b/tests/hikari/test_iterators.py @@ -20,18 +20,23 @@ # SOFTWARE. from __future__ import annotations +import typing + import pytest from hikari import iterators -from tests.hikari import hikari_test_helpers class TestLazyIterator: + class MockLazyIterator(iterators.LazyIterator[typing.Any]): + def __anext__(self) -> typing.Any: + pass + @pytest.fixture - def lazy_iterator(self): - return hikari_test_helpers.mock_class_namespace(iterators.LazyIterator)() + def lazy_iterator(self) -> iterators.LazyIterator[typing.Any]: + return TestLazyIterator.MockLazyIterator() - def test_asynchronous_only(self, lazy_iterator): + def test_asynchronous_only(self, lazy_iterator: iterators.LazyIterator[typing.Any]): with pytest.raises(TypeError, match="is async-only, did you mean 'async for' or `anext`?"): next(lazy_iterator) diff --git a/tests/hikari/test_messages.py b/tests/hikari/test_messages.py index 51a83b7226..b427864873 100644 --- a/tests/hikari/test_messages.py +++ b/tests/hikari/test_messages.py @@ -29,6 +29,7 @@ from hikari import guilds from hikari import messages from hikari import snowflakes +from hikari import traits from hikari import undefined from hikari import urls from hikari import users @@ -38,7 +39,7 @@ class TestAttachment: def test_str_operator(self): attachment = messages.Attachment( - id=123, + id=snowflakes.Snowflake(123), filename="super_cool_file.cool", title="other title", description="description!", @@ -63,12 +64,14 @@ def test_str_operator(self): class TestMessageApplication: @pytest.fixture - def message_application(self): + def message_application(self) -> messages.MessageApplication: return messages.MessageApplication( - id=123, name="test app", description="", icon_hash="123abc", cover_image_hash="abc123" + id=snowflakes.Snowflake(123), name="test app", description="", icon_hash="123abc", cover_image_hash="abc123" ) - def test_make_cover_url_format_set_to_deprecated_ext_argument_if_provided(self, message_application): + def test_make_cover_url_format_set_to_deprecated_ext_argument_if_provided( + self, message_application: messages.MessageApplication + ): with mock.patch.object( routes, "CDN_APPLICATION_COVER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: @@ -78,16 +81,16 @@ def test_make_cover_url_format_set_to_deprecated_ext_argument_if_provided(self, urls.CDN_URL, application_id=123, hash="abc123", size=4096, file_format="JPEG", lossless=True ) - def test_cover_image_url(self, message_application): + def test_cover_image_url(self, message_application: messages.MessageApplication): with mock.patch.object(messages.MessageApplication, "make_cover_image_url") as mock_cover_image: assert message_application.cover_image_url is mock_cover_image() - def test_make_cover_image_url_when_hash_is_none(self, message_application): + def test_make_cover_image_url_when_hash_is_none(self, message_application: messages.MessageApplication): message_application.cover_image_hash = None assert message_application.make_cover_image_url() is None - def test_make_cover_image_url_when_hash_is_not_none(self, message_application): + def test_make_cover_image_url_when_hash_is_not_none(self, message_application: messages.MessageApplication): with mock.patch.object( routes, "CDN_APPLICATION_COVER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: @@ -99,9 +102,9 @@ def test_make_cover_image_url_when_hash_is_not_none(self, message_application): @pytest.fixture -def message(): +def message(hikari_app: traits.RESTAware) -> messages.Message: return messages.Message( - app=None, + app=hikari_app, id=snowflakes.Snowflake(1234), channel_id=snowflakes.Snowflake(5678), guild_id=snowflakes.Snowflake(910112), @@ -117,7 +120,7 @@ def message(): mentions_everyone=False, attachments=(), embeds=(), - poll=object(), + poll=None, reactions=(), is_pinned=True, webhook_id=None, @@ -125,12 +128,12 @@ def message(): activity=None, application=None, message_reference=None, - message_snapshots=None, - flags=None, + flags=messages.MessageFlag.NONE, + message_snapshots=[], nonce=None, referenced_message=None, stickers=[], - application_id=123123, + application_id=snowflakes.Snowflake(123123), components=[], thread=None, interaction_metadata=None, @@ -138,22 +141,22 @@ def message(): class TestMessage: - def test_make_link_when_guild_is_not_none(self, message): - message.id = 789 - message.channel_id = 456 + def test_make_link_when_guild_is_not_none(self, message: messages.Message): + message.id = snowflakes.Snowflake(789) + message.channel_id = snowflakes.Snowflake(456) assert message.make_link(123) == "https://discord.com/channels/123/456/789" - def test_make_link_when_guild_is_none(self, message): + def test_make_link_when_guild_is_none(self, message: messages.Message): message.app = mock.Mock() - message.id = 789 - message.channel_id = 456 + message.id = snowflakes.Snowflake(789) + message.channel_id = snowflakes.Snowflake(456) assert message.make_link(None) == "https://discord.com/channels/@me/456/789" @pytest.fixture -def message_reference(): +def message_reference(hikari_app: traits.RESTAware) -> messages.MessageReference: return messages.MessageReference( - app=None, + app=hikari_app, guild_id=snowflakes.Snowflake(123), channel_id=snowflakes.Snowflake(456), id=snowflakes.Snowflake(789), @@ -162,16 +165,16 @@ def message_reference(): class TestMessageReference: - def test_make_link_when_guild_is_not_none(self, message_reference): + def test_make_link_when_guild_is_not_none(self, message_reference: messages.MessageReference): assert message_reference.message_link == "https://discord.com/channels/123/456/789" assert message_reference.channel_link == "https://discord.com/channels/123/456" - def test_make_link_when_guild_is_none(self, message_reference): + def test_make_link_when_guild_is_none(self, message_reference: messages.MessageReference): message_reference.guild_id = None assert message_reference.message_link == "https://discord.com/channels/@me/456/789" assert message_reference.channel_link == "https://discord.com/channels/@me/456" - def test_make_link_when_id_is_none(self, message_reference): + def test_make_link_when_id_is_none(self, message_reference: messages.MessageReference): message_reference.id = None assert message_reference.message_link is None assert message_reference.channel_link == "https://discord.com/channels/123/456" @@ -179,22 +182,22 @@ def test_make_link_when_id_is_none(self, message_reference): @pytest.mark.asyncio class TestAsyncMessage: - async def test_fetch_channel(self, message): + async def test_fetch_channel(self, message: messages.Message): message.app = mock.AsyncMock() - message.channel_id = 123 + message.channel_id = snowflakes.Snowflake(123) await message.fetch_channel() message.app.rest.fetch_channel.assert_awaited_once_with(123) - async def test_edit(self, message): + async def test_edit(self, message: messages.Message): message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 - embed = object() - embeds = [object(), object()] - component = object() - components = object(), object() - attachment = object() - roles = [object()] + message.id = snowflakes.Snowflake(123) + message.channel_id = snowflakes.Snowflake(456) + embed = mock.Mock() + embeds = [mock.Mock(), mock.Mock()] + component = mock.Mock() + components = mock.Mock(), mock.Mock() + attachment = mock.Mock() + roles = [mock.Mock()] await message.edit( content="test content", embed=embed, @@ -226,19 +229,20 @@ async def test_edit(self, message): flags=messages.MessageFlag.URGENT, ) - async def test_respond(self, message): + async def test_respond(self, message: messages.Message): message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 - embed = object() - embeds = [object(), object()] - poll = object() - roles = [object()] - attachment = object() - attachments = [object()] - component = object() - components = object(), object() - reference_messsage = object() + message.id = snowflakes.Snowflake(123) + message.channel_id = snowflakes.Snowflake(456) + embed = mock.Mock() + embeds = [mock.Mock(), mock.Mock()] + roles = [mock.Mock()] + attachment = mock.Mock() + attachments = [mock.Mock()] + component = mock.Mock() + components = mock.Mock(), mock.Mock() + reference_messsage = mock.Mock() + poll = mock.Mock() + await message.respond( content="test content", embed=embed, @@ -259,6 +263,7 @@ async def test_respond(self, message): mentions_reply=True, flags=321123, ) + message.app.rest.create_message.assert_awaited_once_with( channel=456, content="test content", @@ -281,10 +286,10 @@ async def test_respond(self, message): flags=321123, ) - async def test_respond_when_reply_is_True(self, message): + async def test_respond_when_reply_is_True(self, message: messages.Message): message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 + message.id = snowflakes.Snowflake(123) + message.channel_id = snowflakes.Snowflake(456) await message.respond(reply=True) message.app.rest.create_message.assert_awaited_once_with( channel=456, @@ -308,10 +313,10 @@ async def test_respond_when_reply_is_True(self, message): flags=undefined.UNDEFINED, ) - async def test_respond_when_reply_is_False(self, message): + async def test_respond_when_reply_is_False(self, message: messages.Message): message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 + message.id = snowflakes.Snowflake(123) + message.channel_id = snowflakes.Snowflake(456) await message.respond(reply=False) message.app.rest.create_message.assert_awaited_once_with( channel=456, @@ -335,50 +340,50 @@ async def test_respond_when_reply_is_False(self, message): flags=undefined.UNDEFINED, ) - async def test_delete(self, message): + async def test_delete(self, message: messages.Message): message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 + message.id = snowflakes.Snowflake(123) + message.channel_id = snowflakes.Snowflake(456) await message.delete() message.app.rest.delete_message.assert_awaited_once_with(456, 123) - async def test_add_reaction(self, message): + async def test_add_reaction(self, message: messages.Message): message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 + message.id = snowflakes.Snowflake(123) + message.channel_id = snowflakes.Snowflake(456) await message.add_reaction("👌", 123123) message.app.rest.add_reaction.assert_awaited_once_with(channel=456, message=123, emoji="👌", emoji_id=123123) - async def test_remove_reaction(self, message): + async def test_remove_reaction(self, message: messages.Message): message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 + message.id = snowflakes.Snowflake(123) + message.channel_id = snowflakes.Snowflake(456) await message.remove_reaction("👌", 341231) message.app.rest.delete_my_reaction.assert_awaited_once_with( channel=456, message=123, emoji="👌", emoji_id=341231 ) - async def test_remove_reaction_with_user(self, message): + async def test_remove_reaction_with_user(self, message: messages.Message): message.app = mock.AsyncMock() - user = object() - message.id = 123 - message.channel_id = 456 + user = mock.Mock() + message.id = snowflakes.Snowflake(123) + message.channel_id = snowflakes.Snowflake(456) await message.remove_reaction("👌", 31231, user=user) message.app.rest.delete_reaction.assert_awaited_once_with( channel=456, message=123, emoji="👌", emoji_id=31231, user=user ) - async def test_remove_all_reactions(self, message): + async def test_remove_all_reactions(self, message: messages.Message): message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 + message.id = snowflakes.Snowflake(123) + message.channel_id = snowflakes.Snowflake(456) await message.remove_all_reactions() message.app.rest.delete_all_reactions.assert_awaited_once_with(channel=456, message=123) - async def test_remove_all_reactions_with_emoji(self, message): + async def test_remove_all_reactions_with_emoji(self, message: messages.Message): message.app = mock.AsyncMock() - message.id = 123 - message.channel_id = 456 + message.id = snowflakes.Snowflake(123) + message.channel_id = snowflakes.Snowflake(456) await message.remove_all_reactions("👌", emoji_id=65655) message.app.rest.delete_all_reactions_for_emoji.assert_awaited_once_with( channel=456, message=123, emoji="👌", emoji_id=65655 diff --git a/tests/hikari/test_presences.py b/tests/hikari/test_presences.py index c7ed916ff8..9c540fcdfe 100644 --- a/tests/hikari/test_presences.py +++ b/tests/hikari/test_presences.py @@ -26,6 +26,7 @@ from hikari import files from hikari import presences from hikari import snowflakes +from hikari import traits from hikari import urls from hikari.impl import gateway_bot from hikari.internal import routes @@ -81,7 +82,11 @@ def test_large_image_url_property_when_runtime_error(self): def test_make_large_image_url(self): asset = presences.ActivityAssets( - application_id=45123123, large_image="541sdfasdasd", large_text=None, small_image=None, small_text=None + application_id=snowflakes.Snowflake(45123123), + large_image="541sdfasdasd", + large_text=None, + small_image=None, + small_text=None, ) with mock.patch.object(routes, "CDN_APPLICATION_ASSET") as route: @@ -160,7 +165,11 @@ def test_small_image_url_property_when_runtime_error(self): def test_make_small_image_url(self): asset = presences.ActivityAssets( - application_id=123321, large_image=None, large_text=None, small_image="aseqwsdas", small_text=None + application_id=snowflakes.Snowflake(123321), + large_image=None, + large_text=None, + small_image="aseqwsdas", + small_text=None, ) with mock.patch.object(routes, "CDN_APPLICATION_ASSET") as route: @@ -202,7 +211,7 @@ def test_str_operator(self): class TestMemberPresence: @pytest.fixture - def model(self, mock_app): + def model(self, mock_app: traits.RESTAware) -> presences.MemberPresence: return presences.MemberPresence( app=mock_app, user_id=snowflakes.Snowflake(432), @@ -213,14 +222,14 @@ def model(self, mock_app): ) @pytest.mark.asyncio - async def test_fetch_user(self, model): + async def test_fetch_user(self, model: presences.MemberPresence): model.app.rest.fetch_user = mock.AsyncMock() assert await model.fetch_user() is model.app.rest.fetch_user.return_value model.app.rest.fetch_user.assert_awaited_once_with(432) @pytest.mark.asyncio - async def test_fetch_member(self, model): + async def test_fetch_member(self, model: presences.MemberPresence): model.app.rest.fetch_member = mock.AsyncMock() assert await model.fetch_member() is model.app.rest.fetch_member.return_value diff --git a/tests/hikari/test_scheduled_events.py b/tests/hikari/test_scheduled_events.py index 790849bf7c..0beb619ccf 100644 --- a/tests/hikari/test_scheduled_events.py +++ b/tests/hikari/test_scheduled_events.py @@ -20,55 +20,70 @@ # SOFTWARE. from __future__ import annotations +import datetime + import mock import pytest from hikari import scheduled_events from hikari import snowflakes +from hikari import traits from hikari import urls from hikari.internal import routes -from tests.hikari import hikari_test_helpers class TestScheduledEvent: @pytest.fixture - def model(self) -> scheduled_events.ScheduledEvent: - return hikari_test_helpers.mock_class_namespace(scheduled_events.ScheduledEvent, init_=False, slots_=False)() + def scheduled_event(self, hikari_app: traits.RESTAware) -> scheduled_events.ScheduledEvent: + return scheduled_events.ScheduledEvent( + app=hikari_app, + id=snowflakes.Snowflake(123456), + guild_id=snowflakes.Snowflake(654321), + name="scheduled_event", + description="scheduled_event_description", + start_time=datetime.datetime.fromtimestamp(1000), + end_time=datetime.datetime.fromtimestamp(2000), + privacy_level=scheduled_events.EventPrivacyLevel.GUILD_ONLY, + status=scheduled_events.ScheduledEventStatus.SCHEDULED, + entity_type=scheduled_events.ScheduledEventType.VOICE, + creator=mock.Mock(), + user_count=3, + image_hash="image_hash", + ) - def test_make_image_url_format_set_to_deprecated_ext_argument_if_provided(self, model): - model.id = snowflakes.Snowflake(543123) - model.image_hash = "ododododo" + def test_make_image_url_format_set_to_deprecated_ext_argument_if_provided( + self, scheduled_event: scheduled_events.ScheduledEvent + ): + scheduled_event.id = snowflakes.Snowflake(543123) + scheduled_event.image_hash = "ododododo" with mock.patch.object( routes, "SCHEDULED_EVENT_COVER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_image_url(ext="JPEG") == "file" + assert scheduled_event.make_image_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, scheduled_event_id=543123, hash="ododododo", size=4096, file_format="JPEG", lossless=True ) - def test_image_url_property(self, model: scheduled_events.ScheduledEvent): - model.make_image_url = mock.Mock() - - assert model.image_url == model.make_image_url.return_value + def test_image_url_property(self, scheduled_event: scheduled_events.ScheduledEvent): + with mock.patch.object(scheduled_events.ScheduledEvent, "make_image_url") as patched_make_image_url: + assert scheduled_event.image_url == patched_make_image_url.return_value - model.make_image_url.assert_called_once_with() - - def test_image_url(self, model: scheduled_events.ScheduledEvent): - model.id = snowflakes.Snowflake(543123) - model.image_hash = "ododododo" + def test_image_url(self, scheduled_event: scheduled_events.ScheduledEvent): + scheduled_event.id = snowflakes.Snowflake(543123) + scheduled_event.image_hash = "ododododo" with mock.patch.object(routes, "SCHEDULED_EVENT_COVER") as route: - assert model.make_image_url(file_format="JPEG", size=1) is route.compile_to_file.return_value + assert scheduled_event.make_image_url(file_format="JPEG", size=1) is route.compile_to_file.return_value route.compile_to_file.assert_called_once_with( urls.CDN_URL, scheduled_event_id=543123, hash="ododododo", size=1, file_format="JPEG", lossless=True ) - def test_make_image_url_when_image_hash_is_none(self, model: scheduled_events.ScheduledEvent): - model.image_hash = None + def test_make_image_url_when_image_hash_is_none(self, scheduled_event: scheduled_events.ScheduledEvent): + scheduled_event.image_hash = None - with mock.patch.object(routes, "SCHEDULED_EVENT_COVER") as route: - assert model.make_image_url(ext="JPEG", size=1) is None + with mock.patch.object(routes, "SCHEDULED_EVENT_COVER") as patched_scheduled_event_cover: + assert scheduled_event.make_image_url(ext="JPEG", size=1) is None - route.compile_to_file.assert_not_called() + patched_scheduled_event_cover.assert_not_called() diff --git a/tests/hikari/test_sessions.py b/tests/hikari/test_sessions.py index c6d8d47961..3a13733d93 100644 --- a/tests/hikari/test_sessions.py +++ b/tests/hikari/test_sessions.py @@ -22,6 +22,8 @@ import datetime +import mock + from hikari import sessions @@ -36,6 +38,8 @@ def test_SessionStartLimit_reset_at_property(): obj = sessions.SessionStartLimit( total=100, remaining=2, reset_after=datetime.timedelta(hours=1, days=10), max_concurrency=1 ) - obj._created_at = datetime.datetime(2020, 7, 22, 22, 22, 36, 988017, tzinfo=datetime.timezone.utc) - assert obj.reset_at == datetime.datetime(2020, 8, 1, 23, 22, 36, 988017, tzinfo=datetime.timezone.utc) + with mock.patch.object( + obj, "_created_at", datetime.datetime(2020, 7, 22, 22, 22, 36, 988017, tzinfo=datetime.timezone.utc) + ): + assert obj.reset_at == datetime.datetime(2020, 8, 1, 23, 22, 36, 988017, tzinfo=datetime.timezone.utc) diff --git a/tests/hikari/test_snowflake.py b/tests/hikari/test_snowflake.py index e1e6c4bd99..c9542209a8 100644 --- a/tests/hikari/test_snowflake.py +++ b/tests/hikari/test_snowflake.py @@ -22,6 +22,7 @@ import datetime import operator +import typing import mock import pytest @@ -31,54 +32,54 @@ @pytest.fixture -def raw_id(): +def raw_id() -> int: return 537_340_989_808_050_216 @pytest.fixture -def neko_snowflake(raw_id): +def neko_snowflake(raw_id: int) -> snowflakes.Snowflake: return snowflakes.Snowflake(raw_id) class TestSnowflake: - def test_created_at(self, neko_snowflake): + def test_created_at(self, neko_snowflake: snowflakes.Snowflake): assert neko_snowflake.created_at == datetime.datetime( 2019, 1, 22, 18, 41, 15, 283_000, tzinfo=datetime.timezone.utc ) - def test_increment(self, neko_snowflake): + def test_increment(self, neko_snowflake: snowflakes.Snowflake): assert neko_snowflake.increment == 40 - def test_internal_process_id(self, neko_snowflake): + def test_internal_process_id(self, neko_snowflake: snowflakes.Snowflake): assert neko_snowflake.internal_process_id == 0 - def test_internal_worker_id(self, neko_snowflake): + def test_internal_worker_id(self, neko_snowflake: snowflakes.Snowflake): assert neko_snowflake.internal_worker_id == 2 - def test_hash(self, neko_snowflake, raw_id): + def test_hash(self, neko_snowflake: snowflakes.Snowflake, raw_id: int): assert hash(neko_snowflake) == hash(raw_id) - def test_index(self, neko_snowflake, raw_id): + def test_index(self, neko_snowflake: snowflakes.Snowflake, raw_id: int): assert operator.index(neko_snowflake) == raw_id - def test_int_cast(self, neko_snowflake, raw_id): + def test_int_cast(self, neko_snowflake: snowflakes.Snowflake, raw_id: int): assert int(neko_snowflake) == raw_id - def test_str_cast(self, neko_snowflake, raw_id): + def test_str_cast(self, neko_snowflake: snowflakes.Snowflake, raw_id: int): assert str(neko_snowflake) == str(raw_id) - def test_repr_cast(self, neko_snowflake, raw_id): + def test_repr_cast(self, neko_snowflake: snowflakes.Snowflake, raw_id: int): assert repr(neko_snowflake) == repr(raw_id) - def test_eq(self, neko_snowflake, raw_id): + def test_eq(self, neko_snowflake: snowflakes.Snowflake, raw_id: int): assert neko_snowflake == raw_id assert neko_snowflake == snowflakes.Snowflake(raw_id) assert str(raw_id) != neko_snowflake - def test_lt(self, neko_snowflake, raw_id): + def test_lt(self, neko_snowflake: snowflakes.Snowflake, raw_id: int): assert neko_snowflake < raw_id + 1 - def test_deserialize(self, neko_snowflake, raw_id): + def test_deserialize(self, neko_snowflake: snowflakes.Snowflake, raw_id: int): assert neko_snowflake == snowflakes.Snowflake(raw_id) def test_from_datetime(self): @@ -97,29 +98,35 @@ def test_max(self): class TestUnique: @pytest.fixture - def neko_unique(self, neko_snowflake): + def neko_unique(self, neko_snowflake: snowflakes.Snowflake) -> snowflakes.Unique: class NekoUnique(snowflakes.Unique): - id = neko_snowflake + @property + def id(self) -> snowflakes.Snowflake: + return neko_snowflake return NekoUnique() - def test_created_at(self, neko_unique): + def test_created_at(self, neko_unique: snowflakes.Unique): assert neko_unique.created_at == datetime.datetime( 2019, 1, 22, 18, 41, 15, 283_000, tzinfo=datetime.timezone.utc ) - def test_index(self, neko_unique, raw_id): + def test_index(self, neko_unique: snowflakes.Unique, raw_id: int): assert operator.index(neko_unique) == raw_id - def test__hash__(self, neko_unique, raw_id): + def test__hash__(self, neko_unique: snowflakes.Unique, raw_id: int): assert hash(neko_unique) == hash(raw_id) - def test__eq__(self, neko_snowflake, raw_id): + def test__eq__(self, neko_snowflake: snowflakes.Snowflake, raw_id: int): class NekoUnique(snowflakes.Unique): - id = neko_snowflake + @property + def id(self) -> snowflakes.Snowflake: + return neko_snowflake class NekoUnique2(snowflakes.Unique): - id = neko_snowflake + @property + def id(self) -> snowflakes.Snowflake: + return neko_snowflake unique1 = NekoUnique() unique2 = NekoUnique() @@ -133,7 +140,7 @@ class NekoUnique2(snowflakes.Unique): ("guild_id", "expected_id"), [(140502780547694592, 2), ("655288690192416778", 1), (snowflakes.Snowflake(105785483455418368), 3)], ) -def test_calculate_shard_id_with_shard_count(guild_id, expected_id): +def test_calculate_shard_id_with_shard_count(guild_id: typing.Any, expected_id: int): assert snowflakes.calculate_shard_id(4, guild_id) == expected_id @@ -141,6 +148,6 @@ def test_calculate_shard_id_with_shard_count(guild_id, expected_id): ("guild_id", "expected_id"), [(140502780547694592, 2), ("115590097100865541", 5), (snowflakes.Snowflake(105785483455418368), 7)], ) -def test_calculate_shard_id_with_app(guild_id, expected_id): +def test_calculate_shard_id_with_app(guild_id: typing.Any, expected_id: int): mock_app = mock.Mock(gateway_bot.GatewayBot, shard_count=8) assert snowflakes.calculate_shard_id(mock_app, guild_id) == expected_id diff --git a/tests/hikari/test_stage_instances.py b/tests/hikari/test_stage_instances.py index 8902b6cfa9..48d4f57b1b 100644 --- a/tests/hikari/test_stage_instances.py +++ b/tests/hikari/test_stage_instances.py @@ -20,23 +20,18 @@ # SOFTWARE. from __future__ import annotations -import mock import pytest from hikari import snowflakes from hikari import stage_instances - - -@pytest.fixture -def mock_app(): - return mock.Mock() +from hikari import traits class TestStageInstance: @pytest.fixture - def stage_instance(self, mock_app): + def stage_instance(self, hikari_app: traits.RESTAware) -> stage_instances.StageInstance: return stage_instances.StageInstance( - app=mock_app, + app=hikari_app, id=snowflakes.Snowflake(123), channel_id=snowflakes.Snowflake(6969), guild_id=snowflakes.Snowflake(420), @@ -46,26 +41,26 @@ def stage_instance(self, mock_app): scheduled_event_id=snowflakes.Snowflake(1337), ) - def test_id_property(self, stage_instance): + def test_id_property(self, stage_instance: stage_instances.StageInstance): assert stage_instance.id == 123 - def test_app_property(self, stage_instance, mock_app): - assert stage_instance.app is mock_app + def test_app_property(self, stage_instance: stage_instances.StageInstance, hikari_app: traits.RESTAware): + assert stage_instance.app is hikari_app - def test_channel_id_property(self, stage_instance): + def test_channel_id_property(self, stage_instance: stage_instances.StageInstance): assert stage_instance.channel_id == 6969 - def test_guild_id_property(self, stage_instance): + def test_guild_id_property(self, stage_instance: stage_instances.StageInstance): assert stage_instance.guild_id == 420 - def test_topic_property(self, stage_instance): + def test_topic_property(self, stage_instance: stage_instances.StageInstance): assert stage_instance.topic == "beanos" - def test_privacy_level_property(self, stage_instance): + def test_privacy_level_property(self, stage_instance: stage_instances.StageInstance): assert stage_instance.privacy_level == stage_instances.StageInstancePrivacyLevel.GUILD_ONLY - def test_discoverable_disabled_property(self, stage_instance): + def test_discoverable_disabled_property(self, stage_instance: stage_instances.StageInstance): assert stage_instance.discoverable_disabled is True - def test_guild_scheduled_event_id_property(self, stage_instance): + def test_guild_scheduled_event_id_property(self, stage_instance: stage_instances.StageInstance): assert stage_instance.scheduled_event_id == 1337 diff --git a/tests/hikari/test_stickers.py b/tests/hikari/test_stickers.py index 3ce96efa12..8a730f9133 100644 --- a/tests/hikari/test_stickers.py +++ b/tests/hikari/test_stickers.py @@ -31,18 +31,18 @@ class TestStickerPack: @pytest.fixture - def model(self): + def model(self) -> stickers.StickerPack: return stickers.StickerPack( - id=123, + id=snowflakes.Snowflake(123), name="testing", description="testing description", cover_sticker_id=snowflakes.Snowflake(6541234), stickers=[], - sku_id=123, + sku_id=snowflakes.Snowflake(123), banner_asset_id=snowflakes.Snowflake(541231), ) - def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, model): + def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, model: stickers.StickerPack): with mock.patch.object( routes, "CDN_STICKER_PACK_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: @@ -52,34 +52,36 @@ def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, urls.CDN_URL, hash=541231, size=4096, file_format="JPEG", lossless=True ) - def test_banner_url(self, model): + def test_banner_url(self, model: stickers.StickerPack): banner = object() with mock.patch.object(stickers.StickerPack, "make_banner_url", return_value=banner): assert model.banner_url is banner - def test_make_banner_url(self, model): + def test_make_banner_url(self, model: stickers.StickerPack): with mock.patch.object( routes, "CDN_STICKER_PACK_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert model.make_banner_url(file_format="URL", size=512) == "file" + assert model.make_banner_url(file_format="PNG", size=512) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, hash=541231, size=512, file_format="URL", lossless=True + urls.CDN_URL, hash=541231, size=512, file_format="PNG", lossless=True ) - def test_make_banner_url_when_no_banner_asset(self, model): + def test_make_banner_url_when_no_banner_asset(self, model: stickers.StickerPack): model.banner_asset_id = None - assert model.make_banner_url(file_format="URL", size=512) is None + assert model.make_banner_url(file_format="PNG", size=512) is None class TestPartialSticker: @pytest.fixture - def model(self): - return stickers.PartialSticker(id=123, name="testing", format_type="some") + def model(self) -> stickers.PartialSticker: + return stickers.PartialSticker( + id=snowflakes.Snowflake(123), name="testing", format_type=stickers.StickerFormatType.PNG + ) - def test_make_url_uses_CDN_when_LOTTIE(self, model): + def test_make_url_uses_CDN_when_LOTTIE(self, model: stickers.PartialSticker): model.format_type = stickers.StickerFormatType.LOTTIE with mock.patch.object( @@ -91,7 +93,7 @@ def test_make_url_uses_CDN_when_LOTTIE(self, model): urls.CDN_URL, sticker_id=123, file_format="LOTTIE", size=4096, lossless=True ) - def test_make_url_uses_MEDIA_PROXY_when_not_LOTTIE(self, model): + def test_make_url_uses_MEDIA_PROXY_when_not_LOTTIE(self, model: stickers.PartialSticker): model.format_type = stickers.StickerFormatType.GIF with mock.patch.object( @@ -103,13 +105,15 @@ def test_make_url_uses_MEDIA_PROXY_when_not_LOTTIE(self, model): urls.MEDIA_PROXY_URL, sticker_id=123, file_format="GIF", size=4096, lossless=True ) - def test_make_url_raises_TypeError_when_GIF_sticker_requested_as_APNG(self, model): + def test_make_url_raises_TypeError_when_GIF_sticker_requested_as_APNG(self, model: stickers.PartialSticker): model.format_type = stickers.StickerFormatType.GIF with pytest.raises(TypeError): model.make_url(file_format="APNG") - def test_make_url_raises_TypeError_when_APNG_sticker_requested_as_AWEBP_or_GIF(self, model): + def test_make_url_raises_TypeError_when_APNG_sticker_requested_as_AWEBP_or_GIF( + self, model: stickers.PartialSticker + ): model.format_type = stickers.StickerFormatType.APNG with pytest.raises(TypeError): @@ -118,7 +122,9 @@ def test_make_url_raises_TypeError_when_APNG_sticker_requested_as_AWEBP_or_GIF(s with pytest.raises(TypeError): model.make_url(file_format="GIF") - def test_make_url_raises_TypeError_when_PNG_sticker_requested_as_animated_format(self, model): + def test_make_url_raises_TypeError_when_PNG_sticker_requested_as_animated_format( + self, model: stickers.PartialSticker + ): model.format_type = stickers.StickerFormatType.PNG with pytest.raises(TypeError): @@ -130,19 +136,23 @@ def test_make_url_raises_TypeError_when_PNG_sticker_requested_as_animated_format with pytest.raises(TypeError): model.make_url(file_format="GIF") - def test_make_url_raises_TypeError_when_LOTTIE_sticker_requested_as_non_LOTTIE_format(self, model): + def test_make_url_raises_TypeError_when_LOTTIE_sticker_requested_as_non_LOTTIE_format( + self, model: stickers.PartialSticker + ): model.format_type = stickers.StickerFormatType.LOTTIE with pytest.raises(TypeError): model.make_url(file_format="PNG") - def test_make_url_raises_TypeError_when_non_LOTTIE_sticker_requested_as_LOTTIE(self, model): + def test_make_url_raises_TypeError_when_non_LOTTIE_sticker_requested_as_LOTTIE( + self, model: stickers.PartialSticker + ): model.format_type = stickers.StickerFormatType.PNG with pytest.raises(TypeError): model.make_url(file_format="LOTTIE") - def test_make_url_applies_correct_settings_for_APNG(self, model): + def test_make_url_applies_correct_settings_for_APNG(self, model: stickers.PartialSticker): model.format_type = stickers.StickerFormatType.APNG with mock.patch.object( @@ -154,7 +164,7 @@ def test_make_url_applies_correct_settings_for_APNG(self, model): urls.MEDIA_PROXY_URL, sticker_id=123, file_format="PNG", size=4096, lossless=True ) - def test_make_url_applies_correct_settings_for_AWEBP(self, model): + def test_make_url_applies_correct_settings_for_AWEBP(self, model: stickers.PartialSticker): model.format_type = stickers.StickerFormatType.GIF with mock.patch.object( @@ -166,7 +176,7 @@ def test_make_url_applies_correct_settings_for_AWEBP(self, model): urls.MEDIA_PROXY_URL, sticker_id=123, file_format="AWEBP", size=4096, lossless=True ) - def test_make_url_applies_correct_settings_for_WEBP_lossless(self, model): + def test_make_url_applies_correct_settings_for_WEBP_lossless(self, model: stickers.PartialSticker): model.format_type = stickers.StickerFormatType.PNG with mock.patch.object( @@ -178,7 +188,7 @@ def test_make_url_applies_correct_settings_for_WEBP_lossless(self, model): urls.MEDIA_PROXY_URL, sticker_id=123, file_format="WEBP", size=4096, lossless=True ) - def test_make_url_applies_correct_settings_for_WEBP_lossy(self, model): + def test_make_url_applies_correct_settings_for_WEBP_lossy(self, model: stickers.PartialSticker): model.format_type = stickers.StickerFormatType.PNG with mock.patch.object( @@ -190,7 +200,7 @@ def test_make_url_applies_correct_settings_for_WEBP_lossy(self, model): urls.MEDIA_PROXY_URL, sticker_id=123, file_format="WEBP", size=4096, lossless=False ) - def test_make_url_applies_no_extra_settings_for_non_special_formats(self, model): + def test_make_url_applies_no_extra_settings_for_non_special_formats(self, model: stickers.PartialSticker): model.format_type = stickers.StickerFormatType.PNG with mock.patch.object( diff --git a/tests/hikari/test_templates.py b/tests/hikari/test_templates.py index 300e201144..225196dbbe 100644 --- a/tests/hikari/test_templates.py +++ b/tests/hikari/test_templates.py @@ -28,22 +28,22 @@ class TestTemplate: @pytest.fixture - def obj(self): + def obj(self) -> templates.Template: return templates.Template( app=mock.Mock(), code="abc123", name="Test Template", description="Template used for testing", usage_count=101, - creator=object(), - created_at=object(), - updated_at=object(), + creator=mock.Mock(), + created_at=mock.Mock(), + updated_at=mock.Mock(), source_guild=mock.Mock(id=123), is_unsynced=True, ) @pytest.mark.asyncio - async def test_fetch_self(self, obj): + async def test_fetch_self(self, obj: templates.Template): obj.app.rest.fetch_template = mock.AsyncMock() assert await obj.fetch_self() is obj.app.rest.fetch_template.return_value @@ -51,7 +51,7 @@ async def test_fetch_self(self, obj): obj.app.rest.fetch_template.assert_awaited_once_with("abc123") @pytest.mark.asyncio - async def test_edit(self, obj): + async def test_edit(self, obj: templates.Template): obj.app.rest.edit_template = mock.AsyncMock() returned = await obj.edit(name="Test Template 2", description="Electric Boogaloo") @@ -62,7 +62,7 @@ async def test_edit(self, obj): ) @pytest.mark.asyncio - async def test_delete(self, obj): + async def test_delete(self, obj: templates.Template): obj.app.rest.delete_template = mock.AsyncMock() await obj.delete() @@ -70,12 +70,12 @@ async def test_delete(self, obj): obj.app.rest.delete_template.assert_awaited_once_with(obj.source_guild, obj) @pytest.mark.asyncio - async def test_sync(self, obj): + async def test_sync(self, obj: templates.Template): obj.app.rest.sync_guild_template = mock.AsyncMock() assert await obj.sync() is obj.app.rest.sync_guild_template.return_value obj.app.rest.sync_guild_template.assert_awaited_once_with(123, "abc123") - def test_str(self, obj): + def test_str(self, obj: templates.Template): assert str(obj) == "https://discord.new/abc123" diff --git a/tests/hikari/test_undefined.py b/tests/hikari/test_undefined.py index 7ee9ad0a1a..a353a7ba43 100644 --- a/tests/hikari/test_undefined.py +++ b/tests/hikari/test_undefined.py @@ -77,7 +77,7 @@ def test__getstate__(self): ((undefined.UNDEFINED, undefined.UNDEFINED, undefined.UNDEFINED, 34123), False), ], ) -def test_all_undefined(values, result): +def test_all_undefined(values: tuple[str | undefined.UndefinedType | int, ...], result: bool): assert undefined.all_undefined(*values) is result @@ -92,7 +92,7 @@ def test_all_undefined(values, result): ((undefined.UNDEFINED, 34123, 5432123, "312", False), True), ], ) -def test_any_undefined(values, result): +def test_any_undefined(values: tuple[str | undefined.UndefinedType | int, ...], result: bool): assert undefined.any_undefined(*values) is result @@ -107,5 +107,5 @@ def test_any_undefined(values, result): ((undefined.UNDEFINED, "32123123", undefined.UNDEFINED, 34123123, undefined.UNDEFINED), 3), ], ) -def test_count(values, result): +def test_count(values: tuple[str | undefined.UndefinedType | int, ...], result: bool): assert undefined.count(*values) == result diff --git a/tests/hikari/test_users.py b/tests/hikari/test_users.py index e1cf61598b..e86e72657f 100644 --- a/tests/hikari/test_users.py +++ b/tests/hikari/test_users.py @@ -23,360 +23,530 @@ import mock import pytest +from hikari import colors from hikari import snowflakes from hikari import traits from hikari import undefined from hikari import urls from hikari import users from hikari.internal import routes -from tests.hikari import hikari_test_helpers class TestPartialUser: + class MockedPartialUser(users.PartialUser): + def __init__(self, app: traits.RESTAware): + self._app = app + self._id = snowflakes.Snowflake(12) + self._avatar_hash = "avatar_hash" + self._banner_hash = "banner_hash" + self._accent_color = colors.Color.from_hex_code("FFB123") + self._discriminator = "discriminator" + self._username = "username" + self._global_name = "global_name" + self._display_name = "display_name" + self._avatar_decoration = users.AvatarDecoration( + asset_hash="avatar_decoration_asset_hash", sku_id=snowflakes.Snowflake(999), expires_at=None + ) + self._is_bot = False + self._is_system = False + self._flags = users.UserFlag.NONE + self._mention = "mention" + self._primary_guild = users.PrimaryGuild( + identity_guild_id=snowflakes.Snowflake(1234), identity_enabled=True, tag="HKRI", badge_hash="1234" + ) + + @property + def app(self) -> traits.RESTAware: + return self._app + + @property + def id(self) -> snowflakes.Snowflake: + return self._id + + @property + def avatar_hash(self) -> undefined.UndefinedNoneOr[str]: + return self._avatar_hash + + @property + def banner_hash(self) -> undefined.UndefinedNoneOr[str]: + return self._banner_hash + + @property + def accent_color(self) -> undefined.UndefinedNoneOr[colors.Color]: + return self._accent_color + + @property + def discriminator(self) -> undefined.UndefinedOr[str]: + return self._discriminator + + @property + def username(self) -> undefined.UndefinedOr[str]: + return self._username + + @property + def global_name(self) -> undefined.UndefinedNoneOr[str]: + return self._global_name + + @property + def display_name(self) -> undefined.UndefinedNoneOr[str]: + return self._display_name + + @property + def avatar_decoration(self) -> users.AvatarDecoration | None: + return self._avatar_decoration + + @property + def is_bot(self) -> undefined.UndefinedOr[bool]: + return self._is_bot + + @property + def is_system(self) -> undefined.UndefinedOr[bool]: + return self._is_system + + @property + def flags(self) -> undefined.UndefinedOr[users.UserFlag]: + return self._flags + + @property + def mention(self) -> str: + return self._mention + + @property + def primary_guild(self) -> users.PrimaryGuild | None: + return self._primary_guild + @pytest.fixture - def obj(self): + def partial_user(self, hikari_app: traits.RESTAware) -> users.PartialUser: # ABC, so must be stubbed. - return hikari_test_helpers.mock_class_namespace(users.PartialUser, slots_=False)() - - def test_accent_colour_alias_property(self, obj): - obj.accent_color = object() + return TestPartialUser.MockedPartialUser(hikari_app) - assert obj.accent_colour is obj.accent_color + def test_accent_colour_alias_property(self, partial_user: users.PartialUser): + with mock.patch.object(partial_user, "_accent_color", mock.Mock): + assert partial_user.accent_colour is partial_user.accent_color @pytest.mark.asyncio - async def test_fetch_self(self, obj): - obj.id = 123 - obj.app = mock.AsyncMock() - - assert await obj.fetch_self() is obj.app.rest.fetch_user.return_value - obj.app.rest.fetch_user.assert_awaited_once_with(user=123) + async def test_fetch_self(self, partial_user: users.PartialUser): + with ( + mock.patch.object(partial_user, "_id", snowflakes.Snowflake(123)), + mock.patch.object(partial_user.app.rest, "fetch_user", new_callable=mock.AsyncMock) as patched_fetch_user, + ): + assert await partial_user.fetch_self() is patched_fetch_user.return_value + patched_fetch_user.assert_awaited_once_with(user=123) @pytest.mark.asyncio - async def test_send_uses_cached_id(self, obj): - obj.id = 4123123 - embed = object() - embeds = [object()] - attachment = object() - attachments = [object(), object()] - component = object() - components = [object(), object()] - user_mentions = [object(), object()] - role_mentions = [object(), object()] - reply = object() - mentions_reply = object() - - obj.app = mock.Mock(spec=traits.CacheAware, rest=mock.AsyncMock()) - obj.fetch_dm_channel = mock.AsyncMock() - - returned = await obj.send( - content="test", - embed=embed, - embeds=embeds, - attachment=attachment, - attachments=attachments, - component=component, - components=components, - tts=True, - reply=reply, - reply_must_exist=False, - mentions_everyone=False, - user_mentions=user_mentions, - role_mentions=role_mentions, - mentions_reply=mentions_reply, - flags=34123342, - ) - - assert returned is obj.app.rest.create_message.return_value - - obj.app.cache.get_dm_channel_id.assert_called_once_with(4123123) - obj.fetch_dm_channel.assert_not_called() - obj.app.rest.create_message.assert_awaited_once_with( - channel=obj.app.cache.get_dm_channel_id.return_value, - content="test", - embed=embed, - embeds=embeds, - attachment=attachment, - attachments=attachments, - component=component, - components=components, - tts=True, - mentions_everyone=False, - reply_must_exist=False, - reply=reply, - user_mentions=user_mentions, - role_mentions=role_mentions, - mentions_reply=mentions_reply, - flags=34123342, - ) + async def test_send_uses_cached_id(self, partial_user: users.PartialUser): + embed = mock.Mock() + embeds = [mock.Mock()] + attachment = mock.Mock() + attachments = [mock.Mock(), mock.Mock()] + component = mock.Mock() + components = [mock.Mock(), mock.Mock()] + user_mentions = [mock.Mock(), mock.Mock()] + role_mentions = [mock.Mock(), mock.Mock()] + reply = mock.Mock() + mentions_reply = mock.Mock() + + # partial_user.fetch_dm_channel = mock.AsyncMock() + + with ( + mock.patch.object(partial_user, "fetch_dm_channel", new=mock.AsyncMock()) as patched_fetch_dm_channel, + mock.patch.object( + partial_user, "_app", mock.Mock(spec=traits.CacheAware, rest=mock.AsyncMock()) + ) as patched_app, + mock.patch.object(patched_app.cache, "get_dm_channel_id") as patched_get_dm_channel_id, + mock.patch.object(patched_app.rest, "create_message") as patched_create_message, + ): + returned = await partial_user.send( + content="test", + embed=embed, + embeds=embeds, + attachment=attachment, + attachments=attachments, + component=component, + components=components, + tts=True, + reply=reply, + reply_must_exist=False, + mentions_everyone=False, + user_mentions=user_mentions, + role_mentions=role_mentions, + mentions_reply=mentions_reply, + flags=34123342, + ) + + assert returned is patched_create_message.return_value + + patched_get_dm_channel_id.assert_called_once_with(12) + patched_fetch_dm_channel.assert_not_called() + patched_create_message.assert_awaited_once_with( + channel=patched_get_dm_channel_id.return_value, + content="test", + embed=embed, + embeds=embeds, + attachment=attachment, + attachments=attachments, + component=component, + components=components, + tts=True, + mentions_everyone=False, + reply_must_exist=False, + reply=reply, + user_mentions=user_mentions, + role_mentions=role_mentions, + mentions_reply=mentions_reply, + flags=34123342, + ) @pytest.mark.asyncio - async def test_send_when_not_cached(self, obj): - obj.id = 522234 - obj.app = mock.Mock(spec=traits.CacheAware, rest=mock.AsyncMock()) - obj.app.cache.get_dm_channel_id = mock.Mock(return_value=None) - obj.fetch_dm_channel = mock.AsyncMock() - - returned = await obj.send() - - assert returned is obj.app.rest.create_message.return_value - - obj.app.cache.get_dm_channel_id.assert_called_once_with(522234) - obj.fetch_dm_channel.assert_awaited_once() - obj.app.rest.create_message.assert_awaited_once_with( - channel=obj.fetch_dm_channel.return_value.id, - content=undefined.UNDEFINED, - embed=undefined.UNDEFINED, - embeds=undefined.UNDEFINED, - attachment=undefined.UNDEFINED, - attachments=undefined.UNDEFINED, - component=undefined.UNDEFINED, - components=undefined.UNDEFINED, - tts=undefined.UNDEFINED, - mentions_everyone=undefined.UNDEFINED, - reply=undefined.UNDEFINED, - reply_must_exist=undefined.UNDEFINED, - user_mentions=undefined.UNDEFINED, - role_mentions=undefined.UNDEFINED, - mentions_reply=undefined.UNDEFINED, - flags=undefined.UNDEFINED, - ) + async def test_send_when_not_cached(self, partial_user: users.PartialUser): + with ( + mock.patch.object( + partial_user, "_app", mock.Mock(spec=traits.CacheAware, rest=mock.AsyncMock()) + ) as patched_app, + mock.patch.object( + patched_app.cache, "get_dm_channel_id", new=mock.Mock(return_value=None) + ) as patched_get_dm_channel_id, + mock.patch.object(patched_app.rest, "create_message") as patched_create_message, + mock.patch.object(partial_user, "fetch_dm_channel", new=mock.AsyncMock()) as patched_fetch_dm_channel, + ): + returned = await partial_user.send() + + assert returned is patched_create_message.return_value + + patched_get_dm_channel_id.assert_called_once_with(12) + patched_fetch_dm_channel.assert_awaited_once() + patched_create_message.assert_awaited_once_with( + channel=patched_fetch_dm_channel.return_value.id, + content=undefined.UNDEFINED, + embed=undefined.UNDEFINED, + embeds=undefined.UNDEFINED, + attachment=undefined.UNDEFINED, + attachments=undefined.UNDEFINED, + component=undefined.UNDEFINED, + components=undefined.UNDEFINED, + tts=undefined.UNDEFINED, + mentions_everyone=undefined.UNDEFINED, + reply=undefined.UNDEFINED, + reply_must_exist=undefined.UNDEFINED, + user_mentions=undefined.UNDEFINED, + role_mentions=undefined.UNDEFINED, + mentions_reply=undefined.UNDEFINED, + flags=undefined.UNDEFINED, + ) @pytest.mark.asyncio - async def test_send_when_not_cache_aware(self, obj): - obj.id = 522234 - obj.app = mock.Mock(spec=traits.RESTAware, rest=mock.AsyncMock()) - obj.fetch_dm_channel = mock.AsyncMock() - - returned = await obj.send() - - assert returned is obj.app.rest.create_message.return_value - - obj.fetch_dm_channel.assert_awaited_once() - obj.app.rest.create_message.assert_awaited_once_with( - channel=obj.fetch_dm_channel.return_value.id, - content=undefined.UNDEFINED, - embed=undefined.UNDEFINED, - embeds=undefined.UNDEFINED, - attachment=undefined.UNDEFINED, - attachments=undefined.UNDEFINED, - component=undefined.UNDEFINED, - components=undefined.UNDEFINED, - tts=undefined.UNDEFINED, - mentions_everyone=undefined.UNDEFINED, - reply=undefined.UNDEFINED, - reply_must_exist=undefined.UNDEFINED, - user_mentions=undefined.UNDEFINED, - role_mentions=undefined.UNDEFINED, - mentions_reply=undefined.UNDEFINED, - flags=undefined.UNDEFINED, - ) + async def test_send_when_not_cache_aware(self, partial_user: users.PartialUser): + with ( + mock.patch.object(partial_user, "_id", snowflakes.Snowflake(522234)), + mock.patch.object( + partial_user, "fetch_dm_channel", new_callable=mock.AsyncMock + ) as patched_fetch_dm_channel, + mock.patch.object( + partial_user.app.rest, "create_message", new_callable=mock.AsyncMock + ) as patched_create_message, + ): + returned = await partial_user.send() + + assert returned is patched_create_message.return_value + + patched_fetch_dm_channel.assert_awaited_once() + patched_create_message.assert_awaited_once_with( + channel=patched_fetch_dm_channel.return_value.id, + content=undefined.UNDEFINED, + embed=undefined.UNDEFINED, + embeds=undefined.UNDEFINED, + attachment=undefined.UNDEFINED, + attachments=undefined.UNDEFINED, + component=undefined.UNDEFINED, + components=undefined.UNDEFINED, + tts=undefined.UNDEFINED, + mentions_everyone=undefined.UNDEFINED, + reply=undefined.UNDEFINED, + reply_must_exist=undefined.UNDEFINED, + user_mentions=undefined.UNDEFINED, + role_mentions=undefined.UNDEFINED, + mentions_reply=undefined.UNDEFINED, + flags=undefined.UNDEFINED, + ) @pytest.mark.asyncio - async def test_fetch_dm_channel(self, obj): - obj.id = 123 - obj.app = mock.Mock() - obj.app.rest.create_dm_channel = mock.AsyncMock() - - assert await obj.fetch_dm_channel() == obj.app.rest.create_dm_channel.return_value + async def test_fetch_dm_channel(self, partial_user: users.PartialUser): + with ( + mock.patch.object(partial_user, "_id", snowflakes.Snowflake(123)), + mock.patch.object( + partial_user.app.rest, "create_dm_channel", new_callable=mock.AsyncMock + ) as patched_create_dm_channel, + ): + assert await partial_user.fetch_dm_channel() == patched_create_dm_channel.return_value - obj.app.rest.create_dm_channel.assert_awaited_once_with(123) + patched_create_dm_channel.assert_awaited_once_with(123) class TestUser: + class MockedUser(users.User): + def __init__(self, app: traits.RESTAware): + self._app = app + self._id = snowflakes.Snowflake(12) + self._avatar_hash = "avatar_hash" + self._banner_hash = "banner_hash" + self._accent_color = colors.Color.from_hex_code("FFB123") + self._discriminator = "discriminator" + self._username = "username" + self._global_name = "global_name" + self._global_name = "global_name" + self._display_name = "display_name" + self._avatar_decoration = users.AvatarDecoration( + asset_hash="avatar_decoration_asset_hash", sku_id=snowflakes.Snowflake(999), expires_at=None + ) + self._is_bot = False + self._is_system = False + self._flags = users.UserFlag.NONE + self._mention = "mention" + self._primary_guild = users.PrimaryGuild( + identity_guild_id=snowflakes.Snowflake(123), identity_enabled=True, tag="HKRI", badge_hash="amogus" + ) + + @property + def app(self) -> traits.RESTAware: + return self._app + + @property + def id(self) -> snowflakes.Snowflake: + return self._id + + @property + def avatar_hash(self) -> str: + return self._avatar_hash + + @property + def banner_hash(self) -> str: + return self._banner_hash + + @property + def accent_color(self) -> colors.Color: + return self._accent_color + + @property + def discriminator(self) -> str: + return self._discriminator + + @property + def username(self) -> str: + return self._username + + @property + def global_name(self) -> str: + return self._global_name + + @property + def display_name(self) -> str: + return "display_name" + + @property + def avatar_decoration(self) -> users.AvatarDecoration | None: + return self._avatar_decoration + + @property + def is_bot(self) -> bool: + return self._is_bot + + @property + def is_system(self) -> bool: + return self._is_system + + @property + def flags(self) -> users.UserFlag: + return self._flags + + @property + def mention(self) -> str: + return self._mention + + @property + def primary_guild(self) -> users.PrimaryGuild | None: + return self._primary_guild + @pytest.fixture - def obj(self): + def user(self, hikari_app: traits.RESTAware) -> users.User: # ABC, so must be stubbed. - return hikari_test_helpers.mock_class_namespace(users.User, slots_=False)() - - def test_accent_colour_alias_property(self, obj): - obj.accent_color = object() - - assert obj.accent_colour is obj.accent_color + return TestUser.MockedUser(hikari_app) - def test_avatar_decoration_property(self, obj): - obj.avatar_decoration = users.AvatarDecoration( - asset_hash="18dnf8dfbakfdh", sku_id=snowflakes.Snowflake(123), expires_at=None - ) + def test_accent_colour_alias_property(self, user: users.User): + assert user.accent_colour is user.accent_color + def test_avatar_decoration_property(self, user: users.User): with mock.patch.object(users.AvatarDecoration, "make_url") as make_url: - assert obj.avatar_decoration.url is make_url.return_value - - def test_avatar_decoration_make_url_with_all_args(self, obj): - obj.avatar_decoration = users.AvatarDecoration( - asset_hash="18dnf8dfbakfdh", sku_id=snowflakes.Snowflake(123), expires_at=None - ) + assert user.avatar_decoration is not None + assert user.avatar_decoration.url is make_url.return_value + def test_avatar_decoration_make_url_with_all_args(self, user: users.User): with mock.patch.object( routes, "CDN_AVATAR_DECORATION", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert obj.avatar_decoration.make_url(size=4096) == "file" + assert user.avatar_decoration is not None + assert user.avatar_decoration.make_url(size=4096) == "file" route.compile_to_file.assert_called_once_with( - urls.MEDIA_PROXY_URL, hash=obj.avatar_decoration.asset_hash, size=4096, file_format="PNG", lossless=True + urls.MEDIA_PROXY_URL, hash=user.avatar_decoration.asset_hash, size=4096, file_format="PNG", lossless=True ) - def test_make_avatar_url_format_set_to_deprecated_ext_argument_if_provided(self, obj): - obj.id = 123321 - obj.avatar_hash = "fofoof" - + def test_make_avatar_url_format_set_to_deprecated_ext_argument_if_provided(self, user: users.User): with mock.patch.object( routes, "CDN_USER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert obj.make_avatar_url(ext="JPEG") == "file" + assert user.make_avatar_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, user_id=123321, hash="fofoof", size=4096, file_format="JPEG", lossless=True + urls.CDN_URL, user_id=12, hash="avatar_hash", size=4096, file_format="JPEG", lossless=True ) - def test_avatar_url_property(self, obj): + def test_avatar_url_property(self, user: users.User): with mock.patch.object(users.User, "make_avatar_url") as make_avatar_url: - assert obj.avatar_url is make_avatar_url.return_value - - def test_make_avatar_url_when_no_hash(self, obj): - obj.avatar_hash = None - assert obj.make_avatar_url(file_format="PNG", size=1024) is None + assert user.avatar_url is make_avatar_url.return_value - def test_make_avatar_url_when_format_is_None_and_avatar_hash_is_for_gif(self, obj): - obj.avatar_hash = "a_18dnf8dfbakfdh" + def test_make_avatar_url_when_no_hash(self, user: users.User): + with mock.patch.object(user, "_avatar_hash", None): + assert user.make_avatar_url(file_format="PNG", size=1024) is None - with mock.patch.object( - routes, "CDN_USER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) - ) as route: - assert obj.make_avatar_url(file_format=None, size=4096) == "file" + def test_make_avatar_url_when_format_is_None_and_avatar_hash_is_for_gif(self, user: users.User): + with ( + mock.patch.object(user, "_avatar_hash", "a_18dnf8dfbakfdh"), + mock.patch.object( + routes, "CDN_USER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) + ) as route, + ): + assert user.make_avatar_url(file_format="GIF", size=4096) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, user_id=obj.id, hash="a_18dnf8dfbakfdh", size=4096, file_format="GIF", lossless=True + urls.CDN_URL, user_id=user.id, hash="a_18dnf8dfbakfdh", size=4096, file_format="GIF", lossless=True ) - def test_make_avatar_url_when_format_is_None_and_avatar_hash_is_not_for_gif(self, obj): - obj.avatar_hash = "18dnf8dfbakfdh" - + def test_make_avatar_url_when_format_is_None_and_avatar_hash_is_not_for_gif(self, user: users.User): with mock.patch.object( routes, "CDN_USER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert obj.make_avatar_url(file_format=None, size=4096) == "file" + assert user.make_avatar_url(file_format="PNG", size=4096) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, user_id=obj.id, hash=obj.avatar_hash, size=4096, file_format="PNG", lossless=True + urls.CDN_URL, user_id=user.id, hash=user.avatar_hash, size=4096, file_format="PNG", lossless=True ) - def test_make_avatar_url_with_all_args(self, obj): - obj.avatar_hash = "18dnf8dfbakfdh" - + def test_make_avatar_url_with_all_args(self, user: users.User): with mock.patch.object( routes, "CDN_USER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert obj.make_avatar_url(file_format="URL", size=4096) == "file" + assert user.make_avatar_url(file_format="JPG", size=4096) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, user_id=obj.id, hash=obj.avatar_hash, size=4096, file_format="URL", lossless=True + urls.CDN_URL, user_id=user.id, hash=user.avatar_hash, size=4096, file_format="JPG", lossless=True ) - def test_display_avatar_url_when_avatar_url(self, obj): + def test_display_avatar_url_when_avatar_url(self, user: users.User): with mock.patch.object(users.User, "make_avatar_url") as mock_make_avatar_url: - assert obj.display_avatar_url is mock_make_avatar_url.return_value + assert user.display_avatar_url is mock_make_avatar_url.return_value - def test_display_avatar_url_when_no_avatar_url(self, obj): - with mock.patch.object(users.User, "make_avatar_url", return_value=None): - with mock.patch.object(users.User, "default_avatar_url") as mock_default_avatar_url: - assert obj.display_avatar_url is mock_default_avatar_url + def test_display_avatar_url_when_no_avatar_url(self, user: users.User): + with ( + mock.patch.object(users.User, "make_avatar_url", return_value=None), + mock.patch.object(users.User, "default_avatar_url") as mock_default_avatar_url, + ): + assert user.display_avatar_url is mock_default_avatar_url - def test_display_banner_url_when_banner_url(self, obj): + def test_display_banner_url_when_banner_url(self, user: users.User): with mock.patch.object(users.User, "make_banner_url") as mock_make_banner_url: - assert obj.display_banner_url is mock_make_banner_url.return_value + assert user.display_banner_url is mock_make_banner_url.return_value - def test_display_banner_url_when_no_banner_url(self, obj): + def test_display_banner_url_when_no_banner_url(self, user: users.User): with mock.patch.object(users.User, "make_banner_url", return_value=None): - assert obj.display_banner_url is None - - def test_default_avatar(self, obj): - obj.avatar_hash = "18dnf8dfbakfdh" - obj.discriminator = "1234" - - with mock.patch.object( - routes, "CDN_DEFAULT_USER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) - ) as route: - assert obj.default_avatar_url == "file" - - route.compile_to_file.assert_called_once_with(urls.CDN_URL, style=4, file_format="PNG") - - def test_default_avatar_for_migrated_users(self, obj): - obj.id = 377812572784820226 - obj.avatar_hash = "18dnf8dfbakfdh" - obj.discriminator = "0" - - with mock.patch.object( - routes, "CDN_DEFAULT_USER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) - ) as route: - assert obj.default_avatar_url == "file" - - route.compile_to_file.assert_called_once_with(urls.CDN_URL, style=0, file_format="PNG") - - def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, obj): - obj.id = 123321 - obj.banner_hash = "fofoof" - + assert user.display_banner_url is None + + def test_default_avatar(self, user: users.User): + with ( + mock.patch.object(user, "_id", 377812572784820226), + mock.patch.object(user, "_discriminator", "1234"), + mock.patch.object(routes, "CDN_DEFAULT_USER_AVATAR") as patched_route, + mock.patch.object( + patched_route, "compile_to_file", new=mock.Mock(return_value="file") + ) as patched_compile_to_file, + ): + assert user.default_avatar_url == "file" + + patched_compile_to_file.assert_called_once_with(urls.CDN_URL, style=4, file_format="PNG") + + def test_default_avatar_for_migrated_users(self, user: users.User): + with ( + mock.patch.object(user, "_id", 377812572784820226), + mock.patch.object(user, "_discriminator", "0"), + mock.patch.object(routes, "CDN_DEFAULT_USER_AVATAR") as patched_route, + mock.patch.object( + patched_route, "compile_to_file", new=mock.Mock(return_value="file") + ) as patched_compile_to_file, + ): + assert user.default_avatar_url == "file" + + patched_compile_to_file.assert_called_once_with(urls.CDN_URL, style=0, file_format="PNG") + + def test_make_banner_url_format_set_to_deprecated_ext_argument_if_provided(self, user: users.User): with mock.patch.object( routes, "CDN_USER_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert obj.make_banner_url(ext="JPEG") == "file" + assert user.make_banner_url(ext="JPEG") == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, user_id=123321, hash="fofoof", size=4096, file_format="JPEG", lossless=True + urls.CDN_URL, user_id=12, hash="banner_hash", size=4096, file_format="JPEG", lossless=True ) - def test_banner_url_property(self, obj): + def test_banner_url_property(self, user: users.User): with mock.patch.object(users.User, "make_banner_url") as make_banner_url: - assert obj.banner_url is make_banner_url.return_value - - def test_make_banner_url_when_no_hash(self, obj): - obj.banner_hash = None - - with mock.patch.object(routes, "CDN_USER_BANNER") as route: - assert obj.make_banner_url(file_format=None, size=4096) is None + assert user.banner_url is make_banner_url.return_value - route.compile_to_file.assert_not_called() + def test_make_banner_url_when_no_hash(self, user: users.User): + with mock.patch.object(user, "_banner_hash", None), mock.patch.object(routes, "CDN_USER_BANNER"): + assert user.make_banner_url(file_format="JPG", size=4096) is None - def test_make_banner_url_when_format_is_None_and_banner_hash_is_for_gif(self, obj): - obj.banner_hash = "a_18dnf8dfbakfdh" - - with mock.patch.object( - routes, "CDN_USER_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) - ) as route: - assert obj.make_banner_url(file_format=None, size=4096) == "file" + def test_make_banner_url_when_format_is_None_and_banner_hash_is_for_gif(self, user: users.User): + with ( + mock.patch.object(user, "_banner_hash", "a_18dnf8dfbakfdh"), + mock.patch.object( + routes, "CDN_USER_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) + ) as route, + ): + assert user.make_banner_url(file_format="GIF", size=4096) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, user_id=obj.id, hash="a_18dnf8dfbakfdh", size=4096, file_format="GIF", lossless=True + urls.CDN_URL, user_id=user.id, hash="a_18dnf8dfbakfdh", size=4096, file_format="GIF", lossless=True ) - def test_make_banner_url_when_format_is_None_and_banner_hash_is_not_for_gif(self, obj): - obj.banner_hash = "18dnf8dfbakfdh" - - with mock.patch.object( - routes, "CDN_USER_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) - ) as route: - assert obj.make_banner_url(file_format=None, size=4096) == "file" + def test_make_banner_url_when_format_is_None_and_banner_hash_is_not_for_gif(self, user: users.User): + with ( + mock.patch.object(user, "_banner_hash", "banner_hash"), + mock.patch.object( + routes, "CDN_USER_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) + ) as route, + ): + assert user.make_banner_url(file_format="PNG", size=4096) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, user_id=obj.id, hash=obj.banner_hash, size=4096, file_format="PNG", lossless=True + urls.CDN_URL, user_id=user.id, hash=user.banner_hash, size=4096, file_format="PNG", lossless=True ) - def test_make_banner_url_with_all_args(self, obj): - obj.banner_hash = "18dnf8dfbakfdh" - - with mock.patch.object( - routes, "CDN_USER_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) - ) as route: - assert obj.make_banner_url(file_format="URL", size=4096) == "file" + def test_make_banner_url_with_all_args(self, user: users.User): + with ( + mock.patch.object(user, "_banner_hash", "banner_hash"), + mock.patch.object( + routes, "CDN_USER_BANNER", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) + ) as route, + ): + assert user.make_banner_url(file_format="PNG", size=4096) == "file" route.compile_to_file.assert_called_once_with( - urls.CDN_URL, user_id=obj.id, hash=obj.banner_hash, size=4096, file_format="URL", lossless=True + urls.CDN_URL, user_id=user.id, hash=user.banner_hash, size=4096, file_format="PNG", lossless=True ) class TestPartialUserImpl: @pytest.fixture - def obj(self): + def partial_user(self) -> users.PartialUserImpl: return users.PartialUserImpl( id=snowflakes.Snowflake(123), app=mock.Mock(), @@ -393,36 +563,36 @@ def obj(self): primary_guild=None, ) - def test_str_operator(self, obj): - assert str(obj) == "thomm.o#8637" + def test_str_operator(self, partial_user: users.PartialUserImpl): + assert str(partial_user) == "thomm.o#8637" - def test_str_operator_when_partial(self, obj): - obj.username = undefined.UNDEFINED - assert str(obj) == "Partial user ID 123" + def test_str_operator_when_partial(self, partial_user: users.PartialUserImpl): + partial_user.username = undefined.UNDEFINED + assert str(partial_user) == "Partial user ID 123" - def test_mention_property(self, obj): - assert obj.mention == "<@123>" + def test_mention_property(self, partial_user: users.PartialUserImpl): + assert partial_user.mention == "<@123>" - def test_display_name_property_when_global_name(self, obj): - obj.global_name = "Thommo" - assert obj.display_name == obj.global_name + def test_display_name_property_when_global_name(self, partial_user: users.PartialUserImpl): + partial_user.global_name = "Thommo" + assert partial_user.display_name == partial_user.global_name - def test_display_name_property_when_no_global_name(self, obj): - obj.global_name = None - assert obj.display_name == obj.username + def test_display_name_property_when_no_global_name(self, partial_user: users.PartialUserImpl): + partial_user.global_name = None + assert partial_user.display_name == partial_user.username @pytest.mark.asyncio - async def test_fetch_self(self, obj): - user = object() - obj.app.rest.fetch_user = mock.AsyncMock(return_value=user) - assert await obj.fetch_self() is user - obj.app.rest.fetch_user.assert_awaited_once_with(user=123) + async def test_fetch_self(self, partial_user: users.PartialUserImpl): + user = mock.Mock() + partial_user.app.rest.fetch_user = mock.AsyncMock(return_value=user) + assert await partial_user.fetch_self() is user + partial_user.app.rest.fetch_user.assert_awaited_once_with(user=123) @pytest.mark.asyncio class TestOwnUser: @pytest.fixture - def obj(self): + def own_user(self) -> users.OwnUser: return users.OwnUser( id=snowflakes.Snowflake(12345), app=mock.Mock(), @@ -432,7 +602,7 @@ def obj(self): avatar_decoration=None, avatar_hash="69420", banner_hash="42069", - accent_color=123456, + accent_color=colors.Color(123456), is_bot=False, is_system=False, flags=users.UserFlag.PARTNERED_SERVER_OWNER, @@ -444,43 +614,43 @@ def obj(self): primary_guild=None, ) - async def test_fetch_self(self, obj): - user = object() - obj.app.rest.fetch_my_user = mock.AsyncMock(return_value=user) - assert await obj.fetch_self() is user - obj.app.rest.fetch_my_user.assert_awaited_once_with() + async def test_fetch_self(self, own_user: users.OwnUser): + user = mock.Mock() + own_user.app.rest.fetch_my_user = mock.AsyncMock(return_value=user) + assert await own_user.fetch_self() is user + own_user.app.rest.fetch_my_user.assert_awaited_once_with() - async def test_fetch_dm_channel(self, obj): + async def test_fetch_dm_channel(self, own_user: users.OwnUser): with pytest.raises(TypeError, match=r"Unable to fetch your own DM channel"): - await obj.fetch_dm_channel() + await own_user.fetch_dm_channel() - async def test_send(self, obj): + async def test_send(self, own_user: users.OwnUser): with pytest.raises(TypeError, match=r"Unable to send a DM to yourself"): - await obj.send() + await own_user.send() class TestPrimaryGuild: @pytest.fixture - def obj(self): + def primary_guild(self) -> users.PrimaryGuild: return users.PrimaryGuild( identity_guild_id=snowflakes.Snowflake(1234), identity_enabled=True, tag="HKRI", badge_hash="abcd1234" ) - def test_make_url(self, obj): + def test_make_url(self, primary_guild: users.PrimaryGuild): with mock.patch.object( routes, "CDN_PRIMARY_GUILD_BADGE", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert obj.make_url() == "file" + assert primary_guild.make_url() == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=1234, hash="abcd1234", size=4096, file_format="PNG", lossless=True ) - def test_make_url_with_all_args(self, obj): + def test_make_url_with_all_args(self, primary_guild: users.PrimaryGuild): with mock.patch.object( routes, "CDN_PRIMARY_GUILD_BADGE", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: - assert obj.make_url(file_format="WEBP", size=280, lossless=False) == "file" + assert primary_guild.make_url(file_format="WEBP", size=280, lossless=False) == "file" route.compile_to_file.assert_called_once_with( urls.CDN_URL, guild_id=1234, hash="abcd1234", size=280, file_format="WEBP", lossless=False diff --git a/tests/hikari/test_webhooks.py b/tests/hikari/test_webhooks.py index 0bf0d57649..5830e0cc28 100644 --- a/tests/hikari/test_webhooks.py +++ b/tests/hikari/test_webhooks.py @@ -20,59 +20,83 @@ # SOFTWARE. from __future__ import annotations +import typing + +from hikari.internal import routes import mock import pytest -from hikari import channels, urls +from hikari import channels, snowflakes, traits, urls from hikari import undefined from hikari import webhooks -from tests.hikari import hikari_test_helpers class TestExecutableWebhook: + class MockedExecutableWebhook(webhooks.ExecutableWebhook): + def __init__(self, app: traits.RESTAware): + super().__init__() + + self._app = app + self._webhook_id = snowflakes.Snowflake(548398213) + self._token = "webhook_token" + + @property + def app(self) -> traits.RESTAware: + return self._app + + @property + def webhook_id(self) -> snowflakes.Snowflake: + return self._webhook_id + + @property + def token(self) -> typing.Optional[str]: + return self._token + @pytest.fixture - def executable_webhook(self): - return hikari_test_helpers.mock_class_namespace( - webhooks.ExecutableWebhook, slots_=False, app=mock.AsyncMock() - )() + def executable_webhook(self, hikari_app: traits.RESTAware) -> webhooks.ExecutableWebhook: + return TestExecutableWebhook.MockedExecutableWebhook(hikari_app) @pytest.mark.asyncio - async def test_execute_when_no_token(self, executable_webhook): - executable_webhook.token = None - - with pytest.raises(ValueError, match=r"Cannot send a message using a webhook where we don't know the token"): + async def test_execute_when_no_token(self, executable_webhook: webhooks.ExecutableWebhook): + with ( + mock.patch.object(executable_webhook, "_token", None), + pytest.raises(ValueError, match=r"Cannot send a message using a webhook where we don't know the token"), + ): await executable_webhook.execute() @pytest.mark.asyncio - async def test_execute_with_optionals(self, executable_webhook): - mock_attachment_1 = object() - mock_attachment_2 = object() - mock_component = object() - mock_components = object(), object() - mock_embed = object() - mock_embeds = object(), object() - mock_poll = object() - - result = await executable_webhook.execute( - content="coooo", - username="oopp", - avatar_url="urlurlurl", - tts=True, - attachment=mock_attachment_1, - attachments=mock_attachment_2, - component=mock_component, - components=mock_components, - embed=mock_embed, - embeds=mock_embeds, - poll=mock_poll, - mentions_everyone=False, - user_mentions=[1235432], - role_mentions=[65234123], - flags=64, - ) + async def test_execute_with_optionals(self, executable_webhook: webhooks.ExecutableWebhook): + mock_attachment_1 = mock.Mock() + mock_attachment_2 = mock.Mock() + mock_component = mock.Mock() + mock_components = mock.Mock(), mock.Mock() + mock_embed = mock.Mock() + mock_embeds = mock.Mock(), mock.Mock() + mock_poll = mock.Mock() - assert result is executable_webhook.app.rest.execute_webhook.return_value - executable_webhook.app.rest.execute_webhook.assert_awaited_once_with( + with mock.patch.object( + executable_webhook.app.rest, "execute_webhook", new_callable=mock.AsyncMock + ) as patched_execute_webhook: + result = await executable_webhook.execute( + content="coooo", + username="oopp", + avatar_url="urlurlurl", + tts=True, + attachment=mock_attachment_1, + attachments=mock_attachment_2, + component=mock_component, + components=mock_components, + embed=mock_embed, + embeds=mock_embeds, + poll=mock_poll, + mentions_everyone=False, + user_mentions=[1235432], + role_mentions=[65234123], + flags=64, + ) + + assert result is patched_execute_webhook.return_value + patched_execute_webhook.assert_awaited_once_with( webhook=executable_webhook.webhook_id, token=executable_webhook.token, content="coooo", @@ -93,11 +117,14 @@ async def test_execute_with_optionals(self, executable_webhook): ) @pytest.mark.asyncio - async def test_execute_without_optionals(self, executable_webhook): - result = await executable_webhook.execute() + async def test_execute_without_optionals(self, executable_webhook: webhooks.ExecutableWebhook): + with mock.patch.object( + executable_webhook.app.rest, "execute_webhook", new_callable=mock.AsyncMock + ) as patched_execute_webhook: + result = await executable_webhook.execute() - assert result is executable_webhook.app.rest.execute_webhook.return_value - executable_webhook.app.rest.execute_webhook.assert_awaited_once_with( + assert result is patched_execute_webhook.return_value + patched_execute_webhook.assert_awaited_once_with( webhook=executable_webhook.webhook_id, token=executable_webhook.token, content=undefined.UNDEFINED, @@ -118,113 +145,133 @@ async def test_execute_without_optionals(self, executable_webhook): ) @pytest.mark.asyncio - async def test_fetch_message(self, executable_webhook): - message = object() - returned_message = object() - executable_webhook.app.rest.fetch_webhook_message = mock.AsyncMock(return_value=returned_message) + async def test_fetch_message(self, executable_webhook: webhooks.ExecutableWebhook): + message = mock.Mock() + returned_message = mock.Mock() - returned = await executable_webhook.fetch_message(message) + with mock.patch.object( + executable_webhook.app.rest, "fetch_webhook_message", new=mock.AsyncMock(return_value=returned_message) + ) as patched_fetch_webhook_message: + returned = await executable_webhook.fetch_message(message) - assert returned is returned_message + assert returned is returned_message - executable_webhook.app.rest.fetch_webhook_message.assert_awaited_once_with( - executable_webhook.webhook_id, token=executable_webhook.token, message=message - ) + patched_fetch_webhook_message.assert_awaited_once_with( + executable_webhook.webhook_id, token=executable_webhook.token, message=message + ) @pytest.mark.asyncio - async def test_fetch_message_when_no_token(self, executable_webhook): - executable_webhook.token = None - with pytest.raises(ValueError, match=r"Cannot fetch a message using a webhook where we don't know the token"): + async def test_fetch_message_when_no_token(self, executable_webhook: webhooks.ExecutableWebhook): + with ( + mock.patch.object(executable_webhook, "_token", None), + pytest.raises(ValueError, match=r"Cannot fetch a message using a webhook where we don't know the token"), + ): await executable_webhook.fetch_message(987) @pytest.mark.asyncio - async def test_edit_message(self, executable_webhook): - message = object() - embed = object() - attachment = object() - component = object() - components = object() - - returned = await executable_webhook.edit_message( - message, - content="test", - embed=embed, - embeds=[embed, embed], - attachment=attachment, - attachments=[attachment, attachment], - component=component, - components=components, - mentions_everyone=False, - user_mentions=True, - role_mentions=[567, 890], - ) - - assert returned is executable_webhook.app.rest.edit_webhook_message.return_value - - executable_webhook.app.rest.edit_webhook_message.assert_awaited_once_with( - executable_webhook.webhook_id, - token=executable_webhook.token, - message=message, - content="test", - embed=embed, - embeds=[embed, embed], - attachment=attachment, - attachments=[attachment, attachment], - component=component, - components=components, - mentions_everyone=False, - user_mentions=True, - role_mentions=[567, 890], - ) + async def test_edit_message(self, executable_webhook: webhooks.ExecutableWebhook): + message = mock.Mock() + embed = mock.Mock() + attachment = mock.Mock() + component = mock.Mock() + components = mock.Mock() + + with ( + mock.patch.object(executable_webhook.app, "rest") as patched_rest, + mock.patch.object( + patched_rest, "edit_webhook_message", new_callable=mock.AsyncMock + ) as patched_edit_webhook_message, + ): + returned = await executable_webhook.edit_message( + message, + content="test", + embed=embed, + embeds=[embed, embed], + attachment=attachment, + attachments=[attachment, attachment], + component=component, + components=components, + mentions_everyone=False, + user_mentions=True, + role_mentions=[567, 890], + ) + + assert returned is patched_edit_webhook_message.return_value + + patched_edit_webhook_message.assert_awaited_once_with( + executable_webhook.webhook_id, + token=executable_webhook.token, + message=message, + content="test", + embed=embed, + embeds=[embed, embed], + attachment=attachment, + attachments=[attachment, attachment], + component=component, + components=components, + mentions_everyone=False, + user_mentions=True, + role_mentions=[567, 890], + ) @pytest.mark.asyncio - async def test_edit_message_when_no_token(self, executable_webhook): - executable_webhook.token = None - with pytest.raises(ValueError, match=r"Cannot edit a message using a webhook where we don't know the token"): + async def test_edit_message_when_no_token(self, executable_webhook: webhooks.ExecutableWebhook): + with ( + mock.patch.object(executable_webhook, "_token", None), + pytest.raises(ValueError, match=r"Cannot edit a message using a webhook where we don't know the token"), + ): await executable_webhook.edit_message(987) @pytest.mark.asyncio - async def test_delete_message(self, executable_webhook): - message = object() + async def test_delete_message(self, executable_webhook: webhooks.ExecutableWebhook): + message = mock.Mock() - await executable_webhook.delete_message(message) + with ( + mock.patch.object(executable_webhook.app, "rest") as patched_rest, + mock.patch.object( + patched_rest, "delete_webhook_message", new_callable=mock.AsyncMock + ) as patched_delete_webhook_message, + ): + await executable_webhook.delete_message(message) - executable_webhook.app.rest.delete_webhook_message.assert_awaited_once_with( - executable_webhook.webhook_id, token=executable_webhook.token, message=message - ) + patched_delete_webhook_message.assert_awaited_once_with( + executable_webhook.webhook_id, token=executable_webhook.token, message=message + ) @pytest.mark.asyncio - async def test_delete_message_when_no_token(self, executable_webhook): - executable_webhook.token = None - with pytest.raises(ValueError, match=r"Cannot delete a message using a webhook where we don't know the token"): + async def test_delete_message_when_no_token(self, executable_webhook: webhooks.ExecutableWebhook): + with ( + mock.patch.object(executable_webhook, "_token", None), + pytest.raises(ValueError, match=r"Cannot delete a message using a webhook where we don't know the token"), + ): assert await executable_webhook.delete_message(987) class TestPartialWebhook: @pytest.fixture - def webhook(self): + def webhook(self) -> webhooks.PartialWebhook: return webhooks.PartialWebhook( app=mock.Mock(rest=mock.AsyncMock()), - id=987654321, + id=snowflakes.Snowflake(987654321), type=webhooks.WebhookType.CHANNEL_FOLLOWER, name="not a webhook", avatar_hash="hook", application_id=None, ) - def test_str(self, webhook): + def test_str(self, webhook: webhooks.PartialWebhook): assert str(webhook) == "not a webhook" - def test_str_when_name_is_None(self, webhook): - webhook.name = None - assert str(webhook) == "Unnamed webhook ID 987654321" + def test_str_when_name_is_None(self, webhook: webhooks.PartialWebhook): + with mock.patch.object(webhook, "name", None): + assert str(webhook) == "Unnamed webhook ID 987654321" - def test_mention_property(self, webhook): + def test_mention_property(self, webhook: webhooks.PartialWebhook): assert webhook.mention == "<@987654321>" - def test_make_avatar_url_format_set_to_deprecated_ext_argument_if_provided(self, webhook): + def test_make_avatar_url_format_set_to_deprecated_ext_argument_if_provided(self, webhook: webhooks.PartialWebhook): with mock.patch.object( - channels.routes, "CDN_USER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) + routes, "CDN_USER_AVATAR", new=mock.Mock(compile_to_file=mock.Mock(return_value="file")) ) as route: assert webhook.make_avatar_url(ext="JPEG") == "file" @@ -232,18 +279,19 @@ def test_make_avatar_url_format_set_to_deprecated_ext_argument_if_provided(self, urls.CDN_URL, user_id=987654321, hash="hook", size=4096, file_format="JPEG", lossless=True ) - def test_avatar_url_property(self, webhook): + def test_avatar_url_property(self, webhook: webhooks.PartialWebhook): assert webhook.avatar_url == webhook.make_avatar_url() - def test_default_avatar_url(self, webhook): + def test_default_avatar_url(self, webhook: webhooks.PartialWebhook): assert webhook.default_avatar_url.url == "https://cdn.discordapp.com/embed/avatars/0.png" - def test_make_avatar_url(self, webhook): + def test_make_avatar_url(self, webhook: webhooks.PartialWebhook): result = webhook.make_avatar_url(ext="jpeg", size=2048) + assert result is not None assert result.url == "https://cdn.discordapp.com/avatars/987654321/hook.jpeg?size=2048" - def test_make_avatar_url_when_no_avatar(self, webhook): + def test_make_avatar_url_when_no_avatar(self, webhook: webhooks.PartialWebhook): webhook.avatar_hash = None assert webhook.make_avatar_url() is None @@ -251,13 +299,13 @@ def test_make_avatar_url_when_no_avatar(self, webhook): class TestIncomingWebhook: @pytest.fixture - def webhook(self): + def webhook(self) -> webhooks.IncomingWebhook: return webhooks.IncomingWebhook( app=mock.Mock(rest=mock.AsyncMock()), - id=987654321, + id=snowflakes.Snowflake(987654321), type=webhooks.WebhookType.CHANNEL_FOLLOWER, - guild_id=123, - channel_id=456, + guild_id=snowflakes.Snowflake(123), + channel_id=snowflakes.Snowflake(456), author=None, name="not a webhook", avatar_hash=None, @@ -265,217 +313,268 @@ def webhook(self): application_id=None, ) - def test_webhook_id_property(self, webhook): + def test_webhook_id_property(self, webhook: webhooks.IncomingWebhook): assert webhook.webhook_id is webhook.id @pytest.mark.asyncio - async def test_delete(self, webhook): + async def test_delete(self, webhook: webhooks.IncomingWebhook): webhook.token = None - await webhook.delete() + with mock.patch.object( + webhook.app.rest, "delete_webhook", new_callable=mock.AsyncMock + ) as patched_delete_webhook: + await webhook.delete() - webhook.app.rest.delete_webhook.assert_awaited_once_with(987654321, token=undefined.UNDEFINED) + patched_delete_webhook.assert_awaited_once_with(987654321, token=undefined.UNDEFINED) @pytest.mark.asyncio - async def test_delete_uses_token_property(self, webhook): + async def test_delete_uses_token_property(self, webhook: webhooks.IncomingWebhook): webhook.token = "123321" - await webhook.delete() + with mock.patch.object( + webhook.app.rest, "delete_webhook", new_callable=mock.AsyncMock + ) as patched_delete_webhook: + await webhook.delete() - webhook.app.rest.delete_webhook.assert_awaited_once_with(987654321, token="123321") + patched_delete_webhook.assert_awaited_once_with(987654321, token="123321") @pytest.mark.asyncio - async def test_delete_use_token_is_true(self, webhook): + async def test_delete_use_token_is_true(self, webhook: webhooks.IncomingWebhook): webhook.token = "322312" - await webhook.delete(use_token=True) + with mock.patch.object( + webhook.app.rest, "delete_webhook", new_callable=mock.AsyncMock + ) as patched_delete_webhook: + await webhook.delete(use_token=True) - webhook.app.rest.delete_webhook.assert_awaited_once_with(987654321, token="322312") + patched_delete_webhook.assert_awaited_once_with(987654321, token="322312") @pytest.mark.asyncio - async def test_delete_use_token_is_true_without_token(self, webhook): + async def test_delete_use_token_is_true_without_token(self, webhook: webhooks.IncomingWebhook): webhook.token = None - with pytest.raises(ValueError, match="This webhook's token is unknown, so cannot be used"): - await webhook.delete(use_token=True) + with mock.patch.object( + webhook.app.rest, "delete_webhook", new_callable=mock.AsyncMock + ) as patched_delete_webhook: + with pytest.raises(ValueError, match="This webhook's token is unknown, so cannot be used"): + await webhook.delete(use_token=True) - webhook.app.rest.delete_webhook.assert_not_called() + patched_delete_webhook.assert_not_called() @pytest.mark.asyncio - async def test_delete_use_token_is_false(self, webhook): + async def test_delete_use_token_is_false(self, webhook: webhooks.IncomingWebhook): webhook.token = "322312" - await webhook.delete(use_token=False) + with mock.patch.object( + webhook.app.rest, "delete_webhook", new_callable=mock.AsyncMock + ) as patched_delete_webhook: + await webhook.delete(use_token=False) - webhook.app.rest.delete_webhook.assert_awaited_once_with(987654321, token=undefined.UNDEFINED) + patched_delete_webhook.assert_awaited_once_with(987654321, token=undefined.UNDEFINED) @pytest.mark.asyncio - async def test_edit(self, webhook): + async def test_edit(self, webhook: webhooks.IncomingWebhook): webhook.token = None - webhook.app.rest.edit_webhook.return_value = mock.Mock(webhooks.IncomingWebhook) - mock_avatar = object() + with mock.patch.object( + webhook.app.rest, "edit_webhook", new=mock.AsyncMock(return_value=mock.Mock(webhooks.IncomingWebhook)) + ) as patched_edit_webhook: + mock_avatar = mock.Mock() - result = await webhook.edit(name="OK", avatar=mock_avatar, channel=33333, reason="byebye") + result = await webhook.edit(name="OK", avatar=mock_avatar, channel=33333, reason="byebye") - assert result is webhook.app.rest.edit_webhook.return_value - webhook.app.rest.edit_webhook.assert_awaited_once_with( - 987654321, token=undefined.UNDEFINED, name="OK", avatar=mock_avatar, channel=33333, reason="byebye" - ) + assert result is patched_edit_webhook.return_value + patched_edit_webhook.assert_awaited_once_with( + 987654321, token=undefined.UNDEFINED, name="OK", avatar=mock_avatar, channel=33333, reason="byebye" + ) @pytest.mark.asyncio - async def test_edit_uses_token_property(self, webhook): + async def test_edit_uses_token_property(self, webhook: webhooks.IncomingWebhook): webhook.token = "aye" - webhook.app.rest.edit_webhook.return_value = mock.Mock(webhooks.IncomingWebhook) - mock_avatar = object() + with mock.patch.object( + webhook.app.rest, "edit_webhook", new=mock.AsyncMock(return_value=mock.Mock(webhooks.IncomingWebhook)) + ) as patched_edit_webhook: + mock_avatar = mock.Mock() - result = await webhook.edit(name="bye", avatar=mock_avatar, channel=33333, reason="byebye") + result = await webhook.edit(name="bye", avatar=mock_avatar, channel=33333, reason="byebye") - assert result is webhook.app.rest.edit_webhook.return_value - webhook.app.rest.edit_webhook.assert_awaited_once_with( - 987654321, token="aye", name="bye", avatar=mock_avatar, channel=33333, reason="byebye" - ) + assert result is patched_edit_webhook.return_value + patched_edit_webhook.assert_awaited_once_with( + 987654321, token="aye", name="bye", avatar=mock_avatar, channel=33333, reason="byebye" + ) @pytest.mark.asyncio - async def test_edit_when_use_token_is_true(self, webhook): + async def test_edit_when_use_token_is_true(self, webhook: webhooks.IncomingWebhook): webhook.token = "owoowow" - webhook.app.rest.edit_webhook.return_value = mock.Mock(webhooks.IncomingWebhook) - mock_avatar = object() + with mock.patch.object( + webhook.app.rest, "edit_webhook", new=mock.AsyncMock(return_value=mock.Mock(webhooks.IncomingWebhook)) + ) as patched_edit_webhook: + mock_avatar = mock.Mock() - result = await webhook.edit(use_token=True, name="hiu", avatar=mock_avatar, channel=231, reason="sus") + result = await webhook.edit(use_token=True, name="hiu", avatar=mock_avatar, channel=231, reason="sus") - assert result is webhook.app.rest.edit_webhook.return_value - webhook.app.rest.edit_webhook.assert_awaited_once_with( - 987654321, token="owoowow", name="hiu", avatar=mock_avatar, channel=231, reason="sus" - ) + assert result is patched_edit_webhook.return_value + patched_edit_webhook.assert_awaited_once_with( + 987654321, token="owoowow", name="hiu", avatar=mock_avatar, channel=231, reason="sus" + ) @pytest.mark.asyncio - async def test_edit_when_use_token_is_true_and_no_token(self, webhook): + async def test_edit_when_use_token_is_true_and_no_token(self, webhook: webhooks.IncomingWebhook): webhook.token = None - with pytest.raises(ValueError, match="This webhook's token is unknown, so cannot be used"): + with ( + mock.patch.object( + webhook.app.rest, "edit_webhook", new=mock.AsyncMock(return_value=mock.Mock(webhooks.IncomingWebhook)) + ) as patched_edit_webhook, + pytest.raises(ValueError, match="This webhook's token is unknown, so cannot be used"), + ): await webhook.edit(use_token=True) - webhook.app.rest.edit_webhook.assert_not_called() + patched_edit_webhook.assert_not_called() @pytest.mark.asyncio - async def test_edit_when_use_token_is_false(self, webhook): + async def test_edit_when_use_token_is_false(self, webhook: webhooks.IncomingWebhook): webhook.token = "owoowow" - webhook.app.rest.edit_webhook.return_value = mock.Mock(webhooks.IncomingWebhook) - mock_avatar = object() + with mock.patch.object( + webhook.app.rest, "edit_webhook", new=mock.AsyncMock(return_value=mock.Mock(webhooks.IncomingWebhook)) + ) as patched_edit_webhook: + mock_avatar = mock.Mock() - result = await webhook.edit(use_token=False, name="eee", avatar=mock_avatar, channel=231, reason="rrr") + result = await webhook.edit(use_token=False, name="eee", avatar=mock_avatar, channel=231, reason="rrr") - assert result is webhook.app.rest.edit_webhook.return_value - webhook.app.rest.edit_webhook.assert_awaited_once_with( - 987654321, token=undefined.UNDEFINED, name="eee", avatar=mock_avatar, channel=231, reason="rrr" - ) + assert result is patched_edit_webhook.return_value + patched_edit_webhook.assert_awaited_once_with( + 987654321, token=undefined.UNDEFINED, name="eee", avatar=mock_avatar, channel=231, reason="rrr" + ) @pytest.mark.asyncio - async def test_fetch_channel(self, webhook): - webhook.app.rest.fetch_channel.return_value = mock.Mock(channels.GuildTextChannel) - - assert await webhook.fetch_channel() is webhook.app.rest.fetch_channel.return_value + async def test_fetch_channel(self, webhook: webhooks.IncomingWebhook): + with mock.patch.object( + webhook.app.rest, "fetch_channel", new=mock.AsyncMock(return_value=mock.Mock(channels.GuildTextChannel)) + ) as patched_fetch_channel: + assert await webhook.fetch_channel() is patched_fetch_channel.return_value - webhook.app.rest.fetch_channel.assert_awaited_once_with(webhook.channel_id) + patched_fetch_channel.assert_awaited_once_with(webhook.channel_id) @pytest.mark.asyncio - async def test_fetch_self(self, webhook): - webhook.token = None - webhook.app.rest.fetch_webhook.return_value = mock.Mock(webhooks.IncomingWebhook) - - result = await webhook.fetch_self() - - assert result is webhook.app.rest.fetch_webhook.return_value - webhook.app.rest.fetch_webhook.assert_awaited_once_with(987654321, token=undefined.UNDEFINED) + async def test_fetch_self(self, webhook: webhooks.IncomingWebhook): + with ( + mock.patch.object(webhook, "token", None), + mock.patch.object( + webhook.app.rest, "fetch_webhook", new=mock.AsyncMock(return_value=mock.Mock(webhooks.IncomingWebhook)) + ) as patched_fetch_webhook, + ): + result = await webhook.fetch_self() + + assert result is patched_fetch_webhook.return_value + patched_fetch_webhook.assert_awaited_once_with(987654321, token=undefined.UNDEFINED) @pytest.mark.asyncio - async def test_fetch_self_uses_token_property(self, webhook): - webhook.token = "no gnomo" - webhook.app.rest.fetch_webhook.return_value = mock.Mock(webhooks.IncomingWebhook) - - result = await webhook.fetch_self() - - assert result is webhook.app.rest.fetch_webhook.return_value - webhook.app.rest.fetch_webhook.assert_awaited_once_with(987654321, token="no gnomo") + async def test_fetch_self_uses_token_property(self, webhook: webhooks.IncomingWebhook): + with ( + mock.patch.object(webhook, "token", "no gnomo"), + mock.patch.object( + webhook.app.rest, "fetch_webhook", new=mock.AsyncMock(return_value=mock.Mock(webhooks.IncomingWebhook)) + ) as patched_fetch_webhook, + ): + result = await webhook.fetch_self() + + assert result is patched_fetch_webhook.return_value + patched_fetch_webhook.assert_awaited_once_with(987654321, token="no gnomo") @pytest.mark.asyncio - async def test_fetch_self_when_use_token_is_true(self, webhook): - webhook.token = "no momo" - webhook.app.rest.fetch_webhook.return_value = mock.Mock(webhooks.IncomingWebhook) - - result = await webhook.fetch_self(use_token=True) - - assert result is webhook.app.rest.fetch_webhook.return_value - webhook.app.rest.fetch_webhook.assert_awaited_once_with(987654321, token="no momo") + async def test_fetch_self_when_use_token_is_true(self, webhook: webhooks.IncomingWebhook): + with ( + mock.patch.object(webhook, "token", "no momo"), + mock.patch.object( + webhook.app.rest, "fetch_webhook", new=mock.AsyncMock(return_value=mock.Mock(webhooks.IncomingWebhook)) + ) as patched_fetch_webhook, + ): + result = await webhook.fetch_self(use_token=True) + + assert result is patched_fetch_webhook.return_value + patched_fetch_webhook.assert_awaited_once_with(987654321, token="no momo") @pytest.mark.asyncio - async def test_fetch_self_when_use_token_is_true_without_token_property(self, webhook): + async def test_fetch_self_when_use_token_is_true_without_token_property(self, webhook: webhooks.IncomingWebhook): webhook.token = None - with pytest.raises(ValueError, match="This webhook's token is unknown, so cannot be used"): - await webhook.fetch_self(use_token=True) + with mock.patch.object(webhook.app.rest, "fetch_webhook") as patched_fetch_webhook: + with pytest.raises(ValueError, match="This webhook's token is unknown, so cannot be used"): + await webhook.fetch_self(use_token=True) - webhook.app.rest.fetch_webhook.assert_not_called() + patched_fetch_webhook.assert_not_called() @pytest.mark.asyncio - async def test_fetch_self_when_use_token_is_false(self, webhook): - webhook.token = "no momo" - webhook.app.rest.fetch_webhook.return_value = mock.Mock(webhooks.IncomingWebhook) + async def test_fetch_self_when_use_token_is_false(self, webhook: webhooks.IncomingWebhook): + with ( + mock.patch.object(webhook, "token", "no momo"), + mock.patch.object( + webhook.app.rest, "fetch_webhook", new=mock.AsyncMock(return_value=mock.Mock(webhooks.IncomingWebhook)) + ) as patched_fetch_webhook, + ): + result = await webhook.fetch_self(use_token=False) - result = await webhook.fetch_self(use_token=False) - - assert result is webhook.app.rest.fetch_webhook.return_value - webhook.app.rest.fetch_webhook.assert_awaited_once_with(987654321, token=undefined.UNDEFINED) + assert result is patched_fetch_webhook.return_value + patched_fetch_webhook.assert_awaited_once_with(987654321, token=undefined.UNDEFINED) class TestChannelFollowerWebhook: @pytest.fixture - def webhook(self): + def webhook(self) -> webhooks.ChannelFollowerWebhook: return webhooks.ChannelFollowerWebhook( app=mock.Mock(rest=mock.AsyncMock()), - id=987654321, + id=snowflakes.Snowflake(987654321), type=webhooks.WebhookType.CHANNEL_FOLLOWER, - guild_id=123, - channel_id=456, + guild_id=snowflakes.Snowflake(123), + channel_id=snowflakes.Snowflake(456), author=None, name="not a webhook", avatar_hash=None, application_id=None, - source_channel=object(), - source_guild=object(), + source_channel=mock.Mock(), + source_guild=mock.Mock(), ) @pytest.mark.asyncio - async def test_delete(self, webhook): - await webhook.delete() + async def test_delete(self, webhook: webhooks.ChannelFollowerWebhook): + with mock.patch.object(webhook.app.rest, "delete_webhook") as patched_delete_webhook: + await webhook.delete() - webhook.app.rest.delete_webhook.assert_awaited_once_with(987654321) + patched_delete_webhook.assert_awaited_once_with(987654321) @pytest.mark.asyncio - async def test_edit(self, webhook): - mock_avatar = object() - webhook.app.rest.edit_webhook.return_value = mock.Mock(webhooks.ChannelFollowerWebhook) + async def test_edit(self, webhook: webhooks.ChannelFollowerWebhook): + mock_avatar = mock.Mock() - result = await webhook.edit(name="hi", avatar=mock_avatar, channel=43123, reason="ok") + with mock.patch.object( + webhook.app.rest, + "edit_webhook", + new=mock.AsyncMock(return_value=mock.Mock(webhooks.ChannelFollowerWebhook)), + ) as patched_edit_webhook: + result = await webhook.edit(name="hi", avatar=mock_avatar, channel=43123, reason="ok") - assert result is webhook.app.rest.edit_webhook.return_value - webhook.app.rest.edit_webhook.assert_awaited_once_with( - 987654321, name="hi", avatar=mock_avatar, channel=43123, reason="ok" - ) + assert result is patched_edit_webhook.return_value + patched_edit_webhook.assert_awaited_once_with( + 987654321, name="hi", avatar=mock_avatar, channel=43123, reason="ok" + ) @pytest.mark.asyncio - async def test_fetch_channel(self, webhook): - webhook.app.rest.fetch_channel.return_value = mock.Mock(channels.GuildTextChannel) - - assert await webhook.fetch_channel() is webhook.app.rest.fetch_channel.return_value + async def test_fetch_channel(self, webhook: webhooks.ChannelFollowerWebhook): + with mock.patch.object( + webhook.app.rest, "fetch_channel", new=mock.AsyncMock(return_value=mock.Mock(channels.GuildTextChannel)) + ) as patched_fetch_channel: + assert await webhook.fetch_channel() is patched_fetch_channel.return_value - webhook.app.rest.fetch_channel.assert_awaited_once_with(webhook.channel_id) + patched_fetch_channel.assert_awaited_once_with(webhook.channel_id) @pytest.mark.asyncio - async def test_fetch_self(self, webhook): - webhook.app.rest.fetch_webhook.return_value = mock.Mock(webhooks.ChannelFollowerWebhook) - - result = await webhook.fetch_self() - - assert result is webhook.app.rest.fetch_webhook.return_value - webhook.app.rest.fetch_webhook.assert_awaited_once_with(987654321) + async def test_fetch_self(self, webhook: webhooks.ChannelFollowerWebhook): + with mock.patch.object( + webhook.app.rest, + "fetch_webhook", + new=mock.AsyncMock(return_value=mock.Mock(webhooks.ChannelFollowerWebhook)), + ) as patched_fetch_webhook: + result = await webhook.fetch_self() + + assert result is patched_fetch_webhook.return_value + patched_fetch_webhook.assert_awaited_once_with(987654321)