Commit 1a61811

[python] introduce update_columns in python api.
1 parent 4e20084 commit 1a61811
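
For orientation, here is a minimal usage sketch of the new API, pieced together from the calls exercised in the tests below. The `table` variable and its 5-row, 'age'-column schema are assumptions taken from the test fixture, and the commit step (not shown in this diff) is omitted:

import pyarrow as pa

# Assumed: `table` is an existing Paimon table with 5 rows and an 'age' column.
write_builder = table.new_batch_write_builder().update_columns_by_row_id()
batch_write = write_builder.new_write().with_write_type(['age'])

# A full-table update: exactly one value per existing row, keyed by _ROW_ID.
update_data = pa.Table.from_pydict({
    '_ROW_ID': [0, 1, 2, 3, 4],
    'age': [26, 31, 36, 37, 38]
})
batch_write.write_arrow(update_data)
batch_write.close()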

File tree

2 files changed (+74, -5 lines)


paimon-python/pypaimon/tests/partial_columns_write_test.py

Lines changed: 43 additions & 0 deletions
@@ -328,6 +328,49 @@ def test_multiple_calls(self):
         self.assertEqual(ages, expected_ages, "Age column was not updated correctly")
         self.assertEqual(cities, expected_cities, "City column was not updated correctly")

+    def test_wrong_total_row_count(self):
+        """Test that a wrong total row count raises an error."""
+        # Create table with initial data
+        table = self._create_table()
+
+        # Create data evolution writer using BatchTableWrite
+        write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+        batch_write = write_builder.new_write().with_write_type(['age'])
+
+        # Prepare update data with the wrong row count (only 3 rows instead of 5)
+        update_data = pa.Table.from_pydict({
+            '_ROW_ID': [0, 1, 2],
+            'age': [26, 31, 36]
+        })
+
+        # Should raise ValueError for the total row count mismatch
+        with self.assertRaises(ValueError) as context:
+            batch_write.write_arrow(update_data)
+
+        self.assertIn("does not match table total row count", str(context.exception))
+        batch_write.close()
+
+    def test_wrong_first_row_id_row_count(self):
+        """Test that a wrong row count for a first_row_id raises an error."""
+        # Create table with initial data
+        table = self._create_table()
+
+        # Create data evolution writer using BatchTableWrite
+        write_builder = table.new_batch_write_builder().update_columns_by_row_id()
+        batch_write = write_builder.new_write().with_write_type(['age'])
+
+        # Prepare update data with a duplicate row id (violates the monotonically increasing requirement)
+        update_data = pa.Table.from_pydict({
+            '_ROW_ID': [0, 1, 1, 4, 5],
+            'age': [26, 31, 36, 37, 38]
+        })
+
+        # Should raise ValueError for row ID validation
+        with self.assertRaises(ValueError) as context:
+            batch_write.write_arrow(update_data)
+
+        self.assertIn("Row IDs are not monotonically increasing", str(context.exception))
+        batch_write.close()


 if __name__ == '__main__':
     unittest.main()
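
Both new tests drive failure paths only. For contrast, an input that passes the validations added in this commit must cover every existing row, with _ROW_ID forming exactly the sequence 0..n-1. A quick check mirroring the validation added to _calculate_first_row_id below shows why the second test's input fails:

row_ids = [0, 1, 1, 4, 5]                     # the second test's input
expected_row_ids = list(range(len(row_ids)))  # [0, 1, 2, 3, 4]
assert row_ids != expected_row_ids            # hence the ValueError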

paimon-python/pypaimon/write/partial_column_write.py

Lines changed: 31 additions & 5 deletions
@@ -44,7 +44,10 @@ def __init__(self, table, commit_user: str):
         self.commit_user = commit_user

         # Load existing first_row_ids and build partition map
-        self.first_row_ids, self.first_row_id_to_partition_map = self._load_existing_files_info()
+        (self.first_row_ids,
+         self.first_row_id_to_partition_map,
+         self.first_row_id_to_row_count_map,
+         self.total_row_count) = self._load_existing_files_info()

         # Collect commit messages
         self.commit_messages = []
@@ -53,18 +56,24 @@ def _load_existing_files_info(self):
         """Load existing first_row_ids and build partition map for efficient lookup."""
         first_row_ids = []
         first_row_id_to_partition_map: Dict[int, GenericRow] = {}
+        first_row_id_to_row_count_map: Dict[int, int] = {}

         read_builder = self.table.new_read_builder()
         scan = read_builder.new_scan()
         splits = scan.plan().splits()

         for split in splits:
             for file in split.files:
-                if file.first_row_id is not None:
-                    first_row_ids.append(file.first_row_id)
-                    first_row_id_to_partition_map[file.first_row_id] = split.partition
+                if file.first_row_id is not None and not file.file_name.endswith('.blob'):
+                    first_row_id = file.first_row_id
+                    first_row_ids.append(first_row_id)
+                    first_row_id_to_partition_map[first_row_id] = split.partition
+                    first_row_id_to_row_count_map[first_row_id] = file.row_count

-        return sorted(list(set(first_row_ids))), first_row_id_to_partition_map
+        total_row_count = sum(first_row_id_to_row_count_map.values())
+
+        return (sorted(list(set(first_row_ids))), first_row_id_to_partition_map,
+                first_row_id_to_row_count_map, total_row_count)

     def update_columns(self, data: pa.Table, column_names: List[str]) -> List:
         """
@@ -91,6 +100,11 @@ def update_columns(self, data: pa.Table, column_names: List[str]) -> List:
             if col_name not in self.table.field_names:
                 raise ValueError(f"Column {col_name} not found in table schema")

+        # Validate data row count matches total row count
+        if data.num_rows != self.total_row_count:
+            raise ValueError(
+                f"Input data row count ({data.num_rows}) does not match table total row count ({self.total_row_count})")
+
         # Sort data by _ROW_ID column
         sorted_data = data.sort_by([(SpecialFields.ROW_ID.name, "ascending")])

@@ -106,6 +120,12 @@ def _calculate_first_row_id(self, data: pa.Table) -> pa.Table:
         """Calculate _first_row_id for each row based on _ROW_ID."""
         row_ids = data[SpecialFields.ROW_ID.name].to_pylist()

+        # Validate that row_ids are monotonically increasing starting from 0
+        expected_row_ids = list(range(len(row_ids)))
+        if row_ids != expected_row_ids:
+            raise ValueError(f"Row IDs are not monotonically increasing starting from 0. "
+                             f"Expected: {expected_row_ids}")
+
         # Calculate first_row_id for each row_id
         first_row_id_values = []
         for row_id in row_ids:
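
The hunk above shows only the validation; the per-row lookup in the loop that follows is not part of this diff. For intuition, one plausible way to map a row id to its group using the sorted first_row_ids (a sketch under that assumption, not the repository's actual implementation):

import bisect

def find_first_row_id(row_id, first_row_ids):
    # first_row_ids is sorted ascending; take the largest entry <= row_id.
    return first_row_ids[bisect.bisect_right(first_row_ids, row_id) - 1]

# With first_row_ids = [0, 3]: row 2 -> group 0; rows 3 and 4 -> group 3.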
@@ -155,6 +175,12 @@ def _write_group(self, partition: GenericRow, first_row_id: int,
                      data: pa.Table, column_names: List[str]):
         """Write a group of data with the same first_row_id."""

+        # Validate data row count matches the first_row_id's row count
+        expected_row_count = self.first_row_id_to_row_count_map.get(first_row_id, 0)
+        if data.num_rows != expected_row_count:
+            raise ValueError(
+                f"Data row count ({data.num_rows}) does not match expected row count ({expected_row_count}) for first_row_id {first_row_id}")
+
         # Create a file store write for this partition
         file_store_write = FileStoreWrite(self.table, self.commit_user)
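
One consequence of the .get(first_row_id, 0) default above: a group whose first_row_id never appeared in the scan gets an expected count of 0, so any non-empty batch routed to it fails this check rather than being written against an unknown file.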
