88from typing import Any
99
1010from sqlalchemy .orm import ColumnProperty
11+ from sqlalchemy .orm .attributes import InstrumentedAttribute
1112from sqlalchemy .orm .strategies import DeferredColumnLoader , _state_session
1213
1314
15+ class SafeAttributeWrapper :
16+ """
17+ A simple wrapper around InstrumentedAttribute that checks session validity
18+ on every access and returns fallback values when needed.
19+ """
20+
21+ def __init__ (self , original_attr , fallback_value ):
22+ self .original_attr = original_attr
23+ self .fallback_value = fallback_value
24+ # Copy important attributes from original
25+ self .__name__ = getattr (original_attr , "__name__" , None )
26+ self .__doc__ = getattr (original_attr , "__doc__" , None )
27+
28+ def __get__ (self , instance , owner ):
29+ """Intercept attribute access to check session validity"""
30+ if instance is None :
31+ return self
32+
33+ # Check session state before accessing
34+ try :
35+ state = instance ._sa_instance_state
36+ session = _state_session (state )
37+
38+ # First check for invalid async context - regardless of session state
39+ if self ._is_invalid_async_context (state ):
40+ return self .fallback_value
41+
42+ # If no session, check if attribute is already loaded
43+ if session is None :
44+ if (
45+ hasattr (instance , "__dict__" )
46+ and self .original_attr .key in instance .__dict__
47+ ):
48+ # Attribute was loaded previously but session is now invalid
49+ # However, if we detect we SHOULD be in async context but aren't,
50+ # return fallback instead of cached value
51+ return instance .__dict__ [self .original_attr .key ]
52+ else :
53+ # Not loaded and no session - return fallback
54+ return self .fallback_value
55+
56+ # Session is valid, proceed with normal access through original attribute
57+ return self .original_attr .__get__ (instance , owner )
58+
59+ except Exception as e :
60+ # If any error occurs during access, check if it's async-related
61+ error_msg = str (e ).lower ()
62+ if any (
63+ keyword in error_msg
64+ for keyword in [
65+ "greenlet" ,
66+ "await_only" ,
67+ "asyncio" ,
68+ "async" ,
69+ "missinggreenlet" ,
70+ ]
71+ ):
72+ return self .fallback_value
73+ # For other errors, re-raise
74+ raise
75+
76+ def __set__ (self , instance , value ):
77+ """Delegate setting to original attribute"""
78+ return self .original_attr .__set__ (instance , value )
79+
80+ def __delete__ (self , instance ):
81+ """Delegate deletion to original attribute"""
82+ return self .original_attr .__delete__ (instance )
83+
84+ def _is_invalid_async_context (self , state ):
85+ """Check if we're in an invalid async context that would cause MissingGreenlet"""
86+ try :
87+ # Check if we have async session
88+ if hasattr (state , "async_session" ) and state .async_session is not None :
89+ # We have async session, need to check greenlet context
90+ try :
91+ import greenlet
92+
93+ current = greenlet .getcurrent ()
94+ # If we're not in a greenlet context but have async session,
95+ # accessing deferred attributes will fail
96+ if current is None or current .parent is None :
97+ return True
98+ except ImportError :
99+ # No greenlet support, assume we're in invalid context if async_session exists
100+ return True
101+ return False
102+ except Exception :
103+ # If any check fails, assume we're in invalid context
104+ return True
105+
106+ # Make wrapper transparent to SQLAlchemy inspection system
107+ def __getattr__ (self , name ):
108+ """Proxy all other attributes to the original InstrumentedAttribute"""
109+ return getattr (self .original_attr , name )
110+
111+ def _sa_inspect_type (self ):
112+ """Support SQLAlchemy inspection by delegating to original attribute"""
113+ if hasattr (self .original_attr , "_sa_inspect_type" ):
114+ return self .original_attr ._sa_inspect_type ()
115+ return None
116+
117+
14118class SafeDeferredColumnLoader (DeferredColumnLoader ):
15119 """
16120 A custom deferred column loader that returns a fallback value instead of
@@ -53,19 +157,59 @@ def _load_for_state(self, state, passive):
53157 return LoaderCallableStatus .ATTR_WAS_SET
54158 return self .fallback_value
55159
56- # Check if this is an async session context that might cause MissingGreenlet
160+ # Check if this is an AsyncSession that might cause MissingGreenlet
161+ async_session = state .async_session
162+ if async_session is not None :
163+ # We have an async session, check if we're in proper async context
164+ try :
165+ # Try to import greenlet to check context
166+ import greenlet
167+
168+ current_greenlet = greenlet .getcurrent ()
169+ # If we're in the main thread without proper async context,
170+ # the greenlet will not have a proper parent or spawn context
171+ if current_greenlet .parent is None and not hasattr (
172+ current_greenlet , "_spawning_greenlet"
173+ ):
174+ # We're likely in sync code trying to access async session attributes
175+ instance = state .obj ()
176+ if instance is not None :
177+ set_committed_value (instance , self .key , self .fallback_value )
178+ return LoaderCallableStatus .ATTR_WAS_SET
179+ return self .fallback_value
180+ except (ImportError , AttributeError ):
181+ # greenlet not available, but we know it's an async session
182+ # in sync context - return fallback
183+ instance = state .obj ()
184+ if instance is not None :
185+ set_committed_value (instance , self .key , self .fallback_value )
186+ return LoaderCallableStatus .ATTR_WAS_SET
187+ return self .fallback_value
188+
189+ # Final attempt with error handling for any remaining async issues
57190 try :
58- # Try to access session._connection_for_bind to check if we're in async context
59- # without proper greenlet
60- if hasattr (session , "get_bind" ) and hasattr (
61- session , "_connection_for_bind"
62- ):
63- # This is a more elegant way to detect async context issues
64- # If we're in async session without greenlet context, this will fail
65- session .get_bind ()
191+ return super ()._load_for_state (state , passive )
66192 except Exception as e :
67193 # Handle async-related errors (MissingGreenlet, etc.)
68194 error_msg = str (e ).lower ()
195+ if any (
196+ keyword in error_msg
197+ for keyword in [
198+ "greenlet" ,
199+ "await_only" ,
200+ "asyncio" ,
201+ "async" ,
202+ "missinggreenlet" ,
203+ ]
204+ ):
205+ # This is an async-related error, set fallback value
206+ instance = state .obj ()
207+ if instance is not None :
208+ set_committed_value (instance , self .key , self .fallback_value )
209+ return LoaderCallableStatus .ATTR_WAS_SET
210+ return self .fallback_value
211+ # For other exceptions, re-raise them
212+ raise
69213 if any (
70214 keyword in error_msg
71215 for keyword in [
@@ -85,9 +229,6 @@ def _load_for_state(self, state, passive):
85229 # For other exceptions, re-raise them
86230 raise
87231
88- # We have a proper session, use the parent implementation
89- return super ()._load_for_state (state , passive )
90-
91232
92233class SafeColumnProperty (ColumnProperty ):
93234 """
0 commit comments