1+ import threading
2+ from unittest .mock import patch , MagicMock
3+
4+ from databricks .sql .client import Connection
5+ from databricks .sql .telemetry .telemetry_client import TelemetryClientFactory , TelemetryClient
6+ from databricks .sql .thrift_backend import ThriftBackend
7+ from databricks .sql .utils import ExecuteResponse
8+ from databricks .sql .thrift_api .TCLIService .ttypes import TSessionHandle , TOperationHandle , TOperationState , THandleIdentifier
9+
10+ try :
11+ import pyarrow as pa
12+ except ImportError :
13+ pa = None
14+
15+
16+ def run_in_threads (target , num_threads , pass_index = False ):
17+ """Helper to run target function in multiple threads."""
18+ threads = [
19+ threading .Thread (target = target , args = (i ,) if pass_index else ())
20+ for i in range (num_threads )
21+ ]
22+ for t in threads :
23+ t .start ()
24+ for t in threads :
25+ t .join ()
26+
27+
28+ class MockArrowQueue :
29+ """Mock queue that behaves like ArrowQueue but returns empty results."""
30+
31+ def __init__ (self ):
32+ # Create an empty arrow table if pyarrow is available, otherwise use None
33+ if pa is not None :
34+ self .empty_table = pa .table ({'column' : pa .array ([])})
35+ else :
36+ # Create a simple mock table-like object
37+ self .empty_table = MagicMock ()
38+ self .empty_table .num_rows = 0
39+ self .empty_table .num_columns = 0
40+
41+ def next_n_rows (self , num_rows : int ):
42+ """Return empty results."""
43+ return self .empty_table
44+
45+ def remaining_rows (self ):
46+ """Return empty results."""
47+ return self .empty_table
48+
49+
50+ def test_concurrent_queries_with_telemetry_capture ():
51+ """
52+ Test showing concurrent threads executing queries with real telemetry capture.
53+ Uses the actual Connection and Cursor classes, mocking only the ThriftBackend.
54+ """
55+ num_threads = 5
56+ captured_telemetry = []
57+ connections = [] # Store connections to close them later
58+ connections_lock = threading .Lock () # Thread safety for connections list
59+
60+ def mock_send_telemetry (self , events ):
61+ """Capture telemetry events instead of sending them over network."""
62+ captured_telemetry .extend (events )
63+
64+ # Clean up any existing state
65+ if TelemetryClientFactory ._executor :
66+ TelemetryClientFactory ._executor .shutdown (wait = True )
67+ TelemetryClientFactory ._clients .clear ()
68+ TelemetryClientFactory ._executor = None
69+ TelemetryClientFactory ._initialized = False
70+
71+ with patch .object (TelemetryClient , '_send_telemetry' , mock_send_telemetry ):
72+ # Mock the ThriftBackend to avoid actual network calls
73+ with patch .object (ThriftBackend , 'open_session' ) as mock_open_session , \
74+ patch .object (ThriftBackend , 'execute_command' ) as mock_execute_command , \
75+ patch .object (ThriftBackend , 'close_session' ) as mock_close_session , \
76+ patch .object (ThriftBackend , 'fetch_results' ) as mock_fetch_results , \
77+ patch .object (ThriftBackend , 'close_command' ) as mock_close_command , \
78+ patch .object (ThriftBackend , 'handle_to_hex_id' ) as mock_handle_to_hex_id , \
79+ patch ('databricks.sql.auth.thrift_http_client.THttpClient.open' ) as mock_transport_open :
80+
81+ # Mock transport.open() to prevent actual network connection
82+ mock_transport_open .return_value = None
83+
84+ # Set up mock responses with proper structure
85+ mock_handle_identifier = THandleIdentifier ()
86+ mock_handle_identifier .guid = b'1234567890abcdef'
87+ mock_handle_identifier .secret = b'test_secret_1234'
88+
89+ mock_session_handle = TSessionHandle ()
90+ mock_session_handle .sessionId = mock_handle_identifier
91+ mock_session_handle .serverProtocolVersion = 1
92+
93+ mock_open_session .return_value = MagicMock (
94+ sessionHandle = mock_session_handle ,
95+ serverProtocolVersion = 1
96+ )
97+
98+ mock_handle_to_hex_id .return_value = "test-session-id-12345678"
99+
100+ mock_op_handle = TOperationHandle ()
101+ mock_op_handle .operationId = THandleIdentifier ()
102+ mock_op_handle .operationId .guid = b'abcdef1234567890'
103+ mock_op_handle .operationId .secret = b'op_secret_abcd'
104+
105+ # Create proper mock arrow_queue with required methods
106+ mock_arrow_queue = MockArrowQueue ()
107+
108+ mock_execute_response = ExecuteResponse (
109+ arrow_queue = mock_arrow_queue ,
110+ description = [],
111+ command_handle = mock_op_handle ,
112+ status = TOperationState .FINISHED_STATE ,
113+ has_been_closed_server_side = False ,
114+ has_more_rows = False ,
115+ lz4_compressed = False ,
116+ arrow_schema_bytes = b'' ,
117+ is_staging_operation = False
118+ )
119+ mock_execute_command .return_value = mock_execute_response
120+
121+ # Mock fetch_results to return empty results
122+ mock_fetch_results .return_value = (mock_arrow_queue , False )
123+
124+ # Mock close_command to do nothing
125+ mock_close_command .return_value = None
126+
127+ # Mock close_session to do nothing
128+ mock_close_session .return_value = None
129+
130+ def execute_query_worker (thread_id ):
131+ """Each thread creates a connection and executes a query."""
132+
133+ # Create real Connection and Cursor objects
134+ conn = Connection (
135+ server_hostname = "test-host" ,
136+ http_path = "/test/path" ,
137+ access_token = "test-token" ,
138+ enable_telemetry = True
139+ )
140+
141+ # Thread-safe storage of connection
142+ with connections_lock :
143+ connections .append (conn )
144+
145+ cursor = conn .cursor ()
146+ # This will trigger the @log_latency decorator naturally
147+ cursor .execute (f"SELECT { thread_id } as thread_id" )
148+ result = cursor .fetchall ()
149+ conn .close ()
150+
151+
152+ run_in_threads (execute_query_worker , num_threads , pass_index = True )
153+
154+ # We expect at least 2 events per thread (one for open_session and one for execute_command)
155+ assert len (captured_telemetry ) >= num_threads * 2
156+ print (f"Captured telemetry: { captured_telemetry } " )
157+
158+ # Verify the decorator was used (check some telemetry events have latency measurement)
159+ events_with_latency = [
160+ e for e in captured_telemetry
161+ if hasattr (e , 'entry' ) and hasattr (e .entry , 'sql_driver_log' )
162+ and e .entry .sql_driver_log .operation_latency_ms is not None
163+ ]
164+ assert len (events_with_latency ) >= num_threads
165+
166+ # Verify we have events with statement IDs (indicating @log_latency decorator worked)
167+ events_with_statements = [
168+ e for e in captured_telemetry
169+ if hasattr (e , 'entry' ) and hasattr (e .entry , 'sql_driver_log' )
170+ and e .entry .sql_driver_log .sql_statement_id is not None
171+ ]
172+ assert len (events_with_statements ) >= num_threads
173+
174+
0 commit comments