Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
48 changes: 48 additions & 0 deletions daos_pytorch.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,48 @@
# Running DLIO benchmark with DAOS PyTorch DataLoader

## Prerequisites

- Python 3.10 or higher
- openmpi and openmpi-devel packages
- DAOS client libraries built and installed from [master branch](https://github.com/daos-stack/daos/tree/master/src/client/pydaos/torch)
- configured and working DAOS agent on the compute nodes


## Getting started


Since the `DAOS PyTorch` client has not been released outside the `master` branch, you need to build and install the `pydaos` package on the compute node manually (the `torch` integration ships as part of the `pydaos` package):


```bash

$: pip install ${DAOS_BUILD_OUTPUT}/install/lib/daos/python

```



## Example of running benchmark with `DAOS` PyTorch client


```bash
# LD_LIBRARY_PATH is needed to load DAOS libraries from build directory
export LD_LIBRARY_PATH=/lus/flare/projects/DAOS_Testing/daos/install/lib64/:$LD_LIBRARY_PATH

mpiexec --np ${NTOTRANKS} -ppn ${NRANKS} --cpu-bind depth -d ${NDEPTH} --no-vni \
dlio_benchmark workload=daos_pytorch \
++workload.workflow.generate_data=True \
++workload.dataset.daos_pool=DAOS_Testing \
++workload.dataset.daos_cont=defaults \
++workload.workflow.checkpoint=True \
++workload.checkpoint.checkpoint_daos_pool=DAOS_Testing \
++workload.checkpoint.checkpoint_daos_cont=defaults \
++workload.checkpoint.checkpoint_folder=/checkpoints \
++workload.dataset.data_folder=/datasets/small-08 \
++workload.dataset.num_files_train=80000 \
++workload.dataset.num_files_eval=10000 \
++workload.reader.batch_size=32 \
++workload.reader.read_threads=4 \
++workload.dataset.record_length_bytes=1048576 \
++workload.train.epochs=5
```
80 changes: 80 additions & 0 deletions dlio_benchmark/checkpointing/pytorch_daos_checkpointing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
"""
Copyright (c) 2026, Enakta Labs Ltd
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import logging
import torch
from pydaos.torch import Checkpoint as DaosCheckpoint

from dlio_benchmark.checkpointing.base_checkpointing import BaseCheckpointing
from dlio_benchmark.checkpointing.pytorch_checkpointing import PyTorchCheckpointing
from dlio_benchmark.utils.utility import Profile, dft_ai

from dlio_benchmark.common.constants import MODULE_CHECKPOINT

dlp = Profile(MODULE_CHECKPOINT)


class PyTorchDaosCheckpointing(PyTorchCheckpointing):
    """Checkpointing implementation that stores PyTorch checkpoints in DAOS.

    Streams ``torch.save`` / ``torch.load`` payloads through the
    ``pydaos.torch`` Checkpoint API directly into a DAOS container instead
    of a POSIX filesystem.
    """

    __instance = None  # singleton instance

    @staticmethod
    def get_instance():
        """Static access method returning the singleton instance."""
        if PyTorchDaosCheckpointing.__instance is None:
            logging.basicConfig(level=logging.INFO)
            PyTorchDaosCheckpointing.__instance = PyTorchDaosCheckpointing()
        return PyTorchDaosCheckpointing.__instance

    @dft_ai.checkpoint.init
    def __init__(self):
        # Intentionally bypass PyTorchCheckpointing.__init__: checkpoint I/O
        # goes through DAOS, so only the base bookkeeping ("pt" extension)
        # is needed here.
        BaseCheckpointing.__init__(self, "pt")

        prefix = self.args.checkpoint_folder
        pool = self.args.checkpoint_daos_pool
        cont = self.args.checkpoint_daos_cont
        chunk_size = self.args.checkpoint_daos_chunk_size
        chunks_limit = self.args.checkpoint_daos_chunks_limit

        logging.info(f"Checkpointing is set to DAOS pool: {pool}, container: {cont}, prefix: {prefix}, chunk_size: {chunk_size} and chunks_limit: {chunks_limit}")
        self.ckpt = DaosCheckpoint(pool, cont, prefix, transfer_chunk_size=chunk_size, chunks_limit=chunks_limit)

    @dft_ai.checkpoint.capture
    def save_state(self, suffix, state, fsync=False):
        """Serialize ``state`` with ``torch.save`` into a DAOS-backed stream.

        ``fsync`` is accepted for interface compatibility but ignored: the
        write is completed when the writer context closes.
        """
        name = self.get_name(suffix)
        with self.ckpt.writer(name) as f:
            torch.save(state, f)

    @dft_ai.checkpoint.restart
    def load_state(self, suffix, state):
        """Read back the checkpoint named by ``suffix`` and sanity-check it.

        NOTE(review): the ``state`` argument is ignored and the loaded dict
        is discarded after validation — this measures restore I/O only; it
        does not restore the caller's state.
        """
        name = self.get_name(suffix)
        with self.ckpt.reader(name) as f:
            # Bind to a fresh local instead of shadowing the `state` parameter
            # (the original dead-stored `state = dict()` before overwriting it).
            loaded = torch.load(f)
            self.logger.debug(f"checkpoint state loaded: {loaded}")
            # Fail loudly if the checkpoint came back empty (note: asserts are
            # stripped under `python -O`).
            assert(len(loaded.keys())>0)

    @dft_ai.checkpoint.capture
    def save_checkpoint(self, epoch, step_number):
        super().save_checkpoint(epoch, step_number)

    @dlp.log
    def load_checkpoint(self, epoch, step_number):
        super().load_checkpoint(epoch, step_number)

    @dlp.log
    def finalize(self):
        super().finalize()
2 changes: 2 additions & 0 deletions dlio_benchmark/common/enumerations.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,7 @@ class StorageType(Enum):
PARALLEL_FS = 'parallel_fs'
S3 = 's3'
AISTORE = 'aistore'
DAOS_PYTORCH = 'daos_pytorch'

def __str__(self):
return self.value
Expand Down Expand Up @@ -174,6 +175,7 @@ class DataLoaderType(Enum):
CUSTOM='custom'
NONE='none'
SYNTHETIC='synthetic'
DAOS_PYTORCH="daos_pytorch"

def __str__(self):
return self.value
Expand Down
37 changes: 37 additions & 0 deletions dlio_benchmark/configs/workload/daos_pytorch.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
model:
name: default

framework: pytorch
storage:
storage_type: "daos_pytorch"
storage_root: "/"

workflow:
generate_data: False
train: True
evaluation: True
profiling: False

dataset:
data_folder: /data/default
format: npz
num_files_train: 1024
num_files_eval: 64
num_samples_per_file: 1
record_length: 4096
num_subfolders_train: 2
num_subfolders_eval: 2
daos_pool: default-pool
daos_cont: samples


reader:
data_loader: daos_pytorch
batch_size: 32
read_threads: 4

checkpoint:
checkpoint_folder: /checkpoints
checkpoint_daos_pool: default-pool
checkpoint_daos_cont: checkpoints
checkpoint_mechanism_classname: dlio_benchmark.checkpointing.pytorch_daos_checkpointing.PyTorchDaosCheckpointing
1 change: 1 addition & 0 deletions dlio_benchmark/data_loader/base_data_loader.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ def __init__(self, format_type, dataset_type, epoch_number, data_loader_type):
self.data_loader_type = data_loader_type
self.num_samples = self._args.total_samples_train if self.dataset_type is DatasetType.TRAIN else self._args.total_samples_eval
self.batch_size = self._args.batch_size if self.dataset_type is DatasetType.TRAIN else self._args.batch_size_eval
self.read_threads = self._args.read_threads
self.logger = self._args.logger

@abstractmethod
Expand Down
114 changes: 114 additions & 0 deletions dlio_benchmark/data_loader/daos_torch_data_loader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
"""
Copyright (c) 2026, Enakta Labs, LTD
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import os
import io
import numpy as np
import math
import pickle
from torch.utils.data import Dataset

from dlio_benchmark.common.constants import MODULE_DATA_LOADER
from dlio_benchmark.common.enumerations import DataLoaderType, FormatType
from dlio_benchmark.data_loader.torch_data_loader import BaseTorchDataLoader
from dlio_benchmark.utils.utility import utcnow, DLIOMPI, Profile, dft_ai
from dlio_benchmark.utils.config import ConfigArguments

dlp = Profile(MODULE_DATA_LOADER)


def get_format_reader(format):
    """Return a callable that decodes one raw sample buffer for *format*.

    Supported formats are NPZ (array stored under key ``"x"``) and NPY;
    any other format raises ``ValueError``.
    """
    def _read_npy(buf):
        # Decode a raw .npy byte buffer into an ndarray.
        return np.load(io.BytesIO(buf), allow_pickle=True)

    def _read_npz(buf):
        # An .npz archive is a zip of arrays; samples live under key "x".
        return np.load(io.BytesIO(buf), allow_pickle=True)["x"]

    if format == FormatType.NPZ:
        return _read_npz
    if format == FormatType.NPY:
        return _read_npy
    raise ValueError(f"TorchDaosDataset does not support {format} format")

class TorchDaosDataset(Dataset):
    """PyTorch Dataset backed by the ``pydaos.torch`` DAOS Dataset.

    Wrapper over DaosDataset that logs calls for the profiler and decodes
    each raw sample buffer with a per-format reader (npz/npy).
    """

    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch, num_samples, num_workers, batch_size):
        self.format_type = format_type
        self.dataset_type = dataset_type
        self.epoch_number = epoch
        self.num_samples = num_samples
        self.num_images_read = 0  # running count used to derive the step number
        self.batch_size = batch_size
        args = ConfigArguments.get_instance()
        # Serialized config is re-hydrated in each DataLoader worker process.
        self.serial_args = pickle.dumps(args)
        self.logger = args.logger
        self.dlp_logger = None

        # to avoid loading pydaos.torch at the top level if not needed or not installed
        from pydaos.torch import Dataset as DaosDataset

        prefix = os.path.join(args.data_folder, f"{self.dataset_type}")
        self.dataset = DaosDataset(pool=args.daos_pool,
                                   cont=args.daos_cont,
                                   path=prefix,
                                   transform_fn=get_format_reader(self.format_type))

        # With num_workers == 0 the DataLoader never spawns workers, so run
        # the per-worker initialization inline (worker id -1 = main process).
        if num_workers == 0:
            self.worker_init(-1)


    @dlp.log
    def worker_init(self, worker_id):
        """Per-worker setup: rebuild config, logging and tracer in the child."""
        pickle.loads(self.serial_args)
        _args = ConfigArguments.get_instance()
        _args.configure_dlio_logging(is_child=True)
        self.dlp_logger = _args.configure_dftracer(is_child=True, use_pid=True)
        self.logger.debug(f"{utcnow()} worker initialized {worker_id} with format {self.format_type}")
        self.dataset.worker_init(worker_id)

    def __del__(self):
        # Flush the per-worker tracer if one was created.
        if self.dlp_logger:
            self.dlp_logger.finalize()

    @dlp.log
    def __len__(self):
        return self.num_samples

    @dlp.log
    def __getitem__(self, image_idx):
        """Fetch a single sample and advance the profiler step counter."""
        self.num_images_read += 1
        step = int(math.ceil(self.num_images_read / self.batch_size))
        self.logger.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading {image_idx} sample")
        dlp.update(step=step)
        dft_ai.update(step=step)
        return self.dataset.__getitem__(image_idx)

    @dlp.log
    def __getitems__(self, indices):
        """Batched fetch used by DataLoader when the dataset supports it."""
        self.num_images_read += len(indices)
        step = int(math.ceil(self.num_images_read / self.batch_size))
        self.logger.debug(f"{utcnow()} Rank {DLIOMPI.get_instance().rank()} reading {len(indices)} samples")
        dlp.update(step=step)
        dft_ai.update(step=step)
        return self.dataset.__getitems__(indices)

