Skip to content

Commit bfbd52c

Browse files
committed
fix linter and tests
1 parent b0b7873 commit bfbd52c

File tree

4 files changed

+36
-33
lines changed

4 files changed

+36
-33
lines changed

pyiceberg/catalog/hive.py

Lines changed: 12 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,11 @@
5353
SerDeInfo,
5454
StorageDescriptor,
5555
UnlockRequest,
56+
)
57+
from hive_metastore.v4.ttypes import (
5658
Database as HiveDatabase,
59+
)
60+
from hive_metastore.v4.ttypes import (
5761
Table as HiveTable,
5862
)
5963
from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_exponential
@@ -169,13 +173,13 @@ def __init__(
169173
self._kerberos_service_name = kerberos_service_name
170174
self._ugi = ugi.split(":") if ugi else None
171175
self._transport = self._init_thrift_transport()
172-
self.hms_v3 = importlib.import_module('hive_metastore.v3.ThriftHiveMetastore')
173-
self.hms_v4 = importlib.import_module('hive_metastore.v4.ThriftHiveMetastore')
176+
self.hms_v3 = importlib.import_module("hive_metastore.v3.ThriftHiveMetastore")
177+
self.hms_v4 = importlib.import_module("hive_metastore.v4.ThriftHiveMetastore")
174178
self._hive_version = self._get_hive_version()
175179

176180
def _get_hive_version(self) -> int:
177181
with self as open_client:
178-
major, *_ = open_client.getVersion().split('.')
182+
major, *_ = open_client.getVersion().split(".")
179183
return int(major)
180184

181185
def _init_thrift_transport(self) -> TTransport:
@@ -188,10 +192,8 @@ def _init_thrift_transport(self) -> TTransport:
188192

189193
def _client(self) -> Client:
190194
protocol = TBinaryProtocol.TBinaryProtocol(self._transport)
191-
if self._hive_version < 4:
192-
client: Client = self.hms_v3.Client(protocol)
193-
else:
194-
client: Client = self.hms_v4.Client(protocol)
195+
hms = self.hms_v3 if self._hive_version < 4 else self.hms_v4
196+
client: Client = hms.Client(protocol)
195197
if self._ugi:
196198
client.set_ugi(*self._ugi)
197199
return client
@@ -404,12 +406,12 @@ def _create_hive_table(self, open_client: Client, hive_table: HiveTable) -> None
404406
except AlreadyExistsException as e:
405407
raise TableAlreadyExistsError(f"Table {hive_table.dbName}.{hive_table.tableName} already exists") from e
406408

407-
def _get_hive_table(self, open_client, *, dbname, tbl_name) -> HiveTable:
409+
def _get_hive_table(self, open_client: Client, *, dbname: str, tbl_name: str) -> HiveTable:
408410
if open_client._hive_version < 4:
409411
return open_client.get_table(dbname=dbname, tbl_name=tbl_name)
410412
return open_client.get_table_req(GetTableRequest(dbName=dbname, tblName=tbl_name)).table
411413

412-
def _get_table_objects_by_name(self, open_client, *, dbname, tbl_names) -> list[HiveTable]:
414+
def _get_table_objects_by_name(self, open_client: Client, *, dbname: str, tbl_names: list[str]) -> list[HiveTable]:
413415
if open_client._hive_version < 4:
414416
return open_client.get_table_objects_by_name(dbname=dbname, tbl_names=tbl_names)
415417
return open_client.get_table_objects_by_name_req(GetTablesRequest(dbName=dbname, tblNames=tbl_names)).tables
@@ -750,8 +752,7 @@ def list_tables(self, namespace: Union[str, Identifier]) -> List[Identifier]:
750752
return [
751753
(database_name, table.tableName)
752754
for table in self._get_table_objects_by_name(
753-
open_client, dbname=database_name,
754-
tbl_names=open_client.get_all_tables(db_name=database_name)
755+
open_client, dbname=database_name, tbl_names=open_client.get_all_tables(db_name=database_name)
755756
)
756757
if table.parameters.get(TABLE_TYPE, "").lower() == ICEBERG
757758
]

tests/catalog/test_hive.py

Lines changed: 22 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,11 @@
3939
SerDeInfo,
4040
SkewedInfo,
4141
StorageDescriptor,
42+
)
43+
from hive_metastore.v4.ttypes import (
4244
Database as HiveDatabase,
45+
)
46+
from hive_metastore.v4.ttypes import (
4347
Table as HiveTable,
4448
)
4549

