1212# See the License for the specific language governing permissions and
1313# limitations under the License.
1414
15+ from threading import local
16+ from weakref import WeakKeyDictionary
17+
1518from sqlalchemy .event import listen # pylint: disable=no-name-in-module
1619
1720from opentelemetry import trace
@@ -66,12 +69,21 @@ def __init__(self, tracer, engine):
6669 self .tracer = tracer
6770 self .engine = engine
6871 self .vendor = _normalize_vendor (engine .name )
69- self .current_span = None
72+ self .cursor_mapping = WeakKeyDictionary ()
73+ self .local = local ()
7074
7175 listen (engine , "before_cursor_execute" , self ._before_cur_exec )
7276 listen (engine , "after_cursor_execute" , self ._after_cur_exec )
7377 listen (engine , "handle_error" , self ._handle_error )
7478
79+ @property
80+ def current_thread_span (self ):
81+ return getattr (self .local , "current_span" , None )
82+
83+ @current_thread_span .setter
84+ def current_thread_span (self , span ):
85+ setattr (self .local , "current_span" , span )
86+
7587 def _operation_name (self , db_name , statement ):
7688 parts = []
7789 if isinstance (statement , str ):
@@ -94,34 +106,38 @@ def _before_cur_exec(self, conn, cursor, statement, *args):
94106 attrs = _get_attributes_from_cursor (self .vendor , cursor , attrs )
95107
96108 db_name = attrs .get (_DB , "" )
97- self . current_span = self .tracer .start_span (
109+ span = self .tracer .start_span (
98110 self ._operation_name (db_name , statement ),
99111 kind = trace .SpanKind .CLIENT ,
100112 )
101- with trace .use_span (self .current_span , end_on_exit = False ):
102- if self .current_span .is_recording ():
103- self .current_span .set_attribute (_STMT , statement )
104- self .current_span .set_attribute ("db.system" , self .vendor )
113+ self .current_thread_span = self .cursor_mapping [cursor ] = span
114+ with trace .use_span (span , end_on_exit = False ):
115+ if span .is_recording ():
116+ span .set_attribute (_STMT , statement )
117+ span .set_attribute ("db.system" , self .vendor )
105118 for key , value in attrs .items ():
106- self . current_span .set_attribute (key , value )
119+ span .set_attribute (key , value )
107120
108121 # pylint: disable=unused-argument
109122 def _after_cur_exec (self , conn , cursor , statement , * args ):
110- if self .current_span is None :
123+ span = self .cursor_mapping .get (cursor , None )
124+ if span is None :
111125 return
112- self .current_span .end ()
126+
127+ span .end ()
113128
114129 def _handle_error (self , context ):
115- if self .current_span is None :
130+ span = self .current_thread_span
131+ if span is None :
116132 return
117133
118134 try :
119- if self . current_span .is_recording ():
120- self . current_span .set_status (
135+ if span .is_recording ():
136+ span .set_status (
121137 Status (StatusCode .ERROR , str (context .original_exception ),)
122138 )
123139 finally :
124- self . current_span .end ()
140+ span .end ()
125141
126142
127143def _get_attributes_from_url (url ):
0 commit comments