Skip to content

Commit 2ccd729

Browse files
authored
Scope connection parameter creation (#929)
1 parent 40c2337 commit 2ccd729

File tree

5 files changed

+172
-371
lines changed

5 files changed

+172
-371
lines changed

CHANGELOG.md

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,10 @@
1-
## dbt-databricks 1.9.5 (Feb 11, 2025)
1+
## dbt-databricks 1.9.6 (TBD)
2+
3+
### Under the Hood
4+
5+
- Refactoring of some connection internals ([929](https://github.com/databricks/dbt-databricks/pull/929))
6+
7+
## dbt-databricks 1.9.5 (Feb 13, 2025)
28

39
### Features
410

@@ -11,7 +17,6 @@
1117
- Fix for regression in glue table listing behavior ([934](https://github.com/databricks/dbt-databricks/pull/934))
1218
- Use POSIX standard when creating location for the tables (thanks @gsolasab!) ([919](https://github.com/databricks/dbt-databricks/pull/919))
1319

14-
1520
### Under the Hood
1621

1722
- Collapsing to a single connection manager (since the old one no longer works) ([910](https://github.com/databricks/dbt-databricks/pull/910))

dbt/adapters/databricks/connections.py

Lines changed: 72 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -94,6 +94,26 @@
9494
"""
9595

9696

97+
@dataclass(frozen=True)
98+
class QueryContextWrapper:
99+
"""
100+
Until dbt tightens this protocol up, we need to wrap the context for safety
101+
"""
102+
103+
compute_name: Optional[str] = None
104+
relation_name: Optional[str] = None
105+
106+
@staticmethod
107+
def from_context(query_header_context: Any) -> "QueryContextWrapper":
108+
if query_header_context is None:
109+
return QueryContextWrapper()
110+
compute_name = None
111+
relation_name = getattr(query_header_context, "relation_name", "[unknown]")
112+
if hasattr(query_header_context, "config") and query_header_context.config:
113+
compute_name = query_header_context.config.get("databricks_compute")
114+
return QueryContextWrapper(compute_name=compute_name, relation_name=relation_name)
115+
116+
97117
class DatabricksMacroQueryStringSetter(MacroQueryStringSetter):
98118
def _get_comment_macro(self) -> Optional[str]:
99119
if self.config.query_comment.comment == DEFAULT_QUERY_COMMENT:
@@ -238,16 +258,16 @@ def set_connection_name(
238258
self._cleanup_idle_connections()
239259

240260
conn_name: str = "master" if name is None else name
241-
261+
wrapped = QueryContextWrapper.from_context(query_header_context)
242262
# Get a connection for this thread
243-
conn = self._get_if_exists_compute_connection(_get_compute_name(query_header_context) or "")
263+
conn = self._get_if_exists_compute_connection(wrapped.compute_name or "")
244264

245265
if conn is None:
246-
conn = self._create_compute_connection(conn_name, query_header_context)
266+
conn = self._create_compute_connection(conn_name, wrapped)
247267
else: # existing connection either wasn't open or didn't have the right name
248268
conn = self._update_compute_connection(conn, conn_name)
249269

250-
conn._acquire(query_header_context)
270+
conn._acquire(wrapped)
251271

252272
return conn
253273

@@ -518,13 +538,13 @@ def _cleanup_idle_connections(self) -> None:
518538
conn._reset_handle(self.open)
519539

520540
def _create_compute_connection(
521-
self, conn_name: str, query_header_context: Any = None
541+
self, conn_name: str, query_header_context: QueryContextWrapper
522542
) -> DatabricksDBTConnection:
523543
"""Create anew connection for the combination of current thread and compute associated
524544
with the given node."""
525545

526546
# Create a new connection
527-
compute_name = _get_compute_name(query_header_context) or ""
547+
compute_name = query_header_context.compute_name or ""
528548

529549
conn = DatabricksDBTConnection(
530550
type=Identifier(self.TYPE),
@@ -536,9 +556,9 @@ def _create_compute_connection(
536556
)
537557
conn.compute_name = compute_name
538558
creds = cast(DatabricksCredentials, self.profile.credentials)
539-
conn.http_path = _get_http_path(query_header_context, creds=creds) or ""
559+
conn.http_path = QueryConfigUtils.get_http_path(query_header_context, creds)
540560
conn.thread_identifier = cast(tuple[int, int], self.get_thread_identifier())
541-
conn.max_idle_time = _get_max_idle_time(query_header_context, creds=creds)
561+
conn.max_idle_time = QueryConfigUtils.get_max_idle_time(query_header_context, creds)
542562

543563
conn.handle = LazyHandle(self.open)
544564

@@ -604,74 +624,56 @@ def _update_compute_connection(
604624
return conn
605625

606626

607-
def _get_compute_name(query_header_context: Any) -> Optional[str]:
608-
# Get the name of the specified compute resource from the node's
609-
# config.
610-
compute_name = None
611-
if (
612-
query_header_context
613-
and hasattr(query_header_context, "config")
614-
and query_header_context.config
615-
):
616-
compute_name = query_header_context.config.get("databricks_compute", None)
617-
return compute_name
618-
619-
620-
def _get_http_path(query_header_context: Any, creds: DatabricksCredentials) -> Optional[str]:
621-
"""Get the http_path for the compute specified for the node.
622-
If none is specified default will be used."""
623-
624-
# ResultNode *should* have relation_name attr, but we work around a core
625-
# issue by checking.
626-
relation_name = getattr(query_header_context, "relation_name", "[unknown]")
627-
628-
# If there is no node we return the http_path for the default compute.
629-
if not query_header_context:
630-
return creds.http_path
631-
632-
# Get the name of the compute resource specified in the node's config.
633-
# If none is specified return the http_path for the default compute.
634-
compute_name = _get_compute_name(query_header_context)
635-
if not compute_name:
636-
return creds.http_path
637-
638-
# Get the http_path for the named compute.
639-
http_path = None
640-
if creds.compute:
641-
http_path = creds.compute.get(compute_name, {}).get("http_path", None)
642-
643-
# no http_path for the named compute resource is an error condition
644-
if not http_path:
645-
raise DbtRuntimeError(
646-
f"Compute resource {compute_name} does not exist or "
647-
f"does not specify http_path, relation: {relation_name}"
648-
)
627+
class QueryConfigUtils:
628+
"""
629+
Utility class for getting config values from QueryHeaderContextWrapper and Credentials.
630+
"""
649631

650-
return http_path
632+
@staticmethod
633+
def get_http_path(context: QueryContextWrapper, creds: DatabricksCredentials) -> str:
634+
"""
635+
Get the http_path for the compute specified for the node.
636+
If none is specified default will be used.
637+
"""
651638

639+
if not context.compute_name:
640+
return creds.http_path or ""
641+
642+
# Get the http_path for the named compute.
643+
http_path = None
644+
if creds.compute:
645+
http_path = creds.compute.get(context.compute_name, {}).get("http_path", None)
646+
647+
# no http_path for the named compute resource is an error condition
648+
if not http_path:
649+
raise DbtRuntimeError(
650+
f"Compute resource {context.compute_name} does not exist or "
651+
f"does not specify http_path, relation: {context.relation_name}"
652+
)
652653

653-
def _get_max_idle_time(query_header_context: Any, creds: DatabricksCredentials) -> int:
654-
"""Get the http_path for the compute specified for the node.
655-
If none is specified default will be used."""
654+
return http_path
656655

657-
max_idle_time = (
658-
DEFAULT_MAX_IDLE_TIME if creds.connect_max_idle is None else creds.connect_max_idle
659-
)
656+
@staticmethod
657+
def get_max_idle_time(context: QueryContextWrapper, creds: DatabricksCredentials) -> int:
658+
"""Get the http_path for the compute specified for the node.
659+
If none is specified default will be used."""
660660

661-
if query_header_context:
662-
compute_name = _get_compute_name(query_header_context)
663-
if compute_name and creds.compute:
664-
max_idle_time = creds.compute.get(compute_name, {}).get(
661+
max_idle_time = (
662+
DEFAULT_MAX_IDLE_TIME if creds.connect_max_idle is None else creds.connect_max_idle
663+
)
664+
665+
if context.compute_name and creds.compute:
666+
max_idle_time = creds.compute.get(context.compute_name, {}).get(
665667
"connect_max_idle", max_idle_time
666668
)
667669

668-
if not isinstance(max_idle_time, int):
669-
if isinstance(max_idle_time, str) and max_idle_time.strip().isnumeric():
670-
return int(max_idle_time.strip())
671-
else:
672-
raise DbtRuntimeError(
673-
f"{max_idle_time} is not a valid value for connect_max_idle. "
674-
"Must be a number of seconds."
675-
)
670+
if not isinstance(max_idle_time, int):
671+
if isinstance(max_idle_time, str) and max_idle_time.strip().isnumeric():
672+
return int(max_idle_time.strip())
673+
else:
674+
raise DbtRuntimeError(
675+
f"{max_idle_time} is not a valid value for connect_max_idle. "
676+
"Must be a number of seconds."
677+
)
676678

677-
return max_idle_time
679+
return max_idle_time

tests/unit/test_compute_config.py

Lines changed: 0 additions & 59 deletions
This file was deleted.

0 commit comments

Comments
 (0)