Skip to content

Commit 1585e9b

Browse files
committed
Use vendored v3 and v4 hive-metastore modules to support both Hive Metastore versions
1 parent c32f86f commit 1585e9b

File tree

11 files changed

+101636
-50
lines changed

11 files changed

+101636
-50
lines changed

pyiceberg/catalog/hive.py

Lines changed: 45 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515
# specific language governing permissions and limitations
1616
# under the License.
1717
import getpass
18+
import importlib
1819
import logging
1920
import socket
2021
import time
@@ -32,12 +33,14 @@
3233
)
3334
from urllib.parse import urlparse
3435

35-
from hive_metastore.ThriftHiveMetastore import Client
36-
from hive_metastore.ttypes import (
36+
from hive_metastore.v4.ThriftHiveMetastore import Client
37+
from hive_metastore.v4.ttypes import (
3738
AlreadyExistsException,
3839
CheckLockRequest,
3940
EnvironmentContext,
4041
FieldSchema,
42+
GetTableRequest,
43+
GetTablesRequest,
4144
InvalidOperationException,
4245
LockComponent,
4346
LockLevel,
@@ -50,9 +53,9 @@
5053
SerDeInfo,
5154
StorageDescriptor,
5255
UnlockRequest,
56+
Database as HiveDatabase,
57+
Table as HiveTable,
5358
)
54-
from hive_metastore.ttypes import Database as HiveDatabase
55-
from hive_metastore.ttypes import Table as HiveTable
5659
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
5760
from thrift.protocol import TBinaryProtocol
5861
from thrift.transport import TSocket, TTransport
@@ -150,6 +153,9 @@ class _HiveClient:
150153

151154
_transport: TTransport
152155
_ugi: Optional[List[str]]
156+
_hive_version: int = 4
157+
_hms_v3: object
158+
_hms_v4: object
153159

154160
def __init__(
155161
self,
@@ -163,6 +169,14 @@ def __init__(
163169
self._kerberos_service_name = kerberos_service_name
164170
self._ugi = ugi.split(":") if ugi else None
165171
self._transport = self._init_thrift_transport()
172+
self.hms_v3 = importlib.import_module('hive_metastore.v3.ThriftHiveMetastore')
173+
self.hms_v4 = importlib.import_module('hive_metastore.v4.ThriftHiveMetastore')
174+
self._hive_version = self._get_hive_version()
175+
176+
def _get_hive_version(self) -> int:
177+
with self as open_client:
178+
major, *_ = open_client.getVersion().split('.')
179+
return int(major)
166180

167181
def _init_thrift_transport(self) -> TTransport:
168182
url_parts = urlparse(self._uri)
@@ -174,7 +188,9 @@ def _init_thrift_transport(self) -> TTransport:
174188

175189
def _client(self) -> Client:
176190
protocol = TBinaryProtocol.TBinaryProtocol(self._transport)
177-
client = Client(protocol)
191+
client: Client = self.hms_v4.Client(protocol)
192+
if self._hive_version and self._hive_version < 4:
193+
client: Client = self.hms_v3.Client(protocol)
178194
if self._ugi:
179195
client.set_ugi(*self._ugi)
180196
return client
@@ -332,6 +348,12 @@ def _create_hive_client(properties: Dict[str, str]) -> _HiveClient:
332348
else:
333349
raise ValueError(f"Unable to connect to hive using uri: {properties[URI]}")
334350

351+
def _update_imported_metastore_modules(self, properties) -> None:
352+
breakpoint()
353+
v3 = importlib.import_module('hive_metastore.v3.ThriftHiveMetastore')
354+
Client = v3.Client
355+
self._client = self._create_hive_client(properties)
356+
335357
def _convert_hive_into_iceberg(self, table: HiveTable) -> Table:
336358
properties: Dict[str, str] = table.parameters
337359
if TABLE_TYPE not in properties:
@@ -387,11 +409,15 @@ def _create_hive_table(self, open_client: Client, hive_table: HiveTable) -> None
387409
except AlreadyExistsException as e:
388410
raise TableAlreadyExistsError(f"Table {hive_table.dbName}.{hive_table.tableName} already exists") from e
389411

390-
def _get_hive_table(self, open_client: Client, database_name: str, table_name: str) -> HiveTable:
391-
try:
392-
return open_client.get_table_objects_by_name(dbname=database_name, tbl_names=[table_name]).pop()
393-
except IndexError as e:
394-
raise NoSuchTableError(f"Table does not exists: {table_name}") from e
412+
def _get_hive_table(self, open_client, *, dbname, tbl_name) -> HiveTable:
413+
if open_client._hive_version < 4:
414+
return open_client.get_table(dbname=dbname, tbl_name=tbl_name)
415+
return open_client.get_table_req(GetTableRequest(dbName=dbname, tblName=tbl_name)).table
416+
417+
def _get_table_objects_by_name(self, open_client, *, dbname, tbl_names) -> list[HiveTable]:
418+
if open_client._hive_version < 4:
419+
return open_client.get_table_objects_by_name(dbname=dbname, tbl_names=tbl_names)
420+
return open_client.get_table_objects_by_name_req(GetTablesRequest(dbName=dbname, tblNames=tbl_names)).tables
395421

396422
def create_table(
397423
self,
@@ -435,10 +461,7 @@ def create_table(
435461

436462
with self._client as open_client:
437463
self._create_hive_table(open_client, tbl)
438-
try:
439-
hive_table = open_client.get_table_objects_by_name(dbname=database_name, tbl_names=[table_name]).pop()
440-
except IndexError as e:
441-
raise NoSuchObjectException("get_table failed: unknown result") from e
464+
hive_table: HiveTable = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)
442465

443466
return self._convert_hive_into_iceberg(hive_table)
444467

@@ -468,10 +491,7 @@ def register_table(self, identifier: Union[str, Identifier], metadata_location:
468491
tbl = self._convert_iceberg_into_hive(staged_table)
469492
with self._client as open_client:
470493
self._create_hive_table(open_client, tbl)
471-
try:
472-
hive_table = open_client.get_table_objects_by_name(dbname=database_name, tbl_names=[table_name]).pop()
473-
except IndexError as e:
474-
raise NoSuchObjectException("get_table failed: unknown result") from e
494+
hive_table: HiveTable = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)
475495

476496
return self._convert_hive_into_iceberg(hive_table)
477497

@@ -544,7 +564,7 @@ def commit_table(
544564
hive_table: Optional[HiveTable]
545565
current_table: Optional[Table]
546566
try:
547-
hive_table = self._get_hive_table(open_client, database_name, table_name)
567+
hive_table = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)
548568
current_table = self._convert_hive_into_iceberg(hive_table)
549569
except NoSuchTableError:
550570
hive_table = None
@@ -618,7 +638,7 @@ def load_table(self, identifier: Union[str, Identifier]) -> Table:
618638
database_name, table_name = self.identifier_to_database_and_table(identifier, NoSuchTableError)
619639

620640
with self._client as open_client:
621-
hive_table = self._get_hive_table(open_client, database_name, table_name)
641+
hive_table = self._get_hive_table(open_client, dbname=database_name, tbl_name=table_name)
622642

623643
return self._convert_hive_into_iceberg(hive_table)
624644

@@ -662,10 +682,7 @@ def rename_table(self, from_identifier: Union[str, Identifier], to_identifier: U
662682
to_database_name, to_table_name = self.identifier_to_database_and_table(to_identifier)
663683
try:
664684
with self._client as open_client:
665-
try:
666-
tbl = open_client.get_table_objects_by_name(dbname=from_database_name, tbl_names=[from_table_name]).pop()
667-
except IndexError as e:
668-
raise NoSuchObjectException("get_table failed: unknown result") from e
685+
tbl: HiveTable = self._get_hive_table(open_client, dbname=from_database_name, tbl_name=from_table_name)
669686
tbl.dbName = to_database_name
670687
tbl.tableName = to_table_name
671688
open_client.alter_table_with_environment_context(
@@ -737,8 +754,9 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
737754
with self._client as open_client:
738755
return [
739756
(database_name, table.tableName)
740-
for table in open_client.get_table_objects_by_name(
741-
dbname=database_name, tbl_names=open_client.get_all_tables(db_name=database_name)
757+
for table in self._get_table_objects_by_name(
758+
open_client, dbname=database_name,
759+
tbl_names=open_client.get_all_tables(db_name=database_name)
742760
)
743761
if table.parameters.get(TABLE_TYPE, "").lower() == ICEBERG
744762
]
@@ -809,7 +827,7 @@ def update_namespace_properties(
809827
if removals:
810828
for key in removals:
811829
if key in parameters:
812-
parameters.pop(key)
830+
parameters[key] = None
813831
removed.add(key)
814832
if updates:
815833
for key, value in updates.items():

tests/catalog/test_hive.py

Lines changed: 36 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -27,7 +27,7 @@
2727

2828
import pytest
2929
import thrift.transport.TSocket
30-
from hive_metastore.ttypes import (
30+
from hive_metastore.v4.ttypes import (
3131
AlreadyExistsException,
3232
EnvironmentContext,
3333
FieldSchema,
@@ -39,9 +39,9 @@
3939
SerDeInfo,
4040
SkewedInfo,
4141
StorageDescriptor,
42+
Database as HiveDatabase,
43+
Table as HiveTable,
4244
)
43-
from hive_metastore.ttypes import Database as HiveDatabase
44-
from hive_metastore.ttypes import Table as HiveTable
4545

4646
from pyiceberg.catalog import PropertiesUpdateSummary
4747
from pyiceberg.catalog.hive import (
@@ -254,6 +254,8 @@ def test_no_uri_supplied() -> None:
254254

255255

256256
def test_check_number_of_namespaces(table_schema_simple: Schema) -> None:
257+
_HiveClient._get_hive_version = MagicMock()
258+
_HiveClient._get_hive_version.return_value = 3
257259
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
258260

259261
with pytest.raises(ValueError):
@@ -280,7 +282,8 @@ def test_create_table(
280282

281283
catalog._client = MagicMock()
282284
catalog._client.__enter__().create_table.return_value = None
283-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
285+
catalog._get_hive_table = MagicMock()
286+
catalog._get_hive_table.return_value = hive_table
284287
catalog._client.__enter__().get_database.return_value = hive_database
285288
catalog.create_table(("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"})
286289

@@ -459,7 +462,8 @@ def test_create_table_with_given_location_removes_trailing_slash(
459462

460463
catalog._client = MagicMock()
461464
catalog._client.__enter__().create_table.return_value = None
462-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
465+
catalog._get_hive_table = MagicMock()
466+
catalog._get_hive_table.return_value = hive_table
463467
catalog._client.__enter__().get_database.return_value = hive_database
464468
catalog.create_table(
465469
("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"}, location=f"{location}/"
@@ -632,8 +636,9 @@ def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabas
632636
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
633637

634638
catalog._client = MagicMock()
639+
catalog._get_hive_table = MagicMock()
635640
catalog._client.__enter__().create_table.return_value = None
636-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
641+
catalog._get_hive_table.return_value = hive_table
637642
catalog._client.__enter__().get_database.return_value = hive_database
638643
catalog.create_table(
639644
("default", "table"), schema=table_schema_simple, properties={"owner": "javaberg", "format-version": "1"}
@@ -684,10 +689,11 @@ def test_load_table(hive_table: HiveTable) -> None:
684689
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
685690

686691
catalog._client = MagicMock()
687-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
692+
catalog._get_hive_table = MagicMock()
693+
catalog._get_hive_table.return_value = hive_table
688694
table = catalog.load_table(("default", "new_tabl2e"))
689695

690-
catalog._client.__enter__().get_table_objects_by_name.assert_called_with(dbname="default", tbl_names=["new_tabl2e"])
696+
catalog._get_hive_table.assert_called_with(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e")
691697

692698
expected = TableMetadataV2(
693699
location="s3://bucket/test/location",
@@ -784,11 +790,12 @@ def test_load_table_from_self_identifier(hive_table: HiveTable) -> None:
784790
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
785791

786792
catalog._client = MagicMock()
787-
catalog._client.__enter__().get_table_objects_by_name.side_effect = lambda dbname, tbl_names: [hive_table]
793+
catalog._get_hive_table = MagicMock()
794+
catalog._get_hive_table.return_value = hive_table
788795
intermediate = catalog.load_table(("default", "new_tabl2e"))
789796
table = catalog.load_table(intermediate.name())
790797

791-
catalog._client.__enter__().get_table_objects_by_name.assert_called_with(dbname="default", tbl_names=["new_tabl2e"])
798+
catalog._get_hive_table.assert_called_with(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e")
792799

793800
expected = TableMetadataV2(
794801
location="s3://bucket/test/location",
@@ -889,7 +896,8 @@ def test_rename_table(hive_table: HiveTable) -> None:
889896
renamed_table.tableName = "new_tabl3e"
890897

891898
catalog._client = MagicMock()
892-
catalog._client.__enter__().get_table_objects_by_name.side_effect = [[hive_table], [renamed_table]]
899+
catalog._get_hive_table = MagicMock()
900+
catalog._get_hive_table.side_effect = [hive_table, renamed_table]
893901
catalog._client.__enter__().alter_table_with_environment_context.return_value = None
894902

895903
from_identifier = ("default", "new_tabl2e")
@@ -898,8 +906,8 @@ def test_rename_table(hive_table: HiveTable) -> None:
898906

899907
assert table.name() == to_identifier
900908

901-
calls = [call(dbname="default", tbl_names=["new_tabl2e"]), call(dbname="default", tbl_names=["new_tabl3e"])]
902-
catalog._client.__enter__().get_table_objects_by_name.assert_has_calls(calls)
909+
calls = [call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e"), call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl3e")]
910+
catalog._get_hive_table.assert_has_calls(calls)
903911
catalog._client.__enter__().alter_table_with_environment_context.assert_called_with(
904912
dbname="default",
905913
tbl_name="new_tabl2e",
@@ -912,25 +920,26 @@ def test_rename_table_from_self_identifier(hive_table: HiveTable) -> None:
912920
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
913921

914922
catalog._client = MagicMock()
915-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
923+
catalog._get_hive_table = MagicMock()
924+
catalog._get_hive_table.return_value = hive_table
916925

917926
from_identifier = ("default", "new_tabl2e")
918927
from_table = catalog.load_table(from_identifier)
919-
catalog._client.__enter__().get_table_objects_by_name.assert_called_with(dbname="default", tbl_names=["new_tabl2e"])
928+
catalog._get_hive_table.assert_called_with(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e")
920929

921930
renamed_table = copy.deepcopy(hive_table)
922931
renamed_table.dbName = "default"
923932
renamed_table.tableName = "new_tabl3e"
924933

925-
catalog._client.__enter__().get_table_objects_by_name.side_effect = [[hive_table], [renamed_table]]
934+
catalog._get_hive_table.side_effect = [hive_table, renamed_table]
926935
catalog._client.__enter__().alter_table_with_environment_context.return_value = None
927936
to_identifier = ("default", "new_tabl3e")
928937
table = catalog.rename_table(from_table.name(), to_identifier)
929938

930939
assert table.name() == to_identifier
931940

932-
calls = [call(dbname="default", tbl_names=["new_tabl2e"]), call(dbname="default", tbl_names=["new_tabl3e"])]
933-
catalog._client.__enter__().get_table_objects_by_name.assert_has_calls(calls)
941+
calls = [call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e"), call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl3e")]
942+
catalog._get_hive_table.assert_has_calls(calls)
934943
catalog._client.__enter__().alter_table_with_environment_context.assert_called_with(
935944
dbname="default",
936945
tbl_name="new_tabl2e",
@@ -943,6 +952,7 @@ def test_rename_table_from_does_not_exists() -> None:
943952
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
944953

945954
catalog._client = MagicMock()
955+
catalog._client.__enter__()._hive_version = 3
946956
catalog._client.__enter__().alter_table_with_environment_context.side_effect = NoSuchObjectException(
947957
message="hive.default.does_not_exists table not found"
948958
)
@@ -957,6 +967,7 @@ def test_rename_table_to_namespace_does_not_exists() -> None:
957967
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
958968

959969
catalog._client = MagicMock()
970+
catalog._client.__enter__()._hive_version = 3
960971
catalog._client.__enter__().alter_table_with_environment_context.side_effect = InvalidOperationException(
961972
message="Unable to change partition or table. Database default does not exist Check metastore logs for detailed stack.does_not_exists"
962973
)
@@ -1013,13 +1024,14 @@ def test_list_tables(hive_table: HiveTable) -> None:
10131024

10141025
catalog._client = MagicMock()
10151026
catalog._client.__enter__().get_all_tables.return_value = ["table1", "table2", "table3", "table4"]
1016-
catalog._client.__enter__().get_table_objects_by_name.return_value = [tbl1, tbl2, tbl3, tbl4]
1027+
catalog._get_table_objects_by_name = MagicMock()
1028+
catalog._get_table_objects_by_name.return_value = [tbl1, tbl2, tbl3, tbl4]
10171029

10181030
got_tables = catalog.list_tables("database")
10191031
assert got_tables == [("database", "table1"), ("database", "table2")]
10201032
catalog._client.__enter__().get_all_tables.assert_called_with(db_name="database")
1021-
catalog._client.__enter__().get_table_objects_by_name.assert_called_with(
1022-
dbname="database", tbl_names=["table1", "table2", "table3", "table4"]
1033+
catalog._get_table_objects_by_name.assert_called_with(
1034+
catalog._client.__enter__(), dbname="database", tbl_names=["table1", "table2", "table3", "table4"]
10231035
)
10241036

10251037

@@ -1049,7 +1061,8 @@ def test_drop_table_from_self_identifier(hive_table: HiveTable) -> None:
10491061
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
10501062

10511063
catalog._client = MagicMock()
1052-
catalog._client.__enter__().get_table_objects_by_name.return_value = [hive_table]
1064+
catalog._get_hive_table = MagicMock()
1065+
catalog._get_hive_table.return_value = hive_table
10531066
table = catalog.load_table(("default", "new_tabl2e"))
10541067

10551068
catalog._client.__enter__().get_all_databases.return_value = ["namespace1", "namespace2"]
@@ -1156,7 +1169,7 @@ def test_update_namespace_properties(hive_database: HiveDatabase) -> None:
11561169
name="default",
11571170
description=None,
11581171
locationUri=hive_database.locationUri,
1159-
parameters={"label": "core"},
1172+
parameters={"test": None, "label": "core"},
11601173
privileges=None,
11611174
ownerName=None,
11621175
ownerType=1,
File renamed without changes.

0 commit comments

Comments
 (0)