
Commit dd1630b

refactor(libcommon): consolidate query() and query_truncated_binary() methods (#3253)
* refactor(libcommon): remove unused `hf_token` in `parquet_utils` (#3249)
* refactor(libcommon): remove the effectively unused arguments of `Indexer` (#3250)
* style: remove unnecessary imports
* refactor(libcommon): remove `unsupported_features` argument from `RowsIndex`
* refactor(libcommon): remove `Indexer`
* refactor(services): directly create `RowsIndex` instead of `Indexer`
* test(libcommon): fix `test_rows_index_query_with_empty_dataset` to use `ds_empty`
* chore: missing import and mypy types
* style: fix import order
* fix(libcommon): cache the latest instance of `RowsIndex`
* test(libcommon): add a test for caching the latest RowsIndex instance
* fix(libcommon): only cache RowsIndex when serving from the rows endpoint
* test(libcommon): remove previously added test case for caching RowsIndex instances
* chore: missing type annotations
* refactor(libcommon): remove now obsolete `get_supported_unsupported_columns()` function
* style(libcommon): remove unnecessary imports
* chore: remove redundant iterations
* refactor(libcommon): consolidate query() and query_truncated_binary() methods
* refactor(libcommon): pull out truncate_binary_columns() as a function to be used later by libviewer
* style: fix mypy errors
* chore: expect pathlib.Path rather than string path
* style: fix ruff check error
1 parent f0ba03f commit dd1630b

File tree

6 files changed

+134 -201 lines changed

libs/libcommon/src/libcommon/parquet_utils.py

Lines changed: 78 additions & 164 deletions
@@ -138,11 +138,30 @@ def is_list_pa_type(parquet_file_path: Path, feature_name: str) -> bool:
     return is_list


+def truncate_binary_columns(table: pa.Table, max_binary_length: int, features: Features) -> tuple[pa.Table, list[str]]:
+    # truncate binary columns in the Arrow table to the specified maximum length
+    # return a new Arrow table and the list of truncated columns
+    if max_binary_length < 0:
+        return table, []
+
+    columns: dict[str, pa.Array] = {}
+    truncated_column_names: list[str] = []
+    for field_idx, field in enumerate(table.schema):  # noqa: F402
+        if features[field.name] == Value("binary") and table[field_idx].nbytes > max_binary_length:
+            truncated_array = pc.binary_slice(table[field_idx], 0, max_binary_length // len(table))
+            columns[field.name] = truncated_array
+            truncated_column_names.append(field.name)
+        else:
+            columns[field.name] = table[field_idx]
+
+    return pa.table(columns), truncated_column_names
+
+
 @dataclass
 class RowGroupReader:
     parquet_file: pq.ParquetFile
     group_id: int
-    features: Features
+    schema: pa.Schema

     def read(self, columns: list[str]) -> pa.Table:
         if not set(self.parquet_file.schema_arrow.names) <= set(columns):
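For orientation, here is a minimal, self-contained sketch of how the newly extracted truncate_binary_columns() could be exercised on its own. The table contents and the 20-byte budget are made up for illustration, and the import assumes the package is installed as libcommon (matching the file path above):

import pyarrow as pa
from datasets import Features, Value

from libcommon.parquet_utils import truncate_binary_columns

# a tiny table with one binary column and one int column (example data, not from the repo)
table = pa.table({"image": [b"x" * 1000, b"y" * 1000], "label": [0, 1]})
features = Features({"image": Value("binary"), "label": Value("int64")})

# a 20-byte budget for the whole binary column -> each row keeps 20 // len(table) = 10 bytes
truncated_table, truncated_column_names = truncate_binary_columns(table, 20, features)
assert truncated_column_names == ["image"]
assert all(len(v.as_py()) == 10 for v in truncated_table["image"])

Because the budget is divided by the number of rows, each binary value is sliced to max_binary_length // len(table) bytes, which is the same per-row-group behavior the removed read_truncated_binary() had.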
@@ -151,18 +170,7 @@ def read(self, columns: list[str]) -> pa.Table:
             )
         pa_table = self.parquet_file.read_row_group(i=self.group_id, columns=columns)
         # cast_table_to_schema adds null values to missing columns
-        return cast_table_to_schema(pa_table, self.features.arrow_schema)
-
-    def read_truncated_binary(self, columns: list[str], max_binary_length: int) -> tuple[pa.Table, list[str]]:
-        pa_table = self.parquet_file.read_row_group(i=self.group_id, columns=columns)
-        truncated_columns: list[str] = []
-        if max_binary_length:
-            for field_idx, field in enumerate(pa_table.schema):
-                if self.features[field.name] == Value("binary") and pa_table[field_idx].nbytes > max_binary_length:
-                    truncated_array = pc.binary_slice(pa_table[field_idx], 0, max_binary_length // len(pa_table))
-                    pa_table = pa_table.set_column(field_idx, field, truncated_array)
-                    truncated_columns.append(field.name)
-        return cast_table_to_schema(pa_table, self.features.arrow_schema), truncated_columns
+        return cast_table_to_schema(pa_table, self.schema)

     def read_size(self, columns: Optional[Iterable[str]] = None) -> int:
         if columns is None:
@@ -179,32 +187,33 @@ def read_size(self, columns: Optional[Iterable[str]] = None) -> int:

 @dataclass
 class ParquetIndexWithMetadata:
+    files: list[ParquetFileMetadataItem]
     features: Features
-    parquet_files_urls: list[str]
-    metadata_paths: list[str]
-    num_bytes: list[int]
-    num_rows: list[int]
     httpfs: HTTPFileSystem
     max_arrow_data_in_memory: int
     partial: bool
+    metadata_dir: Path

+    file_offsets: np.ndarray = field(init=False)
     num_rows_total: int = field(init=False)

     def __post_init__(self) -> None:
         if self.httpfs._session is None:
             self.httpfs_session = asyncio.run(self.httpfs.set_session())
         else:
             self.httpfs_session = self.httpfs._session
-        self.num_rows_total = sum(self.num_rows)

-    def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
+        num_rows = np.array([f["num_rows"] for f in self.files])
+        self.file_offsets = np.cumsum(num_rows)
+        self.num_rows_total = np.sum(num_rows)
+
+    def query(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
         """Query the parquet files

         Note that this implementation will always read at least one row group, to get the list of columns and always
         have the same schema, even if the requested rows are invalid (out of range).

-        This is the same as query() except that:
-
+        If binary columns are present, then:
         - it computes a maximum size to allocate to binary data in step "parquet_index_with_metadata.row_groups_size_check_truncated_binary"
         - it uses `read_truncated_binary()` in step "parquet_index_with_metadata.query_truncated_binary".

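The new __post_init__ replaces the per-file num_rows list with a precomputed file_offsets array. As a quick self-contained illustration (the row counts here are made up), this is how the cumulative offsets and np.searchsorted(..., side="right") used in query() map a requested row range onto the parquet files that contain it:

import numpy as np

# suppose three parquet files with 100, 50 and 200 rows (made-up sizes)
num_rows = np.array([100, 50, 200])
file_offsets = np.cumsum(num_rows)      # -> [100, 150, 350]
num_rows_total = int(np.sum(num_rows))  # -> 350

offset, length = 120, 40
last_row_in_parquet = file_offsets[-1] - 1
first_row = min(offset, last_row_in_parquet)
last_row = min(offset + length - 1, last_row_in_parquet)

# side="right" maps a row index to the first file whose cumulative row count exceeds it
first_file_id, last_file_id = np.searchsorted(file_offsets, [first_row, last_row], side="right")
assert (first_file_id, last_file_id) == (1, 2)  # rows 120-159 span the 2nd and 3rd files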
@@ -219,27 +228,19 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
             `pa.Table`: The requested rows.
             `list[strl]: List of truncated columns.
         """
-        all_columns = set(self.features)
-        binary_columns = set(column for column, feature in self.features.items() if feature == Value("binary"))
-        if not binary_columns:
-            return self.query(offset=offset, length=length), []
         with StepProfiler(
             method="parquet_index_with_metadata.query", step="get the parquet files than contain the requested rows"
         ):
-            parquet_file_offsets = np.cumsum(self.num_rows)
-
-            last_row_in_parquet = parquet_file_offsets[-1] - 1
+            last_row_in_parquet = self.file_offsets[-1] - 1
             first_row = min(offset, last_row_in_parquet)
             last_row = min(offset + length - 1, last_row_in_parquet)
             first_parquet_file_id, last_parquet_file_id = np.searchsorted(
-                parquet_file_offsets, [first_row, last_row], side="right"
+                self.file_offsets, [first_row, last_row], side="right"
             )
             parquet_offset = (
-                offset - parquet_file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
+                offset - self.file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
             )
-            urls = self.parquet_files_urls[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
-            metadata_paths = self.metadata_paths[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
-            num_bytes = self.num_bytes[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
+            files_to_scan = self.files[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203

         with StepProfiler(
             method="parquet_index_with_metadata.query", step="load the remote parquet files using metadata from disk"
@@ -248,17 +249,17 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
                 pq.ParquetFile(
                     HTTPFile(
                         self.httpfs,
-                        url,
+                        f["url"],
                         session=self.httpfs_session,
-                        size=size,
+                        size=f["size"],
                         loop=self.httpfs.loop,
                         cache_type=None,
                         **self.httpfs.kwargs,
                     ),
-                    metadata=pq.read_metadata(metadata_path),
+                    metadata=pq.read_metadata(self.metadata_dir / f["parquet_metadata_subpath"]),
                     pre_buffer=True,
                 )
-                for url, metadata_path, size in zip(urls, metadata_paths, num_bytes)
+                for f in files_to_scan
             ]

         with StepProfiler(
@@ -272,7 +273,7 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
                 ]
             )
             row_group_readers = [
-                RowGroupReader(parquet_file=parquet_file, group_id=group_id, features=self.features)
+                RowGroupReader(parquet_file=parquet_file, group_id=group_id, schema=self.features.arrow_schema)
                 for parquet_file in parquet_files
                 for group_id in range(parquet_file.metadata.num_row_groups)
             ]
@@ -290,6 +291,28 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
                 row_group_offsets, [first_row, last_row], side="right"
             )

+        all_columns = set(self.features)
+        binary_columns = set(column for column, feature in self.features.items() if feature == Value("binary"))
+        if binary_columns:
+            pa_table, truncated_columns = self._read_with_binary(
+                row_group_readers, first_row_group_id, last_row_group_id, all_columns, binary_columns
+            )
+        else:
+            pa_table, truncated_columns = self._read_without_binary(
+                row_group_readers, first_row_group_id, last_row_group_id
+            )
+
+        first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
+        return pa_table.slice(parquet_offset - first_row_in_pa_table, length), truncated_columns
+
+    def _read_with_binary(
+        self,
+        row_group_readers: list[RowGroupReader],
+        first_row_group_id: int,
+        last_row_group_id: int,
+        all_columns: set[str],
+        binary_columns: set[str],
+    ) -> tuple[pa.Table, list[str]]:
         with StepProfiler(
             method="parquet_index_with_metadata.row_groups_size_check_truncated_binary",
             step="check if the rows can fit in memory",
@@ -329,100 +352,21 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
                 columns = list(self.features.keys())
                 truncated_columns: set[str] = set()
                 for i in range(first_row_group_id, last_row_group_id + 1):
-                    rg_pa_table, rg_truncated_columns = row_group_readers[i].read_truncated_binary(
-                        columns, max_binary_length=max_binary_length
+                    rg_pa_table = row_group_readers[i].read(columns)
+                    rg_pa_table, rg_truncated_columns = truncate_binary_columns(
+                        rg_pa_table, max_binary_length, self.features
                     )
                     pa_tables.append(rg_pa_table)
                     truncated_columns |= set(rg_truncated_columns)
                 pa_table = pa.concat_tables(pa_tables)
             except ArrowInvalid as err:
                 raise SchemaMismatchError("Parquet files have different schema.", err)
-            first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
-            return pa_table.slice(parquet_offset - first_row_in_pa_table, length), list(truncated_columns)
-
-    def query(self, offset: int, length: int) -> pa.Table:
-        """Query the parquet files
-
-        Note that this implementation will always read at least one row group, to get the list of columns and always
-        have the same schema, even if the requested rows are invalid (out of range).
-
-        Args:
-            offset (`int`): The first row to read.
-            length (`int`): The number of rows to read.
-
-        Raises:
-            [`TooBigRows`]: if the arrow data from the parquet row groups is bigger than max_arrow_data_in_memory
-
-        Returns:
-            `pa.Table`: The requested rows.
-        """
-        with StepProfiler(
-            method="parquet_index_with_metadata.query", step="get the parquet files than contain the requested rows"
-        ):
-            parquet_file_offsets = np.cumsum(self.num_rows)

-            last_row_in_parquet = parquet_file_offsets[-1] - 1
-            first_row = min(offset, last_row_in_parquet)
-            last_row = min(offset + length - 1, last_row_in_parquet)
-            first_parquet_file_id, last_parquet_file_id = np.searchsorted(
-                parquet_file_offsets, [first_row, last_row], side="right"
-            )
-            parquet_offset = (
-                offset - parquet_file_offsets[first_parquet_file_id - 1] if first_parquet_file_id > 0 else offset
-            )
-            urls = self.parquet_files_urls[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
-            metadata_paths = self.metadata_paths[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
-            num_bytes = self.num_bytes[first_parquet_file_id : last_parquet_file_id + 1]  # noqa: E203
-
-        with StepProfiler(
-            method="parquet_index_with_metadata.query", step="load the remote parquet files using metadata from disk"
-        ):
-            parquet_files = [
-                pq.ParquetFile(
-                    HTTPFile(
-                        self.httpfs,
-                        url,
-                        session=self.httpfs_session,
-                        size=size,
-                        loop=self.httpfs.loop,
-                        cache_type=None,
-                        **self.httpfs.kwargs,
-                    ),
-                    metadata=pq.read_metadata(metadata_path),
-                    pre_buffer=True,
-                )
-                for url, metadata_path, size in zip(urls, metadata_paths, num_bytes)
-            ]
-
-        with StepProfiler(
-            method="parquet_index_with_metadata.query", step="get the row groups than contain the requested rows"
-        ):
-            row_group_offsets = np.cumsum(
-                [
-                    parquet_file.metadata.row_group(group_id).num_rows
-                    for parquet_file in parquet_files
-                    for group_id in range(parquet_file.metadata.num_row_groups)
-                ]
-            )
-            row_group_readers = [
-                RowGroupReader(parquet_file=parquet_file, group_id=group_id, features=self.features)
-                for parquet_file in parquet_files
-                for group_id in range(parquet_file.metadata.num_row_groups)
-            ]
-
-        if len(row_group_offsets) == 0 or row_group_offsets[-1] == 0:  # if the dataset is empty
-            if offset < 0:
-                raise IndexError("Offset must be non-negative")
-            return cast_table_to_schema(parquet_files[0].read(), self.features.arrow_schema)
-
-        last_row_in_parquet = row_group_offsets[-1] - 1
-        first_row = min(parquet_offset, last_row_in_parquet)
-        last_row = min(parquet_offset + length - 1, last_row_in_parquet)
-
-        first_row_group_id, last_row_group_id = np.searchsorted(
-            row_group_offsets, [first_row, last_row], side="right"
-        )
+        return pa_table, list(truncated_columns)

+    def _read_without_binary(
+        self, row_group_readers: list[RowGroupReader], first_row_group_id: int, last_row_group_id: int
+    ) -> tuple[pa.Table, list[str]]:
         with StepProfiler(
             method="parquet_index_with_metadata.row_groups_size_check", step="check if the rows can fit in memory"
         ):
@@ -443,8 +387,8 @@ def query(self, offset: int, length: int) -> pa.Table:
                 )
             except ArrowInvalid as err:
                 raise SchemaMismatchError("Parquet files have different schema.", err)
-            first_row_in_pa_table = row_group_offsets[first_row_group_id - 1] if first_row_group_id > 0 else 0
-            return pa_table.slice(parquet_offset - first_row_in_pa_table, length)
+
+        return pa_table, []

     @staticmethod
     def from_parquet_metadata_items(
@@ -458,40 +402,31 @@ def from_parquet_metadata_items(
             raise EmptyParquetMetadataError("No parquet files found.")

         partial = parquet_export_is_partial(parquet_file_metadata_items[0]["url"])
+        metadata_dir = Path(parquet_metadata_directory)

         with StepProfiler(
             method="parquet_index_with_metadata.from_parquet_metadata_items",
             step="get the index from parquet metadata",
         ):
             try:
-                parquet_files_metadata = sorted(
-                    parquet_file_metadata_items, key=lambda parquet_file_metadata: parquet_file_metadata["filename"]
-                )
-                parquet_files_urls = [parquet_file_metadata["url"] for parquet_file_metadata in parquet_files_metadata]
-                metadata_paths = [
-                    os.path.join(parquet_metadata_directory, parquet_file_metadata["parquet_metadata_subpath"])
-                    for parquet_file_metadata in parquet_files_metadata
-                ]
-                num_bytes = [parquet_file_metadata["size"] for parquet_file_metadata in parquet_files_metadata]
-                num_rows = [parquet_file_metadata["num_rows"] for parquet_file_metadata in parquet_files_metadata]
+                files = sorted(parquet_file_metadata_items, key=lambda f: f["filename"])
             except Exception as e:
                 raise ParquetResponseFormatError(f"Could not parse the list of parquet files: {e}") from e

         with StepProfiler(
             method="parquet_index_with_metadata.from_parquet_metadata_items", step="get the dataset's features"
         ):
             if features is None:  # config-parquet version<6 didn't have features
-                features = Features.from_arrow_schema(pq.read_schema(metadata_paths[0]))
+                first_arrow_schema = pq.read_schema(metadata_dir / files[0]["parquet_metadata_subpath"])
+                features = Features.from_arrow_schema(first_arrow_schema)

         return ParquetIndexWithMetadata(
+            files=files,
             features=features,
-            parquet_files_urls=parquet_files_urls,
-            metadata_paths=metadata_paths,
-            num_bytes=num_bytes,
-            num_rows=num_rows,
             httpfs=httpfs,
             max_arrow_data_in_memory=max_arrow_data_in_memory,
             partial=partial,
+            metadata_dir=metadata_dir,
         )


@@ -551,28 +486,7 @@ def _init_parquet_index(

     # note that this cache size is global for the class, not per instance
     @lru_cache(maxsize=1)
-    def query(self, offset: int, length: int) -> pa.Table:
-        """Query the parquet files
-
-        Note that this implementation will always read at least one row group, to get the list of columns and always
-        have the same schema, even if the requested rows are invalid (out of range).
-
-        Args:
-            offset (`int`): The first row to read.
-            length (`int`): The number of rows to read.
-
-        Returns:
-            `pa.Table`: The requested rows.
-        """
-        logging.info(
-            f"Query {type(self.parquet_index).__name__} for dataset={self.dataset}, config={self.config},"
-            f" split={self.split}, offset={offset}, length={length}"
-        )
-        return self.parquet_index.query(offset=offset, length=length)
-
-    # note that this cache size is global for the class, not per instance
-    @lru_cache(maxsize=1)
-    def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
+    def query(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
         """Query the parquet files

         Note that this implementation will always read at least one row group, to get the list of columns and always
@@ -590,4 +504,4 @@ def query_truncated_binary(self, offset: int, length: int) -> tuple[pa.Table, list[str]]:
             f"Query {type(self.parquet_index).__name__} for dataset={self.dataset}, config={self.config},"
             f" split={self.split}, offset={offset}, length={length}, with truncated binary"
         )
-        return self.parquet_index.query_truncated_binary(offset=offset, length=length)
+        return self.parquet_index.query(offset=offset, length=length)
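After this change the caller-facing surface of RowsIndex is a single query() method that always returns the table together with the (possibly empty) list of truncated binary column names. Below is a hedged sketch of a call site; rows_index is assumed to be an already-constructed RowsIndex, whose construction (cached parquet metadata, HTTP filesystem) is not shown:

import pyarrow as pa

from libcommon.parquet_utils import RowsIndex  # assumes the libcommon package layout shown above


def fetch_rows(rows_index: RowsIndex, offset: int, length: int) -> pa.Table:
    # single entry point: the second element of the tuple replaces the old
    # query_truncated_binary() return value and is empty when nothing was truncated
    pa_table, truncated_columns = rows_index.query(offset=offset, length=length)
    if truncated_columns:
        print(f"binary columns truncated to fit in memory: {truncated_columns}")
    return pa_table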
