Skip to content

Commit 26c7246

Browse files
saraivdbxSara Ivanyosnfx
authored
Crawl Table ACLs in all databases (#122)
This PR enumerates all databases for `GrantsCrawler`. --------- Co-authored-by: Sara Ivanyos <[email protected]> Co-authored-by: Serge Smertin <[email protected]>
1 parent d8f78a4 commit 26c7246

File tree

11 files changed

+192
-46
lines changed

11 files changed

+192
-46
lines changed

examples/migration_config.yml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -4,8 +4,10 @@ inventory:
44
database: default
55
name: uc_migration_inventory
66

7+
tacl:
8+
databases: [ "default" ]
79

8-
with_table_acls: False
10+
warehouse_id: None
911

1012
groups:
1113
selected: [ "analyst" ]

notebooks/toolkit.py

Lines changed: 24 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -36,7 +36,8 @@
3636
# COMMAND ----------
3737

3838
from databricks.labs.ucx.toolkits.group_migration import GroupMigrationToolkit
39-
from databricks.labs.ucx.config import MigrationConfig, InventoryConfig, GroupsConfig, InventoryTable
39+
from databricks.labs.ucx.config import MigrationConfig, InventoryConfig, GroupsConfig, InventoryTable, TaclConfig
40+
from databricks.labs.ucx.toolkits.table_acls import TaclToolkit
4041

4142
# COMMAND ----------
4243

@@ -47,17 +48,28 @@
4748
# COMMAND ----------
4849

4950
config = MigrationConfig(
50-
with_table_acls=False,
5151
inventory=InventoryConfig(table=InventoryTable(catalog="main", database="default", name="ucx_migration_inventory")),
5252
groups=GroupsConfig(
5353
# use this option to select specific groups manually
5454
selected=["groupA", "groupB"],
5555
# use this option to select all groups automatically
5656
# auto=True
5757
),
58+
tacl=TaclConfig(
59+
# use this option to select specific databases manually
60+
databases=["default"],
61+
# use this option to select all databases automatically
62+
# auto=True
63+
),
5864
log_level="DEBUG",
5965
)
6066
toolkit = GroupMigrationToolkit(config)
67+
tacltoolkit = TaclToolkit(
68+
toolkit._ws,
69+
config.inventory.table.catalog,
70+
config.inventory.table.schema,
71+
databases=config.tacl.databases,
72+
)
6173

6274
# COMMAND ----------
6375

@@ -152,6 +164,16 @@
152164

153165
# COMMAND ----------
154166

167+
# MAGIC %md
168+
# MAGIC
169+
# MAGIC ## Inventorize Table ACLs
170+
171+
# COMMAND ----------
172+
173+
tacltoolkit.grants_snapshot()
174+
175+
# COMMAND ----------
176+
155177
# MAGIC %md
156178
# MAGIC
157179
# MAGIC ## Cleanup the inventory table

pyproject.toml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -187,4 +187,4 @@ exclude_lines = [
187187
"no cov",
188188
"if __name__ == .__main__.:",
189189
"if TYPE_CHECKING:",
190-
]
190+
]

src/databricks/labs/ucx/config.py

Lines changed: 20 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -90,10 +90,28 @@ def from_dict(cls, raw: dict):
9090
return cls(**raw)
9191

9292

93+
@dataclass
94+
class TaclConfig:
95+
databases: list[str] | None = None
96+
auto: bool | None = None
97+
98+
def __post_init__(self):
99+
if not self.databases and self.auto is None:
100+
msg = "Either selected or auto must be set"
101+
raise ValueError(msg)
102+
if self.databases and self.auto is False:
103+
msg = "No selected groups provided, but auto-collection is disabled"
104+
raise ValueError(msg)
105+
106+
@classmethod
107+
def from_dict(cls, raw: dict):
108+
return cls(**raw)
109+
110+
93111
@dataclass
94112
class MigrationConfig:
95113
inventory: InventoryConfig
96-
with_table_acls: bool
114+
tacl: TaclConfig
97115
groups: GroupsConfig
98116
connect: ConnectConfig | None = None
99117
num_threads: int | None = 4
@@ -102,9 +120,6 @@ class MigrationConfig:
102120
def __post_init__(self):
103121
if self.connect is None:
104122
self.connect = ConnectConfig()
105-
if self.with_table_acls:
106-
msg = "Table ACLS are not yet implemented"
107-
raise NotImplementedError(msg)
108123

