Skip to content

Commit fe8cd57

Browse files
authored
Testing for telemetry (#616)
* e2e test telemetry
  Signed-off-by: Sai Shree Pradhan <[email protected]>
* assert session id, statement id
  Signed-off-by: Sai Shree Pradhan <[email protected]>
* minor changes, added checks on server response
  Signed-off-by: Sai Shree Pradhan <[email protected]>
* finally block
  Signed-off-by: Sai Shree Pradhan <[email protected]>
* removed setup clean up
  Signed-off-by: Sai Shree Pradhan <[email protected]>
* finally in test_complex_types
  Signed-off-by: Sai Shree Pradhan <[email protected]>
---------
Signed-off-by: Sai Shree Pradhan <[email protected]>
1 parent e732e96 commit fe8cd57

File tree

2 files changed

+171
-3
lines changed

2 files changed

+171
-3
lines changed

tests/e2e/test_complex_types.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -39,9 +39,11 @@ def table_fixture(self, connection_details):
3939
)
4040
"""
4141
)
42-
yield
43-
# Clean up the table after the test
44-
cursor.execute("DELETE FROM pysql_test_complex_types_table")
42+
try:
43+
yield
44+
finally:
45+
# Clean up the table after the test
46+
cursor.execute("DELETE FROM pysql_test_complex_types_table")
4547

4648
@pytest.mark.parametrize(
4749
"field,expected_type",
Lines changed: 166 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,166 @@
1+
import random
2+
import threading
3+
import time
4+
from unittest.mock import patch
5+
import pytest
6+
7+
from databricks.sql.telemetry.models.enums import StatementType
8+
from databricks.sql.telemetry.telemetry_client import TelemetryClient, TelemetryClientFactory
9+
from tests.e2e.test_driver import PySQLPytestTestCase
10+
11+
def run_in_threads(target, num_threads, pass_index=False):
    """Run ``target`` in ``num_threads`` threads and wait for all of them.

    Args:
        target: Callable executed once per thread.
        num_threads: Number of threads to create; ``0`` creates none.
        pass_index: When truthy, each thread calls ``target(i)`` with its
            zero-based index ``i``; otherwise ``target()`` is called with
            no arguments.

    Returns:
        None. The call blocks until every thread has been joined.
    """
    threads = [
        threading.Thread(target=target, args=(i,) if pass_index else ())
        for i in range(num_threads)
    ]
    # Start every thread before joining any of them so the workers overlap
    # instead of running one after another.
    for t in threads:
        t.start()
    for t in threads:
        t.join()
21+
22+
23+
class TestE2ETelemetry(PySQLPytestTestCase):
    """End-to-end test that concurrent real queries emit the expected telemetry."""

    @pytest.fixture(autouse=True)
    def telemetry_setup_teardown(self):
        """
        Reset the TelemetryClientFactory after each test.

        Nothing is done before the test body runs. Once it finishes (whether
        it passed or failed) the factory's executor, if one exists, is shut
        down and cleared, and the factory is marked as not initialized so the
        next test does not inherit this test's state.
        """
        try:
            yield
        finally:
            if TelemetryClientFactory._executor:
                # wait=True blocks until pending telemetry work has finished.
                TelemetryClientFactory._executor.shutdown(wait=True)
                TelemetryClientFactory._executor = None
            TelemetryClientFactory._initialized = False

    def test_concurrent_queries_sends_telemetry(self):
        """
        An E2E test where concurrent threads execute real queries against
        the staging endpoint, while we capture and verify the generated telemetry.

        ``TelemetryClient._send_telemetry`` and
        ``TelemetryClient._telemetry_request_callback`` are wrapped rather
        than replaced: each wrapper records what it sees and still calls the
        original implementation, so the telemetry is really sent.
        """
        num_threads = 30
        # 2 events per thread (initial_telemetry_log, latency_log (execute))
        events_per_thread = 2
        expected_event_count = num_threads * events_per_thread

        # A single lock guards every capture list below; they are appended to
        # from the worker threads and from the telemetry callbacks.
        capture_lock = threading.Lock()
        captured_telemetry = []
        captured_session_ids = []
        captured_statement_ids = []
        captured_responses = []
        captured_exceptions = []

        original_send_telemetry = TelemetryClient._send_telemetry
        original_callback = TelemetryClient._telemetry_request_callback

        def send_telemetry_wrapper(self_client, events):
            """Record the outgoing events, then forward them to the real sender."""
            with capture_lock:
                captured_telemetry.extend(events)
            original_send_telemetry(self_client, events)

        def callback_wrapper(self_client, future, sent_count):
            """
            Wraps the original callback to capture the server's response
            or any exceptions from the async network call.
            """
            try:
                original_callback(self_client, future, sent_count)

                # Now, capture the result for our assertions
                response = future.result()
                response.raise_for_status()  # Raise an exception for 4xx/5xx errors
                telemetry_response = response.json()
                with capture_lock:
                    captured_responses.append(telemetry_response)
            except Exception as e:
                # Collected instead of raised so the failure is reported by
                # the assertions in the test body.
                with capture_lock:
                    captured_exceptions.append(e)

        # Backslash continuation is kept on purpose: parenthesized context
        # managers need a newer Python than this file otherwise requires.
        with patch.object(TelemetryClient, "_send_telemetry", send_telemetry_wrapper), \
                patch.object(TelemetryClient, "_telemetry_request_callback", callback_wrapper):

            def execute_query_worker(thread_id):
                """Each thread creates a connection and executes a query."""

                # Small random delay so the threads do not all start in lockstep.
                time.sleep(random.uniform(0, 0.05))

                with self.connection(extra_params={"enable_telemetry": True}) as conn:
                    # Capture the session ID from the connection before executing the query
                    session_id_hex = conn.get_session_id_hex()
                    with capture_lock:
                        captured_session_ids.append(session_id_hex)

                    with conn.cursor() as cursor:
                        cursor.execute(f"SELECT {thread_id}")
                        # Capture the statement ID after executing the query
                        statement_id = cursor.query_id
                        with capture_lock:
                            captured_statement_ids.append(statement_id)
                        cursor.fetchall()

            # Run the workers concurrently
            run_in_threads(execute_query_worker, num_threads, pass_index=True)

            # Drain the executor while the patches are still active so that
            # outstanding work goes through the wrappers above.
            if TelemetryClientFactory._executor:
                TelemetryClientFactory._executor.shutdown(wait=True)

        # --- VERIFICATION ---
        assert not captured_exceptions, (
            f"telemetry callbacks raised: {captured_exceptions!r}"
        )
        assert len(captured_responses) > 0, "no telemetry responses were captured"

        total_successful_events = 0
        for response in captured_responses:
            assert "errors" not in response or not response["errors"]
            if "numProtoSuccess" in response:
                total_successful_events += response["numProtoSuccess"]
        assert total_successful_events == expected_event_count

        assert len(captured_telemetry) == expected_event_count
        assert len(captured_session_ids) == num_threads  # One session ID per thread
        assert len(captured_statement_ids) == num_threads  # One statement ID per thread (per query)

        # Separate initial logs from latency logs
        initial_logs = [
            e for e in captured_telemetry
            if e.entry.sql_driver_log.operation_latency_ms is None
            and e.entry.sql_driver_log.driver_connection_params is not None
            and e.entry.sql_driver_log.system_configuration is not None
        ]
        latency_logs = [
            e for e in captured_telemetry
            if e.entry.sql_driver_log.operation_latency_ms is not None
            and e.entry.sql_driver_log.sql_statement_id is not None
            and e.entry.sql_driver_log.sql_operation.statement_type == StatementType.QUERY
        ]

        # Verify counts
        assert len(initial_logs) == num_threads
        assert len(latency_logs) == num_threads

        # Verify that telemetry events contain the exact session IDs we captured from connections
        telemetry_session_ids = set()
        for event in captured_telemetry:
            session_id = event.entry.sql_driver_log.session_id
            assert session_id is not None
            telemetry_session_ids.add(session_id)

        captured_session_ids_set = set(captured_session_ids)
        assert telemetry_session_ids == captured_session_ids_set
        assert len(captured_session_ids_set) == num_threads

        # Verify that telemetry latency logs contain the exact statement IDs we captured from cursors
        telemetry_statement_ids = set()
        for event in latency_logs:
            statement_id = event.entry.sql_driver_log.sql_statement_id
            assert statement_id is not None
            telemetry_statement_ids.add(statement_id)

        captured_statement_ids_set = set(captured_statement_ids)
        assert telemetry_statement_ids == captured_statement_ids_set
        assert len(captured_statement_ids_set) == num_threads

        # Verify that each latency log has a statement ID from our captured set
        # (checked against the sets built above rather than re-scanning the lists).
        for event in latency_logs:
            log = event.entry.sql_driver_log
            assert log.sql_statement_id in captured_statement_ids_set
            assert log.session_id in captured_session_ids_set

0 commit comments

Comments
 (0)