Skip to content

Commit ba9ff98

Browse files
authored
Add typealias for table version (#566)
* typealias for table version
* typealias for table version
* typealias for table version
* typealias for table version
* typealias for table version
* typealias for table version replaced in all files
1 parent 474b37b commit ba9ff98

File tree

4 files changed

+24
-18
lines changed

4 files changed

+24
-18
lines changed

pyiceberg/manifest.py

Lines changed: 12 additions & 12 deletions
Original file line number | Diff line number | Diff line change
@@ -37,7 +37,7 @@
3737
from pyiceberg.io import FileIO, InputFile, OutputFile
3838
from pyiceberg.partitioning import PartitionSpec
3939
from pyiceberg.schema import Schema
40-
from pyiceberg.typedef import EMPTY_DICT, Record
40+
from pyiceberg.typedef import EMPTY_DICT, Record, TableVersion
4141
from pyiceberg.types import (
4242
BinaryType,
4343
BooleanType,
@@ -302,7 +302,7 @@ def _(partition_field_type: PrimitiveType) -> PrimitiveType:
302302
return partition_field_type
303303

304304

305-
def data_file_with_partition(partition_type: StructType, format_version: Literal[1, 2]) -> StructType:
305+
def data_file_with_partition(partition_type: StructType, format_version: TableVersion) -> StructType:
306306
data_file_partition_type = StructType(*[
307307
NestedField(
308308
field_id=field.field_id,
@@ -372,7 +372,7 @@ def __setattr__(self, name: str, value: Any) -> None:
372372
value = FileFormat[value]
373373
super().__setattr__(name, value)
374374

375-
def __init__(self, format_version: Literal[1, 2] = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None:
375+
def __init__(self, format_version: TableVersion = DEFAULT_READ_VERSION, *data: Any, **named_data: Any) -> None:
376376
super().__init__(
377377
*data,
378378
**{"struct": DATA_FILE_TYPE[format_version], **named_data},
@@ -408,7 +408,7 @@ def __eq__(self, other: Any) -> bool:
408408
MANIFEST_ENTRY_SCHEMAS_STRUCT = {format_version: schema.as_struct() for format_version, schema in MANIFEST_ENTRY_SCHEMAS.items()}
409409

410410

411-
def manifest_entry_schema_with_data_file(format_version: Literal[1, 2], data_file: StructType) -> Schema:
411+
def manifest_entry_schema_with_data_file(format_version: TableVersion, data_file: StructType) -> Schema:
412412
return Schema(*[
413413
NestedField(2, "data_file", data_file, required=True) if field.field_id == 2 else field
414414
for field in MANIFEST_ENTRY_SCHEMAS[format_version].fields
@@ -719,9 +719,9 @@ def content(self) -> ManifestContent: ...
719719

720720
@property
721721
@abstractmethod
722-
def version(self) -> Literal[1, 2]: ...
722+
def version(self) -> TableVersion: ...
723723

724-
def _with_partition(self, format_version: Literal[1, 2]) -> Schema:
724+
def _with_partition(self, format_version: TableVersion) -> Schema:
725725
data_file_type = data_file_with_partition(
726726
format_version=format_version, partition_type=self._spec.partition_type(self._schema)
727727
)
@@ -807,7 +807,7 @@ def content(self) -> ManifestContent:
807807
return ManifestContent.DATA
808808

809809
@property
810-
def version(self) -> Literal[1, 2]:
810+
def version(self) -> TableVersion:
811811
return 1
812812

813813
def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry:
@@ -834,7 +834,7 @@ def content(self) -> ManifestContent:
834834
return ManifestContent.DATA
835835

836836
@property
837-
def version(self) -> Literal[1, 2]:
837+
def version(self) -> TableVersion:
838838
return 2
839839

840840
def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry:
@@ -847,7 +847,7 @@ def prepare_entry(self, entry: ManifestEntry) -> ManifestEntry:
847847

848848

849849
def write_manifest(
850-
format_version: Literal[1, 2], spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int
850+
format_version: TableVersion, spec: PartitionSpec, schema: Schema, output_file: OutputFile, snapshot_id: int
851851
) -> ManifestWriter:
852852
if format_version == 1:
853853
return ManifestWriterV1(spec, schema, output_file, snapshot_id)
@@ -858,14 +858,14 @@ def write_manifest(
858858

859859

860860
class ManifestListWriter(ABC):
861-
_format_version: Literal[1, 2]
861+
_format_version: TableVersion
862862
_output_file: OutputFile
863863
_meta: Dict[str, str]
864864
_manifest_files: List[ManifestFile]
865865
_commit_snapshot_id: int
866866
_writer: AvroOutputFile[ManifestFile]
867867

868-
def __init__(self, format_version: Literal[1, 2], output_file: OutputFile, meta: Dict[str, Any]):
868+
def __init__(self, format_version: TableVersion, output_file: OutputFile, meta: Dict[str, Any]):
869869
self._format_version = format_version
870870
self._output_file = output_file
871871
self._meta = meta
@@ -957,7 +957,7 @@ def prepare_manifest(self, manifest_file: ManifestFile) -> ManifestFile:
957957

958958

959959
def write_manifest_list(
960-
format_version: Literal[1, 2],
960+
format_version: TableVersion,
961961
output_file: OutputFile,
962962
snapshot_id: int,
963963
parent_snapshot_id: Optional[int],

pyiceberg/table/__init__.py

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@
121121
KeyDefaultDict,
122122
Properties,
123123
Record,
124+
TableVersion,
124125
)
125126
from pyiceberg.types import (
126127
IcebergType,
@@ -293,7 +294,7 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ
293294

294295
return self
295296

296-
def upgrade_table_version(self, format_version: Literal[1, 2]) -> Transaction:
297+
def upgrade_table_version(self, format_version: TableVersion) -> Transaction:
297298
"""Set the table to a certain version.
298299
299300
Args:
@@ -1023,7 +1024,7 @@ def scan(
10231024
)
10241025

10251026
@property
1026-
def format_version(self) -> Literal[1, 2]:
1027+
def format_version(self) -> TableVersion:
10271028
return self.metadata.format_version
10281029

10291030
def schema(self) -> Schema:

pyiceberg/typedef.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -26,6 +26,7 @@
2626
Dict,
2727
Generic,
2828
List,
29+
Literal,
2930
Optional,
3031
Protocol,
3132
Set,
@@ -37,6 +38,7 @@
3738
from uuid import UUID
3839

3940
from pydantic import BaseModel, ConfigDict, RootModel
41+
from typing_extensions import TypeAlias
4042

4143
if TYPE_CHECKING:
4244
from pyiceberg.types import StructType
@@ -199,3 +201,6 @@ def __repr__(self) -> str:
199201
def record_fields(self) -> List[str]:
200202
"""Return values of all the fields of the Record class except those specified in skip_fields."""
201203
return [self.__getattribute__(v) if hasattr(self, v) else None for v in self._position_to_field_name]
204+
205+
206+
TableVersion: TypeAlias = Literal[1, 2]

tests/utils/test_manifest.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
# under the License.
1717
# pylint: disable=redefined-outer-name,arguments-renamed,fixme
1818
from tempfile import TemporaryDirectory
19-
from typing import Dict, Literal
19+
from typing import Dict
2020

2121
import fastavro
2222
import pytest
@@ -39,7 +39,7 @@
3939
from pyiceberg.schema import Schema
4040
from pyiceberg.table.snapshots import Operation, Snapshot, Summary
4141
from pyiceberg.transforms import IdentityTransform
42-
from pyiceberg.typedef import Record
42+
from pyiceberg.typedef import Record, TableVersion
4343
from pyiceberg.types import IntegerType, NestedField
4444

4545

@@ -308,7 +308,7 @@ def test_read_manifest_v2(generated_manifest_file_file_v2: str) -> None:
308308

309309
@pytest.mark.parametrize("format_version", [1, 2])
310310
def test_write_manifest(
311-
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2]
311+
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion
312312
) -> None:
313313
io = load_file_io()
314314
snapshot = Snapshot(
@@ -478,7 +478,7 @@ def test_write_manifest(
478478

479479
@pytest.mark.parametrize("format_version", [1, 2])
480480
def test_write_manifest_list(
481-
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: Literal[1, 2]
481+
generated_manifest_file_file_v1: str, generated_manifest_file_file_v2: str, format_version: TableVersion
482482
) -> None:
483483
io = load_file_io()
484484

0 commit comments

Comments
 (0)