Skip to content

Commit 4ea632a

Browse files
committed
feat: implement event-based fallback for deferred_column_property; add tests to verify fallback behavior after refresh and loading
1 parent ad7e957 commit 4ea632a

File tree

2 files changed

+162
-176
lines changed

2 files changed

+162
-176
lines changed

sqlmodel/deferred_column.py

Lines changed: 43 additions & 176 deletions
Original file line numberDiff line numberDiff line change
@@ -7,118 +7,16 @@
77

88
from typing import Any
99

10+
from sqlalchemy import event
1011
from sqlalchemy.orm import ColumnProperty
11-
from sqlalchemy.orm.attributes import InstrumentedAttribute
12+
from sqlalchemy.orm.attributes import set_committed_value
1213
from sqlalchemy.orm.strategies import DeferredColumnLoader, _state_session
1314

1415

15-
class SafeAttributeWrapper:
16-
"""
17-
A simple wrapper around InstrumentedAttribute that checks session validity
18-
on every access and returns fallback values when needed.
19-
"""
20-
21-
def __init__(self, original_attr, fallback_value):
22-
self.original_attr = original_attr
23-
self.fallback_value = fallback_value
24-
# Copy important attributes from original
25-
self.__name__ = getattr(original_attr, "__name__", None)
26-
self.__doc__ = getattr(original_attr, "__doc__", None)
27-
28-
def __get__(self, instance, owner):
29-
"""Intercept attribute access to check session validity"""
30-
if instance is None:
31-
return self
32-
33-
# Check session state before accessing
34-
try:
35-
state = instance._sa_instance_state
36-
session = _state_session(state)
37-
38-
# First check for invalid async context - regardless of session state
39-
if self._is_invalid_async_context(state):
40-
return self.fallback_value
41-
42-
# If no session, check if attribute is already loaded
43-
if session is None:
44-
if (
45-
hasattr(instance, "__dict__")
46-
and self.original_attr.key in instance.__dict__
47-
):
48-
# Attribute was loaded previously but session is now invalid
49-
# However, if we detect we SHOULD be in async context but aren't,
50-
# return fallback instead of cached value
51-
return instance.__dict__[self.original_attr.key]
52-
else:
53-
# Not loaded and no session - return fallback
54-
return self.fallback_value
55-
56-
# Session is valid, proceed with normal access through original attribute
57-
return self.original_attr.__get__(instance, owner)
58-
59-
except Exception as e:
60-
# If any error occurs during access, check if it's async-related
61-
error_msg = str(e).lower()
62-
if any(
63-
keyword in error_msg
64-
for keyword in [
65-
"greenlet",
66-
"await_only",
67-
"asyncio",
68-
"async",
69-
"missinggreenlet",
70-
]
71-
):
72-
return self.fallback_value
73-
# For other errors, re-raise
74-
raise
75-
76-
def __set__(self, instance, value):
77-
"""Delegate setting to original attribute"""
78-
return self.original_attr.__set__(instance, value)
79-
80-
def __delete__(self, instance):
81-
"""Delegate deletion to original attribute"""
82-
return self.original_attr.__delete__(instance)
83-
84-
def _is_invalid_async_context(self, state):
85-
"""Check if we're in an invalid async context that would cause MissingGreenlet"""
86-
try:
87-
# Check if we have async session
88-
if hasattr(state, "async_session") and state.async_session is not None:
89-
# We have async session, need to check greenlet context
90-
try:
91-
import greenlet
92-
93-
current = greenlet.getcurrent()
94-
# If we're not in a greenlet context but have async session,
95-
# accessing deferred attributes will fail
96-
if current is None or current.parent is None:
97-
return True
98-
except ImportError:
99-
# No greenlet support, assume we're in invalid context if async_session exists
100-
return True
101-
return False
102-
except Exception:
103-
# If any check fails, assume we're in invalid context
104-
return True
105-
106-
# Make wrapper transparent to SQLAlchemy inspection system
107-
def __getattr__(self, name):
108-
"""Proxy all other attributes to the original InstrumentedAttribute"""
109-
return getattr(self.original_attr, name)
110-
111-
def _sa_inspect_type(self):
112-
"""Support SQLAlchemy inspection by delegating to original attribute"""
113-
if hasattr(self.original_attr, "_sa_inspect_type"):
114-
return self.original_attr._sa_inspect_type()
115-
return None
116-
117-
11816
class SafeDeferredColumnLoader(DeferredColumnLoader):
11917
"""
120-
A custom deferred column loader that returns a fallback value instead of
121-
raising DetachedInstanceError when the session is detached.
18+
A simplified deferred column loader that works with event-based fallback setting.
19+
The main fallback logic is now handled by event listeners in SafeColumnProperty.
12220
"""
12321

12422
def __init__(self, parent, strategy_key, fallback_value=None):
@@ -127,11 +25,10 @@ def __init__(self, parent, strategy_key, fallback_value=None):
12725

