Skip to content

Commit 50e771b

Browse files
committed
actual e2e
Signed-off-by: Sai Shree Pradhan <[email protected]>
1 parent 11d41ce commit 50e771b

File tree

1 file changed

+54
-145
lines changed

1 file changed

+54
-145
lines changed
Lines changed: 54 additions & 145 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,9 @@
11
import threading
2-
from unittest.mock import patch, MagicMock
3-
4-
from databricks.sql.client import Connection
5-
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory, TelemetryClient
6-
from databricks.sql.thrift_backend import ThriftBackend
7-
from databricks.sql.utils import ExecuteResponse
8-
from databricks.sql.thrift_api.TCLIService.ttypes import TSessionHandle, TOperationHandle, TOperationState, THandleIdentifier
9-
10-
try:
11-
import pyarrow as pa
12-
except ImportError:
13-
pa = None
2+
from unittest.mock import patch
3+
import pytest
144

5+
from databricks.sql.telemetry.telemetry_client import TelemetryClient, TelemetryClientFactory
6+
from tests.e2e.test_driver import PySQLPytestTestCase
157

168
def run_in_threads(target, num_threads, pass_index=False):
179
"""Helper to run target function in multiple threads."""
@@ -25,150 +17,67 @@ def run_in_threads(target, num_threads, pass_index=False):
2517
t.join()
2618

2719

28-
class MockArrowQueue:
29-
"""Mock queue that behaves like ArrowQueue but returns empty results."""
30-
31-
def __init__(self):
32-
# Create an empty arrow table if pyarrow is available, otherwise use None
33-
if pa is not None:
34-
self.empty_table = pa.table({'column': pa.array([])})
35-
else:
36-
# Create a simple mock table-like object
37-
self.empty_table = MagicMock()
38-
self.empty_table.num_rows = 0
39-
self.empty_table.num_columns = 0
40-
41-
def next_n_rows(self, num_rows: int):
42-
"""Return empty results."""
43-
return self.empty_table
20+
class TestE2ETelemetry(PySQLPytestTestCase):
4421

45-
def remaining_rows(self):
46-
"""Return empty results."""
47-
return self.empty_table
22+
@pytest.fixture(autouse=True)
23+
def telemetry_setup_teardown(self):
24+
"""
25+
Reset TelemetryClientFactory to a clean state before each test and
26+
shut down its executor afterward. Declared with autouse=True, so it
27+
runs around every test in this class without being requested.
28+
"""
29+
# --- SETUP ---
30+
if TelemetryClientFactory._executor:
31+
TelemetryClientFactory._executor.shutdown(wait=True)
32+
TelemetryClientFactory._clients.clear()
33+
TelemetryClientFactory._executor = None
34+
TelemetryClientFactory._initialized = False
4835

36+
yield # This is where the test runs
4937

50-
def test_concurrent_queries_with_telemetry_capture():
51-
"""
52-
Test showing concurrent threads executing queries with real telemetry capture.
53-
Uses the actual Connection and Cursor classes, mocking only the ThriftBackend.
54-
"""
55-
num_threads = 5
56-
captured_telemetry = []
57-
connections = [] # Store connections to close them later
58-
connections_lock = threading.Lock() # Thread safety for connections list
38+
# --- TEARDOWN ---
39+
if TelemetryClientFactory._executor:
40+
TelemetryClientFactory._executor.shutdown(wait=True)
41+
TelemetryClientFactory._executor = None
42+
TelemetryClientFactory._initialized = False
5943

60-
def mock_send_telemetry(self, events):
61-
"""Capture telemetry events instead of sending them over network."""
62-
captured_telemetry.extend(events)
44+
def test_concurrent_queries_sends_telemetry(self):
45+
"""
46+
An E2E test where concurrent threads execute real queries against
47+
the staging endpoint, while we capture and verify the generated telemetry.
48+
"""
49+
num_threads = 5
50+
captured_telemetry = []
51+
captured_telemetry_lock = threading.Lock()
6352

64-
# Clean up any existing state
65-
if TelemetryClientFactory._executor:
66-
TelemetryClientFactory._executor.shutdown(wait=True)
67-
TelemetryClientFactory._clients.clear()
68-
TelemetryClientFactory._executor = None
69-
TelemetryClientFactory._initialized = False
53+
def mock_send_telemetry(self, events):
54+
"""
55+
This is our telemetry interceptor. It captures events into our list
56+
instead of sending them over the network.
57+
"""
58+
with captured_telemetry_lock:
59+
captured_telemetry.extend(events)
7060

