Commit 5039b5d

Fokko and HonahX authored
Change DataScan to accept Metadata and io (#581)
* Change DataScan to accept Metadata and io

  For the partial deletes I want to do a scan on in-memory metadata. Changing this API allows this.

* fix name-mapping issue

---------

Co-authored-by: HonahX <[email protected]>
1 parent 07442cc commit 5039b5d
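
The change lets callers plan and execute a scan against in-memory table metadata, without a live Table object or catalog round-trip. A minimal sketch of that use case, not taken from this commit: the metadata location and column names are placeholders, and load_file_io / FromInputFile are existing pyiceberg helpers, not part of this change.

from pyiceberg.io import load_file_io
from pyiceberg.serializers import FromInputFile
from pyiceberg.table import DataScan

# Hypothetical metadata file; any readable Iceberg table metadata JSON works.
metadata_location = "s3://warehouse/db/tbl/metadata/00001-abc.metadata.json"

# Build a FileIO for the location and parse the file into a TableMetadata object.
io = load_file_io(properties={}, location=metadata_location)
metadata = FromInputFile.table_metadata(io.new_input(metadata_location))

# DataScan now accepts the metadata and io directly (keyword names per this commit).
scan = DataScan(
    table_metadata=metadata,
    io=io,
    row_filter="id > 100",           # hypothetical column; strings are parsed into expressions
    selected_fields=("id", "name"),  # hypothetical columns
)
print(scan.to_arrow())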

File tree

5 files changed: +102 -110 lines changed


pyiceberg/io/pyarrow.py

Lines changed: 14 additions & 12 deletions
@@ -159,7 +159,7 @@
 from pyiceberg.utils.truncate import truncate_upper_bound_binary_string, truncate_upper_bound_text_string

 if TYPE_CHECKING:
-    from pyiceberg.table import FileScanTask, Table
+    from pyiceberg.table import FileScanTask

 logger = logging.getLogger(__name__)

@@ -1046,7 +1046,8 @@ def _read_all_delete_files(fs: FileSystem, tasks: Iterable[FileScanTask]) -> Dic

 def project_table(
     tasks: Iterable[FileScanTask],
-    table: Table,
+    table_metadata: TableMetadata,
+    io: FileIO,
     row_filter: BooleanExpression,
     projected_schema: Schema,
     case_sensitive: bool = True,
@@ -1056,7 +1057,8 @@ def project_table(

     Args:
         tasks (Iterable[FileScanTask]): A URI or a path to a local file.
-        table (Table): The table that's being queried.
+        table_metadata (TableMetadata): The table metadata of the table that's being queried
+        io (FileIO): A FileIO to open streams to the object store
         row_filter (BooleanExpression): The expression for filtering rows.
         projected_schema (Schema): The output schema.
         case_sensitive (bool): Case sensitivity when looking up column names.
@@ -1065,24 +1067,24 @@ def project_table(
     Raises:
         ResolveError: When an incompatible query is done.
     """
-    scheme, netloc, _ = PyArrowFileIO.parse_location(table.location())
-    if isinstance(table.io, PyArrowFileIO):
-        fs = table.io.fs_by_scheme(scheme, netloc)
+    scheme, netloc, _ = PyArrowFileIO.parse_location(table_metadata.location)
+    if isinstance(io, PyArrowFileIO):
+        fs = io.fs_by_scheme(scheme, netloc)
     else:
         try:
             from pyiceberg.io.fsspec import FsspecFileIO

-            if isinstance(table.io, FsspecFileIO):
+            if isinstance(io, FsspecFileIO):
                 from pyarrow.fs import PyFileSystem

-                fs = PyFileSystem(FSSpecHandler(table.io.get_fs(scheme)))
+                fs = PyFileSystem(FSSpecHandler(io.get_fs(scheme)))
             else:
-                raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}")
+                raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}")
         except ModuleNotFoundError as e:
             # When FsSpec is not installed
-            raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {table.io}") from e
+            raise ValueError(f"Expected PyArrowFileIO or FsspecFileIO, got: {io}") from e

-    bound_row_filter = bind(table.schema(), row_filter, case_sensitive=case_sensitive)
+    bound_row_filter = bind(table_metadata.schema(), row_filter, case_sensitive=case_sensitive)

     projected_field_ids = {
         id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
@@ -1101,7 +1103,7 @@ def project_table(
             deletes_per_file.get(task.file.file_path),
             case_sensitive,
             limit,
-            table.name_mapping(),
+            table_metadata.name_mapping(),
         )
         for task in tasks
     ]
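
project_table now takes the table metadata and a FileIO explicitly instead of a Table. A hedged sketch of the new call shape, mirroring how DataScan.to_arrow invokes it later in this commit; `scan` is the DataScan from the sketch above.

from pyiceberg.io.pyarrow import project_table

# Plan the files first, then project them to an Arrow table using the scan's
# metadata and io rather than a Table handle.
arrow_table = project_table(
    scan.plan_files(),
    scan.table_metadata,
    scan.io,
    scan.row_filter,
    scan.projection(),
    case_sensitive=True,
)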

pyiceberg/table/__init__.py

Lines changed: 29 additions & 41 deletions
@@ -103,7 +103,6 @@
 )
 from pyiceberg.table.name_mapping import (
     NameMapping,
-    parse_mapping_from_json,
     update_mapping,
 )
 from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef
@@ -1215,7 +1214,8 @@ def scan(
         limit: Optional[int] = None,
     ) -> DataScan:
         return DataScan(
-            table=self,
+            table_metadata=self.metadata,
+            io=self.io,
             row_filter=row_filter,
             selected_fields=selected_fields,
             case_sensitive=case_sensitive,
@@ -1312,10 +1312,7 @@ def update_schema(self, allow_incompatible_changes: bool = False, case_sensitive

     def name_mapping(self) -> Optional[NameMapping]:
         """Return the table's field-id NameMapping."""
-        if name_mapping_json := self.properties.get(TableProperties.DEFAULT_NAME_MAPPING):
-            return parse_mapping_from_json(name_mapping_json)
-        else:
-            return None
+        return self.metadata.name_mapping()

     def append(self, df: pa.Table, snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
         """
@@ -1468,7 +1465,8 @@ def _parse_row_filter(expr: Union[str, BooleanExpression]) -> BooleanExpression:


 class TableScan(ABC):
-    table: Table
+    table_metadata: TableMetadata
+    io: FileIO
     row_filter: BooleanExpression
     selected_fields: Tuple[str, ...]
     case_sensitive: bool
@@ -1478,15 +1476,17 @@ class TableScan(ABC):

     def __init__(
         self,
-        table: Table,
+        table_metadata: TableMetadata,
+        io: FileIO,
         row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE,
         selected_fields: Tuple[str, ...] = ("*",),
         case_sensitive: bool = True,
         snapshot_id: Optional[int] = None,
         options: Properties = EMPTY_DICT,
         limit: Optional[int] = None,
     ):
-        self.table = table
+        self.table_metadata = table_metadata
+        self.io = io
         self.row_filter = _parse_row_filter(row_filter)
         self.selected_fields = selected_fields
         self.case_sensitive = case_sensitive
@@ -1496,19 +1496,20 @@ def __init__(

     def snapshot(self) -> Optional[Snapshot]:
         if self.snapshot_id:
-            return self.table.snapshot_by_id(self.snapshot_id)
-        return self.table.current_snapshot()
+            return self.table_metadata.snapshot_by_id(self.snapshot_id)
+        return self.table_metadata.current_snapshot()

     def projection(self) -> Schema:
-        current_schema = self.table.schema()
+        current_schema = self.table_metadata.schema()
         if self.snapshot_id is not None:
-            snapshot = self.table.snapshot_by_id(self.snapshot_id)
+            snapshot = self.table_metadata.snapshot_by_id(self.snapshot_id)
             if snapshot is not None:
                 if snapshot.schema_id is not None:
-                    snapshot_schema = self.table.schemas().get(snapshot.schema_id)
-                    if snapshot_schema is not None:
-                        current_schema = snapshot_schema
-                    else:
+                    try:
+                        current_schema = next(
+                            schema for schema in self.table_metadata.schemas if schema.schema_id == snapshot.schema_id
+                        )
+                    except StopIteration:
                         warnings.warn(f"Metadata does not contain schema with id: {snapshot.schema_id}")
             else:
                 raise ValueError(f"Snapshot not found: {self.snapshot_id}")
@@ -1534,7 +1535,7 @@ def update(self: S, **overrides: Any) -> S:
     def use_ref(self: S, name: str) -> S:
         if self.snapshot_id:
             raise ValueError(f"Cannot override ref, already set snapshot id={self.snapshot_id}")
-        if snapshot := self.table.snapshot_by_name(name):
+        if snapshot := self.table_metadata.snapshot_by_name(name):
             return self.update(snapshot_id=snapshot.snapshot_id)

         raise ValueError(f"Cannot scan unknown ref={name}")
@@ -1626,33 +1627,21 @@ def _match_deletes_to_data_file(data_entry: ManifestEntry, positional_delete_ent


 class DataScan(TableScan):
-    def __init__(
-        self,
-        table: Table,
-        row_filter: Union[str, BooleanExpression] = ALWAYS_TRUE,
-        selected_fields: Tuple[str, ...] = ("*",),
-        case_sensitive: bool = True,
-        snapshot_id: Optional[int] = None,
-        options: Properties = EMPTY_DICT,
-        limit: Optional[int] = None,
-    ):
-        super().__init__(table, row_filter, selected_fields, case_sensitive, snapshot_id, options, limit)
-
     def _build_partition_projection(self, spec_id: int) -> BooleanExpression:
-        project = inclusive_projection(self.table.schema(), self.table.specs()[spec_id])
+        project = inclusive_projection(self.table_metadata.schema(), self.table_metadata.specs()[spec_id])
         return project(self.row_filter)

     @cached_property
     def partition_filters(self) -> KeyDefaultDict[int, BooleanExpression]:
         return KeyDefaultDict(self._build_partition_projection)

     def _build_manifest_evaluator(self, spec_id: int) -> Callable[[ManifestFile], bool]:
-        spec = self.table.specs()[spec_id]
-        return manifest_evaluator(spec, self.table.schema(), self.partition_filters[spec_id], self.case_sensitive)
+        spec = self.table_metadata.specs()[spec_id]
+        return manifest_evaluator(spec, self.table_metadata.schema(), self.partition_filters[spec_id], self.case_sensitive)

     def _build_partition_evaluator(self, spec_id: int) -> Callable[[DataFile], bool]:
-        spec = self.table.specs()[spec_id]
-        partition_type = spec.partition_type(self.table.schema())
+        spec = self.table_metadata.specs()[spec_id]
+        partition_type = spec.partition_type(self.table_metadata.schema())
         partition_schema = Schema(*partition_type.fields)
         partition_expr = self.partition_filters[spec_id]

@@ -1687,16 +1676,14 @@ def plan_files(self) -> Iterable[FileScanTask]:
         if not snapshot:
             return iter([])

-        io = self.table.io
-
         # step 1: filter manifests using partition summaries
         # the filter depends on the partition spec used to write the manifest file, so create a cache of filters for each spec id

         manifest_evaluators: Dict[int, Callable[[ManifestFile], bool]] = KeyDefaultDict(self._build_manifest_evaluator)

         manifests = [
             manifest_file
-            for manifest_file in snapshot.manifests(io)
+            for manifest_file in snapshot.manifests(self.io)
             if manifest_evaluators[manifest_file.partition_spec_id](manifest_file)
         ]

@@ -1705,7 +1692,7 @@ def plan_files(self) -> Iterable[FileScanTask]:

         partition_evaluators: Dict[int, Callable[[DataFile], bool]] = KeyDefaultDict(self._build_partition_evaluator)
         metrics_evaluator = _InclusiveMetricsEvaluator(
-            self.table.schema(), self.row_filter, self.case_sensitive, self.options.get("include_empty_files") == "true"
+            self.table_metadata.schema(), self.row_filter, self.case_sensitive, self.options.get("include_empty_files") == "true"
         ).eval

         min_data_sequence_number = _min_data_file_sequence_number(manifests)
@@ -1719,7 +1706,7 @@ def plan_files(self) -> Iterable[FileScanTask]:
                 lambda args: _open_manifest(*args),
                 [
                     (
-                        io,
+                        self.io,
                         manifest,
                         partition_evaluators[manifest.partition_spec_id],
                         metrics_evaluator,
@@ -1755,7 +1742,8 @@ def to_arrow(self) -> pa.Table:

         return project_table(
             self.plan_files(),
-            self.table,
+            self.table_metadata,
+            self.io,
             self.row_filter,
             self.projection(),
             case_sensitive=self.case_sensitive,
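
TableScan state now lives entirely in table_metadata and io, so ref and schema resolution go through TableMetadata instead of a Table. Continuing the sketch above; the ref name "main" is an assumption.

# Pin the scan to a named ref; use_ref resolves it via table_metadata.snapshot_by_name.
pinned = scan.use_ref("main")
print(pinned.snapshot())     # resolved through table_metadata.snapshot_by_id
print(pinned.projection())   # schema matched against table_metadata.schemas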

pyiceberg/table/metadata.py

Lines changed: 14 additions & 0 deletions
@@ -35,6 +35,7 @@
 from pyiceberg.exceptions import ValidationError
 from pyiceberg.partitioning import PARTITION_FIELD_ID_START, PartitionSpec, assign_fresh_partition_spec_ids
 from pyiceberg.schema import Schema, assign_fresh_schema_ids
+from pyiceberg.table.name_mapping import NameMapping, parse_mapping_from_json
 from pyiceberg.table.refs import MAIN_BRANCH, SnapshotRef, SnapshotRefType
 from pyiceberg.table.snapshots import MetadataLogEntry, Snapshot, SnapshotLogEntry
 from pyiceberg.table.sorting import (
@@ -237,6 +238,13 @@ def schema(self) -> Schema:
         """Return the schema for this table."""
         return next(schema for schema in self.schemas if schema.schema_id == self.current_schema_id)

+    def name_mapping(self) -> Optional[NameMapping]:
+        """Return the table's field-id NameMapping."""
+        if name_mapping_json := self.properties.get("schema.name-mapping.default"):
+            return parse_mapping_from_json(name_mapping_json)
+        else:
+            return None
+
     def spec(self) -> PartitionSpec:
         """Return the partition spec of this table."""
         return next(spec for spec in self.partition_specs if spec.spec_id == self.default_spec_id)
@@ -278,6 +286,12 @@ def new_snapshot_id(self) -> int:

         return snapshot_id

+    def snapshot_by_name(self, name: str) -> Optional[Snapshot]:
+        """Return the snapshot referenced by the given name or null if no such reference exists."""
+        if ref := self.refs.get(name):
+            return self.snapshot_by_id(ref.snapshot_id)
+        return None
+
     def current_snapshot(self) -> Optional[Snapshot]:
         """Get the current snapshot for this table, or None if there is no current snapshot."""
         if self.current_snapshot_id is not None:
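
The two new TableMetadata helpers back the Table and TableScan changes above. A small sketch of calling them on the metadata object from the first sketch; the "main" ref is assumed to exist.

# name_mapping() returns None unless the "schema.name-mapping.default" property is set.
mapping = metadata.name_mapping()

# snapshot_by_name() resolves a ref (branch or tag) to its Snapshot, or None.
if (snapshot := metadata.snapshot_by_name("main")) is not None:
    print(snapshot.snapshot_id, snapshot.schema_id)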

tests/integration/test_add_files.py

Lines changed: 9 additions & 0 deletions
@@ -158,6 +158,9 @@ def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog:
     for col in df.columns:
         assert df.filter(df[col].isNotNull()).count() == 5, "Expected all 5 rows to be non-null"

+    # check that the table can be read by pyiceberg
+    assert len(tbl.scan().to_arrow()) == 5, "Expected 5 rows"
+

 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
@@ -255,6 +258,9 @@ def test_add_files_to_unpartitioned_table_with_schema_updates(
         value_count = 1 if col == "quux" else 6
         assert df.filter(df[col].isNotNull()).count() == value_count, f"Expected {value_count} rows to be non-null"

+    # check that the table can be read by pyiceberg
+    assert len(tbl.scan().to_arrow()) == 6, "Expected 6 rows"
+

 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])
@@ -324,6 +330,9 @@ def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Ca
     assert [row.file_count for row in partition_rows] == [5]
     assert [(row.partition.baz, row.partition.qux_month) for row in partition_rows] == [(123, 650)]

+    # check that the table can be read by pyiceberg
+    assert len(tbl.scan().to_arrow()) == 5, "Expected 5 rows"
+

 @pytest.mark.integration
 @pytest.mark.parametrize("format_version", [1, 2])

0 commit comments