class DaosTorchDataLoader(BaseTorchDataLoader):
    """PyTorch data loader variant whose dataset is served out of DAOS."""

    @dlp.log_init
    def __init__(self, format_type, dataset_type, epoch_number):
        super().__init__(format_type, dataset_type, epoch_number, DataLoaderType.DAOS_PYTORCH)

    def get_dataset(self) -> Dataset:
        """Build the DAOS-backed dataset for this loader's configuration."""
        return TorchDaosDataset(
            self.format_type,
            self.dataset_type,
            self.epoch_number,
            self.num_samples,
            self.read_threads,
            self.batch_size,
        )
3 changes: 3 additions & 0 deletions dlio_benchmark/data_loader/data_loader_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,9 @@ def get_loader(type, format_type, dataset_type, epoch):
elif type == DataLoaderType.PYTORCH:
from dlio_benchmark.data_loader.torch_data_loader import TorchDataLoader
return TorchDataLoader(format_type, dataset_type, epoch)
elif type == DataLoaderType.DAOS_PYTORCH:
from dlio_benchmark.data_loader.daos_torch_data_loader import DaosTorchDataLoader
return DaosTorchDataLoader(format_type, dataset_type, epoch)
elif type == DataLoaderType.TENSORFLOW:
from dlio_benchmark.data_loader.tf_data_loader import TFDataLoader
return TFDataLoader(format_type, dataset_type, epoch)
Expand Down
Loading
Loading