diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py new file mode 100644 index 0000000000..db4746137c --- /dev/null +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/chunk_sampler.py @@ -0,0 +1,95 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Chunk-aware sampler for efficient iteration over chunked SCDL datasets.""" + +import random +import warnings +from typing import Iterator, Optional + +from torch.utils.data import Sampler + +from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset + + +class ChunkAwareSampler(Sampler[int]): + """Sampler that iterates by chunks for efficient access patterns. + + This sampler ensures all rows from a chunk window are accessed together + before moving to the next window. This is optimal for: + - Local: memory locality (chunk data stays in cache) + - Remote: prefetching (download chunks once, use all rows) + + Args: + dataset: A chunked SingleCellMemMapDataset. + shuffle_chunks: Whether to shuffle chunk order each epoch. + shuffle_within_window: Whether to shuffle rows within each chunk window. + chunks_per_window: Number of chunks to load together (more = better randomness). + seed: Random seed for reproducibility. + """ + + def __init__( + self, + dataset: SingleCellMemMapDataset, + shuffle_chunks: bool = True, + shuffle_within_window: bool = True, + chunks_per_window: int = 1, + seed: Optional[int] = None, + ): + """Initialize the chunk aware sampler.""" + if not dataset._is_chunked: + raise ValueError("ChunkAwareSampler requires a chunked dataset") + + self.dataset = dataset + self.shuffle_chunks = shuffle_chunks + self.shuffle_within_window = shuffle_within_window + self.chunks_per_window = max(1, chunks_per_window) + self.rng = random.Random(seed) + self.chunked_info = dataset.header.chunked_info + + # Warn if chunks_per_window exceeds cache size (causes thrashing) + if dataset._chunk_loader and chunks_per_window > dataset._chunk_loader.max_cached_chunks: + warnings.warn( + f"chunks_per_window ({chunks_per_window}) > max_cached_chunks " + f"({dataset._chunk_loader.max_cached_chunks}). This causes cache thrashing. " + f"Increase max_cached_chunks or decrease chunks_per_window." 
+ ) + + def __iter__(self) -> Iterator[int]: + """Yield row indices, grouped by chunk window.""" + chunk_ids = list(range(self.chunked_info.num_chunks)) + + if self.shuffle_chunks: + self.rng.shuffle(chunk_ids) + + # Process in windows of N chunks + for i in range(0, len(chunk_ids), self.chunks_per_window): + window_chunks = chunk_ids[i : i + self.chunks_per_window] + + # Gather all indices from this window + all_indices = [] + for chunk_id in window_chunks: + start = chunk_id * self.chunked_info.chunk_size + end = min(start + self.chunked_info.chunk_size, self.chunked_info.total_rows) + all_indices.extend(range(start, end)) + + if self.shuffle_within_window: + self.rng.shuffle(all_indices) + + yield from all_indices + + def __len__(self) -> int: + """Return total number of samples.""" + return self.chunked_info.total_rows diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py new file mode 100644 index 0000000000..b98b7ce856 --- /dev/null +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/remote_chunk_loader.py @@ -0,0 +1,142 @@ +# SPDX-FileCopyrightText: Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Remote chunk loader with LRU caching for chunked SCDL datasets. + +NOTE: This is a simple POC implementation. For production multi-worker/multi-node use: +- Add file locking for shared cache (filelock) +- Add reference counting to prevent evicting in-use chunks +- Use DistributedChunkSampler to shard chunks across nodes +""" + +import shutil +import tempfile +from collections import OrderedDict +from pathlib import Path +from typing import Optional + +import fsspec + + +class RemoteChunkLoader: + """Downloads and caches chunks from remote storage with LRU eviction. + + Args: + remote_path: Remote path (s3://bucket/path, gs://bucket/path, etc.) + cache_dir: Local directory for caching chunks. If None, uses temp directory. + max_cached_chunks: Maximum number of chunks to keep in cache. + storage_options: Optional dict of options passed to fsspec (e.g., endpoint_url for S3). 
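+
+    Example (a minimal sketch; the bucket and cache paths are illustrative):
+        loader = RemoteChunkLoader("s3://my-bucket/chunked_scdl", max_cached_chunks=2)
+        chunk_dir = loader.get_chunk(0)  # downloads chunk_00000 on first use, then served from the LRU cache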
+ """ + + def __init__( + self, + remote_path: str, + cache_dir: Optional[Path] = None, + max_cached_chunks: int = 2, + storage_options: Optional[dict] = None, + ): + """Initialize the remote chunk loader.""" + self.remote_path = remote_path.rstrip("/") + self.cache_dir = Path(cache_dir) if cache_dir else Path(tempfile.mkdtemp(prefix="scdl_cache_")) + self.max_cached_chunks = max_cached_chunks + self._cache: OrderedDict[int, Path] = OrderedDict() + + # Initialize filesystem with optional storage options + protocol = remote_path.split("://")[0] if "://" in remote_path else "file" + self._fs = fsspec.filesystem(protocol, **(storage_options or {})) + + # Ensure cache directory exists + self.cache_dir.mkdir(parents=True, exist_ok=True) + + def get_chunk(self, chunk_id: int) -> Path: + """Get local path to chunk, downloading if needed. + + Args: + chunk_id: The chunk index to retrieve. + + Returns: + Local path to the chunk directory. + """ + if chunk_id in self._cache: + self._cache.move_to_end(chunk_id) + return self._cache[chunk_id] + + # Evict oldest chunks if at capacity + while len(self._cache) >= self.max_cached_chunks: + old_id, old_path = self._cache.popitem(last=False) + shutil.rmtree(old_path, ignore_errors=True) + + # Download chunk + local_path = self._download_chunk(chunk_id) + self._cache[chunk_id] = local_path + return local_path + + def _download_chunk(self, chunk_id: int) -> Path: + """Download a chunk from remote storage.""" + chunk_name = f"chunk_{chunk_id:05d}" + remote_chunk = f"{self.remote_path}/{chunk_name}" + local_chunk = self.cache_dir / chunk_name + + local_chunk.mkdir(parents=True, exist_ok=True) + + # Download all files in chunk directory + for remote_file in self._fs.ls(remote_chunk): + fname = Path(remote_file).name + self._fs.get(remote_file, str(local_chunk / fname)) + + return local_chunk + + def _remote_exists(self, remote_path: str) -> bool: + """Check if a remote path exists (uses ls instead of exists for compatibility).""" + try: + # Use ls instead of exists() because some S3-compatible storage + # doesn't support HeadObject which exists() relies on + parent = "/".join(remote_path.rsplit("/", 1)[:-1]) + name = remote_path.rsplit("/", 1)[-1] + files = self._fs.ls(parent, detail=False) + return any(f.endswith(name) for f in files) + except Exception: + return False + + def get_metadata(self) -> Path: + """Download and return path to metadata files (header, features, etc.).""" + metadata_dir = self.cache_dir / "_metadata" + if metadata_dir.exists(): + return metadata_dir + + metadata_dir.mkdir(parents=True, exist_ok=True) + + # Download header and feature indices (header.sch is the SCDL header format) + for name in ["header.sch", "version.json", "metadata.json"]: + remote_file = f"{self.remote_path}/{name}" + if self._remote_exists(remote_file): + self._fs.get(remote_file, str(metadata_dir / name)) + + # Download feature directories + for name in ["var_features", "obs_features"]: + remote_dir = f"{self.remote_path}/{name}" + if self._remote_exists(remote_dir): + local_dir = metadata_dir / name + self._fs.get(remote_dir, str(local_dir), recursive=True) + + return metadata_dir + + def cleanup(self): + """Delete all cached data.""" + shutil.rmtree(self.cache_dir, ignore_errors=True) + + def __del__(self): + """Cleanup on garbage collection.""" + self.cleanup() diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py index 89fa3f9d61..cf887bbcd6 100644 
--- a/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/io/single_cell_memmap_dataset.py @@ -18,6 +18,7 @@ import logging import os import shutil +import tempfile import warnings from pathlib import Path from typing import Dict, List, Optional, Tuple, Union @@ -30,6 +31,7 @@ from bionemo.scdl.api.single_cell_row_dataset import SingleCellRowDataset from bionemo.scdl.index.row_feature_index import ObservedFeatureIndex, VariableFeatureIndex +from bionemo.scdl.io.remote_chunk_loader import RemoteChunkLoader from bionemo.scdl.schema.header import ArrayDType, ArrayInfo, Backend, FeatureIndexInfo, SCDLHeader from bionemo.scdl.schema.version import CurrentSCDLVersion from bionemo.scdl.util.filecopyutil import extend_files @@ -41,6 +43,7 @@ determine_dtype, smallest_uint_dtype, ) +from bionemo.scdl.util.partition_scdl import partition_scdl from bionemo.scdl.util.scdl_constants import FLOAT_ORDER, INT_ORDER, FileNames, Mode, NeighborSamplingStrategy @@ -128,6 +131,9 @@ def __init__( self.data_path: str = data_path self.header: SCDLHeader = None self.mode: Mode = mode + self._is_chunked: bool = False + self._chunks: List[Tuple[np.ndarray, np.ndarray, np.ndarray]] = [] + self._chunk_loader = None # For remote chunked datasets self.paginated_load_cutoff = paginated_load_cutoff self.load_block_row_size = load_block_row_size self.var_feature_index_name = var_feature_index_name @@ -260,6 +266,41 @@ def version(self) -> str: """ return self._version + @classmethod + def from_remote( + cls, + remote_path: str, + cache_dir: Optional[str] = None, + max_cached_chunks: int = 2, + storage_options: Optional[Dict] = None, + ): + """Load a chunked dataset from remote storage (S3, GCS, HTTP). + + Args: + remote_path: Remote path (s3://bucket/path, gs://bucket/path, etc.) + cache_dir: Local directory for caching chunks. If None, uses temp directory. + max_cached_chunks: Maximum number of chunks to keep in cache. + storage_options: Options passed to fsspec (e.g., {"endpoint_url": "https://..."} for S3). + """ + loader = RemoteChunkLoader( + remote_path, Path(cache_dir) if cache_dir else None, max_cached_chunks, storage_options + ) + metadata_path = loader.get_metadata() + ds = cls.__new__(cls) + # Initialize essential attributes that __init__ would set + ds._version = importlib.metadata.version("bionemo.scdl") + ds._chunk_loader = loader + ds.data_path = remote_path + ds.header = None + ds.mode = Mode.READ_APPEND + ds._is_chunked = False + ds._chunks = [] + ds.dtypes = {} + ds._var_feature_index = None + ds._obs_feature_index = None + ds.load(str(metadata_path)) + return ds + def _extract_neighbor_data(self, adata) -> bool: """Extracts neighbor data from AnnData.obsp object and saves to memmap files. @@ -436,10 +477,19 @@ def get_row( List[np.ndarray]: optional, corresponding variable (column) features. List[np.ndarray]: optional, corresponding observed (row) features. 
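+
+        Example (assuming ds is a loaded dataset; row index 0 is illustrative):
+            row = ds.get_row(0)  # same call whether the dataset is monolithic, chunked locally, or chunked remotely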
""" - start = self.row_index[index] - end = self.row_index[index + 1] - values = self.data[start:end] - columns = self.col_index[start:end] + if self._is_chunked: + chunk_id, local_idx = self.header.chunked_info.get_chunk_for_row(index) + if self._chunk_loader: + data, rowptr, colptr = self._load_chunk_from_path(self._chunk_loader.get_chunk(chunk_id)) + else: + data, rowptr, colptr = self._chunks[chunk_id] + start, end = rowptr[local_idx], rowptr[local_idx + 1] + values, columns = data[start:end], colptr[start:end] + else: + start = self.row_index[index] + end = self.row_index[index + 1] + values = self.data[start:end] + columns = self.col_index[start:end] ret = (values, columns) var_features = ( self._var_feature_index.lookup(index, select_features=var_feature_names)[0] @@ -681,41 +731,60 @@ def load(self, stored_path: str) -> None: if self.header is not None and hasattr(self.header, "arrays"): # Map from FileNames.value to dtype string for array_info in self.header.arrays: - if FileNames[array_info.name].value not in self.dtypes: - raise ValueError(f"Array name {FileNames[array_info.name].value} not found in dtypes") self.dtypes[FileNames[array_info.name].value] = array_info.dtype.numpy_dtype_string - # Metadata is required, so we must check if it exists and fail if not. - if not os.path.exists(f"{self.data_path}/{FileNames.METADATA.value}"): - raise FileNotFoundError( - f"Error: the metadata file {self.data_path}/{FileNames.METADATA.value} does not exist." - ) - - with open(f"{self.data_path}/{FileNames.METADATA.value}", Mode.READ_APPEND.value) as mfi: - self.metadata = json.load(mfi) + # Load metadata if exists + metadata_path = f"{self.data_path}/{FileNames.METADATA.value}" + if os.path.exists(metadata_path): + with open(metadata_path, Mode.READ_APPEND.value) as mfi: + self.metadata = json.load(mfi) + # Load feature indices if os.path.exists(f"{self.data_path}/{FileNames.VAR_FEATURES.value}"): self._var_feature_index = VariableFeatureIndex.load(f"{self.data_path}/{FileNames.VAR_FEATURES.value}") - elif os.path.exists( - f"{self.data_path}/{FileNames.FEATURES.value}" - ): # Backward compatibility with old features file + elif os.path.exists(f"{self.data_path}/{FileNames.FEATURES.value}"): self._var_feature_index = VariableFeatureIndex.load(f"{self.data_path}/{FileNames.FEATURES.value}") if os.path.exists(f"{self.data_path}/{FileNames.OBS_FEATURES.value}"): self._obs_feature_index = ObservedFeatureIndex.load(f"{self.data_path}/{FileNames.OBS_FEATURES.value}") - # mmap the existing arrays - self.data = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"] - ) - self.row_index = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.ROWPTR.value}", dtype=self.dtypes[f"{FileNames.ROWPTR.value}"] - ) - self.col_index = self._load_mmap_file_if_exists( - f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"] - ) - # Load neighbor data - if self.load_neighbors: - self._load_neighbor_memmaps() + # Load data arrays - chunked vs monolithic + if self.header is not None and self.header.backend == Backend.CHUNKED_MEMMAP_V0: + self._is_chunked = True + self._load_chunk_memmaps() + else: + self.data = self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.DATA.value}", self.dtypes[f"{FileNames.DATA.value}"] + ) + self.row_index = self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.ROWPTR.value}", dtype=self.dtypes[f"{FileNames.ROWPTR.value}"] + ) + self.col_index = 
self._load_mmap_file_if_exists( + f"{self.data_path}/{FileNames.COLPTR.value}", dtype=self.dtypes[f"{FileNames.COLPTR.value}"] + ) + if self.load_neighbors: + self._load_neighbor_memmaps() + + def _load_chunk_memmaps(self) -> None: + """Preload all chunk memmaps (lazy - just file handles, no RAM). + + For local datasets, loads from data_path directly. + For remote datasets, this is skipped - chunks are loaded on demand. + """ + # For remote datasets, don't preload - chunks are fetched on demand via get_row() + if self._chunk_loader is not None: + return + # Local: preload all chunk paths + for chunk_id in range(self.header.chunked_info.num_chunks): + chunk_path = Path(self.data_path) / f"chunk_{chunk_id:05d}" + self._chunks.append(self._load_chunk_from_path(chunk_path)) + + def _load_chunk_from_path(self, chunk_path: Path) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """Load memmaps for a single chunk directory.""" + return ( + np.memmap(chunk_path / FileNames.DATA.value, dtype=self.dtypes[FileNames.DATA.value], mode="r"), + np.memmap(chunk_path / FileNames.ROWPTR.value, dtype=self.dtypes[FileNames.ROWPTR.value], mode="r"), + np.memmap(chunk_path / FileNames.COLPTR.value, dtype=self.dtypes[FileNames.COLPTR.value], mode="r"), + ) def _write_metadata(self) -> None: with open(f"{self.data_path}/{FileNames.METADATA.value}", f"{Mode.CREATE.value}") as mfi: @@ -1218,6 +1287,8 @@ def number_of_rows(self) -> int: ValueError if the length of the number of rows in the feature index does not correspond to the number of stored rows. """ + if self._is_chunked: + return self.header.chunked_info.total_rows if len(self._var_feature_index) > 0 and self._var_feature_index.number_of_rows() != self.row_index.size - 1: raise ValueError( f"""The number of rows in the feature index {self._var_feature_index.number_of_rows()} @@ -1445,3 +1516,32 @@ def concat( mode=Mode.READ_APPEND.value, ) self.save() + + def to_chunked( + self, output_path: Optional[str] = None, chunk_size: int = 100_000, delete_original: bool = False + ) -> "SingleCellMemMapDataset": + """Convert this dataset to a chunked format for efficient remote access. + + Args: + output_path: Path where the chunked dataset will be created. If None, replaces in-place. + chunk_size: Number of rows per chunk (default: 100,000). + delete_original: If True and output_path is set, delete the original after conversion. + + Returns: + A new SingleCellMemMapDataset instance pointing to the chunked data. 
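+
+        Example (paths are illustrative):
+            ds = SingleCellMemMapDataset("my_scdl_dataset")
+            chunked = ds.to_chunked(output_path="my_scdl_dataset_chunked", chunk_size=50_000)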
+ """ + if self._is_chunked: + raise ValueError("Dataset is already chunked") + + src = Path(self.data_path) + if output_path is None: + # In-place: partition to temp, then swap + with tempfile.TemporaryDirectory() as tmp_dir: + tmp_path = Path(tmp_dir) / "chunked" + partition_scdl(src, tmp_path, chunk_size=chunk_size) + shutil.rmtree(src) + shutil.move(str(tmp_path), str(src)) + return SingleCellMemMapDataset(str(src)) + + partition_scdl(src, Path(output_path), chunk_size=chunk_size, delete_original=delete_original) + return SingleCellMemMapDataset(output_path) diff --git a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py index 1affa2a596..e1f1578858 100644 --- a/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py +++ b/sub-packages/bionemo-scdl/src/bionemo/scdl/util/partition_scdl.py @@ -20,7 +20,6 @@ import numpy as np -from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset from bionemo.scdl.schema.header import ChunkedInfo, SCDLHeader from bionemo.scdl.util.scdl_constants import Backend, FileNames @@ -29,8 +28,11 @@ def partition_scdl( input_path: Path, output_path: Path, chunk_size: int = 100_000, + delete_original: bool = False, ) -> SCDLHeader: """Partition an SCDL dataset into chunks.""" + from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset + input_path, output_path = Path(input_path), Path(output_path) if not input_path.exists(): @@ -44,7 +46,11 @@ def partition_scdl( source_ds = SingleCellMemMapDataset(str(input_path)) total_rows = len(source_ds) rowptr = source_ds.row_index - num_chunks = (total_rows + chunk_size - 1) // chunk_size + if chunk_size <= 0: + raise ValueError(f"Chunk size must be greater than 0, got {chunk_size}") + if total_rows <= 0: + raise ValueError(f"Total rows must be greater than 0, got {total_rows}") + num_chunks = max(1, (total_rows + chunk_size - 1) // chunk_size) # Create chunks for chunk_id in range(num_chunks): @@ -78,4 +84,8 @@ def partition_scdl( header.chunked_info = ChunkedInfo(chunk_size=chunk_size, num_chunks=num_chunks, total_rows=total_rows) header.save(str(output_path / FileNames.HEADER.value)) + if delete_original: + del source_ds # Release memmap handles + shutil.rmtree(input_path) + return header diff --git a/sub-packages/bionemo-scdl/test_remote_loading.py b/sub-packages/bionemo-scdl/test_remote_loading.py new file mode 100644 index 0000000000..e1cf367a12 --- /dev/null +++ b/sub-packages/bionemo-scdl/test_remote_loading.py @@ -0,0 +1,102 @@ +#!/usr/bin/env python + +# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: LicenseRef-Apache2 +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +"""Test script for remote chunked SCDL loading with ChunkAwareSampler. 
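+
+Requires fsspec plus the protocol-specific backend (e.g., s3fs for s3:// paths, gcsfs for gs:// paths) to be installed.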
+ +Usage: + python test_remote_loading.py s3://my-bucket/chunked_scdl + python test_remote_loading.py gs://my-bucket/chunked_scdl + python test_remote_loading.py --cache-dir /tmp/cache --max-chunks 3 s3://bucket/path +""" + +import argparse + +from torch.utils.data import DataLoader + +from bionemo.scdl.io.chunk_sampler import ChunkAwareSampler +from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset +from bionemo.scdl.util.torch_dataloader_utils import collate_sparse_matrix_batch + + +def main(): + parser = argparse.ArgumentParser(description="Test remote chunked SCDL loading") + parser.add_argument( + "--remote_path", + default="s3://general-purpose/polina/chunked", + help="Remote path (s3://..., gs://..., https://...)", + ) + parser.add_argument("--endpoint-url", default="https://pbss.s8k.io", help="S3 endpoint URL (for non-AWS S3)") + parser.add_argument("--cache-dir", default="/tmp/scdl_cache", help="Local cache directory") + parser.add_argument("--max-chunks", type=int, default=3, help="Max chunks to cache") + parser.add_argument("--chunks-per-window", type=int, default=2, help="Chunks per sampling window") + parser.add_argument("--batch-size", type=int, default=1, help="Batch size") + parser.add_argument("--num-batches", type=int, default=10, help="Number of batches to iterate") + args = parser.parse_args() + + print(f"Loading remote dataset: {args.remote_path}") + print(f" Endpoint: {args.endpoint_url}") + print(f" Cache dir: {args.cache_dir}") + print(f" Max cached chunks: {args.max_chunks}") + + # Build storage_options for S3-compatible storage + # For s3fs, endpoint_url must be in client_kwargs + storage_options = {} + if args.endpoint_url: + storage_options["client_kwargs"] = {"endpoint_url": args.endpoint_url} + + # 1. Load from remote + ds = SingleCellMemMapDataset.from_remote( + args.remote_path, + cache_dir=args.cache_dir, + max_cached_chunks=args.max_chunks, + storage_options=storage_options if storage_options else None, + ) + print(f" Rows: {len(ds)}") + print(f" Chunks: {ds.header.chunked_info.num_chunks}") + print(f" Chunk size: {ds.header.chunked_info.chunk_size}") + + # 2. Create sampler + print(f"\nCreating ChunkAwareSampler (chunks_per_window={args.chunks_per_window})...") + sampler = ChunkAwareSampler( + ds, + shuffle_chunks=True, + shuffle_within_window=True, + chunks_per_window=args.chunks_per_window, + seed=42, + ) + + # 3. Create DataLoader + print(f"Creating DataLoader (batch_size={args.batch_size})...") + loader = DataLoader(ds, sampler=sampler, batch_size=args.batch_size, collate_fn=collate_sparse_matrix_batch) + + # 4. Iterate batches + print(f"\nIterating {args.num_batches} batches...") + for i, batch in enumerate(loader): + if i >= args.num_batches: + break + print(f" Batch {i}: shape={batch.shape}") + + print("\nSuccess! Remote chunked loading works.") + + # 5. 
Cleanup (optional) + if ds._chunk_loader: + print(f"\nCleaning up cache at {args.cache_dir}...") + ds._chunk_loader.cleanup() + + +if __name__ == "__main__": + main() diff --git a/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py b/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py index 3b8e934471..7152d7e53f 100644 --- a/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py +++ b/sub-packages/bionemo-scdl/tests/bionemo/scdl/conftest.py @@ -199,13 +199,24 @@ def _make(tmp_path): @pytest.fixture def make_h5ad_with_raw(make_random_csr): - """Factory to create an h5ad with uniquely randomized data for the fields .raw.X and .X""" + """Factory to create an h5ad with uniquely randomized data for .raw.X, .X, obs, and var.""" def _make(tmp_path): - X = make_random_csr(total_nnz=100, n_cols=50, seed=42) - X_raw = make_random_csr(total_nnz=100, n_cols=50, seed=43) + n_rows, n_cols = 100, 50 + X = make_random_csr(total_nnz=n_rows, n_cols=n_cols, seed=42) + X_raw = make_random_csr(total_nnz=n_rows, n_cols=n_cols, seed=43) + + obs = pd.DataFrame( + {"cell_type": [f"type_{i % 3}" for i in range(n_rows)]}, + index=[f"cell_{i}" for i in range(n_rows)], + ) + var = pd.DataFrame( + {"gene_name": [f"gene_{i}" for i in range(n_cols)]}, + index=[f"ENSG{i:08d}" for i in range(n_cols)], + ) + h = tmp_path / "var.h5ad" - ad.AnnData(X=X, var=pd.DataFrame(index=np.arange(X.shape[1])), raw={"X": X_raw}).write_h5ad(h) + ad.AnnData(X=X, obs=obs, var=var, raw={"X": X_raw}).write_h5ad(h) return h return _make
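
End-to-end usage sketch (not part of the patch above; the dataset path, bucket name, endpoint URL, and cache directory are illustrative assumptions): convert a local SCDL dataset to the chunked layout with to_chunked(), upload the resulting directory to object storage with any S3/GCS tool, then stream it with from_remote() and ChunkAwareSampler. Keep the DataLoader single-process (the default num_workers=0), since the chunk LRU cache is per-process, per the POC caveats noted in remote_chunk_loader.py.

from torch.utils.data import DataLoader

from bionemo.scdl.io.chunk_sampler import ChunkAwareSampler
from bionemo.scdl.io.single_cell_memmap_dataset import SingleCellMemMapDataset
from bionemo.scdl.util.torch_dataloader_utils import collate_sparse_matrix_batch

# 1. Convert an existing local dataset into 50k-row chunks (paths are illustrative).
SingleCellMemMapDataset("my_scdl_dataset").to_chunked(output_path="my_scdl_dataset_chunked", chunk_size=50_000)

# 2. After uploading "my_scdl_dataset_chunked" to object storage, open it remotely.
ds = SingleCellMemMapDataset.from_remote(
    "s3://my-bucket/my_scdl_dataset_chunked",
    cache_dir="/tmp/scdl_cache",
    max_cached_chunks=3,
    storage_options={"client_kwargs": {"endpoint_url": "https://s3.example.com"}},  # only needed for non-AWS S3
)

# 3. Iterate window-by-window so each downloaded chunk is fully consumed before it can be evicted.
sampler = ChunkAwareSampler(ds, shuffle_chunks=True, shuffle_within_window=True, chunks_per_window=2, seed=42)
loader = DataLoader(ds, sampler=sampler, batch_size=8, collate_fn=collate_sparse_matrix_batch)
for batch in loader:
    ...  # feed sparse batches to the model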