Skip to content

Commit c385d77

Browse files
authored
Use LocationTrie to infer a list of UC external locations (#2965)
This PR removes duplicate code for determining overlapping storage prefixes. It also lays the groundwork for federated SQL connections.
1 parent b635244 commit c385d77

File tree

3 files changed

+120
-114
lines changed

3 files changed

+120
-114
lines changed

src/databricks/labs/ucx/hive_metastore/locations.py

Lines changed: 96 additions & 70 deletions
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,7 @@
66
from dataclasses import dataclass
77
from functools import cached_property
88
from typing import ClassVar, Optional
9-
from urllib.parse import urlparse
9+
from urllib.parse import urlparse, ParseResult
1010

1111
from databricks.labs.blueprint.installation import Installation
1212
from databricks.labs.lsql.backends import SqlBackend
@@ -42,9 +42,9 @@ class LocationTrie:
4242
"""
4343

4444
key: str = ""
45-
parent: OptionalLocationTrie = None
45+
parent: OptionalLocationTrie = dataclasses.field(repr=False, default=None)
4646
children: dict[str, "LocationTrie"] = dataclasses.field(default_factory=dict)
47-
tables: list[Table] = dataclasses.field(default_factory=list)
47+
tables: list[Table] = dataclasses.field(repr=False, default_factory=list)
4848

4949
@cached_property
5050
def _path(self) -> list[str]:
@@ -57,19 +57,47 @@ def _path(self) -> list[str]:
5757
return list(reversed(parts))[1:]
5858

5959
@property
60-
def location(self):
61-
scheme, netloc, *path = self._path
62-
return f"{scheme}://{netloc}/{'/'.join(path)}"
60+
def location(self) -> str | None:
61+
if not self.is_valid():
62+
return None
63+
try:
64+
scheme, netloc, *path = self._path
65+
return f"{scheme}://{netloc}/{'/'.join(path)}".rstrip("/")
66+
except ValueError:
67+
return None
6368

64-
@staticmethod
65-
def _parse_location(location: str | None) -> list[str]:
69+
@classmethod
70+
def _parse_location(cls, location: str | None) -> list[str]:
6671
if not location:
6772
return []
68-
parse_result = urlparse(location)
73+
parse_result = cls._parse_url(location.rstrip("/"))
74+
if not parse_result:
75+
return []
6976
parts = [parse_result.scheme, parse_result.netloc]
70-
parts.extend(parse_result.path.strip("/").split("/"))
77+
for part in parse_result.path.split("/"):
78+
if not part:
79+
continue # remove empty strings
80+
parts.append(part)
7181
return parts
7282

83+
@staticmethod
84+
def _parse_url(location: str) -> ParseResult | None:
85+
parse_result = urlparse(location)
86+
if parse_result.scheme == 'jdbc':
87+
jdbc_path = parse_result.path.split('://')
88+
if len(jdbc_path) != 2:
89+
return None
90+
netloc, path = jdbc_path[1].split('/', 1)
91+
parse_result = ParseResult(
92+
scheme=f'{parse_result.scheme}:{jdbc_path[0]}',
93+
netloc=netloc,
94+
path=path,
95+
params='',
96+
query='',
97+
fragment='',
98+
)
99+
return parse_result
100+
73101
def insert(self, table: Table) -> None:
74102
current = self
75103
for part in self._parse_location(table.location):
@@ -91,11 +119,22 @@ def find(self, table: Table) -> OptionalLocationTrie:
91119

92120
def is_valid(self) -> bool:
93121
"""A valid location has a scheme and netloc; the path is optional."""
94-
if len(self._path) < 3:
122+
if len(self._path) < 2:
95123
return False
96124
scheme, netloc, *_ = self._path
125+
if scheme.startswith('jdbc:') and len(netloc) > 0:
126+
return True
97127
return scheme in _EXTERNAL_FILE_LOCATION_SCHEMES and len(netloc) > 0
98128

129+
def is_jdbc(self) -> bool:
130+
if not self.is_valid():
131+
return False
132+
return self._path[0].startswith('jdbc:')
133+
134+
def all_tables(self) -> Iterable[Table]:
135+
for node in self:
136+
yield from node.tables
137+
99138
def has_children(self):
100139
return len(self.children) > 0
101140

@@ -125,64 +164,59 @@ def __init__(
125164
@cached_property
126165
def _mounts_snapshot(self) -> list['Mount']:
127166
"""Returns all mounts, sorted by longest prefixes first."""
128-
return sorted(self._mounts_crawler.snapshot(), key=lambda _: len(_.name), reverse=True)
167+
return sorted(self._mounts_crawler.snapshot(), key=lambda _: (len(_.name), _.name), reverse=True)
129168

130169
def _external_locations(self) -> Iterable[ExternalLocation]:
131-
min_slash = 2
132-
external_locations: list[ExternalLocation] = []
170+
trie = LocationTrie()
133171
for table in self._tables_crawler.snapshot():
134-
location = table.location
135-
if not location:
172+
table = self._resolve_location(table)
173+
if not table.location:
136174
continue
137-
# TODO: refactor this with LocationTrie
138-
if location.startswith("dbfs:/mnt"):
139-
location = self.resolve_mount(location)
140-
if not location:
175+
trie.insert(table)
176+
queue = list(trie.children.values())
177+
external_locations = []
178+
while queue:
179+
curr = queue.pop()
180+
num_children = len(curr.children) # 0 - take parent
181+
if curr.location and (num_children > 1 or num_children == 0):
182+
if curr.parent and num_children == 0 and not curr.is_jdbc(): # one table having the prefix
183+
curr = curr.parent
184+
assert curr.location is not None
185+
external_location = ExternalLocation(curr.location, len(list(curr.all_tables())))
186+
external_locations.append(external_location)
141187
continue
142-
if (
143-
not location.startswith("dbfs")
144-
and (self._prefix_size[0] < location.find(":/") < self._prefix_size[1])
145-
and not location.startswith("jdbc")
146-
):
147-
self._dbfs_locations(external_locations, location, min_slash)
148-
if location.startswith("jdbc"):
149-
self._add_jdbc_location(external_locations, location, table)
150-
return external_locations
188+
queue.extend(curr.children.values())
189+
return sorted(external_locations, key=lambda _: _.location)
190+
191+
def _resolve_location(self, table: Table) -> Table:
192+
location = table.location
193+
if not location:
194+
return table
195+
location = self._resolve_jdbc(table)
196+
location = self.resolve_mount(location)
197+
return dataclasses.replace(table, location=location)
151198

152199
def resolve_mount(self, location: str | None) -> str | None:
153200
if not location:
154201
return None
202+
if location.startswith('/dbfs'):
203+
location = 'dbfs:' + location[5:] # convert FUSE path to DBFS path
204+
if not location.startswith('dbfs:'):
205+
return location # not a mount, save some cycles
155206
for mount in self._mounts_snapshot:
156-
for prefix in (mount.as_scheme_prefix(), mount.as_fuse_prefix()):
157-
if not location.startswith(prefix):
158-
continue
159-
logger.debug(f"Replacing location {prefix} with {mount.source} in {location}")
160-
location = location.replace(prefix, mount.source)
161-
return location
207+
prefix = mount.as_scheme_prefix()
208+
if not location.startswith(prefix):
209+
continue
210+
logger.debug(f"Replacing location {prefix} with {mount.source} in {location}")
211+
location = location.replace(prefix, mount.source)
212+
return location
162213
logger.debug(f"Mount not found for location {location}. Skipping replacement.")
163214
return location
164215

165-
@staticmethod
166-
def _dbfs_locations(external_locations, location, min_slash):
167-
dupe = False
168-
loc = 0
169-
while loc < len(external_locations) and not dupe:
170-
common = (
171-
os.path.commonpath([external_locations[loc].location, os.path.dirname(location) + "/"]).replace(
172-
":/", "://"
173-
)
174-
+ "/"
175-
)
176-
if common.count("/") > min_slash:
177-
table_count = external_locations[loc].table_count
178-
external_locations[loc] = ExternalLocation(common, table_count + 1)
179-
dupe = True
180-
loc += 1
181-
if not dupe:
182-
external_locations.append(ExternalLocation(os.path.dirname(location) + "/", 1))
183-
184-
def _add_jdbc_location(self, external_locations, location, table):
185-
dupe = False
216+
def _resolve_jdbc(self, table: Table) -> str | None:
217+
location = table.location
218+
if not location or not table.storage_properties or not location.startswith('jdbc:'):
219+
return location
186220
pattern = r"(\w+)=(.*?)(?=\s*,|\s*\])"
187221
# Find all matches in the input string
188222
# Storage properties is of the format
@@ -201,20 +235,12 @@ def _add_jdbc_location(self, external_locations, location, table):
201235
# currently supporting databricks and mysql external tables
202236
# add other jdbc types
203237
if "databricks" in location.lower():
204-
jdbc_location = f"jdbc:databricks://{host};httpPath={httppath}"
205-
elif "mysql" in location.lower():
206-
jdbc_location = f"jdbc:mysql://{host}:{port}/{database}"
207-
elif not provider == "":
208-
jdbc_location = f"jdbc:{provider.lower()}://{host}:{port}/{database}"
209-
else:
210-
jdbc_location = f"{location.lower()}/{host}:{port}/{database}"
211-
for ext_loc in external_locations:
212-
if ext_loc.location == jdbc_location:
213-
ext_loc.table_count += 1
214-
dupe = True
215-
break
216-
if not dupe:
217-
external_locations.append(ExternalLocation(jdbc_location, 1))
238+
return f"jdbc:databricks://{host};httpPath={httppath}"
239+
if "mysql" in location.lower():
240+
return f"jdbc:mysql://{host}:{port}/{database}"
241+
if not provider == "":
242+
return f"jdbc:{provider.lower()}://{host}:{port}/{database}"
243+
return f"{location.lower()}/{host}:{port}/{database}"
218244

219245
def _crawl(self) -> Iterable[ExternalLocation]:
220246
return self._external_locations()

tests/integration/hive_metastore/test_external_locations.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,7 +68,7 @@ def test_external_locations(ws, sql_backend, inventory_schema, env_or_skip):
6868
"bar",
6969
"EXTERNAL",
7070
"delta",
71-
location="jdbc://providerunknown/",
71+
location="jdbc:providerunknown:/",
7272
storage_properties="[database=test_db, host=somedb.us-east-1.rds.amazonaws.com, \
7373
port=1234, dbtable=sometable, user=*********(redacted), password=*********(redacted)]",
7474
),
@@ -81,14 +81,14 @@ def test_external_locations(ws, sql_backend, inventory_schema, env_or_skip):
8181
crawler = ExternalLocations(ws, sql_backend, inventory_schema, tables_crawler, mounts_crawler)
8282
results = crawler.snapshot()
8383
assert results == [
84-
ExternalLocation('s3://test_location/', 2),
85-
ExternalLocation('s3://bar/test3/', 1),
8684
ExternalLocation(
8785
'jdbc:databricks://dbc-test1-aa11.cloud.databricks.com;httpPath=/sql/1.0/warehouses/65b52fb5bd86a7be', 1
8886
),
8987
ExternalLocation('jdbc:mysql://somemysql.us-east-1.rds.amazonaws.com:3306/test_db', 1),
9088
ExternalLocation('jdbc:providerknown://somedb.us-east-1.rds.amazonaws.com:1234/test_db', table_count=2),
91-
ExternalLocation('jdbc://providerunknown//somedb.us-east-1.rds.amazonaws.com:1234/test_db', 1),
89+
ExternalLocation('jdbc:providerunknown://somedb.us-east-1.rds.amazonaws.com:1234/test_db', 1),
90+
ExternalLocation('s3://bar/test3', 1),
91+
ExternalLocation('s3://test_location', 2),
9292
]
9393

9494

tests/unit/hive_metastore/test_locations.py

Lines changed: 20 additions & 40 deletions
Original file line numberDiff line numberDiff line change
@@ -22,7 +22,7 @@
2222
@pytest.mark.parametrize(
2323
"location",
2424
[
25-
"s3://databricks-e2demofieldengwest/b169/b50"
25+
"s3://databricks-e2demofieldengwest/b169/b50",
2626
"s3a://databricks-datasets-oregon/delta-sharing/share/open-datasets.share",
2727
"s3n://bucket-name/path-to-file-in-bucket",
2828
"gcs://test_location2/test2/table2",
@@ -166,8 +166,9 @@ def test_external_locations():
166166
table_factory(["s3://us-east-1-dev-account-staging-uc-ext-loc-bucket-1/Location/Table2", ""]),
167167
table_factory(["s3://us-east-1-dev-account-staging-uc-ext-loc-bucket-23/testloc/Table3", ""]),
168168
table_factory(["s3://us-east-1-dev-account-staging-uc-ext-loc-bucket-23/anotherloc/Table4", ""]),
169+
table_factory(["gcs://test_location2/a/b/table2", ""]),
169170
table_factory(["dbfs:/mnt/ucx/database1/table1", ""]),
170-
table_factory(["dbfs:/mnt/ucx/database2/table2", ""]),
171+
table_factory(["/dbfs/mnt/ucx/database2/table2", ""]),
171172
table_factory(["DatabricksRootmntDatabricksRoot", ""]),
172173
table_factory(
173174
[
@@ -212,18 +213,18 @@ def test_external_locations():
212213
mounts_crawler.snapshot.return_value = [Mount("/mnt/ucx", "s3://us-east-1-ucx-container")]
213214
sql_backend = MockBackend()
214215
crawler = ExternalLocations(Mock(), sql_backend, "test", tables_crawler, mounts_crawler)
215-
result_set = crawler.snapshot()
216-
assert len(result_set) == 7
217-
assert result_set[0].location == "s3://us-east-1-dev-account-staging-uc-ext-loc-bucket-1/Location/"
218-
assert result_set[0].table_count == 2
219-
assert result_set[1].location == "s3://us-east-1-dev-account-staging-uc-ext-loc-bucket-23/"
220-
assert (
221-
result_set[3].location
222-
== "jdbc:databricks://dbc-test1-aa11.cloud.databricks.com;httpPath=/sql/1.0/warehouses/65b52fb5bd86a7be"
223-
)
224-
assert result_set[4].location == "jdbc:mysql://somemysql.us-east-1.rds.amazonaws.com:3306/test_db"
225-
assert result_set[5].location == "jdbc:providerknown://somedb.us-east-1.rds.amazonaws.com:1234/test_db"
226-
assert result_set[6].location == "jdbc:providerunknown://somedb.us-east-1.rds.amazonaws.com:1234/test_db"
216+
assert crawler.snapshot() == [
217+
ExternalLocation('gcs://test_location2/a/b', 1),
218+
ExternalLocation(
219+
'jdbc:databricks://dbc-test1-aa11.cloud.databricks.com;httpPath=/sql/1.0/warehouses/65b52fb5bd86a7be', 1
220+
),
221+
ExternalLocation('jdbc:mysql://somemysql.us-east-1.rds.amazonaws.com:3306/test_db', 1),
222+
ExternalLocation('jdbc:providerknown://somedb.us-east-1.rds.amazonaws.com:1234/test_db', 2),
223+
ExternalLocation('jdbc:providerunknown://somedb.us-east-1.rds.amazonaws.com:1234/test_db', 1),
224+
ExternalLocation('s3://us-east-1-dev-account-staging-uc-ext-loc-bucket-1/Location', 2),
225+
ExternalLocation('s3://us-east-1-dev-account-staging-uc-ext-loc-bucket-23', 2),
226+
ExternalLocation('s3://us-east-1-ucx-container', 2),
227+
]
227228

228229

229230
LOCATION_STORAGE = MockBackend.rows("location", "storage_properties")
@@ -237,10 +238,7 @@ def test_save_external_location_mapping_missing_location():
237238
tables_crawler = create_autospec(TablesCrawler)
238239
tables_crawler.snapshot.return_value = [
239240
table_factory(["s3://test_location/test1/table1", ""]),
240-
table_factory(["gcs://test_location2/test2/table2", ""]),
241-
table_factory(["abfss://cont1@storagetest1.dfs.core.windows.net/test2/table3", ""]),
242-
table_factory(["s3a://test_location_2/test1/table1", ""]),
243-
table_factory(["s3n://test_location_3/test1/table1", ""]),
241+
table_factory(["s3://test_location/test1/table2", ""]),
244242
]
245243
mounts_crawler = create_autospec(MountsCrawler)
246244
mounts_crawler.snapshot.return_value = []
@@ -257,26 +255,6 @@ def test_save_external_location_mapping_missing_location():
257255
' name = "test_location_test1"\n'
258256
' url = "s3://test_location/test1"\n'
259257
" credential_name = databricks_storage_credential.<storage_credential_reference>.id\n"
260-
"}\n\n"
261-
'resource "databricks_external_location" "test_location2_test2" { \n'
262-
' name = "test_location2_test2"\n'
263-
' url = "gcs://test_location2/test2"\n'
264-
" credential_name = databricks_storage_credential.<storage_credential_reference>.id\n"
265-
"}\n\n"
266-
'resource "databricks_external_location" "cont1_storagetest1_test2" { \n'
267-
' name = "cont1_storagetest1_test2"\n'
268-
' url = "abfss://cont1@storagetest1.dfs.core.windows.net/test2"\n'
269-
" credential_name = databricks_storage_credential.<storage_credential_reference>.id\n"
270-
"}\n\n"
271-
'resource "databricks_external_location" "test_location_2_test1" { \n'
272-
' name = "test_location_2_test1"\n'
273-
' url = "s3a://test_location_2/test1"\n'
274-
" credential_name = databricks_storage_credential.<storage_credential_reference>.id\n"
275-
"}\n\n"
276-
'resource "databricks_external_location" "test_location_3_test1" { \n'
277-
' name = "test_location_3_test1"\n'
278-
' url = "s3n://test_location_3/test1"\n'
279-
" credential_name = databricks_storage_credential.<storage_credential_reference>.id\n"
280258
"}\n"
281259
).encode("utf8"),
282260
)
@@ -317,8 +295,10 @@ def test_match_table_external_locations():
317295
matching_locations, missing_locations = location_crawler.match_table_external_locations()
318296

319297
assert len(matching_locations) == 1
320-
assert ExternalLocation("gcs://test_location2/a/b/", 1) in missing_locations
321-
assert ExternalLocation("abfss://cont1@storagetest1/a/", 2) in missing_locations
298+
assert [
299+
ExternalLocation("abfss://cont1@storagetest1/a", 2),
300+
ExternalLocation("gcs://test_location2/a/b", 1),
301+
] == missing_locations
322302

323303

324304
def test_mount_listing_multiple_folders():

0 commit comments

Comments
 (0)