109124
def as_dict(self) -> dict:
110125
from dataclasses import fields, is_dataclass
@@ -126,7 +141,7 @@ def inner(x):
126141
def from_dict(cls, raw: dict) -> "MigrationConfig":
127142
return cls(
128143
inventory=InventoryConfig.from_dict(raw.get("inventory", {})),
129-
with_table_acls=raw.get("with_table_acls", False),
144+
tacl=TaclConfig.from_dict(raw.get("tacl", {})),
130145
groups=GroupsConfig.from_dict(raw.get("groups", {})),
131146
connect=ConnectConfig.from_dict(raw.get("connect", {})),
132147
num_threads=raw.get("num_threads", 4),

src/databricks/labs/ucx/tacl/_internal.py

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -143,15 +143,25 @@ def _snapshot(self, klass, fetcher, loader) -> list[any]:
143143
Returns:
144144
list[any]: A list of data records, either fetched or loaded.
145145
"""
146+
loaded = False
147+
trigger_load = ValueError("trigger records load")
146148
while True:
147149
try:
148150
logger.debug(f"[{self._full_name}] fetching {self._table} inventory")
149-
return list(fetcher())
151+
cached_results = list(fetcher())
152+
if len(cached_results) == 0 and loaded:
153+
return cached_results
154+
if len(cached_results) == 0 and not loaded:
155+
raise trigger_load
156+
return cached_results
150157
except Exception as e:
151-
if "TABLE_OR_VIEW_NOT_FOUND" not in str(e):
158+
if not (e == trigger_load or "TABLE_OR_VIEW_NOT_FOUND" in str(e)):
152159
raise e
153-
logger.debug(f"[{self._full_name}] {self._table} inventory not found, crawling")
154-
self._append_records(klass, loader())
160+
logger.debug(f"[{self._full_name}] crawling new batch for {self._table}")
161+
loaded_records = list(loader())
162+
if len(loaded_records) > 0:
163+
self._append_records(klass, loaded_records)
164+
loaded = True
155165

156166
@staticmethod
157167
def _row_to_sql(row, fields):

src/databricks/labs/ucx/toolkits/table_acls.py

Lines changed: 28 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -1,3 +1,5 @@
1+
import logging
2+
13
from databricks.sdk import WorkspaceClient
24

35
from databricks.labs.ucx.tacl._internal import (
@@ -8,17 +10,38 @@
810
from databricks.labs.ucx.tacl.grants import GrantsCrawler
911
from databricks.labs.ucx.tacl.tables import TablesCrawler
1012

13+
logger = logging.getLogger(__name__)
14+
1115

1216
class TaclToolkit:
13-
def __init__(self, ws: WorkspaceClient, inventory_catalog, inventory_schema, warehouse_id=None):
17+
def __init__(
    self,
    ws: WorkspaceClient,
    inventory_catalog,
    inventory_schema,
    warehouse_id=None,
    databases=None,
):
    """Wire up the table and grant crawlers and fix the database list.

    Args:
        ws: workspace client handed to the SQL backend factory.
        inventory_catalog: catalog passed to the tables crawler.
        inventory_schema: schema passed to the tables crawler.
        warehouse_id: optional warehouse id forwarded to the backend factory.
        databases: database names to crawl; when missing or empty, the
            names are taken from the tables crawler's database listing.
    """
    backend = self._backend(ws, warehouse_id)
    self._tc = TablesCrawler(backend, inventory_catalog, inventory_schema)
    self._gc = GrantsCrawler(self._tc)

    if not databases:
        # Nothing selected explicitly: fall back to every database the
        # tables crawler reports.
        databases = [db.as_dict()["databaseName"] for db in self._tc._all_databases()]
    self._databases = databases
31+
32+
def database_snapshot(self):
    """Return the table snapshots of every configured database as one list.

    Each database in ``self._databases`` is snapshotted through the tables
    crawler under the ``hive_metastore`` catalog; results are concatenated
    in database order.
    """
    return [
        table
        for db in self._databases
        for table in self._tc.snapshot("hive_metastore", db)
    ]
1938

20-
def grants_snapshot(self, schema):
21-
return self._gc.snapshot("hive_metastore", schema)
39+
def grants_snapshot(self):
    """Return the grant snapshots of every configured database as one list.

    Each database in ``self._databases`` is snapshotted through the grants
    crawler under the ``hive_metastore`` catalog; results are concatenated
    in database order.
    """
    return [
        grant
        for db in self._databases
        for grant in self._gc.snapshot("hive_metastore", db)
    ]
2245

2346
@staticmethod
2447
def _backend(ws: WorkspaceClient, warehouse_id: str | None = None) -> SqlBackend:

tests/integration/test_tacls.py

Lines changed: 34 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -8,16 +8,22 @@
88
logger = logging.getLogger(__name__)
99

1010

11-
def test_describe_all_tables(ws: WorkspaceClient, make_catalog, make_schema, make_table):
11+
def test_describe_all_tables_in_databases(ws: WorkspaceClient, make_catalog, make_schema, make_table):
1212
warehouse_id = os.environ["TEST_DEFAULT_WAREHOUSE_ID"]
1313

1414
logger.info("setting up fixtures")
15-
schema = make_schema(catalog="hive_metastore")
16-
managed_table = make_table(schema=schema)
17-
external_table = make_table(schema=schema, external=True)
18-
tmp_table = make_table(schema=schema, ctas="SELECT 2+2 AS four")
19-
view = make_table(schema=schema, ctas="SELECT 2+2 AS four", view=True)
20-
non_delta = make_table(schema=schema, non_detla=True)
15+
16+
schema_a = make_schema(catalog="hive_metastore")
17+
18+
schema_b = make_schema(catalog="hive_metastore")
19+
20+
schema_c = make_schema(catalog="hive_metastore")
21+
22+
managed_table = make_table(schema=schema_a)
23+
external_table = make_table(schema=schema_b, external=True)
24+
tmp_table = make_table(schema=schema_a, ctas="SELECT 2+2 AS four")
25+
view = make_table(schema=schema_b, ctas="SELECT 2+2 AS four", view=True)
26+
non_delta = make_table(schema=schema_a, non_detla=True)
2127

2228
logger.info(
2329
f"managed_table={managed_table}, "
@@ -28,10 +34,13 @@ def test_describe_all_tables(ws: WorkspaceClient, make_catalog, make_schema, mak
2834

2935
inventory_schema = make_schema(catalog=make_catalog())
3036
inventory_catalog, inventory_schema = inventory_schema.split(".")
31-
tak = TaclToolkit(ws, inventory_catalog, inventory_schema, warehouse_id)
37+
38+
databases = [schema_a.split(".")[1], schema_b.split(".")[1], schema_c.split(".")[1]]
39+
40+
tak = TaclToolkit(ws, inventory_catalog, inventory_schema, warehouse_id=warehouse_id, databases=databases)
3241

3342
all_tables = {}
34-
for t in tak.database_snapshot(schema.split(".")[1]):
43+
for t in tak.database_snapshot():
3544
all_tables[t.key] = t
3645

3746
assert len(all_tables) == 5
@@ -43,28 +52,33 @@ def test_describe_all_tables(ws: WorkspaceClient, make_catalog, make_schema, mak
4352
assert all_tables[view].view_text == "SELECT 2+2 AS four"
4453

4554

46-
def test_all_grants_in_database(ws: WorkspaceClient, sql_exec, make_catalog, make_schema, make_table, make_group):
55+
def test_all_grants_in_databases(ws: WorkspaceClient, sql_exec, make_catalog, make_schema, make_table, make_group):
    """Grants issued in two separate schemas must all appear in the snapshot.

    Two groups receive grants on tables and schemas living in two different
    schemas; the toolkit is built without an explicit database list, so it
    has to discover the databases itself before crawling grants.
    """
    warehouse_id = os.environ["TEST_DEFAULT_WAREHOUSE_ID"]

    group_a = make_group()
    group_b = make_group()
    schema_a = make_schema()
    schema_b = make_schema()
    table_a = make_table(schema=schema_a)
    table_b = make_table(schema=schema_b)

    sql_exec(f"GRANT USAGE ON SCHEMA default TO `{group_a.display_name}`")
    sql_exec(f"GRANT USAGE ON SCHEMA default TO `{group_b.display_name}`")
    sql_exec(f"GRANT SELECT ON TABLE {table_a} TO `{group_a.display_name}`")
    sql_exec(f"GRANT SELECT ON TABLE {table_b} TO `{group_b.display_name}`")
    sql_exec(f"GRANT MODIFY ON SCHEMA {schema_b} TO `{group_b.display_name}`")

    inventory_schema = make_schema(catalog=make_catalog())
    inventory_catalog, inventory_schema = inventory_schema.split(".")

    # No `databases` argument on purpose: exercises the enumerate-everything path.
    tak = TaclToolkit(ws, inventory_catalog, inventory_schema, warehouse_id=warehouse_id)

    all_grants = {}
    for grant in tak.grants_snapshot():
        # Use the module-level logger rather than the root logger so the
        # record carries this module's name like every other message here.
        logger.info(f"grant:\n{grant}\n  hive: {grant.hive_grant_sql()}\n  uc: {grant.uc_grant_sql()}")
        all_grants[f"{grant.principal}.{grant.object_key}"] = grant.action_type

    assert len(all_grants) >= 3, "must have at least three grants"
    assert all_grants[f"{group_a.display_name}.{table_a}"] == "SELECT"
    assert all_grants[f"{group_b.display_name}.{table_b}"] == "SELECT"
    assert all_grants[f"{group_b.display_name}.{schema_b}"] == "MODIFY"

tests/unit/test_config.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,14 +3,14 @@
33
from functools import partial
44
from pathlib import Path
55

6-
import pytest
76
import yaml
87

98
from databricks.labs.ucx.config import (
109
GroupsConfig,
1110
InventoryConfig,
1211
InventoryTable,
1312
MigrationConfig,
13+
TaclConfig,
1414
)
1515

1616

@@ -19,12 +19,9 @@ def test_initialization():
1919
MigrationConfig,
2020
inventory=InventoryConfig(table=InventoryTable(catalog="catalog", database="database", name="name")),
2121
groups=GroupsConfig(auto=True),
22+
tacl=TaclConfig(databases=["default"]),
2223
)
23-
24-
with pytest.raises(NotImplementedError):
25-
mc(with_table_acls=True)
26-
27-
mc(with_table_acls=False)
24+
mc()
2825

2926

3027
# path context manager
@@ -54,9 +51,10 @@ def test_reader(tmp_path: Path):
5451
MigrationConfig,
5552
inventory=InventoryConfig(table=InventoryTable(catalog="catalog", database="database", name="name")),
5653
groups=GroupsConfig(auto=True),
54+
tacl=TaclConfig(databases=["default"]),
5755
)
5856

59-
config: MigrationConfig = mc(with_table_acls=False)
57+
config: MigrationConfig = mc()
6058
config_file = tmp_path / "config.yml"
6159

6260
as_dict = config.as_dict()

tests/unit/test_crawler_base.py

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,6 +13,13 @@ class Foo:
1313
second: bool
1414

1515

16+
@dataclass
class Bar:
    """Record type for crawler tests carrying a float field besides str/bool."""

    first: str
    second: bool
    third: float
21+
22+
1623
def test_invalid():
1724
with pytest.raises(ValueError):
1825
CrawlerBase(MockBackend(), "a.a.a", "b", "c")
@@ -42,6 +49,22 @@ def fetcher():
4249
assert insert == b.queries[0]
4350

4451

52+
def test_snapshot_appends_incorrect_type():
    """`_snapshot` must raise ValueError when the loader yields a `Bar` record.

    The fetcher reports a missing table on its first call so that the loader
    is invoked. NOTE(review): the ValueError presumably comes from the float
    field of `Bar` being unsupported when rows are written — confirm against
    `CrawlerBase`.
    """
    crawler = CrawlerBase(MockBackend(), "a", "b", "c")
    already_failed = False

    def fetcher():
        nonlocal already_failed
        if already_failed:
            return []
        # First call only: pretend the inventory table does not exist yet.
        already_failed = True
        msg = "TABLE_OR_VIEW_NOT_FOUND"
        raise RuntimeError(msg)

    def loader():
        return [Bar(first="first", second=True, third=3.14)]

    with pytest.raises(ValueError):
        crawler._snapshot(Bar, fetcher=fetcher, loader=loader)
66+
67+
4568
def test_snapshot_appends_to_new_table():
4669
b = MockBackend(fails_on_first={"INSERT INTO a.b.c": "TABLE_OR_VIEW_NOT_FOUND ..."})
4770
cb = CrawlerBase(b, "a", "b", "c")
@@ -62,3 +85,19 @@ def fetcher():
6285
assert insert == b.queries[0]
6386
assert create == b.queries[1]
6487
assert insert == b.queries[2]
88+
89+
90+
def test_snapshot_wrong_error():
    """A failure other than TABLE_OR_VIEW_NOT_FOUND on insert must propagate.

    The backend is configured to fail the first `INSERT INTO a.b.c` with a
    `TABLE_NOT_FOUND ...` message, and the fetcher reports a missing table on
    its first call so that the loader runs. NOTE(review): assumes
    `MockBackend` surfaces the configured failure as a RuntimeError — confirm.
    """
    backend = MockBackend(fails_on_first={"INSERT INTO a.b.c": "TABLE_NOT_FOUND ..."})
    crawler = CrawlerBase(backend, "a", "b", "c")
    already_failed = False

    def fetcher():
        nonlocal already_failed
        if already_failed:
            return []
        # First call only: pretend the inventory table does not exist yet.
        already_failed = True
        msg = "TABLE_OR_VIEW_NOT_FOUND"
        raise RuntimeError(msg)

    def loader():
        return [Foo(first="first", second=True)]

    with pytest.raises(RuntimeError):
        crawler._snapshot(Foo, fetcher=fetcher, loader=loader)

0 commit comments

Comments
 (0)