Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
18 changes: 17 additions & 1 deletion fieldsignals/signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,9 +65,17 @@ def connect( # type: ignore[override]
def is_reverse_rel(f: Any) -> bool:
return f.many_to_many or f.one_to_many or isinstance(f, ForeignObjectRel)

def is_virtual_field(f: Any) -> bool:
"""Virtual fields like GenericForeignKey don't have database columns."""
return not f.concrete

if fields is None:
resolved_fields = sender._meta.get_fields()
resolved_fields = [f for f in resolved_fields if not is_reverse_rel(f)]
resolved_fields = [
f
for f in resolved_fields
if not is_reverse_rel(f) and not is_virtual_field(f)
]
else:
resolved_fields = [f for f in sender._meta.get_fields() if f.name in set(fields)]
for f in resolved_fields:
Expand All @@ -76,6 +84,11 @@ def is_reverse_rel(f: Any) -> bool:
"django-fieldsignals doesn't handle reverse related fields "
f"({f.name} is a {f.__class__.__name__})"
)
if is_virtual_field(f):
raise ValueError(
"django-fieldsignals doesn't handle virtual fields "
f"({f.name} is a {f.__class__.__name__})"
)

if not resolved_fields:
raise ValueError("fields must be non-empty")
Expand Down Expand Up @@ -166,6 +179,9 @@ def get_and_update_changed_fields(
deferred_fields = instance.get_deferred_fields()

for field in fields:
if not field.concrete:
# Skip virtual fields (e.g. GenericForeignKey)
continue
if field.attname in deferred_fields:
continue
# using value_from_object instead of getattr() means we don't traverse foreignkeys
Expand Down
45 changes: 42 additions & 3 deletions fieldsignals/tests/test_signals.py
Original file line number Diff line number Diff line change
Expand Up @@ -52,6 +52,7 @@ def __init__(self, name: str, m2m: bool = False) -> None:
self.attname = name
self.many_to_many = m2m
self.one_to_many = False
self.concrete = True

def value_from_object(self, instance: Any) -> Any:
return getattr(instance, self.name)
Expand Down Expand Up @@ -144,6 +145,29 @@ def get_fields() -> list[Field | MockOneToOneRel]:
return [Field("f"), MockOneToOneRel("o2o")]


class VirtualField:
"""Mock GenericForeignKey - virtual field with no database column."""

def __init__(self, name: str) -> None:
self.name = name
self.many_to_many = False
self.one_to_many = False
self.concrete = False


class FakeModelWithVirtualField:
f = "a value"
gfk = None

class _meta:
@staticmethod
def get_fields() -> list[Field | VirtualField]:
return [Field("f"), VirtualField("gfk")]

def get_deferred_fields(self) -> set[str]:
return set()


class TestGeneral:
@pytest.fixture(autouse=True)
def ready(self) -> None:
Expand All @@ -157,9 +181,7 @@ def test_m2m_fields_error(self) -> None:
def test_one_to_one_rel_field_error(self) -> None:
with must_be_called(False) as func:
with pytest.raises(ValueError):
post_save_changed.connect(
func, sender=FakeModelWithOneToOne, fields=["o2o", "f"]
)
post_save_changed.connect(func, sender=FakeModelWithOneToOne, fields=["o2o", "f"])

def test_one_to_one_rel_excluded(self) -> None:
with must_be_called(False) as func:
Expand Down Expand Up @@ -223,6 +245,23 @@ def test_boolean_field_transition_to_valid(self) -> None:

assert func.kwargs["changed_fields"] == {"is_active": (None, True)}

def test_virtual_field_excluded(self) -> None:
"""Virtual fields like GenericForeignKey are auto-skipped."""
with must_be_called(False) as func:
post_save_changed.connect(func, sender=FakeModelWithVirtualField)

obj = FakeModelWithVirtualField()
post_init.send(instance=obj, sender=FakeModelWithVirtualField)
post_save.send(instance=obj, sender=FakeModelWithVirtualField)

def test_virtual_field_error(self) -> None:
"""Explicitly requesting a virtual field raises clear error."""
with must_be_called(False) as func:
with pytest.raises(
ValueError, match="doesn't handle virtual fields.*gfk.*VirtualField"
):
post_save_changed.connect(func, sender=FakeModelWithVirtualField, fields=["gfk"])


class TestPostSave:
@pytest.fixture(autouse=True)
Expand Down