import threading
from unittest.mock import patch

import pytest

from databricks.sql.telemetry.telemetry_client import (
    TelemetryClient,
    TelemetryClientFactory,
)
from tests.e2e.test_driver import PySQLPytestTestCase
168def run_in_threads (target , num_threads , pass_index = False ):
179 """Helper to run target function in multiple threads."""
@@ -25,150 +17,67 @@ def run_in_threads(target, num_threads, pass_index=False):
2517 t .join ()
2618
2719
class TestE2ETelemetry(PySQLPytestTestCase):
    """E2E tests verifying that concurrent queries emit the expected telemetry."""

    @pytest.fixture(autouse=True)
    def telemetry_setup_teardown(self):
        """
        Ensure the TelemetryClientFactory is in a clean state before each
        test and shut it down afterward. Using an autouse fixture makes
        this robust and automatic.
        """
        # --- SETUP ---
        # Drain any executor left over from a previous test before resetting
        # the factory's module-level state.
        if TelemetryClientFactory._executor:
            TelemetryClientFactory._executor.shutdown(wait=True)
        TelemetryClientFactory._clients.clear()
        TelemetryClientFactory._executor = None
        TelemetryClientFactory._initialized = False

        yield  # This is where the test runs

        # --- TEARDOWN ---
        # Flush and reset again so state never leaks into the next test.
        if TelemetryClientFactory._executor:
            TelemetryClientFactory._executor.shutdown(wait=True)
        TelemetryClientFactory._clients.clear()
        TelemetryClientFactory._executor = None
        TelemetryClientFactory._initialized = False

    def test_concurrent_queries_sends_telemetry(self):
        """
        An E2E test where concurrent threads execute real queries against
        the staging endpoint, while we capture and verify the generated telemetry.
        """
        num_threads = 5
        captured_telemetry = []
        # Telemetry callbacks may fire from multiple worker threads; guard
        # the shared list with a lock.
        captured_telemetry_lock = threading.Lock()

        def mock_send_telemetry(self, events):
            """
            Telemetry interceptor: captures events into our list instead of
            sending them over the network. (`self` here is the patched
            TelemetryClient instance, not the test case.)
            """
            with captured_telemetry_lock:
                captured_telemetry.extend(events)

        with patch.object(TelemetryClient, "_send_telemetry", mock_send_telemetry):

            def execute_query_worker(thread_id):
                """Each thread creates a connection and executes a query."""
                with self.connection(extra_params={"enable_telemetry": True}) as conn:
                    with conn.cursor() as cursor:
                        cursor.execute(f"SELECT {thread_id}")
                        cursor.fetchall()

            # Run the workers concurrently
            run_in_threads(execute_query_worker, num_threads, pass_index=True)

            # Shut down the executor to flush any pending telemetry sends
            # while the interceptor patch is still active.
            if TelemetryClientFactory._executor:
                TelemetryClientFactory._executor.shutdown(wait=True)

        # --- VERIFICATION ---
        # 4 events per thread: initial_telemetry_log plus 3 latency logs
        # (execute_command, fetchall_arrow, _convert_arrow_table).
        assert len(captured_telemetry) == num_threads * 4

        events_with_latency = [
            e
            for e in captured_telemetry
            if e.entry.sql_driver_log.operation_latency_ms is not None
            and e.entry.sql_driver_log.sql_statement_id is not None
        ]
        # 3 latency events per thread (execute_command, fetchall_arrow,
        # _convert_arrow_table).
        assert len(events_with_latency) == num_threads * 3
0 commit comments