Skip to content

Commit 1355283

Browse files
committed
Merge branch 'telemetry' into telemetry-testing
2 parents ea86fe2 + cac9c7a commit 1355283

File tree

9 files changed

+87
-101
lines changed

9 files changed

+87
-101
lines changed

.github/CODEOWNERS

Lines changed: 0 additions & 5 deletions
This file was deleted.

.github/workflows/code-quality-checks.yml

Lines changed: 3 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -1,15 +1,7 @@
11
name: Code Quality Checks
2-
on:
3-
push:
4-
branches:
5-
- main
6-
- sea-migration
7-
- telemetry
8-
pull_request:
9-
branches:
10-
- main
11-
- sea-migration
12-
- telemetry
2+
3+
on: [pull_request]
4+
135
jobs:
146
run-unit-tests:
157
runs-on: ubuntu-latest

.github/workflows/integration.yml

Lines changed: 3 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,10 @@
11
name: Integration Tests
2+
23
on:
3-
push:
4-
paths-ignore:
5-
- "**.MD"
6-
- "**.md"
7-
pull_request:
4+
push:
85
branches:
96
- main
10-
- sea-migration
11-
- telemetry
7+
pull_request:
128

139
jobs:
1410
run-e2e-tests:

src/databricks/sql/client.py

Lines changed: 3 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -1389,7 +1389,6 @@ def _fill_results_buffer(self):
13891389
self.results = results
13901390
self.has_more_rows = has_more_rows
13911391

1392-
@log_latency()
13931392
def _convert_columnar_table(self, table):
13941393
column_names = [c[0] for c in self.description]
13951394
ResultRow = Row(*column_names)
@@ -1402,7 +1401,6 @@ def _convert_columnar_table(self, table):
14021401

14031402
return result
14041403

1405-
@log_latency()
14061404
def _convert_arrow_table(self, table):
14071405
column_names = [c[0] for c in self.description]
14081406
ResultRow = Row(*column_names)
@@ -1445,7 +1443,6 @@ def _convert_arrow_table(self, table):
14451443
def rownumber(self):
14461444
return self._next_row_index
14471445

1448-
@log_latency()
14491446
def fetchmany_arrow(self, size: int) -> "pyarrow.Table":
14501447
"""
14511448
Fetch the next set of rows of a query result, returning a PyArrow table.
@@ -1488,7 +1485,6 @@ def merge_columnar(self, result1, result2):
14881485
]
14891486
return ColumnTable(merged_result, result1.column_names)
14901487

1491-
@log_latency()
14921488
def fetchmany_columnar(self, size: int):
14931489
"""
14941490
Fetch the next set of rows of a query result, returning a Columnar Table.
@@ -1514,7 +1510,6 @@ def fetchmany_columnar(self, size: int):
15141510

15151511
return results
15161512

1517-
@log_latency()
15181513
def fetchall_arrow(self) -> "pyarrow.Table":
15191514
"""Fetch all (remaining) rows of a query result, returning them as a PyArrow table."""
15201515
results = self.results.remaining_rows()
@@ -1541,7 +1536,6 @@ def fetchall_arrow(self) -> "pyarrow.Table":
15411536
return pyarrow.Table.from_pydict(data)
15421537
return results
15431538

1544-
@log_latency()
15451539
def fetchall_columnar(self):
15461540
"""Fetch all (remaining) rows of a query result, returning them as a Columnar table."""
15471541
results = self.results.remaining_rows()
@@ -1555,6 +1549,7 @@ def fetchall_columnar(self):
15551549

15561550
return results
15571551

1552+
@log_latency()
15581553
def fetchone(self) -> Optional[Row]:
15591554
"""
15601555
Fetch the next row of a query result set, returning a single sequence,
@@ -1571,6 +1566,7 @@ def fetchone(self) -> Optional[Row]:
15711566
else:
15721567
return None
15731568

1569+
@log_latency()
15741570
def fetchall(self) -> List[Row]:
15751571
"""
15761572
Fetch all (remaining) rows of a query result, returning them as a list of rows.
@@ -1580,6 +1576,7 @@ def fetchall(self) -> List[Row]:
15801576
else:
15811577
return self._convert_arrow_table(self.fetchall_arrow())
15821578

