
Commit 990ce80

Make add_files support a snapshot_properties argument (#695)

1 parent 0eb0c1c commit 990ce80

2 files changed (+36, -12 lines)

pyiceberg/table/__init__.py

Lines changed: 4 additions & 4 deletions
@@ -443,7 +443,7 @@ def overwrite(
         for data_file in data_files:
             update_snapshot.append_data_file(data_file)
 
-    def add_files(self, file_paths: List[str]) -> None:
+    def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
         """
         Shorthand API for adding files as data files to the table transaction.
 
@@ -455,7 +455,7 @@ def add_files(self, file_paths: List[str]) -> None:
         """
         if self._table.name_mapping() is None:
             self.set_properties(**{TableProperties.DEFAULT_NAME_MAPPING: self._table.schema().name_mapping.model_dump_json()})
-        with self.update_snapshot().fast_append() as update_snapshot:
+        with self.update_snapshot(snapshot_properties=snapshot_properties).fast_append() as update_snapshot:
             data_files = _parquet_files_to_data_files(
                 table_metadata=self._table.metadata, file_paths=file_paths, io=self._table.io
             )
@@ -1341,7 +1341,7 @@ def overwrite(
         with self.transaction() as tx:
             tx.overwrite(df=df, overwrite_filter=overwrite_filter, snapshot_properties=snapshot_properties)
 
-    def add_files(self, file_paths: List[str]) -> None:
+    def add_files(self, file_paths: List[str], snapshot_properties: Dict[str, str] = EMPTY_DICT) -> None:
         """
         Shorthand API for adding files as data files to the table.
 
@@ -1352,7 +1352,7 @@ def add_files(self, file_paths: List[str]) -> None:
             FileNotFoundError: If the file does not exist.
         """
         with self.transaction() as tx:
-            tx.add_files(file_paths=file_paths)
+            tx.add_files(file_paths=file_paths, snapshot_properties=snapshot_properties)
 
     def update_spec(self, case_sensitive: bool = True) -> UpdateSpec:
         return UpdateSpec(Transaction(self, autocommit=True), case_sensitive=case_sensitive)
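
With this change, both Table.add_files and Transaction.add_files accept an optional snapshot_properties mapping (defaulting to EMPTY_DICT, so existing call sites are unaffected); the pairs are recorded in the summary of the snapshot created by the fast append. A minimal usage sketch, assuming tbl is an already-loaded pyiceberg Table and the file paths (hypothetical here) point at Parquet files matching the table schema:

# tbl is an existing pyiceberg Table; the paths below are illustrative only.
tbl.add_files(
    file_paths=["s3://warehouse/default/my_table/data-0.parquet"],
    snapshot_properties={"ingest-job-id": "job-42"},
)

# The same keyword is available on the transaction API:
with tbl.transaction() as tx:
    tx.add_files(
        file_paths=["s3://warehouse/default/my_table/data-1.parquet"],
        snapshot_properties={"ingest-job-id": "job-42"},
    )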

tests/integration/test_add_files.py

Lines changed: 32 additions & 8 deletions
@@ -17,7 +17,7 @@
 # pylint:disable=redefined-outer-name
 
 from datetime import date
-from typing import Optional
+from typing import Iterator, Optional
 
 import pyarrow as pa
 import pyarrow.parquet as pq
@@ -122,8 +122,13 @@ def _create_table(
     return tbl
 
 
+@pytest.fixture(name="format_version", params=[pytest.param(1, id="format_version=1"), pytest.param(2, id="format_version=2")])
+def format_version_fixure(request: pytest.FixtureRequest) -> Iterator[int]:
+    """Fixture to run tests with different table format versions."""
+    yield request.param
+
+
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
     identifier = f"default.unpartitioned_table_v{format_version}"
     tbl = _create_table(session_catalog, identifier, format_version)
@@ -163,7 +168,6 @@ def test_add_files_to_unpartitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table_raises_file_not_found(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -184,7 +188,6 @@ def test_add_files_to_unpartitioned_table_raises_file_not_found(
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table_raises_has_field_ids(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -205,7 +208,6 @@ def test_add_files_to_unpartitioned_table_raises_has_field_ids(
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_unpartitioned_table_with_schema_updates(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -263,7 +265,6 @@ def test_add_files_to_unpartitioned_table_with_schema_updates(
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
     identifier = f"default.partitioned_table_v{format_version}"
 
@@ -335,7 +336,6 @@ def test_add_files_to_partitioned_table(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_bucket_partitioned_table_fails(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
     identifier = f"default.partitioned_table_bucket_fails_v{format_version}"
 
@@ -378,7 +378,6 @@ def test_add_files_to_bucket_partitioned_table_fails(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
 
 
 @pytest.mark.integration
-@pytest.mark.parametrize("format_version", [1, 2])
 def test_add_files_to_partitioned_table_fails_with_lower_and_upper_mismatch(
     spark: SparkSession, session_catalog: Catalog, format_version: int
 ) -> None:
@@ -424,3 +423,28 @@ def test_add_files_to_partitioned_table_fails_with_lower_and_upper_mismatch(
         "Cannot infer partition value from parquet metadata as there are more than one partition values for Partition Field: baz. lower_value=123, upper_value=124"
         in str(exc_info.value)
     )
+
+
+@pytest.mark.integration
+def test_add_files_snapshot_properties(spark: SparkSession, session_catalog: Catalog, format_version: int) -> None:
+    identifier = f"default.unpartitioned_table_v{format_version}"
+    tbl = _create_table(session_catalog, identifier, format_version)
+
+    file_paths = [f"s3://warehouse/default/unpartitioned/v{format_version}/test-{i}.parquet" for i in range(5)]
+    # write parquet files
+    for file_path in file_paths:
+        fo = tbl.io.new_output(file_path)
+        with fo.create(overwrite=True) as fos:
+            with pq.ParquetWriter(fos, schema=ARROW_SCHEMA) as writer:
+                writer.write_table(ARROW_TABLE)
+
+    # add the parquet files as data files
+    tbl.add_files(file_paths=file_paths, snapshot_properties={"snapshot_prop_a": "test_prop_a"})
+
+    # NameMapping must have been set to enable reads
+    assert tbl.name_mapping() is not None
+
+    summary = spark.sql(f"SELECT * FROM {identifier}.snapshots;").collect()[0].summary
+
+    assert "snapshot_prop_a" in summary
+    assert summary["snapshot_prop_a"] == "test_prop_a"
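
The new integration test reads the summary back through Spark's snapshots metadata table. A rough sketch of the same check done purely through pyiceberg (not part of this commit; the catalog and table identifiers are hypothetical, and it assumes the snapshot's Summary supports mapping-style lookups):

from pyiceberg.catalog import load_catalog

# Hypothetical catalog and table identifiers, for illustration only.
catalog = load_catalog("default")
tbl = catalog.load_table("default.unpartitioned_table_v2")

tbl.add_files(
    file_paths=["s3://warehouse/default/unpartitioned/v2/test-0.parquet"],
    snapshot_properties={"snapshot_prop_a": "test_prop_a"},
)

# The properties should appear in the summary of the newly created snapshot.
snapshot = tbl.current_snapshot()
assert snapshot is not None and snapshot.summary is not None
assert snapshot.summary["snapshot_prop_a"] == "test_prop_a"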
