|
| 1 | +""" |
| 2 | +Copyright 2023 The Dapr Authors |
| 3 | +Licensed under the Apache License, Version 2.0 (the "License"); |
| 4 | +you may not use this file except in compliance with the License. |
| 5 | +You may obtain a copy of the License at |
| 6 | + http://www.apache.org/licenses/LICENSE-2.0 |
| 7 | +Unless required by applicable law or agreed to in writing, software |
| 8 | +distributed under the License is distributed on an "AS IS" BASIS, |
| 9 | +WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| 10 | +See the License for the specific language governing permissions and |
| 11 | +limitations under the License. |
| 12 | +""" |
| 13 | + |
| 14 | +from __future__ import annotations |
| 15 | + |
| 16 | +from datetime import timedelta |
| 17 | +from typing import Any, Optional, TypeVar |
| 18 | + |
| 19 | +from dapr.actor.id import ActorId |
| 20 | +from dapr.actor.runtime._reminder_data import ActorReminderData |
| 21 | +from dapr.actor.runtime._timer_data import TIMER_CALLBACK, ActorTimerData |
| 22 | +from dapr.actor.runtime.actor import Actor |
| 23 | +from dapr.actor.runtime.mock_state_manager import MockStateManager |
| 24 | + |
| 25 | + |
| 26 | +class MockActor(Actor): |
| 27 | + """A mock actor class to be used to override certain Actor methods for unit testing. |
| 28 | + To be used only via the create_mock_actor function, which takes in a class and returns a |
| 29 | + mock actor object for that class. |
| 30 | +
|
| 31 | + Examples: |
| 32 | + class SomeActorInterface(ActorInterface): |
| 33 | + @actor_method(name="method") |
| 34 | + async def set_state(self, data: dict) -> None: |
| 35 | +
|
| 36 | + class SomeActor(Actor, SomeActorInterface): |
| 37 | + async def set_state(self, data: dict) -> None: |
| 38 | + await self._state_manager.set_state('state', data) |
| 39 | + await self._state_manager.save_state() |
| 40 | +
|
| 41 | + mock_actor = create_mock_actor(SomeActor, "actor_1") |
| 42 | + assert mock_actor._state_manager._mock_state == {} |
| 43 | + await mock_actor.set_state({"test":10}) |
| 44 | + assert mock_actor._state_manager._mock_state == {"test":10} |
| 45 | + """ |
| 46 | + |
| 47 | + def __init__(self, actor_id: str, initstate: Optional[dict]): |
| 48 | + self.id = ActorId(actor_id) |
| 49 | + self._runtime_ctx = None # type: ignore |
| 50 | + self._state_manager = MockStateManager(self, initstate) |
| 51 | + |
| 52 | + async def register_timer( |
| 53 | + self, |
| 54 | + name: Optional[str], |
| 55 | + callback: TIMER_CALLBACK, |
| 56 | + state: Any, |
| 57 | + due_time: timedelta, |
| 58 | + period: timedelta, |
| 59 | + ttl: Optional[timedelta] = None, |
| 60 | + ) -> None: |
| 61 | + """Adds actor timer to self._state_manager._mock_timers. |
| 62 | + Args: |
| 63 | + name (str): the name of the timer to register. |
| 64 | + callback (Callable): An awaitable callable which will be called when the timer fires. |
| 65 | + state (Any): An object which will pass to the callback method, or None. |
| 66 | + due_time (datetime.timedelta): the amount of time to delay before the awaitable |
| 67 | + callback is first invoked. |
| 68 | + period (datetime.timedelta): the time interval between invocations |
| 69 | + of the awaitable callback. |
| 70 | + ttl (Optional[datetime.timedelta]): the time interval before the timer stops firing |
| 71 | + """ |
| 72 | + name = name or self.__get_new_timer_name() |
| 73 | + timer = ActorTimerData(name, callback, state, due_time, period, ttl) |
| 74 | + self._state_manager._mock_timers[name] = timer # type: ignore |
| 75 | + |
| 76 | + async def unregister_timer(self, name: str) -> None: |
| 77 | + """Unregisters actor timer from self._state_manager._mock_timers. |
| 78 | +
|
| 79 | + Args: |
| 80 | + name (str): the name of the timer to unregister. |
| 81 | + """ |
| 82 | + self._state_manager._mock_timers.pop(name, None) # type: ignore |
| 83 | + |
| 84 | + async def register_reminder( |
| 85 | + self, |
| 86 | + name: str, |
| 87 | + state: bytes, |
| 88 | + due_time: timedelta, |
| 89 | + period: timedelta, |
| 90 | + ttl: Optional[timedelta] = None, |
| 91 | + ) -> None: |
| 92 | + """Adds actor reminder to self._state_manager._mock_reminders. |
| 93 | +
|
| 94 | + Args: |
| 95 | + name (str): the name of the reminder to register. the name must be unique per actor. |
| 96 | + state (bytes): the user state passed to the reminder invocation. |
| 97 | + due_time (datetime.timedelta): the amount of time to delay before invoking the reminder |
| 98 | + for the first time. |
| 99 | + period (datetime.timedelta): the time interval between reminder invocations after |
| 100 | + the first invocation. |
| 101 | + ttl (datetime.timedelta): the time interval before the reminder stops firing |
| 102 | + """ |
| 103 | + reminder = ActorReminderData(name, state, due_time, period, ttl) |
| 104 | + self._state_manager._mock_reminders[name] = reminder # type: ignore |
| 105 | + |
| 106 | + async def unregister_reminder(self, name: str) -> None: |
| 107 | + """Unregisters actor reminder from self._state_manager._mock_reminders.. |
| 108 | +
|
| 109 | + Args: |
| 110 | + name (str): the name of the reminder to unregister. |
| 111 | + """ |
| 112 | + self._state_manager._mock_reminders.pop(name, None) # type: ignore |
| 113 | + |
| 114 | + |
| 115 | +T = TypeVar('T', bound=Actor) |
| 116 | + |
| 117 | + |
| 118 | +def create_mock_actor(cls1: type[T], actor_id: str, initstate: Optional[dict] = None) -> T: |
| 119 | + class MockSuperClass(MockActor, cls1): # type: ignore |
| 120 | + pass |
| 121 | + |
| 122 | + return MockSuperClass(actor_id, initstate) # type: ignore |
0 commit comments