Skip to content

Commit 2d5013c

Browse files
committed
refactor
1 parent 38755d0 commit 2d5013c

File tree

3 files changed

+35
-48
lines changed

3 files changed

+35
-48
lines changed

peewee_async/signals.py

Lines changed: 6 additions & 47 deletions
Original file line numberDiff line numberDiff line change
@@ -1,49 +1,8 @@
1-
"""
2-
Provide django-style hooks for model events.
3-
"""
41
from peewee_async import AioModel as _Model
52
from typing import Union, Literal, Any
3+
from playhouse.signals import Signal
64

7-
8-
class Signal(object):
9-
def __init__(self) -> None:
10-
self._flush()
11-
12-
def _flush(self)-> None:
13-
self._receivers = set()
14-
self._receiver_list = []
15-
16-
def connect(self, receiver, name=None, sender=None) -> None:
17-
name = name or receiver.__name__
18-
key = (name, sender)
19-
if key not in self._receivers:
20-
self._receivers.add(key)
21-
self._receiver_list.append((name, receiver, sender))
22-
else:
23-
raise ValueError('receiver named %s (for sender=%s) already '
24-
'connected' % (name, sender or 'any'))
25-
26-
def disconnect(self, receiver=None, name=None, sender=None) -> None:
27-
if receiver:
28-
name = name or receiver.__name__
29-
if not name:
30-
raise ValueError('a receiver or a name must be provided')
31-
32-
key = (name, sender)
33-
if key not in self._receivers:
34-
raise ValueError('receiver named %s for sender=%s not found.' %
35-
(name, sender or 'any'))
36-
37-
self._receivers.remove(key)
38-
self._receiver_list = [(n, r, s) for n, r, s in self._receiver_list
39-
if (n, s) != key]
40-
41-
def __call__(self, name=None, sender=None):
42-
def decorator(fn):
43-
self.connect(fn, name, sender)
44-
return fn
45-
return decorator
46-
5+
class AioSignal(Signal):
476
async def send(self, instance, *args, **kwargs):
487
sender = type(instance)
498
responses = []
@@ -53,10 +12,10 @@ async def send(self, instance, *args, **kwargs):
5312
return responses
5413

5514

56-
aio_pre_save = Signal()
57-
aio_post_save = Signal()
58-
aio_pre_delete = Signal()
59-
aio_post_delete = Signal()
15+
aio_pre_save = AioSignal()
16+
aio_post_save = AioSignal()
17+
aio_pre_delete = AioSignal()
18+
aio_post_delete = AioSignal()
6019
pre_init = Signal() # can't be async !
6120

6221

tests/models.py

Lines changed: 10 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
import peewee
44
import peewee_async
5+
import peewee_async.signals
56

67

78
class TestModel(peewee_async.AioModel):
@@ -61,7 +62,15 @@ class IntegerTestModel(peewee_async.AioModel):
6162
num = peewee.IntegerField()
6263

6364

65+
class TestSignalModel(peewee_async.signals.AioModel):
66+
__test__ = False # disable pytest warnings
67+
text = peewee.CharField(max_length=100)
68+
69+
def __str__(self) -> str:
70+
return '<%s id=%s> %s' % (self.__class__.__name__, self.id, self.text)
71+
72+
6473
ALL_MODELS = (
6574
TestModel, UUIDTestModel, TestModelAlpha, TestModelBeta, TestModelGamma,
66-
CompositeTestModel, IntegerTestModel
75+
CompositeTestModel, IntegerTestModel, TestSignalModel
6776
)

tests/test_signals.py

Lines changed: 19 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,19 @@
1+
import uuid
2+
3+
from peewee_async.databases import AioDatabase
4+
from tests.conftest import dbs_all, dbs_postgres
5+
from tests.models import TestSignalModel
6+
from tests.utils import model_has_fields
7+
from peewee_async.signals import aio_pre_save
8+
9+
10+
11+
12+
@dbs_all
13+
async def test_aio_pre_save(db: AioDatabase) -> None:
14+
15+
@aio_pre_save(sender=TestSignalModel)
16+
async def on_save_handler(model_class, instance, created):
17+
print(model_class, instance, created)
18+
19+
await TestSignalModel.aio_create(text="text")

0 commit comments

Comments
 (0)