Skip to content

Commit 11d41ce

Browse files
committed
test
Signed-off-by: Sai Shree Pradhan <[email protected]>
1 parent 3e9b47d commit 11d41ce

File tree

1 file changed

+174
-0
lines changed

1 file changed

+174
-0
lines changed
Lines changed: 174 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,174 @@
1+
import threading
2+
from unittest.mock import patch, MagicMock
3+
4+
from databricks.sql.client import Connection
5+
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory, TelemetryClient
6+
from databricks.sql.thrift_backend import ThriftBackend
7+
from databricks.sql.utils import ExecuteResponse
8+
from databricks.sql.thrift_api.TCLIService.ttypes import TSessionHandle, TOperationHandle, TOperationState, THandleIdentifier
9+
10+
try:
11+
import pyarrow as pa
12+
except ImportError:
13+
pa = None
14+
15+
16+
def run_in_threads(target, num_threads, pass_index=False):
17+
"""Helper to run target function in multiple threads."""
18+
threads = [
19+
threading.Thread(target=target, args=(i,) if pass_index else ())
20+
for i in range(num_threads)
21+
]
22+
for t in threads:
23+
t.start()
24+
for t in threads:
25+
t.join()
26+
27+
28+
class MockArrowQueue:
29+
"""Mock queue that behaves like ArrowQueue but returns empty results."""
30+
31+
def __init__(self):
32+
# Create an empty arrow table if pyarrow is available, otherwise use None
33+
if pa is not None:
34+
self.empty_table = pa.table({'column': pa.array([])})
35+
else:
36+
# Create a simple mock table-like object
37+
self.empty_table = MagicMock()
38+
self.empty_table.num_rows = 0
39+
self.empty_table.num_columns = 0
40+
41+
def next_n_rows(self, num_rows: int):
42+
"""Return empty results."""
43+
return self.empty_table
44+
45+
def remaining_rows(self):
46+
"""Return empty results."""
47+
return self.empty_table
48+
49+
50+
def test_concurrent_queries_with_telemetry_capture():
51+
"""
52+
Test showing concurrent threads executing queries with real telemetry capture.
53+
Uses the actual Connection and Cursor classes, mocking only the ThriftBackend.
54+
"""
55+
num_threads = 5
56+
captured_telemetry = []
57+
connections = [] # Store connections to close them later
58+
connections_lock = threading.Lock() # Thread safety for connections list
59+
60+
def mock_send_telemetry(self, events):
61+
"""Capture telemetry events instead of sending them over network."""
62+
captured_telemetry.extend(events)
63+
64+
# Clean up any existing state
65+
if TelemetryClientFactory._executor:
66+
TelemetryClientFactory._executor.shutdown(wait=True)
67+
TelemetryClientFactory._clients.clear()
68+
TelemetryClientFactory._executor = None
69+
TelemetryClientFactory._initialized = False
70+
71+
with patch.object(TelemetryClient, '_send_telemetry', mock_send_telemetry):
72+
# Mock the ThriftBackend to avoid actual network calls
73+
with patch.object(ThriftBackend, 'open_session') as mock_open_session, \
74+
patch.object(ThriftBackend, 'execute_command') as mock_execute_command, \
75+
patch.object(ThriftBackend, 'close_session') as mock_close_session, \
76+
patch.object(ThriftBackend, 'fetch_results') as mock_fetch_results, \
77+
patch.object(ThriftBackend, 'close_command') as mock_close_command, \
78+
patch.object(ThriftBackend, 'handle_to_hex_id') as mock_handle_to_hex_id, \
79+
patch('databricks.sql.auth.thrift_http_client.THttpClient.open') as mock_transport_open:
80+
81+
# Mock transport.open() to prevent actual network connection
82+
mock_transport_open.return_value = None
83+
84+
# Set up mock responses with proper structure
85+
mock_handle_identifier = THandleIdentifier()
86+
mock_handle_identifier.guid = b'1234567890abcdef'
87+
mock_handle_identifier.secret = b'test_secret_1234'
88+
89+
mock_session_handle = TSessionHandle()
90+
mock_session_handle.sessionId = mock_handle_identifier
91+
mock_session_handle.serverProtocolVersion = 1
92+
93+
mock_open_session.return_value = MagicMock(
94+
sessionHandle=mock_session_handle,
95+
serverProtocolVersion=1
96+
)
97+
98+
mock_handle_to_hex_id.return_value = "test-session-id-12345678"
99+
100+
mock_op_handle = TOperationHandle()
101+
mock_op_handle.operationId = THandleIdentifier()
102+
mock_op_handle.operationId.guid = b'abcdef1234567890'
103+
mock_op_handle.operationId.secret = b'op_secret_abcd'
104+
105+
# Create proper mock arrow_queue with required methods
106+
mock_arrow_queue = MockArrowQueue()
107+
108+
mock_execute_response = ExecuteResponse(
109+
arrow_queue=mock_arrow_queue,
110+
description=[],
111+
command_handle=mock_op_handle,
112+
status=TOperationState.FINISHED_STATE,
113+
has_been_closed_server_side=False,
114+
has_more_rows=False,
115+
lz4_compressed=False,
116+
arrow_schema_bytes=b'',
117+
is_staging_operation=False
118+
)
119+
mock_execute_command.return_value = mock_execute_response
120+
121+
# Mock fetch_results to return empty results
122+
mock_fetch_results.return_value = (mock_arrow_queue, False)
123+
124+
# Mock close_command to do nothing
125+
mock_close_command.return_value = None
126+
127+
# Mock close_session to do nothing
128+
mock_close_session.return_value = None
129+
130+
def execute_query_worker(thread_id):
131+
"""Each thread creates a connection and executes a query."""
132+
133+
# Create real Connection and Cursor objects
134+
conn = Connection(
135+
server_hostname="test-host",
136+
http_path="/test/path",
137+
access_token="test-token",
138+
enable_telemetry=True
139+
)
140+
141+
# Thread-safe storage of connection
142+
with connections_lock:
143+
connections.append(conn)
144+
145+
cursor = conn.cursor()
146+
# This will trigger the @log_latency decorator naturally
147+
cursor.execute(f"SELECT {thread_id} as thread_id")
148+
result = cursor.fetchall()
149+
conn.close()
150+
151+
152+
run_in_threads(execute_query_worker, num_threads, pass_index=True)
153+
154+
# We expect at least 2 events per thread (one for open_session and one for execute_command)
155+
assert len(captured_telemetry) >= num_threads*2
156+
print(f"Captured telemetry: {captured_telemetry}")
157+
158+
# Verify the decorator was used (check some telemetry events have latency measurement)
159+
events_with_latency = [
160+
e for e in captured_telemetry
161+
if hasattr(e, 'entry') and hasattr(e.entry, 'sql_driver_log')
162+
and e.entry.sql_driver_log.operation_latency_ms is not None
163+
]
164+
assert len(events_with_latency) >= num_threads
165+
166+
# Verify we have events with statement IDs (indicating @log_latency decorator worked)
167+
events_with_statements = [
168+
e for e in captured_telemetry
169+
if hasattr(e, 'entry') and hasattr(e.entry, 'sql_driver_log')
170+
and e.entry.sql_driver_log.sql_statement_id is not None
171+
]
172+
assert len(events_with_statements) >= num_threads
173+
174+

0 commit comments

Comments
 (0)