diff --git a/jax_dataloader/_modidx.py b/jax_dataloader/_modidx.py
index 06daffc..1ccd0d7 100644
--- a/jax_dataloader/_modidx.py
+++ b/jax_dataloader/_modidx.py
@@ -66,6 +66,14 @@
                                                                                             'jax_dataloader/loaders/base.py'),
                                           'jax_dataloader.loaders.base.BaseDataLoader.__next__': ( 'loader.base.html#basedataloader.__next__',
                                                                                                    'jax_dataloader/loaders/base.py')},
+            'jax_dataloader.loaders.grain': { 'jax_dataloader.loaders.grain.DataLoaderGrain': ( 'loader.grain.html#dataloadergrain',
+                                                                                                'jax_dataloader/loaders/grain.py'),
+                                              'jax_dataloader.loaders.grain.DataLoaderGrain.__init__': ( 'loader.grain.html#dataloadergrain.__init__',
+                                                                                                         'jax_dataloader/loaders/grain.py'),
+                                              'jax_dataloader.loaders.grain.DataLoaderGrain.__iter__': ( 'loader.grain.html#dataloadergrain.__iter__',
+                                                                                                         'jax_dataloader/loaders/grain.py'),
+                                              'jax_dataloader.loaders.grain.DataLoaderGrain.__next__': ( 'loader.grain.html#dataloadergrain.__next__',
+                                                                                                         'jax_dataloader/loaders/grain.py')},
             'jax_dataloader.loaders.jax': { 'jax_dataloader.loaders.jax.DataLoaderJAX': ( 'loader.jax.html#dataloaderjax',
                                                                                           'jax_dataloader/loaders/jax.py'),
                                             'jax_dataloader.loaders.jax.DataLoaderJAX.__init__': ( 'loader.jax.html#dataloaderjax.__init__',
diff --git a/jax_dataloader/imports.py b/jax_dataloader/imports.py
index d8bc69f..32e1cdb 100644
--- a/jax_dataloader/imports.py
+++ b/jax_dataloader/imports.py
@@ -62,6 +62,11 @@
     tfds = None
 TFDataset = Annotated[None, Is[lambda _: tf is not None]]
 
+try:
+    import grain.python as grain
+except ModuleNotFoundError:
+    grain = None
+
 try:
     import haiku as hk
 except ModuleNotFoundError:
diff --git a/jax_dataloader/loaders/grain.py b/jax_dataloader/loaders/grain.py
new file mode 100644
index 0000000..f349fdf
--- /dev/null
+++ b/jax_dataloader/loaders/grain.py
@@ -0,0 +1,47 @@
+# AUTOGENERATED! DO NOT EDIT! File to edit: ../../nbs/loader.grain.ipynb.
+
+# %% ../../nbs/loader.grain.ipynb 3
+from __future__ import print_function, division, annotations
+from ..imports import *
+from ..datasets import ArrayDataset, JAXDataset
+from . import BaseDataLoader
+from ..utils import get_config
+from ..tests import *
+import jax_dataloader as jdl
+
+# %% auto 0
+__all__ = ['DataLoaderGrain']
+
+# %% ../../nbs/loader.grain.ipynb 4
+class DataLoaderGrain(BaseDataLoader):
+
+    # @typecheck
+    def __init__(
+        self,
+        dataset: Union[JAXDataset, TorchDataset, HFDataset],
+        batch_size: int = 1, # Batch size
+        shuffle: bool = False, # If true, dataloader shuffles before sampling each batch
+        num_workers: int = 0, # Number of workers to use
+        drop_last: bool = False, # Drop last batch or not
+        **kwargs
+    ):
+
+        sampler = grain.IndexSampler(
+            num_records=len(dataset),
+            shuffle=shuffle,
+            seed=get_config().global_seed,
+            shard_options=grain.NoSharding()
+        )
+        operations = (grain.Batch(batch_size, drop_remainder=drop_last),)
+        self.dataloader = grain.DataLoader(
+            data_source=dataset,
+            sampler=sampler,
+            operations=operations,
+            worker_count=num_workers
+        )
+
+    def __next__(self):
+        return next(self.dataloader)
+
+    def __iter__(self):
+        return self.dataloader.__iter__()
diff --git a/nbs/loader.grain.ipynb b/nbs/loader.grain.ipynb
new file mode 100644
index 0000000..2006027
--- /dev/null
+++ b/nbs/loader.grain.ipynb
@@ -0,0 +1,113 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "# Grain Dataloader"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| default_exp loaders.grain"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| include: false\n",
+    "%load_ext autoreload\n",
+    "%autoreload 2\n",
+    "from ipynb_path import *\n",
+    "import warnings\n",
+    "warnings.simplefilter(action='ignore', category=FutureWarning)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "from __future__ import print_function, division, annotations\n",
+    "from jax_dataloader.imports import *\n",
+    "from jax_dataloader.datasets import ArrayDataset, JAXDataset\n",
+    "from jax_dataloader.loaders import BaseDataLoader\n",
+    "from jax_dataloader.utils import get_config\n",
+    "from jax_dataloader.tests import *\n",
+    "import jax_dataloader as jdl"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| export\n",
+    "class DataLoaderGrain(BaseDataLoader):\n",
+    "\n",
+    "    # @typecheck\n",
+    "    def __init__(\n",
+    "        self, \n",
+    "        dataset: Union[JAXDataset, TorchDataset, HFDataset],\n",
+    "        batch_size: int = 1, # Batch size\n",
+    "        shuffle: bool = False, # If true, dataloader shuffles before sampling each batch\n",
+    "        num_workers: int = 0, # Number of workers to use\n",
+    "        drop_last: bool = False, # Drop last batch or not\n",
+    "        **kwargs\n",
+    "    ):\n",
+    "\n",
+    "        sampler = grain.IndexSampler(\n",
+    "            num_records=len(dataset),\n",
+    "            shuffle=shuffle,\n",
+    "            seed=get_config().global_seed,\n",
+    "            shard_options=grain.NoSharding()\n",
+    "        )\n",
+    "        operations = (grain.Batch(batch_size, drop_remainder=drop_last),)\n",
+    "        self.dataloader = grain.DataLoader(\n",
+    "            data_source=dataset,\n",
+    "            sampler=sampler,\n",
+    "            operations=operations,\n",
+    "            worker_count=num_workers\n",
+    "        )\n",
+    "\n",
+    "    def __next__(self):\n",
+    "        return next(self.dataloader)\n",
+    "\n",
+    "    def __iter__(self):\n",
+    "        return self.dataloader.__iter__()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "#| hide\n",
+    "# test_dataloader(DataLoaderGrain, samples=20, batch_size=12, test_len=False)\n",
+    "# test_dataloader(DataLoaderGrain, samples=20, batch_size=10, test_len=False)\n",
+    "# test_dataloader(DataLoaderGrain, samples=11, batch_size=10, test_len=False)\n",
+    "# test_dataloader(DataLoaderGrain, samples=40, batch_size=12, test_len=False)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "python3",
+   "language": "python",
+   "name": "python3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/settings.ini b/settings.ini
index bad1648..48b090d 100644
--- a/settings.ini
+++ b/settings.ini
@@ -30,6 +30,7 @@ dev_requirements = scikit-learn pandas nbdev jupyter dm-haiku optax nbdev-mkdocs
 torch_requirements = torch torchvision
 tensorflow_requirements = tensorflow tensorflow-datasets
 huggingface_requirements = datasets
+grain_requirements = grain
 black_formatting = False
 readme_nb = index.ipynb
 allowed_metadata_keys =
diff --git a/setup.py b/setup.py
index f446ad4..8161893 100644
--- a/setup.py
+++ b/setup.py
@@ -32,8 +32,12 @@
 tensorflow_requirements = (cfg.get('tensorflow_requirements') or '').split()
 huggingface_requirements = (cfg.get('huggingface_requirements') or '').split()
 torch_requirements = (cfg.get('torch_requirements') or '').split()
+grain_requirements = (cfg.get('grain_requirements') or '').split()
 dev_requirements = (cfg.get('dev_requirements') or '').split()
-all_requirements = requirements + tensorflow_requirements + huggingface_requirements + torch_requirements + dev_requirements
+all_requirements = (
+    requirements + tensorflow_requirements + huggingface_requirements
+    + torch_requirements + grain_requirements + dev_requirements
+)
 extras_require = {
     'all': all_requirements,
     'dev': all_requirements,
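
For reference, a minimal usage sketch of the new loader, not part of the diff: it constructs `DataLoaderGrain` directly on an `ArrayDataset`, since the diff does not show how (or whether) a `'grain'` backend string is wired into `jdl.DataLoader`. It assumes `grain` is installed; the array shapes and batch size are illustrative only.

```python
# Illustrative sketch (not from the diff). Assumes `grain` is installed;
# dataset shapes and batch size are made up for demonstration.
import jax.numpy as jnp
import jax_dataloader as jdl
from jax_dataloader.loaders.grain import DataLoaderGrain

# 10 samples of (features, label)
ds = jdl.ArrayDataset(jnp.arange(20).reshape(10, 2), jnp.arange(10))
dl = DataLoaderGrain(ds, batch_size=4, shuffle=True, drop_last=False)

# Iteration goes through grain.DataLoader's __iter__ under the hood.
for x, y in dl:
    print(x.shape, y.shape)  # full batches: (4, 2) (4,); last batch smaller since drop_last=False
```

Batching and shuffling are delegated entirely to Grain: `grain.IndexSampler` draws (optionally shuffled) record indices seeded from `get_config().global_seed`, and `grain.Batch` collates them, mirroring how the other backends wrap their native loaders.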