Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -42,7 +42,7 @@
import functools
import logging
import re
from typing import Any, Callable, Generic, TypeVar
from typing import Any, Awaitable, Callable, Generic, TypeVar

import wrapt
from wrapt import wrap_function_wrapper
Expand Down Expand Up @@ -596,6 +596,44 @@ def traced_execution(
self._populate_span(span, cursor, *args)
return query_method(*args, **kwargs)

async def traced_execution_async(
self,
cursor: CursorT,
query_method: Callable[..., Awaitable[Any]],
*args: tuple[Any, ...],
**kwargs: dict[Any, Any],
):
name = self.get_operation_name(cursor, args)
if not name:
name = (
self._db_api_integration.database
if self._db_api_integration.database
else self._db_api_integration.name
)

with self._db_api_integration._tracer.start_as_current_span(
name, kind=SpanKind.CLIENT
) as span:
if span.is_recording():
if args and self._commenter_enabled:
if self._enable_attribute_commenter:
# sqlcomment is added to executed query and db.statement span attribute
args = self._update_args_with_added_sql_comment(
args, cursor
)
self._populate_span(span, cursor, *args)
else:
# sqlcomment is only added to executed query
# so db.statement is set before add_sql_comment
self._populate_span(span, cursor, *args)
args = self._update_args_with_added_sql_comment(
args, cursor
)
else:
# no sqlcomment anywhere
self._populate_span(span, cursor, *args)
return await query_method(*args, **kwargs)


# pylint: disable=abstract-method
class TracedCursorProxy(wrapt.ObjectProxy, Generic[CursorT]):
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -399,17 +399,17 @@ def _new_cursor_async_factory(

class TracedCursorAsyncFactory(base_factory):
async def execute(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
return await _cursor_tracer.traced_execution_async(
self, super().execute, *args, **kwargs
)

async def executemany(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
return await _cursor_tracer.traced_execution_async(
self, super().executemany, *args, **kwargs
)

async def callproc(self, *args: Any, **kwargs: Any):
return await _cursor_tracer.traced_execution(
return await _cursor_tracer.traced_execution_async(
self, super().callproc, *args, **kwargs
)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.

import asyncio
import types
from unittest import IsolatedAsyncioTestCase, mock

Expand Down Expand Up @@ -50,10 +51,15 @@ def __init__(self, *args, **kwargs):
pass

# pylint: disable=unused-argument, no-self-use
async def execute(self, query, params=None, throw_exception=False):
async def execute(
self, query, params=None, throw_exception=False, delay=0.0
):
if throw_exception:
raise psycopg.Error("Test Exception")

if delay:
await asyncio.sleep(delay)

# pylint: disable=unused-argument, no-self-use
async def executemany(self, query, params=None, throw_exception=False):
if throw_exception:
Expand Down Expand Up @@ -492,3 +498,27 @@ async def test_not_recording_async(self):
self.assertFalse(mock_span.set_status.called)

PsycopgInstrumentor().uninstrument()

async def test_tracing_is_async(self):
PsycopgInstrumentor().instrument()

# before this async fix cursor.execute would take 14000 ns, delaying for
# 100,000ns
delay = 0.0001

async def test_async_connection():
acnx = await psycopg.AsyncConnection.connect("test")
async with acnx as cnx:
async with cnx.cursor() as cursor:
await cursor.execute("SELECT * FROM test", delay=delay)

await test_async_connection()
spans_list = self.memory_exporter.get_finished_spans()
self.assertEqual(len(spans_list), 1)
span = spans_list[0]

# duration is nanoseconds
duration = span.end_time - span.start_time
self.assertGreater(duration, delay * 1e9)

PsycopgInstrumentor().uninstrument()
Loading