Skip to content

Commit 3a303be

Browse files
author
xiaohongbo
committed
[python] light refactor for stats collect
[python] remove duplicate compute min and max fix code format fix key fields stats collect twice issue clean code fix fix fix fix fix fix fix clean code fix code format fix code format fix fix code format fix code format fix fix fix
1 parent e3284b4 commit 3a303be

File tree

3 files changed

+158
-37
lines changed

3 files changed

+158
-37
lines changed

paimon-python/pypaimon/tests/reader_base_test.py

Lines changed: 106 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -700,6 +700,112 @@ def _test_value_stats_cols_case(self, manifest_manager, table, value_stats_cols,
700700

701701
self.assertEqual(read_entry.file.value_stats.null_counts, null_counts)
702702

703+
def test_primary_key_value_stats(self):
704+
pa_schema = pa.schema([
705+
('id', pa.int64()),
706+
('name', pa.string()),
707+
('price', pa.float64()),
708+
('category', pa.string())
709+
])
710+
schema = Schema.from_pyarrow_schema(
711+
pa_schema,
712+
primary_keys=['id'],
713+
options={'metadata.stats-mode': 'full', 'bucket': '2'}
714+
)
715+
self.catalog.create_table('default.test_pk_value_stats', schema, False)
716+
table = self.catalog.get_table('default.test_pk_value_stats')
717+
718+
test_data = pa.Table.from_pydict({
719+
'id': [1, 2, 3, 4, 5],
720+
'name': ['Alice', 'Bob', 'Charlie', 'David', 'Eve'],
721+
'price': [10.5, 20.3, 30.7, 40.1, 50.9],
722+
'category': ['A', 'B', 'C', 'D', 'E']
723+
}, schema=pa_schema)
724+
725+
write_builder = table.new_batch_write_builder()
726+
writer = write_builder.new_write()
727+
writer.write_arrow(test_data)
728+
commit_messages = writer.prepare_commit()
729+
commit = write_builder.new_commit()
730+
commit.commit(commit_messages)
731+
writer.close()
732+
733+
read_builder = table.new_read_builder()
734+
table_scan = read_builder.new_scan()
735+
latest_snapshot = SnapshotManager(table).get_latest_snapshot()
736+
manifest_files = table_scan.starting_scanner.manifest_list_manager.read_all(latest_snapshot)
737+
manifest_entries = table_scan.starting_scanner.manifest_file_manager.read(
738+
manifest_files[0].file_name,
739+
lambda row: table_scan.starting_scanner._filter_manifest_entry(row),
740+
False
741+
)
742+
743+
self.assertGreater(len(manifest_entries), 0, "Should have at least one manifest entry")
744+
file_meta = manifest_entries[0].file
745+
746+
key_stats = file_meta.key_stats
747+
self.assertIsNotNone(key_stats, "key_stats should not be None")
748+
self.assertGreater(key_stats.min_values.arity, 0, "key_stats should contain key fields")
749+
self.assertEqual(key_stats.min_values.arity, 1, "key_stats should contain exactly 1 key field (id)")
750+
751+
value_stats = file_meta.value_stats
752+
self.assertIsNotNone(value_stats, "value_stats should not be None")
753+
754+
if file_meta.value_stats_cols is None:
755+
expected_value_fields = ['name', 'price', 'category']
756+
self.assertGreaterEqual(value_stats.min_values.arity, len(expected_value_fields),
757+
f"value_stats should contain at least {len(expected_value_fields)} value fields")
758+
else:
759+
self.assertNotIn('id', file_meta.value_stats_cols,
760+
"Key field 'id' should NOT be in value_stats_cols")
761+
762+
expected_value_fields = ['name', 'price', 'category']
763+
self.assertTrue(set(expected_value_fields).issubset(set(file_meta.value_stats_cols)),
764+
f"value_stats_cols should contain value fields: {expected_value_fields}, "
765+
f"but got: {file_meta.value_stats_cols}")
766+
767+
expected_arity = len(file_meta.value_stats_cols)
768+
self.assertEqual(value_stats.min_values.arity, expected_arity,
769+
f"value_stats should contain {expected_arity} fields (matching value_stats_cols), "
770+
f"but got {value_stats.min_values.arity}")
771+
self.assertEqual(value_stats.max_values.arity, expected_arity,
772+
f"value_stats should contain {expected_arity} fields (matching value_stats_cols), "
773+
f"but got {value_stats.max_values.arity}")
774+
self.assertEqual(len(value_stats.null_counts), expected_arity,
775+
f"value_stats null_counts should have {expected_arity} elements, "
776+
f"but got {len(value_stats.null_counts)}")
777+
778+
self.assertEqual(value_stats.min_values.arity, len(file_meta.value_stats_cols),
779+
f"value_stats.min_values.arity ({value_stats.min_values.arity}) must match "
780+
f"value_stats_cols length ({len(file_meta.value_stats_cols)})")
781+
782+
for field_name in file_meta.value_stats_cols:
783+
is_system_field = (field_name.startswith('_KEY_') or
784+
field_name in ['_SEQUENCE_NUMBER', '_VALUE_KIND', '_ROW_ID'])
785+
self.assertFalse(is_system_field,
786+
f"value_stats_cols should not contain system field: {field_name}")
787+
788+
value_stats_fields = table_scan.starting_scanner.manifest_file_manager._get_value_stats_fields(
789+
{'_VALUE_STATS_COLS': file_meta.value_stats_cols},
790+
table.fields
791+
)
792+
min_value_stats = GenericRowDeserializer.from_bytes(
793+
value_stats.min_values.data,
794+
value_stats_fields
795+
).values
796+
max_value_stats = GenericRowDeserializer.from_bytes(
797+
value_stats.max_values.data,
798+
value_stats_fields
799+
).values
800+
801+
self.assertEqual(len(min_value_stats), 3, "min_value_stats should have 3 values")
802+
self.assertEqual(len(max_value_stats), 3, "max_value_stats should have 3 values")
803+
804+
actual_data = read_builder.new_read().to_arrow(table_scan.plan().splits())
805+
self.assertEqual(actual_data.num_rows, 5, "Should have 5 rows")
806+
actual_ids = sorted(actual_data.column('id').to_pylist())
807+
self.assertEqual(actual_ids, [1, 2, 3, 4, 5], "All IDs should be present")
808+
703809
def test_split_target_size(self):
704810
"""Test source.split.target-size configuration effect on split generation."""
705811
from pypaimon.common.options.core_options import CoreOptions

paimon-python/pypaimon/write/writer/data_blob_writer.py

Lines changed: 3 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -276,14 +276,7 @@ def _create_data_file_meta(self, file_name: str, file_path: str, data: pa.Table,
276276
# Column stats (only for normal columns)
277277
metadata_stats_enabled = self.options.metadata_stats_enabled()
278278
stats_columns = self.normal_columns if metadata_stats_enabled else []
279-
column_stats = {
280-
field.name: self._get_column_stats(data, field.name)
281-
for field in stats_columns
282-
}
283-
284-
min_value_stats = [column_stats[field.name]['min_values'] for field in stats_columns]
285-
max_value_stats = [column_stats[field.name]['max_values'] for field in stats_columns]
286-
value_null_counts = [column_stats[field.name]['null_counts'] for field in stats_columns]
279+
value_stats = self._collect_value_stats(data, stats_columns)
287280

288281
self.sequence_generator.start = self.sequence_generator.current
289282

@@ -293,14 +286,8 @@ def _create_data_file_meta(self, file_name: str, file_path: str, data: pa.Table,
293286
row_count=data.num_rows,
294287
min_key=GenericRow([], []),
295288
max_key=GenericRow([], []),
296-
key_stats=SimpleStats(
297-
GenericRow([], []),
298-
GenericRow([], []),
299-
[]),
300-
value_stats=SimpleStats(
301-
GenericRow(min_value_stats, stats_columns),
302-
GenericRow(max_value_stats, stats_columns),
303-
value_null_counts),
289+
key_stats=SimpleStats.empty_stats(),
290+
value_stats=value_stats,
304291
min_sequence_number=-1,
305292
max_sequence_number=-1,
306293
schema_id=self.table.table_schema.id,

paimon-python/pypaimon/write/writer/data_writer.py

Lines changed: 49 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -192,42 +192,38 @@ def _write_data_to_file(self, data: pa.Table):
192192

193193
# key stats & value stats
194194
value_stats_enabled = self.options.metadata_stats_enabled()
195-
stats_fields = PyarrowFieldParser.to_paimon_schema(data.schema) if value_stats_enabled\
196-
else self.table.trimmed_primary_keys_fields
195+
stats_fields = self.table.fields if self.table.is_primary_key_table \
196+
else PyarrowFieldParser.to_paimon_schema(data.schema)
197197
column_stats = {
198198
field.name: self._get_column_stats(data, field.name)
199199
for field in stats_fields
200200
}
201-
data_fields = stats_fields if value_stats_enabled else []
202-
min_value_stats = [column_stats[field.name]['min_values'] for field in data_fields]
203-
max_value_stats = [column_stats[field.name]['max_values'] for field in data_fields]
204-
value_null_counts = [column_stats[field.name]['null_counts'] for field in data_fields]
205201
key_fields = self.trimmed_primary_keys_fields
206-
min_key_stats = [column_stats[field.name]['min_values'] for field in key_fields]
207-
max_key_stats = [column_stats[field.name]['max_values'] for field in key_fields]
208-
key_null_counts = [column_stats[field.name]['null_counts'] for field in key_fields]
209-
if not all(count == 0 for count in key_null_counts):
202+
key_stats = self._collect_value_stats(data, key_fields, column_stats)
203+
if not all(count == 0 for count in key_stats.null_counts):
210204
raise RuntimeError("Primary key should not be null")
205+
206+
value_fields = stats_fields if value_stats_enabled else []
207+
value_stats = self._collect_value_stats(data, value_fields, column_stats) if value_stats_enabled else SimpleStats.empty_stats()
211208

212209
min_seq = self.sequence_generator.start
213210
max_seq = self.sequence_generator.current
214211
self.sequence_generator.start = self.sequence_generator.current
212+
if value_stats_enabled:
213+
if len(value_fields) == len(self.table.fields):
214+
value_stats_cols = None
215+
else:
216+
value_stats_cols = [field.name for field in value_fields]
217+
else:
218+
value_stats_cols = []
215219
self.committed_files.append(DataFileMeta.create(
216220
file_name=file_name,
217221
file_size=self.file_io.get_file_size(file_path),
218222
row_count=data.num_rows,
219223
min_key=GenericRow(min_key, self.trimmed_primary_keys_fields),
220224
max_key=GenericRow(max_key, self.trimmed_primary_keys_fields),
221-
key_stats=SimpleStats(
222-
GenericRow(min_key_stats, self.trimmed_primary_keys_fields),
223-
GenericRow(max_key_stats, self.trimmed_primary_keys_fields),
224-
key_null_counts,
225-
),
226-
value_stats=SimpleStats(
227-
GenericRow(min_value_stats, data_fields),
228-
GenericRow(max_value_stats, data_fields),
229-
value_null_counts,
230-
),
225+
key_stats=key_stats,
226+
value_stats=value_stats,
231227
min_sequence_number=min_seq,
232228
max_sequence_number=max_seq,
233229
schema_id=self.table.table_schema.id,
@@ -236,7 +232,7 @@ def _write_data_to_file(self, data: pa.Table):
236232
creation_time=Timestamp.now(),
237233
delete_row_count=0,
238234
file_source=0,
239-
value_stats_cols=None if value_stats_enabled else [],
235+
value_stats_cols=value_stats_cols,
240236
external_path=external_path_str, # Set external path if using external paths
241237
first_row_id=None,
242238
write_cols=self.write_cols,
@@ -275,6 +271,27 @@ def _find_optimal_split_point(data: pa.RecordBatch, target_size: int) -> int:
275271

276272
return best_split
277273

274+
def _collect_value_stats(self, data: pa.Table, fields: List,
275+
column_stats: Optional[Dict[str, Dict]] = None) -> SimpleStats:
276+
if not fields:
277+
return SimpleStats.empty_stats()
278+
279+
if column_stats is None:
280+
column_stats = {
281+
field.name: self._get_column_stats(data, field.name)
282+
for field in fields
283+
}
284+
285+
min_stats = [column_stats[field.name]['min_values'] for field in fields]
286+
max_stats = [column_stats[field.name]['max_values'] for field in fields]
287+
null_counts = [column_stats[field.name]['null_counts'] for field in fields]
288+
289+
return SimpleStats(
290+
GenericRow(min_stats, fields),
291+
GenericRow(max_stats, fields),
292+
null_counts
293+
)
294+
278295
@staticmethod
279296
def _get_column_stats(record_batch: pa.RecordBatch, column_name: str) -> Dict:
280297
column_array = record_batch.column(column_name)
@@ -284,6 +301,17 @@ def _get_column_stats(record_batch: pa.RecordBatch, column_name: str) -> Dict:
284301
"max_values": None,
285302
"null_counts": column_array.null_count,
286303
}
304+
305+
column_type = column_array.type
306+
supports_minmax = not (pa.types.is_nested(column_type) or pa.types.is_map(column_type))
307+
308+
if not supports_minmax:
309+
return {
310+
"min_values": None,
311+
"max_values": None,
312+
"null_counts": column_array.null_count,
313+
}
314+
287315
min_values = pc.min(column_array).as_py()
288316
max_values = pc.max(column_array).as_py()
289317
null_counts = column_array.null_count

0 commit comments

Comments
 (0)