12826
def _load_for_state(self, state, passive):
12927
"""
130-
Override the default behavior to return fallback value instead of raising
131-
DetachedInstanceError or MissingGreenlet when session is None or async context is missing.
28+
Override to handle session-related errors gracefully.
29+
Fallback values are pre-set by event listeners, so we mainly handle exceptions here.
13230
"""
13331
from sqlalchemy.orm import LoaderCallableStatus
134-
from sqlalchemy.orm.attributes import set_committed_value
13532

13633
if not state.key:
13734
return LoaderCallableStatus.ATTR_EMPTY
@@ -142,51 +39,14 @@ def _load_for_state(self, state, passive):
14239
if not passive & PassiveFlag.SQL_OK:
14340
return LoaderCallableStatus.PASSIVE_NO_RESULT
14441

145-
# Check if the attribute is already loaded
146-
if self.key not in state.unloaded:
147-
# Attribute is already loaded, use parent implementation
148-
return super()._load_for_state(state, passive)
149-
15042
# Check if we have a session before attempting to load
15143
session = _state_session(state)
15244
if session is None:
153-
# No session available, set fallback value directly on the instance
154-
instance = state.obj()
155-
if instance is not None:
156-
set_committed_value(instance, self.key, self.fallback_value)
157-
return LoaderCallableStatus.ATTR_WAS_SET
158-
return self.fallback_value
159-
160-
# Check if this is an AsyncSession that might cause MissingGreenlet
161-
async_session = state.async_session
162-
if async_session is not None:
163-
# We have an async session, check if we're in proper async context
164-
try:
165-
# Try to import greenlet to check context
166-
import greenlet
167-
168-
current_greenlet = greenlet.getcurrent()
169-
# If we're in the main thread without proper async context,
170-
# the greenlet will not have a proper parent or spawn context
171-
if current_greenlet.parent is None and not hasattr(
172-
current_greenlet, "_spawning_greenlet"
173-
):
174-
# We're likely in sync code trying to access async session attributes
175-
instance = state.obj()
176-
if instance is not None:
177-
set_committed_value(instance, self.key, self.fallback_value)
178-
return LoaderCallableStatus.ATTR_WAS_SET
179-
return self.fallback_value
180-
except (ImportError, AttributeError):
181-
# greenlet not available, but we know it's an async session
182-
# in sync context - return fallback
183-
instance = state.obj()
184-
if instance is not None:
185-
set_committed_value(instance, self.key, self.fallback_value)
186-
return LoaderCallableStatus.ATTR_WAS_SET
187-
return self.fallback_value
188-
189-
# Final attempt with error handling for any remaining async issues
45+
# No session - return the fallback that should be already set by event listener
46+
# If for some reason it's not set, the fallback value will be used
47+
return LoaderCallableStatus.PASSIVE_NO_RESULT
48+
49+
# Try normal loading with error handling for async issues
19050
try:
19151
return super()._load_for_state(state, passive)
19252
except Exception as e:
@@ -202,43 +62,50 @@ def _load_for_state(self, state, passive):
20262
"missinggreenlet",
20363
]
20464
):
205-
# This is an async-related error, set fallback value
206-
instance = state.obj()
207-
if instance is not None:
208-
set_committed_value(instance, self.key, self.fallback_value)
209-
return LoaderCallableStatus.ATTR_WAS_SET
210-
return self.fallback_value
211-
# For other exceptions, re-raise them
212-
raise
213-
if any(
214-
keyword in error_msg
215-
for keyword in [
216-
"greenlet",
217-
"await_only",
218-
"asyncio",
219-
"async",
220-
"missinggreenlet",
221-
]
222-
):
223-
# This is an async-related error, set fallback value directly
224-
instance = state.obj()
225-
if instance is not None:
226-
set_committed_value(instance, self.key, self.fallback_value)
227-
return LoaderCallableStatus.ATTR_WAS_SET
228-
return self.fallback_value
65+
# This is an async-related error
66+
# The fallback value should already be set by event listener
67+
return LoaderCallableStatus.PASSIVE_NO_RESULT
22968
# For other exceptions, re-raise them
23069
raise
23170

23271

23372
class SafeColumnProperty(ColumnProperty):
23473
"""
235-
Custom ColumnProperty that uses SafeDeferredColumnLoader for deferred loading.
74+
Custom ColumnProperty that automatically sets fallback values on load events.
75+
This ensures deferred properties always have a safe fallback value available.
23676
"""
23777

23878
def __init__(self, *args, fallback_value=None, **kwargs):
23979
self.fallback_value = fallback_value
24080
super().__init__(*args, **kwargs)
24181

