Commit 4df749c

feat: raise TooBigRows exceptions if the scan size would exceed a limit

1 parent 8104e67 commit 4df749c

File tree: 8 files changed, +135 -37 lines changed

libs/libcommon/src/libcommon/parquet_utils.py

Lines changed: 16 additions & 3 deletions
@@ -26,7 +26,7 @@
 from libcommon.viewer_utils.features import get_supported_unsupported_columns

 try:
-    from libviewer import Dataset as LibviewerDataset  # type: ignore [import-untyped]
+    import libviewer as lv  # type: ignore [import-untyped]

     _has_libviewer = True
 except ImportError:
@@ -522,12 +522,14 @@ def __init__(
         hf_token: Optional[str],
         parquet_metadata_directory: StrPath,
         max_arrow_data_in_memory: int,
+        max_scan_size: int,
         unsupported_features: Sequence[FeatureType] = (),
         data_store: str = "hf://",
     ):
         self.dataset = dataset
         self.config = config
         self.split = split
+        self.max_scan_size = max_scan_size

         self._init_dataset_info(parquet_metadata_directory)
         self._init_parquet_index(
@@ -608,7 +610,7 @@ def _init_viewer_index(self, data_store: str, metadata_store: str) -> None:
            }
        )

-        self.viewer_index = LibviewerDataset(
+        self.viewer_index = lv.Dataset(
            name=self.dataset,
            files=files,
            revision=self.revision,
@@ -649,7 +651,17 @@ def query_with_page_pruning(self, offset: int, length: int) -> pa.Table:
            raise IndexError("Offset must be non-negative")
        if length < 0:
            raise IndexError("Length must be non-negative")
-        batches, _files_to_index = self.viewer_index.sync_scan(offset=offset, limit=length)
+
+        try:
+            batches, _files_to_index = self.viewer_index.sync_scan(
+                offset=offset, limit=length, scan_size_limit=self.max_scan_size
+            )
+        except lv.DatasetError as e:
+            if "Scan size limit exceeded" in str(e):
+                raise TooBigRows(str(e)) from e
+            else:
+                raise
+
        return pa.Table.from_batches(batches, schema=self.features.arrow_schema)

    # note that this cache size is global for the class, not per instance
@@ -707,6 +719,7 @@ def get_rows_index(self, dataset: str, config: str, split: str, data_store: str
            hf_token=self.hf_token,
            parquet_metadata_directory=self.parquet_metadata_directory,
            max_arrow_data_in_memory=self.max_arrow_data_in_memory,
+            max_scan_size=self.max_arrow_data_in_memory,
            unsupported_features=unsupported_features,
            data_store=data_store,
        )
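With max_scan_size wired to max_arrow_data_in_memory in get_rows_index, query_with_page_pruning now enforces the same budget as query: when the estimated compressed bytes the scan would touch exceed the limit, the libviewer error is translated into TooBigRows. A minimal caller-side sketch of the new behaviour, assuming rows_index is a RowsIndex returned by Indexer.get_rows_index and that TooBigRows is the exception class defined in libcommon.parquet_utils:

from libcommon.parquet_utils import TooBigRows

def fetch_page(rows_index, offset: int, length: int):
    # query_with_page_pruning raises TooBigRows when the compressed row groups
    # touched by the (offset, length) window exceed rows_index.max_scan_size.
    try:
        return rows_index.query_with_page_pruning(offset=offset, length=length)
    except TooBigRows:
        # Hypothetical fallback: ask for a smaller page instead of failing hard.
        return rows_index.query_with_page_pruning(offset=offset, length=max(1, length // 2))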

libs/libcommon/tests/test_parquet_utils.py

Lines changed: 19 additions & 11 deletions
@@ -368,7 +368,8 @@ def rows_index_with_empty_dataset(
 ) -> Generator[RowsIndex, None, None]:
     with ds_empty_fs.open("default/train/0000.parquet") as f:
         with patch("libcommon.parquet_utils.HTTPFile", return_value=f):
-            yield indexer.get_rows_index("ds_empty", "default", "train")
+            data_store = f"file://{ds_empty_fs.local_root_dir}"
+            yield indexer.get_rows_index("ds_empty", "default", "train", data_store=data_store)


 @pytest.fixture
@@ -386,7 +387,8 @@ def rows_index_with_too_big_rows(
     )
     with ds_sharded_fs.open("default/train/0003.parquet") as f:
         with patch("libcommon.parquet_utils.HTTPFile", return_value=f):
-            yield indexer.get_rows_index("ds_sharded", "default", "train")
+            data_store = f"file://{ds_sharded_fs.local_root_dir}"
+            yield indexer.get_rows_index("ds_sharded", "default", "train", data_store=data_store)


 @pytest.fixture
@@ -465,24 +467,18 @@ def test_rows_index_query_with_parquet_metadata(
     with pytest.raises(IndexError):
         rows_index_with_parquet_metadata.query(offset=-1, length=2)

+    # test the same with page pruning API
+    import libviewer as lv  # type: ignore [import-untyped]

-def test_rows_index_query_with_page_pruning(rows_index_with_parquet_metadata: RowsIndex, ds_sharded: Dataset) -> None:
-    from libviewer import Dataset as LibviewerDataset  # type: ignore [import-untyped]
-
-    assert isinstance(rows_index_with_parquet_metadata.viewer_index, LibviewerDataset)
-
+    assert isinstance(rows_index_with_parquet_metadata.viewer_index, lv.Dataset)
     result = rows_index_with_parquet_metadata.query_with_page_pruning(offset=1, length=3)
     assert result.to_pydict() == ds_sharded[1:4]
-
     result = rows_index_with_parquet_metadata.query_with_page_pruning(offset=1, length=0)
     assert result.to_pydict() == ds_sharded[:0]
-
     result = rows_index_with_parquet_metadata.query_with_page_pruning(offset=999999, length=1)
     assert result.to_pydict() == ds_sharded[:0]
-
     result = rows_index_with_parquet_metadata.query_with_page_pruning(offset=1, length=99999999)
     assert result.to_pydict() == ds_sharded[1:]
-
     with pytest.raises(IndexError):
         rows_index_with_parquet_metadata.query_with_page_pruning(offset=0, length=-1)
     with pytest.raises(IndexError):
@@ -493,13 +489,25 @@ def test_rows_index_query_with_too_big_rows(rows_index_with_too_big_rows: RowsIndex) -> None:
     with pytest.raises(TooBigRows):
         rows_index_with_too_big_rows.query(offset=0, length=3)

+    # test the same with page pruning API
+    with pytest.raises(TooBigRows):
+        rows_index_with_too_big_rows.query_with_page_pruning(offset=0, length=2)
+

 def test_rows_index_query_with_empty_dataset(rows_index_with_empty_dataset: RowsIndex, ds_sharded: Dataset) -> None:
     assert isinstance(rows_index_with_empty_dataset.parquet_index, ParquetIndexWithMetadata)
     assert rows_index_with_empty_dataset.query(offset=0, length=1).to_pydict() == ds_sharded[:0]
     with pytest.raises(IndexError):
         rows_index_with_empty_dataset.query(offset=-1, length=2)

+    # test the same with page pruning API
+    import libviewer as lv  # type: ignore [import-untyped]
+    assert isinstance(rows_index_with_empty_dataset.viewer_index, lv.Dataset)
+    result = rows_index_with_empty_dataset.query_with_page_pruning(offset=0, length=1)
+    assert result.to_pydict() == ds_sharded[:0]
+    with pytest.raises(IndexError):
+        rows_index_with_empty_dataset.query_with_page_pruning(offset=-1, length=2)
+

 def test_indexer_schema_mistmatch_error(
     indexer: Indexer,

libs/libviewer/Cargo.lock

Lines changed: 1 addition & 0 deletions
Some generated files are not rendered by default.

libs/libviewer/Cargo.toml

Lines changed: 1 addition & 0 deletions
@@ -10,6 +10,7 @@ crate-type = ["cdylib"]

 [dependencies]
 arrow = { version = "56.2", features = ["pyarrow"] }
+bytes = "1.10.1"
 futures = "0.3"
 object_store = "0.12.0"
 object_store_opendal = "0.52.0"

libs/libviewer/libviewer/__init__.py

Lines changed: 1 addition & 1 deletion
@@ -2,7 +2,7 @@

 from huggingface_hub import hf_hub_download, list_repo_files

-from ._internal import PyDataset
+from ._internal import PyDataset, PyDatasetError as DatasetError


 __all__ = ["Dataset"]
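Re-exporting PyDatasetError as DatasetError is what lets Python callers (and libcommon above them) catch the Rust-side error by name. A small illustrative check, assuming only that libviewer was built from this revision:

import libviewer as lv

# DatasetError is the package-level alias of the Rust-defined PyDatasetError and
# derives from the built-in Exception, so generic handlers still catch it.
assert issubclass(lv.DatasetError, Exception)

try:
    raise lv.DatasetError("Scan size limit exceeded: example message")
except Exception as err:
    assert isinstance(err, lv.DatasetError)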

libs/libviewer/src/dataset.rs

Lines changed: 14 additions & 8 deletions
@@ -30,7 +30,7 @@ pub enum DatasetError {
     #[error("Arrow error: {0}")]
     Arrow(#[from] arrow::error::ArrowError),

-    #[error("Parquet error: {0}")]
+    #[error("{0}")]
     Parquet(#[from] ::parquet::errors::ParquetError),

     #[error("Object store error: {0}")]
@@ -58,7 +58,7 @@ impl ParquetScan {
         self.metadata.offset_index().is_some()
     }

-    fn estimate_scan_size(&self) -> i64 {
+    fn estimate_scan_size(&self) -> u64 {
         let mut scan_size = 0;
         let mut rows_to_skip = self.offset;
         let mut rows_needed = self.limit;
@@ -78,7 +78,7 @@

             // Accumulate the size if we need to scan any rows from this row group
             if rows_to_read > 0 {
-                scan_size += row_group.compressed_size();
+                scan_size += row_group.compressed_size() as u64;
                 rows_needed -= rows_to_read;
             }

@@ -94,7 +94,7 @@
         scan_size
     }

-    fn shall_index(&self, scan_size_limit: i64) -> bool {
+    fn shall_index(&self, scan_size_limit: u64) -> bool {
         // TODO(kszucs): reconsider the case when we want to index:
         // 1. file reads with row groups larger than the scan size limit
         // 2. file reads with overall size larger than the scan size limit (multiple row groups)
@@ -107,13 +107,18 @@
         self.estimate_scan_size() > scan_size_limit
     }

-    async fn execute(&self, data_store: Arc<dyn ObjectStore>) -> Result<Vec<RecordBatch>> {
+    async fn execute(
+        &self,
+        data_store: Arc<dyn ObjectStore>,
+        scan_size_limit: u64,
+    ) -> Result<Vec<RecordBatch>> {
         let stream = read_batch_stream(
             data_store,
             self.file.path.clone(),
             self.metadata.clone(),
             self.offset,
             self.limit,
+            scan_size_limit,
         )?;
         Ok(stream.try_collect::<Vec<_>>().await?)
     }
@@ -159,7 +164,7 @@ pub struct Dataset {
     metadata_store: Arc<dyn ObjectStore>,
     pub metadata_store_uri: String,
     /// Scan size limit for triggering indexing
-    indexing_size_threshold: i64,
+    indexing_size_threshold: u64,
 }

 impl Dataset {
@@ -169,7 +174,7 @@
         revision: Option<&str>,
         data_uri: &str,
         metadata_uri: &str,
-        indexing_size_threshold: i64,
+        indexing_size_threshold: u64,
     ) -> Result<Self> {
         Ok(Self {
             name: name.to_string(),
@@ -265,6 +270,7 @@
         &self,
         limit: Option<u64>,
         offset: Option<u64>,
+        scan_size_limit: u64,
     ) -> Result<(Vec<RecordBatch>, Vec<IndexedFile>)> {
         // 1. create an object reader for each file in the access plan
         // 2. generate a stream of record batches from each reader
@@ -280,7 +286,7 @@

         let tasks = plan.into_iter().map(|scan| {
             let data_store = self.data_store.clone();
-            task::spawn(async move { scan.execute(data_store).await })
+            task::spawn(async move { scan.execute(data_store, scan_size_limit).await })
         });
         let results = future::try_join_all(tasks).await?;
         let batches = results
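For context, estimate_scan_size walks the file's row groups, skips those that lie entirely before the requested offset, and adds the full compressed size of every row group the (offset, limit) window touches; shall_index (and, with this commit, the scan itself) compares that total against a byte limit. A rough Python illustration of the same accounting, where the function name and the pyarrow-based metadata access are illustrative and not part of libviewer:

import pyarrow.parquet as pq

def estimate_scan_size(metadata: pq.FileMetaData, offset: int, limit: int) -> int:
    """Approximate the compressed bytes a scan of `limit` rows starting at `offset` would touch."""
    scan_size = 0
    rows_to_skip = offset
    rows_needed = limit
    for i in range(metadata.num_row_groups):
        if rows_needed == 0:
            break
        row_group = metadata.row_group(i)
        if rows_to_skip >= row_group.num_rows:
            # The whole row group lies before the requested window.
            rows_to_skip -= row_group.num_rows
            continue
        rows_to_read = min(row_group.num_rows - rows_to_skip, rows_needed)
        rows_to_skip = 0
        if rows_to_read > 0:
            # As in the Rust code, the full compressed row group is counted
            # even when only a slice of its rows is actually read.
            scan_size += row_group.total_compressed_size
            rows_needed -= rows_to_read
    return scan_size

# The limit check then reduces to: estimate_scan_size(md, offset, limit) > scan_size_limit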

libs/libviewer/src/lib.rs

Lines changed: 17 additions & 8 deletions
@@ -2,18 +2,22 @@ mod dataset;
 mod parquet;

 use arrow::pyarrow::IntoPyArrow;
-use pyo3::exceptions::PyValueError;
+use pyo3::create_exception;
 use pyo3::prelude::*;
 use pyo3_async_runtimes;
+
 use tokio;

 use crate::dataset::{Dataset, DatasetError};

-const INDEXING_SIZE_THRESHOLD: i64 = 100 * 1024 * 1024; // 100 MiB
+const INDEXING_SIZE_THRESHOLD: u64 = 100 * 1024 * 1024; // 100 MiB
+const DEFAULT_SCAN_SIZE_LIMIT: u64 = 1 * 1024 * 1024 * 1024; // 1 GiB
+
+create_exception!(libviewer, PyDatasetError, pyo3::exceptions::PyException);

 impl From<DatasetError> for PyErr {
     fn from(err: DatasetError) -> Self {
-        PyValueError::new_err(err.to_string())
+        PyDatasetError::new_err(err.to_string())
     }
 }

@@ -46,7 +50,7 @@ impl PyDataset {
         metadata_store: &str,
         data_store: &str,
         revision: Option<&str>,
-        indexing_size_threshold: i64,
+        indexing_size_threshold: u64,
     ) -> PyResult<Self> {
         let dataset = Dataset::try_new(
             name,
@@ -88,32 +92,36 @@ impl PyDataset {
         Ok(&self.dataset.metadata_store_uri)
     }

-    #[pyo3(signature = (limit=None, offset=None))]
+    #[pyo3(signature = (limit=None, offset=None, scan_size_limit=DEFAULT_SCAN_SIZE_LIMIT))]
     fn sync_scan(
         &self,
         py: Python<'_>,
         limit: Option<u64>,
         offset: Option<u64>,
+        scan_size_limit: u64,
     ) -> PyResult<(Vec<PyObject>, Vec<IndexedFile>)> {
         let rt = tokio::runtime::Runtime::new()?;
-        let (record_batches, files_to_index) = rt.block_on(self.dataset.scan(limit, offset))?;
+        let (record_batches, files_to_index) =
+            rt.block_on(self.dataset.scan(limit, offset, scan_size_limit))?;
         let pyarrow_batches = record_batches
             .into_iter()
             .map(|batch| batch.into_pyarrow(py))
             .collect::<PyResult<Vec<_>>>()?;
         Ok((pyarrow_batches, files_to_index))
     }

-    #[pyo3(signature = (limit=None, offset=None))]
+    #[pyo3(signature = (limit=None, offset=None, scan_size_limit=DEFAULT_SCAN_SIZE_LIMIT))]
     fn scan<'py>(
         &self,
         py: Python<'py>,
         limit: Option<u64>,
         offset: Option<u64>,
+        scan_size_limit: u64,
     ) -> PyResult<Bound<'py, PyAny>> {
         let this = self.clone();
         pyo3_async_runtimes::tokio::future_into_py(py, async move {
-            let (record_batches, files_to_index) = this.dataset.scan(limit, offset).await?;
+            let (record_batches, files_to_index) =
+                this.dataset.scan(limit, offset, scan_size_limit).await?;
             let pyarrow_batches = Python::with_gil(|py| {
                 record_batches
                     .into_iter()
@@ -151,5 +159,6 @@ impl PyDataset {
 #[pyo3(name = "_internal")]
 fn dv(m: &Bound<'_, PyModule>) -> PyResult<()> {
     m.add_class::<PyDataset>()?;
+    m.add("PyDatasetError", m.py().get_type::<PyDatasetError>())?;
     Ok(())
 }
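On the Python binding, scan_size_limit is now an optional third argument defaulting to DEFAULT_SCAN_SIZE_LIMIT (1 GiB), and scan failures surface as the dedicated PyDatasetError instead of ValueError. A hedged usage sketch, assuming ds is an already-constructed lv.Dataset:

import libviewer as lv

# ds is assumed to be an lv.Dataset built elsewhere (name, files, revision, stores, ...).
# With scan_size_limit omitted, the 1 GiB default applies:
batches, files_to_index = ds.sync_scan(offset=0, limit=100)

# With an explicit, tighter limit, an oversized scan raises lv.DatasetError
# (previously ValueError), whose message contains "Scan size limit exceeded":
try:
    batches, files_to_index = ds.sync_scan(offset=0, limit=100, scan_size_limit=64 * 1024 * 1024)
except lv.DatasetError as err:
    print(f"scan rejected: {err}")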
