Skip to content

Commit 9b1b1f5

Browse files
remove repetition from Session.__init__
Signed-off-by: varun-edachali-dbx <[email protected]>
1 parent 2c9368a commit 9b1b1f5

File tree

3 files changed

+19
-53
lines changed

3 files changed

+19
-53
lines changed

src/databricks/sql/backend/thrift_backend.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -592,17 +592,17 @@ def open_session(self, session_configuration, catalog, schema) -> SessionId:
592592
response = self.make_request(self._client.OpenSession, open_session_req)
593593
self._check_initial_namespace(catalog, schema, response)
594594
self._check_protocol_version(response)
595-
self._session_id_hex = (
596-
self.handle_to_hex_id(response.sessionHandle)
597-
if response.sessionHandle
598-
else None
599-
)
595+
600596
properties = (
601597
{"serverProtocolVersion": response.serverProtocolVersion}
602598
if response.serverProtocolVersion
603599
else {}
604600
)
605-
return SessionId.from_thrift_handle(response.sessionHandle, properties)
601+
session_id = SessionId.from_thrift_handle(
602+
response.sessionHandle, properties
603+
)
604+
self._session_id_hex = session_id.hex_guid
605+
return session_id
606606
except:
607607
self._transport.close()
608608
raise

src/databricks/sql/client.py

Lines changed: 8 additions & 42 deletions
Original file line numberDiff line numberDiff line change
@@ -242,47 +242,15 @@ def read(self) -> Optional[OAuthToken]:
242242

243243
self.disable_pandas = kwargs.get("_disable_pandas", False)
244244
self.lz4_compression = kwargs.get("enable_query_result_lz4_compression", True)
245-
246-
auth_provider = get_python_sql_connector_auth_provider(
247-
server_hostname, **kwargs
248-
)
245+
self.use_cloud_fetch = kwargs.get("use_cloud_fetch", True)
246+
self._cursors = [] # type: List[Cursor]
249247

250248
self.server_telemetry_enabled = True
251249
self.client_telemetry_enabled = kwargs.get("enable_telemetry", False)
252250
self.telemetry_enabled = (
253251
self.client_telemetry_enabled and self.server_telemetry_enabled
254252
)
255253

256-
user_agent_entry = kwargs.get("user_agent_entry")
257-
if user_agent_entry is None:
258-
user_agent_entry = kwargs.get("_user_agent_entry")
259-
if user_agent_entry is not None:
260-
logger.warning(
261-
"[WARN] Parameter '_user_agent_entry' is deprecated; use 'user_agent_entry' instead. "
262-
"This parameter will be removed in the upcoming releases."
263-
)
264-
265-
if user_agent_entry:
266-
useragent_header = "{}/{} ({})".format(
267-
USER_AGENT_NAME, __version__, user_agent_entry
268-
)
269-
else:
270-
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)
271-
272-
base_headers = [("User-Agent", useragent_header)]
273-
274-
self._ssl_options = SSLOptions(
275-
# Double negation is generally a bad thing, but we have to keep backward compatibility
276-
tls_verify=not kwargs.get(
277-
"_tls_no_verify", False
278-
), # by default - verify cert and host
279-
tls_verify_hostname=kwargs.get("_tls_verify_hostname", True),
280-
tls_trusted_ca_file=kwargs.get("_tls_trusted_ca_file"),
281-
tls_client_cert_file=kwargs.get("_tls_client_cert_file"),
282-
tls_client_cert_key_file=kwargs.get("_tls_client_cert_key_file"),
283-
tls_client_cert_key_password=kwargs.get("_tls_client_cert_key_password"),
284-
)
285-
286254
self.session = Session(
287255
server_hostname,
288256
http_path,
@@ -303,8 +271,8 @@ def read(self) -> Optional[OAuthToken]:
303271
TelemetryClientFactory.initialize_telemetry_client(
304272
telemetry_enabled=self.telemetry_enabled,
305273
session_id_hex=self.get_session_id_hex(),
306-
auth_provider=auth_provider,
307-
host_url=self.host,
274+
auth_provider=self.session.auth_provider,
275+
host_url=self.session.host,
308276
)
309277

310278
self._telemetry_client = TelemetryClientFactory.get_telemetry_client(
@@ -314,15 +282,15 @@ def read(self) -> Optional[OAuthToken]:
314282
driver_connection_params = DriverConnectionParameters(
315283
http_path=http_path,
316284
mode=DatabricksClientType.THRIFT,
317-
host_info=HostDetails(host_url=server_hostname, port=self.port),
318-
auth_mech=TelemetryHelper.get_auth_mechanism(auth_provider),
319-
auth_flow=TelemetryHelper.get_auth_flow(auth_provider),
285+
host_info=HostDetails(host_url=server_hostname, port=self.session.port),
286+
auth_mech=TelemetryHelper.get_auth_mechanism(self.session.auth_provider),
287+
auth_flow=TelemetryHelper.get_auth_flow(self.session.auth_provider),
320288
socket_timeout=kwargs.get("_socket_timeout", None),
321289
)
322290

323291
self._telemetry_client.export_initial_telemetry_log(
324292
driver_connection_params=driver_connection_params,
325-
user_agent=useragent_header,
293+
user_agent=self.session.useragent_header,
326294
)
327295

328296
def _set_use_inline_params_with_warning(self, value: Union[bool, str]):
@@ -446,8 +414,6 @@ def _close(self, close_cursors=True) -> None:
446414
except Exception as e:
447415
logger.error(f"Attempt to close session raised a local exception: {e}")
448416

449-
self.open = False
450-
451417
TelemetryClientFactory.close(self.get_session_id_hex())
452418

453419
def commit(self):

src/databricks/sql/session.py

Lines changed: 5 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def __init__(
4040
self.catalog = catalog
4141
self.schema = schema
4242

43-
auth_provider = get_python_sql_connector_auth_provider(
43+
self.auth_provider = get_python_sql_connector_auth_provider(
4444
server_hostname, **kwargs
4545
)
4646

@@ -54,13 +54,13 @@ def __init__(
5454
)
5555

5656
if user_agent_entry:
57-
useragent_header = "{}/{} ({})".format(
57+
self.useragent_header = "{}/{} ({})".format(
5858
USER_AGENT_NAME, __version__, user_agent_entry
5959
)
6060
else:
61-
useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)
61+
self.useragent_header = "{}/{}".format(USER_AGENT_NAME, __version__)
6262

63-
base_headers = [("User-Agent", useragent_header)]
63+
base_headers = [("User-Agent", self.useragent_header)]
6464

6565
self._ssl_options = SSLOptions(
6666
# Double negation is generally a bad thing, but we have to keep backward compatibility
@@ -79,7 +79,7 @@ def __init__(
7979
self.port,
8080
http_path,
8181
(http_headers or []) + base_headers,
82-
auth_provider,
82+
self.auth_provider,
8383
ssl_options=self._ssl_options,
8484
_use_arrow_native_complex_types=_use_arrow_native_complex_types,
8585
**kwargs,

0 commit comments

Comments
 (0)