9494"""
9595
9696
97+ @dataclass (frozen = True )
98+ class QueryContextWrapper :
99+ """
100+ Until dbt tightens this protocol up, we need to wrap the context for safety
101+ """
102+
103+ compute_name : Optional [str ] = None
104+ relation_name : Optional [str ] = None
105+
106+ @staticmethod
107+ def from_context (query_header_context : Any ) -> "QueryContextWrapper" :
108+ if query_header_context is None :
109+ return QueryContextWrapper ()
110+ compute_name = None
111+ relation_name = getattr (query_header_context , "relation_name" , "[unknown]" )
112+ if hasattr (query_header_context , "config" ) and query_header_context .config :
113+ compute_name = query_header_context .config .get ("databricks_compute" )
114+ return QueryContextWrapper (compute_name = compute_name , relation_name = relation_name )
115+
116+
97117class DatabricksMacroQueryStringSetter (MacroQueryStringSetter ):
98118 def _get_comment_macro (self ) -> Optional [str ]:
99119 if self .config .query_comment .comment == DEFAULT_QUERY_COMMENT :
@@ -238,16 +258,16 @@ def set_connection_name(
238258 self ._cleanup_idle_connections ()
239259
240260 conn_name : str = "master" if name is None else name
241-
261+ wrapped = QueryContextWrapper . from_context ( query_header_context )
242262 # Get a connection for this thread
243- conn = self ._get_if_exists_compute_connection (_get_compute_name ( query_header_context ) or "" )
263+ conn = self ._get_if_exists_compute_connection (wrapped . compute_name or "" )
244264
245265 if conn is None :
246- conn = self ._create_compute_connection (conn_name , query_header_context )
266+ conn = self ._create_compute_connection (conn_name , wrapped )
247267 else : # existing connection either wasn't open or didn't have the right name
248268 conn = self ._update_compute_connection (conn , conn_name )
249269
250- conn ._acquire (query_header_context )
270+ conn ._acquire (wrapped )
251271
252272 return conn
253273
@@ -518,13 +538,13 @@ def _cleanup_idle_connections(self) -> None:
518538 conn ._reset_handle (self .open )
519539
520540 def _create_compute_connection (
521- self , conn_name : str , query_header_context : Any = None
541+ self , conn_name : str , query_header_context : QueryContextWrapper
522542 ) -> DatabricksDBTConnection :
523543 """Create anew connection for the combination of current thread and compute associated
524544 with the given node."""
525545
526546 # Create a new connection
527- compute_name = _get_compute_name ( query_header_context ) or ""
547+ compute_name = query_header_context . compute_name or ""
528548
529549 conn = DatabricksDBTConnection (
530550 type = Identifier (self .TYPE ),
@@ -536,9 +556,9 @@ def _create_compute_connection(
536556 )
537557 conn .compute_name = compute_name
538558 creds = cast (DatabricksCredentials , self .profile .credentials )
539- conn .http_path = _get_http_path (query_header_context , creds = creds ) or ""
559+ conn .http_path = QueryConfigUtils . get_http_path (query_header_context , creds )
540560 conn .thread_identifier = cast (tuple [int , int ], self .get_thread_identifier ())
541- conn .max_idle_time = _get_max_idle_time (query_header_context , creds = creds )
561+ conn .max_idle_time = QueryConfigUtils . get_max_idle_time (query_header_context , creds )
542562
543563 conn .handle = LazyHandle (self .open )
544564
@@ -604,74 +624,56 @@ def _update_compute_connection(
604624 return conn
605625
606626
607- def _get_compute_name (query_header_context : Any ) -> Optional [str ]:
608- # Get the name of the specified compute resource from the node's
609- # config.
610- compute_name = None
611- if (
612- query_header_context
613- and hasattr (query_header_context , "config" )
614- and query_header_context .config
615- ):
616- compute_name = query_header_context .config .get ("databricks_compute" , None )
617- return compute_name
618-
619-
620- def _get_http_path (query_header_context : Any , creds : DatabricksCredentials ) -> Optional [str ]:
621- """Get the http_path for the compute specified for the node.
622- If none is specified default will be used."""
623-
624- # ResultNode *should* have relation_name attr, but we work around a core
625- # issue by checking.
626- relation_name = getattr (query_header_context , "relation_name" , "[unknown]" )
627-
628- # If there is no node we return the http_path for the default compute.
629- if not query_header_context :
630- return creds .http_path
631-
632- # Get the name of the compute resource specified in the node's config.
633- # If none is specified return the http_path for the default compute.
634- compute_name = _get_compute_name (query_header_context )
635- if not compute_name :
636- return creds .http_path
637-
638- # Get the http_path for the named compute.
639- http_path = None
640- if creds .compute :
641- http_path = creds .compute .get (compute_name , {}).get ("http_path" , None )
642-
643- # no http_path for the named compute resource is an error condition
644- if not http_path :
645- raise DbtRuntimeError (
646- f"Compute resource { compute_name } does not exist or "
647- f"does not specify http_path, relation: { relation_name } "
648- )
627+ class QueryConfigUtils :
628+ """
629+ Utility class for getting config values from QueryHeaderContextWrapper and Credentials.
630+ """
649631
650- return http_path
632+ @staticmethod
633+ def get_http_path (context : QueryContextWrapper , creds : DatabricksCredentials ) -> str :
634+ """
635+ Get the http_path for the compute specified for the node.
636+ If none is specified default will be used.
637+ """
651638
639+ if not context .compute_name :
640+ return creds .http_path or ""
641+
642+ # Get the http_path for the named compute.
643+ http_path = None
644+ if creds .compute :
645+ http_path = creds .compute .get (context .compute_name , {}).get ("http_path" , None )
646+
647+ # no http_path for the named compute resource is an error condition
648+ if not http_path :
649+ raise DbtRuntimeError (
650+ f"Compute resource { context .compute_name } does not exist or "
651+ f"does not specify http_path, relation: { context .relation_name } "
652+ )
652653
653- def _get_max_idle_time (query_header_context : Any , creds : DatabricksCredentials ) -> int :
654- """Get the http_path for the compute specified for the node.
655- If none is specified default will be used."""
654+ return http_path
656655
657- max_idle_time = (
658- DEFAULT_MAX_IDLE_TIME if creds .connect_max_idle is None else creds .connect_max_idle
659- )
656+ @staticmethod
657+ def get_max_idle_time (context : QueryContextWrapper , creds : DatabricksCredentials ) -> int :
658+ """Get the http_path for the compute specified for the node.
659+ If none is specified default will be used."""
660660
661- if query_header_context :
662- compute_name = _get_compute_name (query_header_context )
663- if compute_name and creds .compute :
664- max_idle_time = creds .compute .get (compute_name , {}).get (
661+ max_idle_time = (
662+ DEFAULT_MAX_IDLE_TIME if creds .connect_max_idle is None else creds .connect_max_idle
663+ )
664+
665+ if context .compute_name and creds .compute :
666+ max_idle_time = creds .compute .get (context .compute_name , {}).get (
665667 "connect_max_idle" , max_idle_time
666668 )
667669
668- if not isinstance (max_idle_time , int ):
669- if isinstance (max_idle_time , str ) and max_idle_time .strip ().isnumeric ():
670- return int (max_idle_time .strip ())
671- else :
672- raise DbtRuntimeError (
673- f"{ max_idle_time } is not a valid value for connect_max_idle. "
674- "Must be a number of seconds."
675- )
670+ if not isinstance (max_idle_time , int ):
671+ if isinstance (max_idle_time , str ) and max_idle_time .strip ().isnumeric ():
672+ return int (max_idle_time .strip ())
673+ else :
674+ raise DbtRuntimeError (
675+ f"{ max_idle_time } is not a valid value for connect_max_idle. "
676+ "Must be a number of seconds."
677+ )
676678
677- return max_idle_time
679+ return max_idle_time
0 commit comments