Skip to content
Merged
Show file tree
Hide file tree
Changes from 4 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion examples/set_user_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
server_hostname=os.getenv("DATABRICKS_SERVER_HOSTNAME"),
http_path=os.getenv("DATABRICKS_HTTP_PATH"),
access_token=os.getenv("DATABRICKS_TOKEN"),
_user_agent_entry="ExamplePartnerTag",
user_agent_entry="ExamplePartnerTag",
) as connection:

with connection.cursor() as cursor:
Expand Down
22 changes: 16 additions & 6 deletions src/databricks/sql/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,9 @@ def __init__(
port of the oauth redirect uri (localhost). This is required when custom oauth client_id
`oauth_client_id` is set

user_agent_entry: `str`, optional
A custom tag to append to the User-Agent header. This is typically used by partners to identify their applications.. If not specified, it will use the default user agent PyDatabricksSqlConnector

experimental_oauth_persistence: configures preferred storage for persisting oauth tokens.
This has to be a class implementing `OAuthPersistence`.
When `auth_type` is set to `databricks-oauth` or `azure-oauth` without persisting the oauth token in a
Expand Down Expand Up @@ -176,8 +179,6 @@ def read(self) -> Optional[OAuthToken]:
"""

# Internal arguments in **kwargs:
# _user_agent_entry
# Tag to add to User-Agent header. For use by partners.
# _use_cert_as_auth
# Use a TLS cert instead of a token
# _enable_ssl
Expand Down Expand Up @@ -227,12 +228,21 @@ def read(self) -> Optional[OAuthToken]:
server_hostname, **kwargs
)

if not kwargs.get("_user_agent_entry"):
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)
else:
user_agent_entry = kwargs.get("user_agent_entry")
if user_agent_entry is None:
user_agent_entry = kwargs.get("_user_agent_entry")
if user_agent_entry is not None:
logger.warning(
"[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
"This parameter will be removed in the next release."
)

if user_agent_entry:
useragent_header = "{}/{} ({})".format(
USER_AGENT_NAME, __version__, kwargs.get("_user_agent_entry")
USER_AGENT_NAME, __version__, user_agent_entry
)
else:
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)

base_headers = [("User-Agent", useragent_header)]

Expand Down
2 changes: 0 additions & 2 deletions src/databricks/sql/thrift_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,6 @@ def __init__(
**kwargs,
):
# Internal arguments in **kwargs:
# _user_agent_entry
# Tag to add to User-Agent header. For use by partners.
# _username, _password
# Username and password Basic authentication (no official support)
# _connection_uri
Expand Down
2 changes: 1 addition & 1 deletion tests/unit/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -157,7 +157,7 @@ def test_useragent_header(self, mock_client_class):
)
self.assertIn(user_agent_header, http_headers)

databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, _user_agent_entry="foobar")
databricks.sql.connect(**self.DUMMY_CONNECTION_ARGS, user_agent_entry="foobar")
user_agent_header_with_entry = (
"User-Agent",
"{}/{} ({})".format(
Expand Down
Loading