Skip to content

Commit 0d12cf4

Browse files
authored
Fix race condition on Table.scan with limit (#545)
1 parent cc8552d commit 0d12cf4

File tree

1 file changed

+4 -14 lines changed

1 file changed

+4 -14 lines changed

pyiceberg/io/pyarrow.py

Lines changed: 4 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -946,13 +946,9 @@ def _task_to_table(
946 946
projected_field_ids: Set[int],
947 947
positional_deletes: Optional[List[ChunkedArray]],
948 948
case_sensitive: bool,
949-
row_counts: List[int],
950 949
limit: Optional[int] = None,
951 950
name_mapping: Optional[NameMapping] = None,
952 951
) -> Optional[pa.Table]:
953-
if limit and sum(row_counts) >= limit:
954-
return None
955-
956 952
_, _, path = PyArrowFileIO.parse_location(task.file.file_path)
957 953
arrow_format = ds.ParquetFileFormat(pre_buffer=True, buffer_size=(ONE_MEGABYTE * 8))
958 954
with fs.open_input_file(path) as fin:
@@ -1015,11 +1011,6 @@ def _task_to_table(
1015 1011
if len(arrow_table) < 1:
1016 1012
return None
1017 1013

1018-
if limit is not None and sum(row_counts) >= limit:
1019-
return None
1020-
1021-
row_counts.append(len(arrow_table))
1022-
1023 1014
return to_requested_schema(projected_schema, file_project_schema, arrow_table)
1024 1015

1025 1016

@@ -1085,7 +1076,6 @@ def project_table(
1085 1076
id for id in projected_schema.field_ids if not isinstance(projected_schema.find_type(id), (MapType, ListType))
1086 1077
}.union(extract_field_ids(bound_row_filter))
1087 1078

1088-
row_counts: List[int] = []
1089 1079
deletes_per_file = _read_all_delete_files(fs, tasks)
1090 1080
executor = ExecutorFactory.get_or_create()
1091 1081
futures = [
@@ -1098,21 +1088,21 @@ def project_table(
1098 1088
projected_field_ids,
1099 1089
deletes_per_file.get(task.file.file_path),
1100 1090
case_sensitive,
1101-
row_counts,
1102 1091
limit,
1103 1092
table.name_mapping(),
1104 1093
)
1105 1094
for task in tasks
1106 1095
]
1107-
1096+
total_row_count = 0
1108 1097
# for consistent ordering, we need to maintain future order
1109 1098
futures_index = {f: i for i, f in enumerate(futures)}
1110 1099
completed_futures: SortedList[Future[pa.Table]] = SortedList(iterable=[], key=lambda f: futures_index[f])
1111 1100
for future in concurrent.futures.as_completed(futures):
1112 1101
completed_futures.add(future)
1113-
1102+
if table_result := future.result():
1103+
total_row_count += len(table_result)
1114 1104
# stop early if limit is satisfied
1115-
if limit is not None and sum(row_counts) >= limit:
1105+
if limit is not None and total_row_count >= limit:
1116 1106
break
1117 1107

1118 1108
# by now, we've either completed all tasks or satisfied the limit

0 commit comments

Comments
 (0)