Skip to content

Commit 3e9b47d

Browse files
committed
tests
Signed-off-by: Sai Shree Pradhan <[email protected]>
1 parent 6c5d6ba commit 3e9b47d

File tree

2 files changed

+179
-5
lines changed

2 files changed

+179
-5
lines changed

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -156,7 +156,6 @@ class TelemetryClient(BaseTelemetryClient):
156156
# Telemetry endpoint paths
157157
TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
158158
TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
159-
DEFAULT_BATCH_SIZE = 10
160159

161160
def __init__(
162161
self,
@@ -168,7 +167,7 @@ def __init__(
168167
):
169168
logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
170169
self._telemetry_enabled = telemetry_enabled
171-
self._batch_size = self.DEFAULT_BATCH_SIZE # TODO: Decide on batch size
170+
self._batch_size = 10 # TODO: Decide on batch size
172171
self._session_id_hex = session_id_hex
173172
self._auth_provider = auth_provider
174173
self._user_agent = None
@@ -403,7 +402,7 @@ def get_telemetry_client(session_id_hex):
403402
if session_id_hex in TelemetryClientFactory._clients:
404403
return TelemetryClientFactory._clients[session_id_hex]
405404
else:
406-
logger.error(
405+
logger.debug(
407406
"Telemetry client not initialized for connection %s",
408407
session_id_hex,
409408
)

tests/unit/test_telemetry.py

Lines changed: 177 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,16 @@
22
import pytest
33
import requests
44
from unittest.mock import patch, MagicMock
5+
import threading
6+
import random
7+
import time
8+
from concurrent.futures import ThreadPoolExecutor
59

610
from databricks.sql.telemetry.telemetry_client import (
711
TelemetryClient,
812
NoopTelemetryClient,
913
TelemetryClientFactory,
1014
TelemetryHelper,
11-
BaseTelemetryClient
1215
)
1316
from databricks.sql.telemetry.models.enums import AuthMech, AuthFlow
1417
from databricks.sql.auth.authenticators import (
@@ -283,4 +286,176 @@ def test_factory_shutdown_flow(self, telemetry_system_reset):
283286
# Close second client - factory should shut down
284287
TelemetryClientFactory.close(session2)
285288
assert TelemetryClientFactory._initialized is False
286-
assert TelemetryClientFactory._executor is None
289+
assert TelemetryClientFactory._executor is None
290+
291+
292+
# A helper function to run a target in multiple threads and wait for them.
293+
def run_in_threads(target, num_threads, pass_index=False):
    """Run ``target`` concurrently in ``num_threads`` threads and wait for all.

    All threads are started before any of them is joined, so the calls to
    ``target`` can overlap.

    Args:
        target: The function to run in each thread.
        num_threads: Number of threads to create.
        pass_index: If True, passes the thread index (0, 1, 2, ...) as the
            first argument to ``target``; otherwise ``target`` is called
            with no arguments.

    Raises:
        Exception: The first exception recorded from any worker thread,
            re-raised in the calling thread once every thread has finished.
    """
    # An exception raised inside a threading.Thread never reaches the thread
    # that called join(); without this list a failing worker (including a
    # failed assert) would go unnoticed by the calling test.
    errors = []

    def _guarded(*args):
        try:
            target(*args)
        except Exception as exc:  # recorded here, re-raised by the caller below
            errors.append(exc)

    threads = [
        threading.Thread(target=_guarded, args=(i,) if pass_index else ())
        for i in range(num_threads)
    ]
    for t in threads:
        t.start()
    for t in threads:
        t.join()

    if errors:
        raise errors[0]
310+
311+
class TestTelemetryRaceConditions:
312+
"""Tests for race conditions in multithreaded scenarios."""
313+
314+
@pytest.fixture(autouse=True)
315+
def clean_factory(self):
316+
"""A fixture to automatically reset the factory's state before each test."""
317+
# Clean up at the start of each test
318+
if TelemetryClientFactory._executor:
319+
TelemetryClientFactory._executor.shutdown(wait=True)
320+
TelemetryClientFactory._clients.clear()
321+
TelemetryClientFactory._executor = None
322+
TelemetryClientFactory._initialized = False
323+
324+
yield
325+
326+
# Clean up at the end of each test
327+
if TelemetryClientFactory._executor:
328+
TelemetryClientFactory._executor.shutdown(wait=True)
329+
TelemetryClientFactory._clients.clear()
330+
TelemetryClientFactory._executor = None
331+
TelemetryClientFactory._initialized = False
332+
333+
def test_factory_concurrent_initialization_of_DIFFERENT_clients(self):
334+
"""
335+
Tests that multiple threads creating DIFFERENT clients concurrently
336+
share a single ThreadPoolExecutor and all clients are created successfully.
337+
"""
338+
num_threads = 20
339+
340+
def create_client(thread_id):
341+
TelemetryClientFactory.initialize_telemetry_client(
342+
telemetry_enabled=True,
343+
session_id_hex=f"session_{thread_id}",
344+
auth_provider=None,
345+
host_url="test-host",
346+
)
347+
348+
run_in_threads(create_client, 20, pass_index=True)
349+
350+
# ASSERT: The factory was properly initialized
351+
assert TelemetryClientFactory._initialized is True
352+
assert TelemetryClientFactory._executor is not None
353+
assert isinstance(TelemetryClientFactory._executor, ThreadPoolExecutor)
354+
355+
# ASSERT: All clients were successfully created
356+
assert len(TelemetryClientFactory._clients) == num_threads
357+
358+
# ASSERT: All TelemetryClient instances share the same executor
359+
telemetry_clients = [
360+
client for client in TelemetryClientFactory._clients.values()
361+
if isinstance(client, TelemetryClient)
362+
]
363+
assert len(telemetry_clients) == num_threads
364+
365+
shared_executor = TelemetryClientFactory._executor
366+
for client in telemetry_clients:
367+
assert client._executor is shared_executor
368+
369+
def test_factory_concurrent_initialization_of_SAME_client(self):
370+
"""
371+
Tests that multiple threads trying to initialize the SAME client
372+
result in only one client instance being created.
373+
"""
374+
session_id = "shared-session"
375+
num_threads = 20
376+
377+
def create_same_client():
378+
TelemetryClientFactory.initialize_telemetry_client(
379+
telemetry_enabled=True,
380+
session_id_hex=session_id,
381+
auth_provider=None,
382+
host_url="test-host",
383+
)
384+
385+
run_in_threads(create_same_client, num_threads)
386+
387+
# ASSERT: Only one client was created in the factory.
388+
assert len(TelemetryClientFactory._clients) == 1
389+
client = TelemetryClientFactory.get_telemetry_client(session_id)
390+
assert isinstance(client, TelemetryClient)
391+
392+
def test_client_concurrent_event_export(self):
393+
"""
394+
Tests that no events are lost when multiple threads call _export_event
395+
on the same client instance concurrently.
396+
"""
397+
client = TelemetryClient(True, "session-1", None, "host", MagicMock())
398+
# Mock _flush to prevent auto-flushing when batch size threshold is reached
399+
original_flush = client._flush
400+
client._flush = MagicMock()
401+
402+
num_threads = 5
403+
events_per_thread = 10
404+
405+
def add_events():
406+
for i in range(events_per_thread):
407+
client._export_event(f"event-{i}")
408+
409+
run_in_threads(add_events, num_threads)
410+
411+
# ASSERT: The batch contains all events from all threads, none were lost.
412+
total_expected_events = num_threads * events_per_thread
413+
assert len(client._events_batch) == total_expected_events
414+
415+
# Restore original flush method for cleanup
416+
client._flush = original_flush
417+
418+
def test_client_concurrent_flush(self):
419+
"""
420+
Tests that if multiple threads trigger _flush at the same time,
421+
the underlying send operation is only called once for the batch.
422+
"""
423+
client = TelemetryClient(True, "session-1", None, "host", MagicMock())
424+
client._send_telemetry = MagicMock()
425+
426+
# Pre-fill the batch so there's something to flush
427+
client._events_batch = ["event"] * 5
428+
429+
def call_flush():
430+
client._flush()
431+
432+
run_in_threads(call_flush, 10)
433+
434+
# ASSERT: The send operation was called exactly once.
435+
# This proves the lock prevents multiple threads from sending the same batch.
436+
client._send_telemetry.assert_called_once()
437+
# ASSERT: The event batch is now empty.
438+
assert len(client._events_batch) == 0
439+
440+
def test_factory_concurrent_create_and_close(self):
441+
"""
442+
Tests that concurrently creating and closing different clients
443+
doesn't corrupt the factory state and correctly shuts down the executor.
444+
"""
445+
num_ops = 50
446+
447+
def create_and_close_client(i):
448+
session_id = f"session_{i}"
449+
TelemetryClientFactory.initialize_telemetry_client(
450+
telemetry_enabled=True, session_id_hex=session_id, auth_provider=None, host_url="host"
451+
)
452+
# Small sleep to increase chance of interleaving operations
453+
time.sleep(random.uniform(0, 0.01))
454+
TelemetryClientFactory.close(session_id)
455+
456+
run_in_threads(create_and_close_client, num_ops, pass_index=True)
457+
458+
# ASSERT: After all operations, the factory should be empty and reset.
459+
assert not TelemetryClientFactory._clients
460+
assert TelemetryClientFactory._executor is None
461+
assert not TelemetryClientFactory._initialized

0 commit comments

Comments
 (0)