66from dataclasses import dataclass
77from functools import cached_property
88from typing import ClassVar , Optional
9- from urllib .parse import urlparse
9+ from urllib .parse import urlparse , ParseResult
1010
1111from databricks .labs .blueprint .installation import Installation
1212from databricks .labs .lsql .backends import SqlBackend
@@ -42,9 +42,9 @@ class LocationTrie:
4242 """
4343
4444 key : str = ""
45- parent : OptionalLocationTrie = None
45+ parent : OptionalLocationTrie = dataclasses . field ( repr = False , default = None )
4646 children : dict [str , "LocationTrie" ] = dataclasses .field (default_factory = dict )
47- tables : list [Table ] = dataclasses .field (default_factory = list )
47+ tables : list [Table ] = dataclasses .field (repr = False , default_factory = list )
4848
4949 @cached_property
5050 def _path (self ) -> list [str ]:
@@ -57,19 +57,47 @@ def _path(self) -> list[str]:
5757 return list (reversed (parts ))[1 :]
5858
5959 @property
60- def location (self ):
61- scheme , netloc , * path = self ._path
62- return f"{ scheme } ://{ netloc } /{ '/' .join (path )} "
60+ def location (self ) -> str | None :
61+ if not self .is_valid ():
62+ return None
63+ try :
64+ scheme , netloc , * path = self ._path
65+ return f"{ scheme } ://{ netloc } /{ '/' .join (path )} " .rstrip ("/" )
66+ except ValueError :
67+ return None
6368
64- @staticmethod
65- def _parse_location (location : str | None ) -> list [str ]:
69+ @classmethod
70+ def _parse_location (cls , location : str | None ) -> list [str ]:
6671 if not location :
6772 return []
68- parse_result = urlparse (location )
73+ parse_result = cls ._parse_url (location .rstrip ("/" ))
74+ if not parse_result :
75+ return []
6976 parts = [parse_result .scheme , parse_result .netloc ]
70- parts .extend (parse_result .path .strip ("/" ).split ("/" ))
77+ for part in parse_result .path .split ("/" ):
78+ if not part :
79+ continue # remove empty strings
80+ parts .append (part )
7181 return parts
7282
83+ @staticmethod
84+ def _parse_url (location : str ) -> ParseResult | None :
85+ parse_result = urlparse (location )
86+ if parse_result .scheme == 'jdbc' :
87+ jdbc_path = parse_result .path .split ('://' )
88+ if len (jdbc_path ) != 2 :
89+ return None
90+ netloc , path = jdbc_path [1 ].split ('/' , 1 )
91+ parse_result = ParseResult (
92+ scheme = f'{ parse_result .scheme } :{ jdbc_path [0 ]} ' ,
93+ netloc = netloc ,
94+ path = path ,
95+ params = '' ,
96+ query = '' ,
97+ fragment = '' ,
98+ )
99+ return parse_result
100+
73101 def insert (self , table : Table ) -> None :
74102 current = self
75103 for part in self ._parse_location (table .location ):
@@ -91,11 +119,22 @@ def find(self, table: Table) -> OptionalLocationTrie:
91119
def is_valid(self) -> bool:
    """A valid location has a scheme and netloc; the path is optional.

    The scheme must either start with ``jdbc:`` or be one of
    ``_EXTERNAL_FILE_LOCATION_SCHEMES``, and the netloc must be non-empty.
    """
    if len(self._path) < 2:
        return False
    scheme, netloc, *_ = self._path
    if not netloc:
        return False
    return scheme.startswith('jdbc:') or scheme in _EXTERNAL_FILE_LOCATION_SCHEMES
98128
def is_jdbc(self) -> bool:
    """Return ``True`` when this node is valid and its scheme starts with ``jdbc:``."""
    return self.is_valid() and self._path[0].startswith('jdbc:')
133+
def all_tables(self) -> Iterable[Table]:
    """Yield the tables of every node produced by iterating ``self``.

    NOTE(review): the iteration order and extent come from this class's
    ``__iter__``, which is defined outside this view.
    """
    for node in self:
        yield from node.tables
137+
def has_children(self) -> bool:
    """Return ``True`` when this node has at least one child node."""
    return bool(self.children)
101140
@@ -125,64 +164,59 @@ def __init__(
@cached_property
def _mounts_snapshot(self) -> list['Mount']:
    """Returns all mounts, sorted by longest prefixes first.

    Mounts with names of equal length are ordered by name (descending), so
    the result does not depend on the order the crawler returns them in.
    """
    return sorted(
        self._mounts_crawler.snapshot(),
        key=lambda mount: (len(mount.name), mount.name),
        reverse=True,
    )
129168
def _external_locations(self) -> Iterable[ExternalLocation]:
    """Derive external locations from the crawled tables.

    Every table with a location (after ``_resolve_location``) is inserted
    into a ``LocationTrie``. The trie is then walked: a node that has a
    location and either no children or more than one child produces an
    ``ExternalLocation`` holding that location and the number of tables
    reported by ``all_tables()``. For a non-JDBC leaf the parent node is
    used instead, provided the parent is itself a valid location.

    Returns the external locations sorted by their ``location``.
    """
    trie = LocationTrie()
    for table in self._tables_crawler.snapshot():
        resolved = self._resolve_location(table)
        if not resolved.location:
            continue
        trie.insert(resolved)
    queue = list(trie.children.values())
    external_locations = []
    while queue:
        curr = queue.pop()
        num_children = len(curr.children)  # 0 - take parent
        if curr.location and num_children != 1:
            # A leaf holds a single table prefix, so its parent folder is reported
            # instead. The parent must be valid: a table sitting directly under the
            # scheme/netloc root has a parent without a location.
            if curr.parent and curr.parent.is_valid() and num_children == 0 and not curr.is_jdbc():
                curr = curr.parent
            assert curr.location is not None
            table_count = sum(1 for _ in curr.all_tables())
            external_locations.append(ExternalLocation(curr.location, table_count))
            continue
        queue.extend(curr.children.values())
    return sorted(external_locations, key=lambda ext_loc: ext_loc.location)
190+
def _resolve_location(self, table: Table) -> Table:
    """Return ``table`` with its location passed through JDBC and mount resolution.

    A table without a location is returned unchanged. Otherwise the location
    produced by ``_resolve_jdbc`` is fed through ``resolve_mount`` and a copy
    of the table carrying the resulting location is returned.
    """
    if not table.location:
        return table
    location = self.resolve_mount(self._resolve_jdbc(table))
    return dataclasses.replace(table, location=location)
151198
152199 def resolve_mount (self , location : str | None ) -> str | None :
153200 if not location :
154201 return None
202+ if location .startswith ('/dbfs' ):
203+ location = 'dbfs:' + location [5 :] # convert FUSE path to DBFS path
204+ if not location .startswith ('dbfs:' ):
205+ return location # not a mount, save some cycles
155206 for mount in self ._mounts_snapshot :
156- for prefix in ( mount .as_scheme_prefix (), mount . as_fuse_prefix ()):
157- if not location .startswith (prefix ):
158- continue
159- logger .debug (f"Replacing location { prefix } with { mount .source } in { location } " )
160- location = location .replace (prefix , mount .source )
161- return location
207+ prefix = mount .as_scheme_prefix ()
208+ if not location .startswith (prefix ):
209+ continue
210+ logger .debug (f"Replacing location { prefix } with { mount .source } in { location } " )
211+ location = location .replace (prefix , mount .source )
212+ return location
162213 logger .debug (f"Mount not found for location { location } . Skipping replacement." )
163214 return location
164215
165- @staticmethod
166- def _dbfs_locations (external_locations , location , min_slash ):
167- dupe = False
168- loc = 0
169- while loc < len (external_locations ) and not dupe :
170- common = (
171- os .path .commonpath ([external_locations [loc ].location , os .path .dirname (location ) + "/" ]).replace (
172- ":/" , "://"
173- )
174- + "/"
175- )
176- if common .count ("/" ) > min_slash :
177- table_count = external_locations [loc ].table_count
178- external_locations [loc ] = ExternalLocation (common , table_count + 1 )
179- dupe = True
180- loc += 1
181- if not dupe :
182- external_locations .append (ExternalLocation (os .path .dirname (location ) + "/" , 1 ))
183-
184- def _add_jdbc_location (self , external_locations , location , table ):
185- dupe = False
216+ def _resolve_jdbc (self , table : Table ) -> str | None :
217+ location = table .location
218+ if not location or not table .storage_properties or not location .startswith ('jdbc:' ):
219+ return location
186220 pattern = r"(\w+)=(.*?)(?=\s*,|\s*\])"
187221 # Find all matches in the input string
188222 # Storage properties is of the format
@@ -201,20 +235,12 @@ def _add_jdbc_location(self, external_locations, location, table):
201235 # currently supporting databricks and mysql external tables
202236 # add other jdbc types
203237 if "databricks" in location .lower ():
204- jdbc_location = f"jdbc:databricks://{ host } ;httpPath={ httppath } "
205- elif "mysql" in location .lower ():
206- jdbc_location = f"jdbc:mysql://{ host } :{ port } /{ database } "
207- elif not provider == "" :
208- jdbc_location = f"jdbc:{ provider .lower ()} ://{ host } :{ port } /{ database } "
209- else :
210- jdbc_location = f"{ location .lower ()} /{ host } :{ port } /{ database } "
211- for ext_loc in external_locations :
212- if ext_loc .location == jdbc_location :
213- ext_loc .table_count += 1
214- dupe = True
215- break
216- if not dupe :
217- external_locations .append (ExternalLocation (jdbc_location , 1 ))
238+ return f"jdbc:databricks://{ host } ;httpPath={ httppath } "
239+ if "mysql" in location .lower ():
240+ return f"jdbc:mysql://{ host } :{ port } /{ database } "
241+ if not provider == "" :
242+ return f"jdbc:{ provider .lower ()} ://{ host } :{ port } /{ database } "
243+ return f"{ location .lower ()} /{ host } :{ port } /{ database } "
218244
def _crawl(self) -> Iterable[ExternalLocation]:
    """Return the result of ``self._external_locations()``."""
    return self._external_locations()
0 commit comments