import pytest
import requests
from unittest.mock import patch, MagicMock
+import threading
+import random
+import time
+from concurrent.futures import ThreadPoolExecutor

from databricks.sql.telemetry.telemetry_client import (
    TelemetryClient,
    NoopTelemetryClient,
    TelemetryClientFactory,
    TelemetryHelper,
-    BaseTelemetryClient
)
from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow
from databricks.sql.auth.authenticators import (
@@ -283,4 +286,176 @@ def test_factory_shutdown_flow(self, telemetry_system_reset):
        # Close second client - factory should shut down
        TelemetryClientFactory.close(session2)
        assert TelemetryClientFactory._initialized is False
-        assert TelemetryClientFactory._executor is None
+        assert TelemetryClientFactory._executor is None
+
+
+# A helper function to run a target in multiple threads and wait for them.
+def run_in_threads(target, num_threads, pass_index=False):
+    """Creates, starts, and joins a specified number of threads.
+
+    Args:
+        target: The function to run in each thread
+        num_threads: Number of threads to create
+        pass_index: If True, passes the thread index (0, 1, 2, ...) as the first argument
+    """
+    threads = [
+        threading.Thread(target=target, args=(i,) if pass_index else ())
+        for i in range(num_threads)
+    ]
+    for t in threads:
+        t.start()
+    for t in threads:
+        t.join()
+
+
+class TestTelemetryRaceConditions:
+    """Tests for race conditions in multithreaded scenarios."""
+
+    @pytest.fixture(autouse=True)
+    def clean_factory(self):
+        """A fixture to automatically reset the factory's state before and after each test."""
+        # Clean up at the start of each test
+        if TelemetryClientFactory._executor:
+            TelemetryClientFactory._executor.shutdown(wait=True)
+        TelemetryClientFactory._clients.clear()
+        TelemetryClientFactory._executor = None
+        TelemetryClientFactory._initialized = False
+
+        yield
+
+        # Clean up at the end of each test
+        if TelemetryClientFactory._executor:
+            TelemetryClientFactory._executor.shutdown(wait=True)
+        TelemetryClientFactory._clients.clear()
+        TelemetryClientFactory._executor = None
+        TelemetryClientFactory._initialized = False
+
+    def test_factory_concurrent_initialization_of_DIFFERENT_clients(self):
+        """
+        Tests that multiple threads creating DIFFERENT clients concurrently
+        share a single ThreadPoolExecutor and all clients are created successfully.
+        """
+        num_threads = 20
+
+        def create_client(thread_id):
+            TelemetryClientFactory.initialize_telemetry_client(
+                telemetry_enabled=True,
+                session_id_hex=f"session_{thread_id}",
+                auth_provider=None,
+                host_url="test-host",
+            )
+
+        run_in_threads(create_client, num_threads, pass_index=True)
+
+        # ASSERT: The factory was properly initialized
+        assert TelemetryClientFactory._initialized is True
+        assert TelemetryClientFactory._executor is not None
+        assert isinstance(TelemetryClientFactory._executor, ThreadPoolExecutor)
+
+        # ASSERT: All clients were successfully created
+        assert len(TelemetryClientFactory._clients) == num_threads
+
+        # ASSERT: All TelemetryClient instances share the same executor
+        telemetry_clients = [
+            client for client in TelemetryClientFactory._clients.values()
+            if isinstance(client, TelemetryClient)
+        ]
+        assert len(telemetry_clients) == num_threads
+
+        shared_executor = TelemetryClientFactory._executor
+        for client in telemetry_clients:
+            assert client._executor is shared_executor
+
+    def test_factory_concurrent_initialization_of_SAME_client(self):
+        """
+        Tests that multiple threads trying to initialize the SAME client
+        result in only one client instance being created.
+        """
+        session_id = "shared-session"
+        num_threads = 20
+
+        def create_same_client():
+            TelemetryClientFactory.initialize_telemetry_client(
+                telemetry_enabled=True,
+                session_id_hex=session_id,
+                auth_provider=None,
+                host_url="test-host",
+            )
+
+        run_in_threads(create_same_client, num_threads)
+
+        # ASSERT: Only one client was created in the factory.
+        assert len(TelemetryClientFactory._clients) == 1
+        client = TelemetryClientFactory.get_telemetry_client(session_id)
+        assert isinstance(client, TelemetryClient)
+
+    def test_client_concurrent_event_export(self):
+        """
+        Tests that no events are lost when multiple threads call _export_event
+        on the same client instance concurrently.
+        """
+        client = TelemetryClient(True, "session-1", None, "host", MagicMock())
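+        # The final MagicMock() fills the executor argument that the factory
+        # normally supplies to the client.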
+        # Mock _flush to prevent auto-flushing when the batch size threshold is reached
+        original_flush = client._flush
+        client._flush = MagicMock()
+
+        num_threads = 5
+        events_per_thread = 10
+
+        def add_events():
+            for i in range(events_per_thread):
+                client._export_event(f"event-{i}")
+
+        run_in_threads(add_events, num_threads)
+
+        # ASSERT: The batch contains all events from all threads; none were lost.
+        total_expected_events = num_threads * events_per_thread
+        assert len(client._events_batch) == total_expected_events
+
+        # Restore the original flush method for cleanup
+        client._flush = original_flush
+
+    def test_client_concurrent_flush(self):
+        """
+        Tests that if multiple threads trigger _flush at the same time,
+        the underlying send operation is only called once for the batch.
+        """
+        client = TelemetryClient(True, "session-1", None, "host", MagicMock())
+        client._send_telemetry = MagicMock()
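+        # Stub out the send operation so the test only counts invocations
+        # instead of performing real network I/O.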
+
+        # Pre-fill the batch so there's something to flush
+        client._events_batch = ["event"] * 5
+
+        def call_flush():
+            client._flush()
+
+        run_in_threads(call_flush, 10)
+
+        # ASSERT: The send operation was called exactly once.
+        # This proves the lock prevents multiple threads from sending the same batch.
+        client._send_telemetry.assert_called_once()
+        # ASSERT: The event batch is now empty.
+        assert len(client._events_batch) == 0
+
+    def test_factory_concurrent_create_and_close(self):
+        """
+        Tests that concurrently creating and closing different clients
+        doesn't corrupt the factory state and correctly shuts down the executor.
+        """
+        num_ops = 50
+
+        def create_and_close_client(i):
+            session_id = f"session_{i}"
+            TelemetryClientFactory.initialize_telemetry_client(
+                telemetry_enabled=True, session_id_hex=session_id, auth_provider=None, host_url="host"
+            )
+            # Small sleep to increase the chance of interleaving operations
+            time.sleep(random.uniform(0, 0.01))
+            TelemetryClientFactory.close(session_id)
+
+        run_in_threads(create_and_close_client, num_ops, pass_index=True)
+
+        # ASSERT: After all operations, the factory should be empty and reset.
+        assert not TelemetryClientFactory._clients
+        assert TelemetryClientFactory._executor is None
+        assert not TelemetryClientFactory._initialized