@@ -254,8 +258,7 @@ def test_no_uri_supplied() -> None:
254258

255259

256260
def test_check_number_of_namespaces(table_schema_simple: Schema) -> None:
257-
_HiveClient._get_hive_version = MagicMock()
258-
_HiveClient._get_hive_version.return_value = 3
261+
_HiveClient._get_hive_version = MagicMock(return_value=3) # type: ignore
259262
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
260263

261264
with pytest.raises(ValueError):
@@ -282,8 +285,7 @@ def test_create_table(
282285

283286
catalog._client = MagicMock()
284287
catalog._client.__enter__().create_table.return_value = None
285-
catalog._get_hive_table = MagicMock()
286-
catalog._get_hive_table.return_value = hive_table
288+
catalog._get_hive_table = MagicMock(return_value=hive_table) # type: ignore
287289
catalog._client.__enter__().get_database.return_value = hive_database
288290
catalog.create_table(("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"})
289291

@@ -462,8 +464,7 @@ def test_create_table_with_given_location_removes_trailing_slash(
462464

463465
catalog._client = MagicMock()
464466
catalog._client.__enter__().create_table.return_value = None
465-
catalog._get_hive_table = MagicMock()
466-
catalog._get_hive_table.return_value = hive_table
467+
catalog._get_hive_table = MagicMock(return_value=hive_table) # type: ignore
467468
catalog._client.__enter__().get_database.return_value = hive_database
468469
catalog.create_table(
469470
("default", "table"), schema=table_schema_with_all_types, properties={"owner": "javaberg"}, location=f"{location}/"
@@ -636,7 +637,7 @@ def test_create_v1_table(table_schema_simple: Schema, hive_database: HiveDatabas
636637
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
637638

638639
catalog._client = MagicMock()
639-
catalog._get_hive_table = MagicMock()
640+
catalog._get_hive_table = MagicMock() # type: ignore
640641
catalog._client.__enter__().create_table.return_value = None
641642
catalog._get_hive_table.return_value = hive_table
642643
catalog._client.__enter__().get_database.return_value = hive_database
@@ -689,8 +690,7 @@ def test_load_table(hive_table: HiveTable) -> None:
689690
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
690691

691692
catalog._client = MagicMock()
692-
catalog._get_hive_table = MagicMock()
693-
catalog._get_hive_table.return_value = hive_table
693+
catalog._get_hive_table = MagicMock(return_value=hive_table) # type: ignore
694694
table = catalog.load_table(("default", "new_tabl2e"))
695695

696696
catalog._get_hive_table.assert_called_with(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e")
@@ -790,8 +790,7 @@ def test_load_table_from_self_identifier(hive_table: HiveTable) -> None:
790790
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
791791

792792
catalog._client = MagicMock()
793-
catalog._get_hive_table = MagicMock()
794-
catalog._get_hive_table.return_value = hive_table
793+
catalog._get_hive_table = MagicMock(return_value=hive_table) # type: ignore
795794
intermediate = catalog.load_table(("default", "new_tabl2e"))
796795
table = catalog.load_table(intermediate.name())
797796

@@ -896,8 +895,7 @@ def test_rename_table(hive_table: HiveTable) -> None:
896895
renamed_table.tableName = "new_tabl3e"
897896

898897
catalog._client = MagicMock()
899-
catalog._get_hive_table = MagicMock()
900-
catalog._get_hive_table.side_effect = [hive_table, renamed_table]
898+
catalog._get_hive_table = MagicMock(side_effect=[hive_table, renamed_table]) # type: ignore
901899
catalog._client.__enter__().alter_table_with_environment_context.return_value = None
902900

903901
from_identifier = ("default", "new_tabl2e")
@@ -906,7 +904,10 @@ def test_rename_table(hive_table: HiveTable) -> None:
906904

907905
assert table.name() == to_identifier
908906

909-
calls = [call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e"), call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl3e")]
907+
calls = [
908+
call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e"),
909+
call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl3e"),
910+
]
910911
catalog._get_hive_table.assert_has_calls(calls)
911912
catalog._client.__enter__().alter_table_with_environment_context.assert_called_with(
912913
dbname="default",
@@ -920,8 +921,7 @@ def test_rename_table_from_self_identifier(hive_table: HiveTable) -> None:
920921
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
921922

922923
catalog._client = MagicMock()
923-
catalog._get_hive_table = MagicMock()
924-
catalog._get_hive_table.return_value = hive_table
924+
catalog._get_hive_table = MagicMock(return_value=hive_table) # type: ignore
925925

926926
from_identifier = ("default", "new_tabl2e")
927927
from_table = catalog.load_table(from_identifier)
@@ -938,7 +938,10 @@ def test_rename_table_from_self_identifier(hive_table: HiveTable) -> None:
938938

939939
assert table.name() == to_identifier
940940

941-
calls = [call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e"), call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl3e")]
941+
calls = [
942+
call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl2e"),
943+
call(catalog._client.__enter__(), dbname="default", tbl_name="new_tabl3e"),
944+
]
942945
catalog._get_hive_table.assert_has_calls(calls)
943946
catalog._client.__enter__().alter_table_with_environment_context.assert_called_with(
944947
dbname="default",
@@ -1024,7 +1027,7 @@ def test_list_tables(hive_table: HiveTable) -> None:
10241027

10251028
catalog._client = MagicMock()
10261029
catalog._client.__enter__().get_all_tables.return_value = ["table1", "table2", "table3", "table4"]
1027-
catalog._get_table_objects_by_name = MagicMock()
1030+
catalog._get_table_objects_by_name = MagicMock() # type: ignore
10281031
catalog._get_table_objects_by_name.return_value = [tbl1, tbl2, tbl3, tbl4]
10291032

10301033
got_tables = catalog.list_tables("database")
@@ -1061,8 +1064,7 @@ def test_drop_table_from_self_identifier(hive_table: HiveTable) -> None:
10611064
catalog = HiveCatalog(HIVE_CATALOG_NAME, uri=HIVE_METASTORE_FAKE_URL)
10621065

10631066
catalog._client = MagicMock()
1064-
catalog._get_hive_table = MagicMock()
1065-
catalog._get_hive_table.return_value = hive_table
1067+
catalog._get_hive_table = MagicMock(return_value=hive_table) # type: ignore
10661068
table = catalog.load_table(("default", "new_tabl2e"))
10671069

10681070
catalog._client.__enter__().get_all_databases.return_value = ["namespace1", "namespace2"]

tests/integration/test_reads.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@
2626
import pyarrow as pa
2727
import pyarrow.parquet as pq
2828
import pytest
29-
from hive_metastore.ttypes import LockRequest, LockResponse, LockState, UnlockRequest
29+
from hive_metastore.v4.ttypes import LockRequest, LockResponse, LockState, UnlockRequest
3030
from pyarrow.fs import S3FileSystem
3131
from pydantic_core import ValidationError
3232
from pyspark.sql import SparkSession

tests/integration/test_writes/test_writes.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1174,7 +1174,7 @@ def test_hive_catalog_storage_descriptor_has_changed(
11741174
schema.update_column("binary", doc="this is binary")
11751175

11761176
with session_catalog_hive._client as open_client:
1177-
hive_table = session_catalog_hive._get_hive_table(open_client, "default", "test_storage_descriptor")
1177+
hive_table = session_catalog_hive._get_hive_table(open_client, dbname="default", tbl_name="test_storage_descriptor")
11781178
assert "this is string_long" in str(hive_table.sd)
11791179
assert "this is binary" in str(hive_table.sd)
11801180

0 commit comments

Comments
 (0)