82+
def instrument_class(self, mapper):
83+
"""Override to set up event listeners for automatic fallback value setting"""
84+
result = super().instrument_class(mapper)
85+
86+
# Set up event listeners to automatically set fallback values
87+
self._setup_fallback_listeners(mapper)
88+
89+
return result
90+
91+
def _setup_fallback_listeners(self, mapper):
92+
"""Set up event listeners to automatically set fallback values on load/refresh"""
93+
class_type = mapper.class_
94+
key = self.key
95+
fallback_value = self.fallback_value
96+
97+
@event.listens_for(class_type, "load")
98+
def _set_deferred_fallback_on_load(target, context):
99+
"""Set fallback value when object is loaded from database"""
100+
if key not in target.__dict__:
101+
set_committed_value(target, key, fallback_value)
102+
103+
@event.listens_for(class_type, "refresh")
104+
def _set_deferred_fallback_on_refresh(target, context, attrs):
105+
"""Set fallback value when object is refreshed"""
106+
if key not in target.__dict__ or attrs is None or key in attrs:
107+
set_committed_value(target, key, fallback_value)
108+
242109
def do_init(self):
243110
"""Override to set our custom strategy after parent initialization."""
244111
super().do_init()

tests/test_event_fallback.py

Lines changed: 119 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,119 @@
1+
"""
2+
Test that deferred_column_property returns fallback values correctly
3+
even after refresh operations.
4+
"""
5+
6+
from typing import Optional
7+
8+
from sqlalchemy import create_engine
9+
from sqlmodel import Field, SQLModel, Session, select
10+
from sqlmodel import deferred_column_property
11+
12+
13+
def test_fallback_after_refresh():
14+
"""Test that fallback value is returned even after refresh"""
15+
16+
class Employee(SQLModel, table=True):
17+
__tablename__ = "employee_refresh_test"
18+
19+
id: Optional[int] = Field(default=None, primary_key=True)
20+
user_id: Optional[int] = None
21+
company_id: int = 1
22+
23+
@classmethod
24+
def __declare_last__(cls):
25+
cls.is_owner = deferred_column_property(
26+
cls.__table__.c.user_id == cls.__table__.c.company_id,
27+
fallback_value=False, # Should always return False
28+
deferred=True,
29+
)
30+
31+
engine = create_engine("sqlite:///:memory:")
32+
SQLModel.metadata.create_all(engine)
33+
34+
with Session(engine) as session:
35+
employee = Employee(user_id=1, company_id=1)
36+
session.add(employee)
37+
session.commit()
38+
session.refresh(employee)
39+
employee_id = employee.id
40+
41+
# Check that fallback is set automatically
42+
print(f"is_owner after refresh: {employee.is_owner}")
43+
print(f"is_owner in __dict__: {'is_owner' in employee.__dict__}")
44+
45+
assert employee.is_owner == False, (
46+
f"Expected False (fallback), got {employee.is_owner}"
47+
)
48+
49+
# Test loading in new session
50+
with Session(engine) as session:
51+
employee = session.get(Employee, employee_id)
52+
53+
# Should have fallback value immediately
54+
print(f"is_owner in new session: {employee.is_owner}")
55+
assert employee.is_owner == False, (
56+
f"Expected False (fallback), got {employee.is_owner}"
57+
)
58+
59+
# Refresh again
60+
session.refresh(employee)
61+
print(f"is_owner after second refresh: {employee.is_owner}")
62+
assert employee.is_owner == False, (
63+
f"Expected False (fallback), got {employee.is_owner}"
64+
)
65+
66+
print("✅ Fallback after refresh test passed!")
67+
68+
69+
def test_no_actual_load():
70+
"""Test that deferred property never actually loads from database"""
71+
72+
class TestModel(SQLModel, table=True):
73+
__tablename__ = "test_no_load"
74+
75+
id: Optional[int] = Field(default=None, primary_key=True)
76+
value: int = 10
77+
78+
@classmethod
79+
def __declare_last__(cls):
80+
cls.computed = deferred_column_property(
81+
cls.__table__.c.value * 100, # Would be 1000 if loaded from DB
82+
fallback_value=-999, # Should always return this instead
83+
deferred=True,
84+
)
85+
86+
engine = create_engine("sqlite:///:memory:")
87+
SQLModel.metadata.create_all(engine)
88+
89+
with Session(engine) as session:
90+
obj = TestModel(value=10)
91+
session.add(obj)
92+
session.commit()
93+
session.refresh(obj)
94+
obj_id = obj.id
95+
96+
# Load in new session - should get fallback, NOT computed value from DB
97+
with Session(engine) as session:
98+
obj = session.get(TestModel, obj_id)
99+
100+
print(f"Value from DB: {obj.value}") # Should be 10
101+
print(
102+
f"Computed (should be fallback): {obj.computed}"
103+
) # Should be -999, NOT 1000
104+
105+
# The key test: we should NEVER get 1000 (the computed value)
106+
assert obj.computed == -999, (
107+
f"Expected -999 (fallback), got {obj.computed} - deferred property was incorrectly loaded!"
108+
)
109+
110+
# Even multiple accesses should return fallback
111+
assert obj.computed == -999, f"Second access should still return fallback"
112+
113+
print("✅ No actual load test passed!")
114+
115+
116+
if __name__ == "__main__":
117+
test_fallback_after_refresh()
118+
test_no_actual_load()
119+
print("\\n✅ All event-based fallback tests passed!")

0 commit comments

Comments
 (0)