1579+
@log_latency()
15831580
def fetchmany(self, size: int) -> List[Row]:
15841581
"""
15851582
Fetch the next set of rows of a query result, returning a list of rows.

src/databricks/sql/exc.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,9 @@
11
import json
22
import logging
33

4+
logger = logging.getLogger(__name__)
45
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
56

6-
logger = logging.getLogger(__name__)
77

88
### PEP-249 Mandated ###
99
# https://peps.python.org/pep-0249/#exceptions

src/databricks/sql/telemetry/latency_logger.py

Lines changed: 26 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
import time
22
import functools
33
from typing import Optional
4+
import logging
45
from databricks.sql.telemetry.telemetry_client import TelemetryClientFactory
56
from databricks.sql.telemetry.models.event import (
67
SqlExecutionEvent,
@@ -9,6 +10,8 @@
910
from databricks.sql.utils import ColumnQueue, CloudFetchQueue, ArrowQueue
1011
from uuid import UUID
1112

13+
logger = logging.getLogger(__name__)
14+
1215

1316
class TelemetryExtractor:
1417
"""
@@ -145,14 +148,15 @@ def get_extractor(obj):
145148
TelemetryExtractor: A specialized extractor instance:
146149
- CursorExtractor for Cursor objects
147150
- ResultSetExtractor for ResultSet objects
148-
- Throws a NotImplementedError for all other objects
151+
- None for all other objects
149152
"""
150153
if obj.__class__.__name__ == "Cursor":
151154
return CursorExtractor(obj)
152155
elif obj.__class__.__name__ == "ResultSet":
153156
return ResultSetExtractor(obj)
154157
else:
155-
raise NotImplementedError(f"No extractor found for {obj.__class__.__name__}")
158+
logger.error(f"No extractor found for {obj.__class__.__name__}")
159+
return None
156160

157161

158162
def log_latency(statement_type: StatementType = StatementType.NONE):
@@ -207,24 +211,26 @@ def _safe_call(func_to_call):
207211
duration_ms = int((end_time - start_time) * 1000)
208212

209213
extractor = get_extractor(self)
210-
session_id_hex = _safe_call(extractor.get_session_id_hex)
211-
statement_id = _safe_call(extractor.get_statement_id)
212-
213-
sql_exec_event = SqlExecutionEvent(
214-
statement_type=statement_type,
215-
is_compressed=_safe_call(extractor.get_is_compressed),
216-
execution_result=_safe_call(extractor.get_execution_result),
217-
retry_count=_safe_call(extractor.get_retry_count),
218-
)
219-
220-
telemetry_client = TelemetryClientFactory.get_telemetry_client(
221-
session_id_hex
222-
)
223-
telemetry_client.export_latency_log(
224-
latency_ms=duration_ms,
225-
sql_execution_event=sql_exec_event,
226-
sql_statement_id=statement_id,
227-
)
214+
215+
if extractor is not None:
216+
session_id_hex = _safe_call(extractor.get_session_id_hex)
217+
statement_id = _safe_call(extractor.get_statement_id)
218+
219+
sql_exec_event = SqlExecutionEvent(
220+
statement_type=statement_type,
221+
is_compressed=_safe_call(extractor.get_is_compressed),
222+
execution_result=_safe_call(extractor.get_execution_result),
223+
retry_count=_safe_call(extractor.get_retry_count),
224+
)
225+
226+
telemetry_client = TelemetryClientFactory.get_telemetry_client(
227+
session_id_hex
228+
)
229+
telemetry_client.export_latency_log(
230+
latency_ms=duration_ms,
231+
sql_execution_event=sql_exec_event,
232+
sql_statement_id=statement_id,
233+
)
228234

229235
return wrapper
230236

src/databricks/sql/telemetry/models/enums.py

Lines changed: 23 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,15 +2,17 @@
22

33

44
class AuthFlow(Enum):
5-
TOKEN_PASSTHROUGH = "token_passthrough"
6-
BROWSER_BASED_AUTHENTICATION = "browser_based_authentication"
5+
TYPE_UNSPECIFIED = "TYPE_UNSPECIFIED"
6+
TOKEN_PASSTHROUGH = "TOKEN_PASSTHROUGH"
7+
CLIENT_CREDENTIALS = "CLIENT_CREDENTIALS"
8+
BROWSER_BASED_AUTHENTICATION = "BROWSER_BASED_AUTHENTICATION"
79

810

911
class AuthMech(Enum):
10-
CLIENT_CERT = "CLIENT_CERT" # ssl certificate authentication
11-
PAT = "PAT" # Personal Access Token authentication
12-
DATABRICKS_OAUTH = "DATABRICKS_OAUTH" # Databricks-managed OAuth flow
13-
EXTERNAL_AUTH = "EXTERNAL_AUTH" # External identity provider (AWS, Azure, etc.)
12+
TYPE_UNSPECIFIED = "TYPE_UNSPECIFIED"
13+
OTHER = "OTHER"
14+
PAT = "PAT"
15+
OAUTH = "OAUTH"
1416

1517

1618
class DatabricksClientType(Enum):
@@ -19,24 +21,24 @@ class DatabricksClientType(Enum):
1921

2022

2123
class DriverVolumeOperationType(Enum):
22-
TYPE_UNSPECIFIED = "type_unspecified"
23-
PUT = "put"
24-
GET = "get"
25-
DELETE = "delete"
26-
LIST = "list"
27-
QUERY = "query"
24+
TYPE_UNSPECIFIED = "TYPE_UNSPECIFIED"
25+
PUT = "PUT"
26+
GET = "GET"
27+
DELETE = "DELETE"
28+
LIST = "LIST"
29+
QUERY = "QUERY"
2830

2931

3032
class ExecutionResultFormat(Enum):
31-
FORMAT_UNSPECIFIED = "format_unspecified"
32-
INLINE_ARROW = "inline_arrow"
33-
EXTERNAL_LINKS = "external_links"
34-
COLUMNAR_INLINE = "columnar_inline"
33+
FORMAT_UNSPECIFIED = "FORMAT_UNSPECIFIED"
34+
INLINE_ARROW = "INLINE_ARROW"
35+
EXTERNAL_LINKS = "EXTERNAL_LINKS"
36+
COLUMNAR_INLINE = "COLUMNAR_INLINE"
3537

3638

3739
class StatementType(Enum):
38-
NONE = "none"
39-
QUERY = "query"
40-
SQL = "sql"
41-
UPDATE = "update"
42-
METADATA = "metadata"
40+
NONE = "NONE"
41+
QUERY = "QUERY"
42+
SQL = "SQL"
43+
UPDATE = "UPDATE"
44+
METADATA = "METADATA"

src/databricks/sql/telemetry/telemetry_client.py

Lines changed: 25 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -60,40 +60,34 @@ def get_driver_system_configuration(cls) -> DriverSystemConfiguration:
6060
def get_auth_mechanism(auth_provider):
6161
"""Get the auth mechanism for the auth provider."""
6262
# AuthMech is an enum with the following values:
63-
# PAT, DATABRICKS_OAUTH, EXTERNAL_AUTH, CLIENT_CERT
63+
# TYPE_UNSPECIFIED, OTHER, PAT, OAUTH
6464

6565
if not auth_provider:
6666
return None
6767
if isinstance(auth_provider, AccessTokenAuthProvider):
68-
return AuthMech.PAT # Personal Access Token authentication
68+
return AuthMech.PAT
6969
elif isinstance(auth_provider, DatabricksOAuthProvider):
70-
return AuthMech.DATABRICKS_OAUTH # Databricks-managed OAuth flow
71-
elif isinstance(auth_provider, ExternalAuthProvider):
72-
return (
73-
AuthMech.EXTERNAL_AUTH
74-
) # External identity provider (AWS, Azure, etc.)
75-
return AuthMech.CLIENT_CERT # Client certificate (ssl)
70+
return AuthMech.OAUTH
71+
else:
72+
return AuthMech.OTHER
7673

7774
@staticmethod
7875
def get_auth_flow(auth_provider):
7976
"""Get the auth flow for the auth provider."""
8077
# AuthFlow is an enum with the following values:
81-
# TOKEN_PASSTHROUGH, BROWSER_BASED_AUTHENTICATION
78+
# TYPE_UNSPECIFIED, TOKEN_PASSTHROUGH, CLIENT_CREDENTIALS, BROWSER_BASED_AUTHENTICATION
8279

8380
if not auth_provider:
8481
return None
85-
8682
if isinstance(auth_provider, DatabricksOAuthProvider):
8783
if auth_provider._access_token and auth_provider._refresh_token:
88-
return (
89-
AuthFlow.TOKEN_PASSTHROUGH
90-
) # Has existing tokens, no user interaction needed
91-
if hasattr(auth_provider, "oauth_manager"):
92-
return (
93-
AuthFlow.BROWSER_BASED_AUTHENTICATION
94-
) # Will initiate OAuth flow requiring browser
95-
96-
return None
84+
return AuthFlow.TOKEN_PASSTHROUGH
85+
else:
86+
return AuthFlow.BROWSER_BASED_AUTHENTICATION
87+
elif isinstance(auth_provider, ExternalAuthProvider):
88+
return AuthFlow.CLIENT_CREDENTIALS
89+
else:
90+
return None
9791

9892

9993
class BaseTelemetryClient(ABC):
@@ -104,21 +98,23 @@ class BaseTelemetryClient(ABC):
10498

10599
@abstractmethod
106100
def export_initial_telemetry_log(self, driver_connection_params, user_agent):
107-
raise NotImplementedError(
108-
"Subclasses must implement export_initial_telemetry_log"
109-
)
101+
logger.debug("subclass must implement export_initial_telemetry_log")
102+
pass
110103

111104
@abstractmethod
112105
def export_failure_log(self, error_name, error_message):
113-
raise NotImplementedError("Subclasses must implement export_failure_log")
106+
logger.debug("subclass must implement export_failure_log")
107+
pass
114108

115109
@abstractmethod
116110
def export_latency_log(self, latency_ms, sql_execution_event, sql_statement_id):
117-
raise NotImplementedError("Subclasses must implement export_latency_log")
111+
logger.debug("subclass must implement export_latency_log")
112+
pass
118113

119114
@abstractmethod
120115
def close(self):
121-
raise NotImplementedError("Subclasses must implement close")
116+
logger.debug("subclass must implement close")
117+
pass
122118

123119

124120
class NoopTelemetryClient(BaseTelemetryClient):
@@ -157,6 +153,8 @@ class TelemetryClient(BaseTelemetryClient):
157153
TELEMETRY_AUTHENTICATED_PATH = "/telemetry-ext"
158154
TELEMETRY_UNAUTHENTICATED_PATH = "/telemetry-unauth"
159155

156+
DEFAULT_BATCH_SIZE = 100
157+
160158
def __init__(
161159
self,
162160
telemetry_enabled,
@@ -167,12 +165,12 @@ def __init__(
167165
):
168166
logger.debug("Initializing TelemetryClient for connection: %s", session_id_hex)
169167
self._telemetry_enabled = telemetry_enabled
170-
self._batch_size = 10 # TODO: Decide on batch size
168+
self._batch_size = self.DEFAULT_BATCH_SIZE
171169
self._session_id_hex = session_id_hex
172170
self._auth_provider = auth_provider
173171
self._user_agent = None
174172
self._events_batch = []
175-
self._lock = threading.Lock()
173+
self._lock = threading.RLock()
176174
self._driver_connection_params = None
177175
self._host_url = host_url
178176
self._executor = executor

tests/unit/test_telemetry.py

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -168,9 +168,9 @@ def test_auth_mechanism_detection(self):
168168
"""Test authentication mechanism detection for different providers."""
169169
test_cases = [
170170
(AccessTokenAuthProvider("token"), AuthMech.PAT),
171-
(MagicMock(spec=DatabricksOAuthProvider), AuthMech.DATABRICKS_OAUTH),
172-
(MagicMock(spec=ExternalAuthProvider), AuthMech.EXTERNAL_AUTH),
173-
(MagicMock(), AuthMech.CLIENT_CERT), # Unknown provider
171+
(MagicMock(spec=DatabricksOAuthProvider), AuthMech.OAUTH),
172+
(MagicMock(spec=ExternalAuthProvider), AuthMech.OTHER),
173+
(MagicMock(), AuthMech.OTHER), # Unknown provider
174174
(None, None),
175175
]
176176

0 commit comments

Comments
 (0)