71-
with patch.object(TelemetryClient, '_send_telemetry', mock_send_telemetry):
72-
# Mock the ThriftBackend to avoid actual network calls
73-
with patch.object(ThriftBackend, 'open_session') as mock_open_session, \
74-
patch.object(ThriftBackend, 'execute_command') as mock_execute_command, \
75-
patch.object(ThriftBackend, 'close_session') as mock_close_session, \
76-
patch.object(ThriftBackend, 'fetch_results') as mock_fetch_results, \
77-
patch.object(ThriftBackend, 'close_command') as mock_close_command, \
78-
patch.object(ThriftBackend, 'handle_to_hex_id') as mock_handle_to_hex_id, \
79-
patch('databricks.sql.auth.thrift_http_client.THttpClient.open') as mock_transport_open:
80-
81-
# Mock transport.open() to prevent actual network connection
82-
mock_transport_open.return_value = None
83-
84-
# Set up mock responses with proper structure
85-
mock_handle_identifier = THandleIdentifier()
86-
mock_handle_identifier.guid = b'1234567890abcdef'
87-
mock_handle_identifier.secret = b'test_secret_1234'
88-
89-
mock_session_handle = TSessionHandle()
90-
mock_session_handle.sessionId = mock_handle_identifier
91-
mock_session_handle.serverProtocolVersion = 1
92-
93-
mock_open_session.return_value = MagicMock(
94-
sessionHandle=mock_session_handle,
95-
serverProtocolVersion=1
96-
)
97-
98-
mock_handle_to_hex_id.return_value = "test-session-id-12345678"
99-
100-
mock_op_handle = TOperationHandle()
101-
mock_op_handle.operationId = THandleIdentifier()
102-
mock_op_handle.operationId.guid = b'abcdef1234567890'
103-
mock_op_handle.operationId.secret = b'op_secret_abcd'
104-
105-
# Create proper mock arrow_queue with required methods
106-
mock_arrow_queue = MockArrowQueue()
107-
108-
mock_execute_response = ExecuteResponse(
109-
arrow_queue=mock_arrow_queue,
110-
description=[],
111-
command_handle=mock_op_handle,
112-
status=TOperationState.FINISHED_STATE,
113-
has_been_closed_server_side=False,
114-
has_more_rows=False,
115-
lz4_compressed=False,
116-
arrow_schema_bytes=b'',
117-
is_staging_operation=False
118-
)
119-
mock_execute_command.return_value = mock_execute_response
120-
121-
# Mock fetch_results to return empty results
122-
mock_fetch_results.return_value = (mock_arrow_queue, False)
123-
124-
# Mock close_command to do nothing
125-
mock_close_command.return_value = None
126-
127-
# Mock close_session to do nothing
128-
mock_close_session.return_value = None
61+
with patch.object(TelemetryClient, '_send_telemetry', mock_send_telemetry):
12962

13063
def execute_query_worker(thread_id):
13164
"""Each thread creates a connection and executes a query."""
132-
133-
# Create real Connection and Cursor objects
134-
conn = Connection(
135-
server_hostname="test-host",
136-
http_path="/test/path",
137-
access_token="test-token",
138-
enable_telemetry=True
139-
)
140-
141-
# Thread-safe storage of connection
142-
with connections_lock:
143-
connections.append(conn)
144-
145-
cursor = conn.cursor()
146-
# This will trigger the @log_latency decorator naturally
147-
cursor.execute(f"SELECT {thread_id} as thread_id")
148-
result = cursor.fetchall()
149-
conn.close()
150-
65+
with self.connection(extra_params={"enable_telemetry": True}) as conn:
66+
with conn.cursor() as cursor:
67+
cursor.execute(f"SELECT {thread_id}")
68+
cursor.fetchall()
15169

70+
# Run the workers concurrently
15271
run_in_threads(execute_query_worker, num_threads, pass_index=True)
15372

154-
# We expect at least 2 events per thread (one for open_session and one for execute_command)
155-
assert len(captured_telemetry) >= num_threads*2
156-
print(f"Captured telemetry: {captured_telemetry}")
157-
158-
# Verify the decorator was used (check some telemetry events have latency measurement)
73+
if TelemetryClientFactory._executor:
74+
TelemetryClientFactory._executor.shutdown(wait=True)
75+
76+
# --- VERIFICATION ---
77+
assert len(captured_telemetry) == num_threads * 4 # 4 events per thread (initial_telemetry_log, 3 latency_logs (execute_command, fetchall_arrow, _convert_arrow_table))
78+
15979
events_with_latency = [
160-
e for e in captured_telemetry
161-
if hasattr(e, 'entry') and hasattr(e.entry, 'sql_driver_log')
162-
and e.entry.sql_driver_log.operation_latency_ms is not None
80+
e for e in captured_telemetry
81+
if e.entry.sql_driver_log.operation_latency_ms is not None and e.entry.sql_driver_log.sql_statement_id is not None
16382
]
164-
assert len(events_with_latency) >= num_threads
165-
166-
# Verify we have events with statement IDs (indicating @log_latency decorator worked)
167-
events_with_statements = [
168-
e for e in captured_telemetry
169-
if hasattr(e, 'entry') and hasattr(e.entry, 'sql_driver_log')
170-
and e.entry.sql_driver_log.sql_statement_id is not None
171-
]
172-
assert len(events_with_statements) >= num_threads
173-
174-
83+
assert len(events_with_latency) == num_threads * 3 # 3 events per thread (execute_command, fetchall_arrow, _convert_arrow_table)

0 commit comments

Comments
 (0)