Skip to content

Commit 4b68fda

Browse files
authored
feat: allow SignalInstances for evented dataclass fields to emit on field mutation (in addition to field change) (#379)
* working example * update test * fix test * hacky fix * fix tests * update comments * add doc * coverage * pragma * skip cov
1 parent e3e2c52 commit 4b68fda

File tree

7 files changed

+150
-17
lines changed

7 files changed

+150
-17
lines changed

.pre-commit-config.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,7 +7,7 @@ exclude: .asv
77

88
repos:
99
- repo: https://github.com/crate-ci/typos
10-
rev: v1
10+
rev: v1.35.3
1111
hooks:
1212
- id: typos
1313
args: [--force-exclude] # omitting --write-changes

src/psygnal/_evented_model.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -722,7 +722,7 @@ def _setattr_with_dependents_impl(self, name: str, value: Any) -> None:
722722
elif name in self.__field_dependents__:
723723
deps_with_callbacks = self.__field_dependents__[name]
724724
else:
725-
return self._super_setattr_(name, value)
725+
return self._super_setattr_(name, value) # pragma: no cover
726726

727727
self._primary_changes.add(name)
728728
if name not in self._changes_queue:

src/psygnal/_group.py

Lines changed: 36 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -319,19 +319,40 @@ def disconnect(self, slot: Callable | None = None, missing_ok: bool = True) -> N
319319
def _slot_index(self, slot: Callable) -> int:
320320
"""Get index of `slot` in `self._slots`. Return -1 if not connected.
321321
322-
Override to handle _relay_partial objects directly without wrapping
323-
them in WeakFunction, which would create different objects.
322+
In interpreted mode, `_relay_partial` callbacks are stored as
323+
weak references to the callable object itself, so comparing the
324+
dereferenced callback to the provided `_relay_partial` works. In compiled
325+
mode (mypyc), `weak_callback` may normalize a `_relay_partial` to a strong
326+
reference to its `__call__` method (a MethodWrapperType), to avoid segfaults on
327+
garbage collection. In that case the default WeakCallback equality logic is the
328+
correct and more robust path.
329+
330+
Therefore, try the base implementation first (which compares normalized
331+
WeakCallback objects). If that fails and we're dealing with a
332+
`_relay_partial`, fall back to comparing the dereferenced callable to the
333+
provided slot for backward compatibility.
324334
"""
325-
if not isinstance(slot, _relay_partial):
326-
# For non-_relay_partial objects, use the default behavior
327-
return super()._slot_index(slot)
328-
329-
with self._lock:
330-
# For _relay_partial objects, compare directly against callbacks
331-
for i, s in enumerate(self._slots):
332-
if s.dereference() == slot:
333-
return i
334-
return -1 # pragma: no cover
335+
# First, try the standard equality path used by SignalInstance
336+
idx = super()._slot_index(slot)
337+
if idx != -1:
338+
return idx
339+
340+
# Fallback: handle direct comparison for `_relay_partial` instances
341+
if isinstance(slot, _relay_partial):
342+
with self._lock:
343+
for i, s in enumerate(self._slots):
344+
deref = s.dereference()
345+
# Case 1: stored deref is the _relay_partial itself (interpreted)
346+
if deref == slot:
347+
return i
348+
# Case 2: compiled path where we stored __call__ bound method
349+
# (these will never hit on codecov, but they are covered in tests)
350+
owner = getattr(deref, "__self__", None) # pragma: no cover
351+
if (
352+
isinstance(owner, _relay_partial) and owner == slot
353+
): # pragma: no cover
354+
return i
355+
return -1 # pragma: no cover
335356

336357

337358
# NOTE
@@ -612,6 +633,7 @@ def connect(
612633
max_args: int | None = None,
613634
on_ref_error: RefErrorChoice = "warn",
614635
priority: int = 0,
636+
emit_on_evented_child_events: bool = False,
615637
) -> Callable[[F], F] | F:
616638
if slot is None:
617639
return self._psygnal_relay.connect(
@@ -622,6 +644,7 @@ def connect(
622644
max_args=max_args,
623645
on_ref_error=on_ref_error,
624646
priority=priority,
647+
emit_on_evented_child_events=emit_on_evented_child_events,
625648
)
626649
else:
627650
return self._psygnal_relay.connect(
@@ -633,6 +656,7 @@ def connect(
633656
max_args=max_args,
634657
on_ref_error=on_ref_error,
635658
priority=priority,
659+
emit_on_evented_child_events=emit_on_evented_child_events,
636660
)
637661

638662
def connect_direct(

src/psygnal/_group_descriptor.py

Lines changed: 55 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,7 @@
3939

4040
T = TypeVar("T", bound=type)
4141
S = TypeVar("S")
42-
42+
F = TypeVar("F", bound=Callable)
4343

4444
_EQ_OPERATORS: dict[type, dict[str, EqOperator]] = {}
4545
_EQ_OPERATOR_NAME = "__eq_operators__"
@@ -133,6 +133,60 @@ def _pick_equality_operator(type_: type | None) -> EqOperator:
133133

134134

135135
class _DataclassFieldSignalInstance(SignalInstance):
136+
"""The type of SignalInstance when emitting dataclass field changes."""
137+
138+
def _connect_child_event_listener(self, slot: Callable[..., Any]) -> None:
139+
# ------------ Emit this signal when the field changes ------------
140+
#
141+
# _DataclassFieldSignalInstance is a SignalInstance that is used for fields
142+
# on evented dataclasses. For example `team.events.leader` is a
143+
# _DataclassFieldSignalInstance that emits when the leader field changes.
144+
# (e.g. `team.leader = new_leader`)
145+
#
146+
# However, by default, it does NOT emit when the leader itself changes.
147+
# (e.g. `team.leader.age = 60`)
148+
# ... because team.leader may not be an evented object, and we can't
149+
# assume that we can track changes on it.
150+
#
151+
# However, if `team.leader` IS itself an evented object, we can connect
152+
# to its events, and emit this signal when it changes. That's what we do here.
153+
154+
# First, ensure that this SignalInstance is indeed a SignalInstance on
155+
# a SignalGroup (presumably a SignalGroupDescriptor)
156+
if not isinstance((group := self.instance), SignalGroup):
157+
return
158+
159+
# then get the root object of the group (e.g. "team")
160+
root_object = group.instance
161+
162+
# get the field name (e.g. "leader") representing this SignalInstance
163+
field_name = self.name
164+
165+
# then get the value of the field (e.g. "team.leader")
166+
try:
167+
member = getattr(root_object, field_name)
168+
except Exception:
169+
return
170+
171+
# If that member is itself evented (e.g. "team.leader" is an evented obj)
172+
# then grab the SignalGroup on it (e.g. "team.leader.events")
173+
if group := _find_signal_group(member):
174+
# next, watch for ANY changes on the member
175+
# and call the slot with the new value of the entire field,
176+
#
177+
# e.g. team.leader.events.connect(lambda: callback(team.leader))
178+
def _on_any_change(info: EmissionInfo) -> None:
179+
new_val = getattr(root_object, field_name)
180+
# TODO somehow get old value? ...
181+
# note that old_value is available only in evented_setattr
182+
# but the `_handle_child_event_connections` could potentially be a
183+
# place to call slot(new_val, old_val)... if we can somehow
184+
# get the slot to it.
185+
old_val: Any = None
186+
slot(new_val, old_val)
187+
188+
group.connect(_on_any_change, check_nargs=False, on_ref_error="ignore")
189+
136190
def connect_setattr(
137191
self,
138192
obj: ref | object,

src/psygnal/_signal.py

Lines changed: 26 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -615,6 +615,7 @@ def connect(
615615
max_args: int | None = None,
616616
on_ref_error: RefErrorChoice = ...,
617617
priority: int = ...,
618+
emit_on_evented_child_events: bool = ...,
618619
) -> Callable[[F], F]: ...
619620

620621
@overload
@@ -629,6 +630,7 @@ def connect(
629630
max_args: int | None = None,
630631
on_ref_error: RefErrorChoice = ...,
631632
priority: int = ...,
633+
emit_on_evented_child_events: bool = ...,
632634
) -> F: ...
633635

634636
def connect(
@@ -642,6 +644,7 @@ def connect(
642644
max_args: int | None = None,
643645
on_ref_error: RefErrorChoice = "warn",
644646
priority: int = 0,
647+
emit_on_evented_child_events: bool = False,
645648
) -> Callable[[F], F] | F:
646649
"""Connect a callback (`slot`) to this signal.
647650
@@ -710,6 +713,16 @@ def my_function(): ...
710713
callbacks are called when multiple are connected to the same signal.
711714
Higher priority callbacks are called first. Negative values are allowed.
712715
The default is 0.
716+
emit_on_evented_child_events : bool
717+
If `True`, and if this is a SignalInstance associated with a specific field
718+
on an evented dataclass, and if that field itself is an evented dataclass,
719+
then the slot will be called both when the field is set directly, *and* when
720+
a child member of that field is set.
721+
For example, if `Team` is an evented-dataclass with a field `leader: Person`
722+
which is itself an evented-dataclass, then
723+
`team.events.leader.connect(callback, emit_on_evented_child_events=True)`
724+
will invoke callback even when `team.leader.age` is mutated (in addition to
725+
when `team.leader` is set directly).
713726
714727
Raises
715728
------
@@ -764,10 +777,22 @@ def _wrapper(
764777
if thread is not None:
765778
cb = QueuedCallback(cb, thread=thread)
766779
self._append_slot(cb)
780+
781+
if emit_on_evented_child_events:
782+
self._connect_child_event_listener(slot)
767783
return slot
768784

769785
return _wrapper if slot is None else _wrapper(slot)
770786

787+
def _connect_child_event_listener(self, slot: Callable) -> None:
788+
"""Connect a child event listener to the slot.
789+
790+
This is called when a slot is connected to this signal. It allows subclasses
791+
to connect additional event listeners to the slot.
792+
"""
793+
# implementing this as a method allows us to override/extend it in subclasses
794+
pass # pragma: no cover
795+
771796
def _append_slot(self, slot: WeakCallback) -> None:
772797
"""Append a slot to the list of slots.
773798
@@ -1138,7 +1163,7 @@ def __contains__(self, slot: Callable) -> bool:
11381163
# this change is needed for some reason after mypy v1.14.0
11391164
if callable(slot):
11401165
return self._slot_index(slot) >= 0
1141-
return False
1166+
return False # pragma: no cover
11421167

11431168
def __len__(self) -> int:
11441169
"""Return number of connected slots."""

src/psygnal/_weak_callback.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -228,6 +228,15 @@ def _on_delete(weak_cb):
228228
)
229229

230230
if callable(cb):
231+
# this is a bit of hack to workaround a segfault observed in testing
232+
# on python <=3.11 when compiled by mypyc,
233+
# during _weak_callback___WeakFunction_traverse
234+
# it specifically happens with MethodWrapperType objects, that I think are made
235+
# by mypyc itself. So we just don't attempt to weakref them here anymore.
236+
_call = getattr(cb, "__call__", None) # noqa
237+
if isinstance(_call, MethodWrapperType):
238+
return StrongFunction(_call, max_args, args, kwargs, priority=priority)
239+
231240
return WeakFunction(
232241
cb, max_args, args, kwargs, finalize, on_ref_error, priority=priority
233242
)

tests/test_evented_decorator.py

Lines changed: 22 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -472,7 +472,7 @@ class Parent:
472472
)
473473

474474

475-
def test_team_example():
475+
def test_team_example() -> None:
476476
@evented
477477
@dataclass
478478
class Person:
@@ -505,3 +505,24 @@ class Team:
505505
testing.assert_not_emitted(team.events.leader),
506506
):
507507
team.leader.name = "Alice"
508+
509+
510+
def test_signal_instance_emits_on_subevents() -> None:
511+
@evented
512+
@dataclass
513+
class Person:
514+
name: str = ""
515+
age: int = 0
516+
517+
@evented
518+
@dataclass
519+
class Team:
520+
name: str = ""
521+
leader: Person = field(default_factory=Person)
522+
523+
team = Team(name="A-Team", leader=Person(name="Hannibal", age=59))
524+
525+
mock = Mock()
526+
team.events.leader.connect(mock, emit_on_evented_child_events=True)
527+
team.leader.age = 60
528+
mock.assert_called_once_with(Person(name="Hannibal", age=60), None)

0 commit comments

Comments
 (0)