Skip to content

Commit 48a681f

Browse files
committed
Improve refresh_from_db to make it works with protected fields
1 parent f017f4c commit 48a681f

File tree

2 files changed

+25
-13
lines changed

2 files changed

+25
-13
lines changed

django_fsm/__init__.py

Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -431,19 +431,16 @@ def is_fsm_and_protected(f):
431431
return {f.attname for f in protected_fields}
432432

433433
def refresh_from_db(self, *args, **kwargs):
434-
fields = kwargs.pop("fields", None)
434+
protected_fields = self._get_protected_fsm_fields()
435435

436-
# Use provided fields, if not set then reload all non-deferred fields.0
437-
if not fields:
438-
deferred_fields = self.get_deferred_fields()
439-
protected_fields = self._get_protected_fsm_fields()
440-
skipped_fields = deferred_fields.union(protected_fields)
436+
for f in protected_fields:
437+
self._meta.get_field(f).protected = False
441438

442-
fields = [f.attname for f in self._meta.concrete_fields if f.attname not in skipped_fields]
443-
444-
kwargs["fields"] = fields
445439
super().refresh_from_db(*args, **kwargs)
446440

441+
for f in protected_fields:
442+
self._meta.get_field(f).protected = True
443+
447444

448445
class ConcurrentTransitionMixin:
449446
"""

tests/testapp/tests/test_protected_fields.py

Lines changed: 19 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -26,18 +26,33 @@ def test_no_direct_access(self):
2626
instance = RefreshableProtectedAccessModel()
2727
assert instance.status == "new"
2828

29-
def try_change():
30-
instance.status = "change"
31-
3229
with pytest.raises(AttributeError):
33-
try_change()
30+
instance.status = "change"
3431

3532
instance.publish()
3633
instance.save()
3734
assert instance.status == "published"
3835

3936
def test_refresh_from_db(self):
4037
instance = RefreshableModel()
38+
assert instance.status == "new"
39+
instance.save()
40+
41+
instance.refresh_from_db()
42+
assert instance.status == "new"
43+
44+
def test_concurrent_refresh_from_db(self):
45+
instance = RefreshableModel()
46+
assert instance.status == "new"
4147
instance.save()
4248

49+
# NOTE: This simulates a concurrent update scenario
50+
concurrent_instance = RefreshableModel.objects.get(pk=instance.pk)
51+
assert concurrent_instance.status == instance.status == "new"
52+
concurrent_instance.publish()
53+
assert concurrent_instance.status == "published"
54+
concurrent_instance.save()
55+
56+
assert instance.status == "new"
4357
instance.refresh_from_db()
58+
assert instance.status == "published"

0 commit comments

Comments
 (0)