
Commit 51ff463

refactor(libcommon): remove Indexer in favor of local caching in the rows endpoint (#3251)
* refactor(libcommon): remove the effectively unused arguments of `Indexer`
* style: remove unnecessarry imports
* refactor(libcommon): remove `Indexer`
* refactor(services): directly create `RowsIndex` instead of `Indexer`
* test(libcommon): fix `test_rows_index_query_with_empty_dataset` to use `ds_empty`
* chore: missing import and mypy types
* style: fix import order
* fix(libcommon): cache the latest instance of `RowsIndex`
* test(libcommon): add a test for caching the latest RowsIndex instance
* fix(libcommon): only cache RowsIndex when serving from the rows endpoint
* test(libcommon): remove previously added test case for caching RowIndex instances
* chore: missing type annotations
1 parent c7d0081 commit 51ff463

4 files changed, +117 −106 lines changed


libs/libcommon/src/libcommon/parquet_utils.py

Lines changed: 0 additions & 28 deletions
@@ -601,31 +601,3 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, li
             f" split={self.split}, offset={offset}, length={length}, with truncated binary"
         )
         return self.parquet_index.query_truncated_binary(offset=offset, length=length)
-
-
-class Indexer:
-    def __init__(
-        self,
-        parquet_metadata_directory: StrPath,
-        httpfs: HTTPFileSystem,
-        max_arrow_data_in_memory: int,
-    ):
-        self.parquet_metadata_directory = parquet_metadata_directory
-        self.httpfs = httpfs
-        self.max_arrow_data_in_memory = max_arrow_data_in_memory
-
-    @lru_cache(maxsize=1)
-    def get_rows_index(
-        self,
-        dataset: str,
-        config: str,
-        split: str,
-    ) -> RowsIndex:
-        return RowsIndex(
-            dataset=dataset,
-            config=config,
-            split=split,
-            httpfs=self.httpfs,
-            parquet_metadata_directory=self.parquet_metadata_directory,
-            max_arrow_data_in_memory=self.max_arrow_data_in_memory,
-        )
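The removed `Indexer.get_rows_index` relied on `functools.lru_cache(maxsize=1)`, which only remembers the single most recent argument tuple; the refactor keeps the same decorator but moves it into the rows endpoint. A minimal sketch of those semantics, with a plain string standing in for `RowsIndex` and made-up dataset names:

from functools import lru_cache

@lru_cache(maxsize=1)
def get_rows_index(dataset: str, config: str, split: str) -> str:
    # stand-in for building a RowsIndex, which reads the dataset's parquet
    # metadata from the cache database on every construction
    return f"index:{dataset}/{config}/{split}"

get_rows_index("glue", "cola", "train")         # miss: the index is built
get_rows_index("glue", "cola", "train")         # hit: the cached instance is reused
get_rows_index("squad", "plain_text", "train")  # miss: evicts the previous entry
print(get_rows_index.cache_info())
# CacheInfo(hits=1, misses=2, maxsize=1, currsize=1)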

libs/libcommon/tests/test_parquet_utils.py

Lines changed: 80 additions & 54 deletions
@@ -17,7 +17,6 @@
 from fsspec.implementations.http import HTTPFileSystem
 
 from libcommon.parquet_utils import (
-    Indexer,
     ParquetIndexWithMetadata,
     RowsIndex,
     SchemaMismatchError,
@@ -346,56 +345,25 @@ def dataset_image_with_config_parquet() -> dict[str, Any]:
     return config_parquet_content
 
 
+# TODO(kszucs): this fixture is used in a single test case, but the tests starts
+# to fail if I move the index creation there.
 @pytest.fixture
 def rows_index_with_parquet_metadata(
-    indexer: Indexer,
     ds_sharded: Dataset,
     ds_sharded_fs: AbstractFileSystem,
     dataset_sharded_with_config_parquet_metadata: dict[str, Any],
-) -> Generator[RowsIndex, None, None]:
-    with ds_sharded_fs.open("default/train/0003.parquet") as f:
-        with patch("libcommon.parquet_utils.HTTPFile", return_value=f):
-            yield indexer.get_rows_index("ds_sharded", "default", "train")
-
-
-@pytest.fixture
-def rows_index_with_empty_dataset(
-    indexer: Indexer,
-    ds_empty: Dataset,
-    ds_empty_fs: AbstractFileSystem,
-    dataset_empty_with_config_parquet_metadata: dict[str, Any],
-) -> Generator[RowsIndex, None, None]:
-    with ds_empty_fs.open("default/train/0000.parquet") as f:
-        with patch("libcommon.parquet_utils.HTTPFile", return_value=f):
-            yield indexer.get_rows_index("ds_empty", "default", "train")
-
-
-@pytest.fixture
-def rows_index_with_too_big_rows(
     parquet_metadata_directory: StrPath,
-    ds_sharded: Dataset,
-    ds_sharded_fs: AbstractFileSystem,
-    dataset_sharded_with_config_parquet_metadata: dict[str, Any],
 ) -> Generator[RowsIndex, None, None]:
-    indexer = Indexer(
-        parquet_metadata_directory=parquet_metadata_directory,
-        httpfs=HTTPFileSystem(),
-        max_arrow_data_in_memory=1,
-    )
     with ds_sharded_fs.open("default/train/0003.parquet") as f:
         with patch("libcommon.parquet_utils.HTTPFile", return_value=f):
-            yield indexer.get_rows_index("ds_sharded", "default", "train")
-
-
-@pytest.fixture
-def indexer(
-    parquet_metadata_directory: StrPath,
-) -> Indexer:
-    return Indexer(
-        parquet_metadata_directory=parquet_metadata_directory,
-        httpfs=HTTPFileSystem(),
-        max_arrow_data_in_memory=9999999999,
-    )
+            yield RowsIndex(
+                dataset="ds_sharded",
+                config="default",
+                split="train",
+                parquet_metadata_directory=parquet_metadata_directory,
+                httpfs=HTTPFileSystem(),
+                max_arrow_data_in_memory=9999999999,
+            )
 
 
 def test_parquet_export_is_partial() -> None:
@@ -411,11 +379,22 @@ def test_parquet_export_is_partial() -> None:
 
 
 def test_indexer_get_rows_index_with_parquet_metadata(
-    indexer: Indexer, ds: Dataset, ds_fs: AbstractFileSystem, dataset_with_config_parquet_metadata: dict[str, Any]
+    ds: Dataset,
+    ds_fs: AbstractFileSystem,
+    parquet_metadata_directory: StrPath,
+    dataset_with_config_parquet_metadata: dict[str, Any],
 ) -> None:
     with ds_fs.open("default/train/0000.parquet") as f:
         with patch("libcommon.parquet_utils.HTTPFile", return_value=f):
-            index = indexer.get_rows_index("ds", "default", "train")
+            index = RowsIndex(
+                dataset="ds",
+                config="default",
+                split="train",
+                parquet_metadata_directory=parquet_metadata_directory,
+                httpfs=HTTPFileSystem(),
+                max_arrow_data_in_memory=9999999999,
+            )
+
     assert isinstance(index.parquet_index, ParquetIndexWithMetadata)
     assert index.parquet_index.features == ds.features
     assert index.parquet_index.num_rows == [len(ds)]
@@ -429,15 +408,23 @@ def test_indexer_get_rows_index_with_parquet_metadata(
 
 
 def test_indexer_get_rows_index_sharded_with_parquet_metadata(
-    indexer: Indexer,
     ds: Dataset,
     ds_sharded: Dataset,
     ds_sharded_fs: AbstractFileSystem,
+    parquet_metadata_directory: StrPath,
     dataset_sharded_with_config_parquet_metadata: dict[str, Any],
 ) -> None:
     with ds_sharded_fs.open("default/train/0003.parquet") as f:
         with patch("libcommon.parquet_utils.HTTPFile", return_value=f):
-            index = indexer.get_rows_index("ds_sharded", "default", "train")
+            index = RowsIndex(
+                dataset="ds_sharded",
+                config="default",
+                split="train",
+                parquet_metadata_directory=parquet_metadata_directory,
+                httpfs=HTTPFileSystem(),
+                max_arrow_data_in_memory=9999999999,
+            )
+
     assert isinstance(index.parquet_index, ParquetIndexWithMetadata)
     assert index.parquet_index.features == ds_sharded.features
     assert index.parquet_index.num_rows == [len(ds)] * 4
@@ -463,28 +450,67 @@ def test_rows_index_query_with_parquet_metadata(
         rows_index_with_parquet_metadata.query(offset=-1, length=2)
 
 
-def test_rows_index_query_with_too_big_rows(rows_index_with_too_big_rows: RowsIndex, ds_sharded: Dataset) -> None:
+def test_rows_index_query_with_too_big_rows(
+    parquet_metadata_directory: StrPath,
+    ds_sharded: Dataset,
+    ds_sharded_fs: AbstractFileSystem,
+    dataset_sharded_with_config_parquet_metadata: dict[str, Any],
+) -> None:
+    with ds_sharded_fs.open("default/train/0003.parquet") as f:
+        with patch("libcommon.parquet_utils.HTTPFile", return_value=f):
+            index = RowsIndex(
+                dataset="ds_sharded",
+                config="default",
+                split="train",
+                parquet_metadata_directory=parquet_metadata_directory,
+                httpfs=HTTPFileSystem(),
+                max_arrow_data_in_memory=1,
+            )
+
     with pytest.raises(TooBigRows):
-        rows_index_with_too_big_rows.query(offset=0, length=3)
+        index.query(offset=0, length=3)
 
 
-def test_rows_index_query_with_empty_dataset(rows_index_with_empty_dataset: RowsIndex, ds_sharded: Dataset) -> None:
-    assert isinstance(rows_index_with_empty_dataset.parquet_index, ParquetIndexWithMetadata)
-    assert rows_index_with_empty_dataset.query(offset=0, length=1).to_pydict() == ds_sharded[:0]
+def test_rows_index_query_with_empty_dataset(
+    ds_empty: Dataset,
+    ds_empty_fs: AbstractFileSystem,
+    dataset_empty_with_config_parquet_metadata: dict[str, Any],
+    parquet_metadata_directory: StrPath,
+) -> None:
+    with ds_empty_fs.open("default/train/0000.parquet") as f:
+        with patch("libcommon.parquet_utils.HTTPFile", return_value=f):
+            index = RowsIndex(
+                dataset="ds_empty",
+                config="default",
+                split="train",
+                parquet_metadata_directory=parquet_metadata_directory,
+                httpfs=HTTPFileSystem(),
+                max_arrow_data_in_memory=9999999999,
+            )
+
+    assert isinstance(index.parquet_index, ParquetIndexWithMetadata)
+    assert index.query(offset=0, length=1).to_pydict() == ds_empty[:0]
     with pytest.raises(IndexError):
-        rows_index_with_empty_dataset.query(offset=-1, length=2)
+        index.query(offset=-1, length=2)
 
 
 def test_indexer_schema_mistmatch_error(
-    indexer: Indexer,
     ds_sharded_fs: AbstractFileSystem,
     ds_sharded_fs_with_different_schema: AbstractFileSystem,
     dataset_sharded_with_config_parquet_metadata: dict[str, Any],
+    parquet_metadata_directory: StrPath,
 ) -> None:
     with ds_sharded_fs_with_different_schema.open("default/train/0000.parquet") as first_parquet:
         with ds_sharded_fs_with_different_schema.open("default/train/0001.parquet") as second_parquet:
             with patch("libcommon.parquet_utils.HTTPFile", side_effect=[first_parquet, second_parquet]):
-                index = indexer.get_rows_index("ds_sharded", "default", "train")
+                index = RowsIndex(
+                    dataset="ds_sharded",
+                    config="default",
+                    split="train",
+                    parquet_metadata_directory=parquet_metadata_directory,
+                    httpfs=HTTPFileSystem(),
+                    max_arrow_data_in_memory=9999999999,
+                )
                 with pytest.raises(SchemaMismatchError):
                     index.query(offset=0, length=3)

services/rows/src/rows/routes/rows.py

Lines changed: 19 additions & 12 deletions
@@ -2,6 +2,7 @@
 # Copyright 2022 The HuggingFace Authors.
 
 import logging
+from functools import lru_cache
 from http import HTTPStatus
 from typing import Optional
 
@@ -22,7 +23,7 @@
     try_backfill_dataset_then_raise,
 )
 from libcommon.constants import CONFIG_PARQUET_METADATA_KIND
-from libcommon.parquet_utils import Indexer, TooBigRows
+from libcommon.parquet_utils import RowsIndex, TooBigRows
 from libcommon.prometheus import StepProfiler
 from libcommon.simple_cache import CachedArtifactError, CachedArtifactNotFoundError
 from libcommon.storage import StrPath
@@ -48,14 +49,24 @@ def create_rows_endpoint(
     max_age_short: int = 0,
     storage_clients: Optional[list[StorageClient]] = None,
 ) -> Endpoint:
-    indexer = Indexer(
-        parquet_metadata_directory=parquet_metadata_directory,
-        httpfs=HTTPFileSystem(headers={"authorization": f"Bearer {hf_token}"}),
-        max_arrow_data_in_memory=max_arrow_data_in_memory,
-    )
+    httpfs = HTTPFileSystem(headers={"authorization": f"Bearer {hf_token}"})
+
+    @lru_cache(maxsize=1)
+    def get_rows_index(dataset: str, config: str, split: str) -> RowsIndex:
+        # cache the RowsIndex instance and therefore save one call to Mongo
+        # if multiple queries to the same dataset are done in a row (90% of
+        # requests in a short time window are to the same dataset)
+        return RowsIndex(
+            dataset=dataset,
+            config=config,
+            split=split,
+            httpfs=httpfs,
+            max_arrow_data_in_memory=max_arrow_data_in_memory,
+            parquet_metadata_directory=parquet_metadata_directory,
+        )
 
     async def rows_endpoint(request: Request) -> Response:
-        await indexer.httpfs.set_session()
+        await httpfs.set_session()
         revision: Optional[str] = None
         with StepProfiler(method="rows_endpoint", step="all"):
             try:
@@ -84,11 +95,7 @@ async def rows_endpoint(request: Request) -> Response:
                 )
                 try:
                     with StepProfiler(method="rows_endpoint", step="get row groups index"):
-                        rows_index = indexer.get_rows_index(
-                            dataset=dataset,
-                            config=config,
-                            split=split,
-                        )
+                        rows_index = get_rows_index(dataset=dataset, config=config, split=split)
                         revision = rows_index.revision
                     with StepProfiler(method="rows_endpoint", step="query the rows"):
                         try:
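The new endpoint-local helper closes over the endpoint configuration (`httpfs`, `max_arrow_data_in_memory`, `parquet_metadata_directory`), so the cache key is just the `(dataset, config, split)` triple. A rough sketch of the pattern, with a hypothetical `StubRowsIndex` and `create_rows_endpoint_like` standing in for the real `RowsIndex` and endpoint factory (the real constructor performs the Mongo lookup the comment above refers to):

from dataclasses import dataclass
from functools import lru_cache

@dataclass
class StubRowsIndex:
    # stand-in for libcommon.parquet_utils.RowsIndex; constructing the real
    # class resolves the dataset's parquet metadata (the "call to Mongo")
    dataset: str
    config: str
    split: str

def create_rows_endpoint_like(parquet_metadata_directory: str, max_arrow_data_in_memory: int):
    @lru_cache(maxsize=1)
    def get_rows_index(dataset: str, config: str, split: str) -> StubRowsIndex:
        # the configuration above is captured by the closure, so only the
        # dataset coordinates participate in the cache key
        return StubRowsIndex(dataset=dataset, config=config, split=split)

    return get_rows_index

get_rows_index = create_rows_endpoint_like("/parquet-metadata", 300_000_000)
first = get_rows_index("glue", "cola", "train")
second = get_rows_index("glue", "cola", "train")
assert first is second  # back-to-back requests to the same split reuse the cached index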

services/worker/src/worker/job_runners/split/first_rows.py

Lines changed: 18 additions & 12 deletions
@@ -18,7 +18,7 @@
     SplitParquetSchemaMismatchError,
     TooBigContentError,
 )
-from libcommon.parquet_utils import EmptyParquetMetadataError, Indexer, SchemaMismatchError, TooBigRows
+from libcommon.parquet_utils import EmptyParquetMetadataError, RowsIndex, SchemaMismatchError, TooBigRows
 from libcommon.simple_cache import CachedArtifactError, CachedArtifactNotFoundError
 from libcommon.storage import StrPath
 from libcommon.storage_client import StorageClient
@@ -41,7 +41,9 @@ def compute_first_rows_from_parquet_response(
     rows_max_number: int,
     rows_min_number: int,
     columns_max_number: int,
-    indexer: Indexer,
+    httpfs: HTTPFileSystem,
+    max_arrow_data_in_memory: int,
+    parquet_metadata_directory: StrPath,
 ) -> SplitFirstRowsResponse:
     """
     Compute the response of 'split-first-rows' for one specific split of a dataset from the parquet files.
@@ -67,8 +69,12 @@ def compute_first_rows_from_parquet_response(
             The minimum number of rows of the response.
         columns_max_number (`int`):
             The maximum number of columns supported.
-        indexer (`Indexer`):
-            An indexer to get the rows index.
+        httpfs (`HTTPFileSystem`):
+            An HTTP filesystem to access the parquet files.
+        parquet_metadata_directory (`StrPath`):
+            The local directory where the parquet metadata are stored.
+        max_arrow_data_in_memory (`int`):
+            The maximum size in bytes of Arrow data loaded in memory.
 
     Raises:
         [~`libcommon.exceptions.ParquetResponseEmptyError`]:
@@ -85,10 +91,13 @@ def compute_first_rows_from_parquet_response(
     logging.info(f"compute 'split-first-rows' from parquet for {dataset=} {config=} {split=}")
 
     try:
-        rows_index = indexer.get_rows_index(
+        rows_index = RowsIndex(
             dataset=dataset,
             config=config,
             split=split,
+            httpfs=httpfs,
+            max_arrow_data_in_memory=max_arrow_data_in_memory,
+            parquet_metadata_directory=parquet_metadata_directory,
         )
     except EmptyParquetMetadataError:
         raise ParquetResponseEmptyError("No parquet files found.")
@@ -272,7 +281,6 @@ def get_rows_content(rows_max_number: int) -> RowsContent:
 
 class SplitFirstRowsJobRunner(SplitJobRunnerWithDatasetsCache):
     first_rows_config: FirstRowsConfig
-    indexer: Indexer
 
     @staticmethod
     def get_job_type() -> str:
@@ -293,11 +301,7 @@ def __init__(
         )
         self.first_rows_config = app_config.first_rows
         self.parquet_metadata_directory = parquet_metadata_directory
-        self.indexer = Indexer(
-            parquet_metadata_directory=parquet_metadata_directory,
-            httpfs=HTTPFileSystem(headers={"authorization": f"Bearer {self.app_config.common.hf_token}"}),
-            max_arrow_data_in_memory=app_config.rows_index.max_arrow_data_in_memory,
-        )
+        self.httpfs = HTTPFileSystem(headers={"authorization": f"Bearer {self.app_config.common.hf_token}"})
         self.storage_client = storage_client
 
     def compute(self) -> CompleteJobResult:
@@ -314,7 +318,9 @@ def compute(self) -> CompleteJobResult:
                     rows_min_number=self.first_rows_config.min_number,
                     rows_max_number=MAX_NUM_ROWS_PER_PAGE,
                     columns_max_number=self.first_rows_config.columns_max_number,
-                    indexer=self.indexer,
+                    httpfs=self.httpfs,
+                    max_arrow_data_in_memory=self.app_config.rows_index.max_arrow_data_in_memory,
+                    parquet_metadata_directory=self.parquet_metadata_directory,
                 )
             )
         except (
