Skip to content

Commit 3e391a7

Browse files
authored
Clear updates/requirements after commit (#1961)
# Rationale for this change Resolves #1946 # Are these changes tested? Yes, using a test that used to fail before :) # Are there any user-facing changes? <!-- In the case of user-facing changes, please add the changelog label. -->
1 parent 14ee8da commit 3e391a7

File tree

2 files changed

+32
-11
lines changed

2 files changed

+32
-11
lines changed

pyiceberg/table/__init__.py

Lines changed: 15 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -292,8 +292,6 @@ def _apply(self, updates: Tuple[TableUpdate, ...], requirements: Tuple[TableRequ
292292

293293
if self._autocommit:
294294
self.commit_transaction()
295-
self._updates = ()
296-
self._requirements = ()
297295

298296
return self
299297

@@ -937,13 +935,15 @@ def commit_transaction(self) -> Table:
937935
updates=self._updates,
938936
requirements=self._requirements,
939937
)
940-
return self._table
941-
else:
942-
return self._table
938+
939+
self._updates = ()
940+
self._requirements = ()
941+
942+
return self._table
943943

944944

945945
class CreateTableTransaction(Transaction):
946-
"""A transaction that involves the creation of a a new table."""
946+
"""A transaction that involves the creation of a new table."""
947947

948948
def _initial_changes(self, table_metadata: TableMetadata) -> None:
949949
"""Set the initial changes that can reconstruct the initial table metadata when creating the CreateTableTransaction."""
@@ -988,11 +988,15 @@ def commit_transaction(self) -> Table:
988988
Returns:
989989
The table with the updates applied.
990990
"""
991-
self._requirements = (AssertCreate(),)
992-
self._table._do_commit( # pylint: disable=W0212
993-
updates=self._updates,
994-
requirements=self._requirements,
995-
)
991+
if len(self._updates) > 0:
992+
self._table._do_commit( # pylint: disable=W0212
993+
updates=self._updates,
994+
requirements=(AssertCreate(),),
995+
)
996+
997+
self._updates = ()
998+
self._requirements = ()
999+
9961000
return self._table
9971001

9981002

tests/integration/test_writes/test_writes.py

Lines changed: 17 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1802,6 +1802,23 @@ def test_write_optional_list(session_catalog: Catalog) -> None:
18021802
assert len(session_catalog.load_table(identifier).scan().to_arrow()) == 4
18031803

18041804

1805+
@pytest.mark.integration
1806+
@pytest.mark.parametrize("format_version", [1, 2])
1807+
def test_double_commit_transaction(
1808+
spark: SparkSession, session_catalog: Catalog, arrow_table_with_null: pa.Table, format_version: int
1809+
) -> None:
1810+
identifier = "default.arrow_data_files"
1811+
tbl = _create_table(session_catalog, identifier, {"format-version": format_version}, [])
1812+
1813+
assert len(tbl.metadata.metadata_log) == 0
1814+
1815+
with tbl.transaction() as tx:
1816+
tx.append(arrow_table_with_null)
1817+
tx.commit_transaction()
1818+
1819+
assert len(tbl.metadata.metadata_log) == 1
1820+
1821+
18051822
@pytest.mark.integration
18061823
@pytest.mark.parametrize("format_version", [1, 2])
18071824
def test_evolve_and_write(

0 commit comments

Comments
 (0)