Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 3 additions & 0 deletions dlio_benchmark/checkpointing/checkpointing_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,5 +42,8 @@ def get_mechanism(checkpoint_mechanism_type):
elif checkpoint_mechanism_type == CheckpointMechanismType.PT_S3_SAVE:
from dlio_benchmark.checkpointing.pytorch_s3_checkpointing import PyTorchS3Checkpointing
return PyTorchS3Checkpointing.get_instance()
elif checkpoint_mechanism_type == CheckpointMechanismType.PT_MSC_SAVE:
from dlio_benchmark.checkpointing.pytorch_msc_checkpointing import PyTorchMscCheckpointing
return PyTorchMscCheckpointing.get_instance()
else:
raise Exception(str(ErrorCodes.EC1005))
79 changes: 79 additions & 0 deletions dlio_benchmark/checkpointing/pytorch_msc_checkpointing.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
"""
Copyright (c) 2026, UChicago Argonne, LLC
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import logging
import os

try:
    import multistorageclient as msc
    MSC_AVAILABLE = True
except ImportError:
    # MSC is an optional dependency. Define `msc` as a sentinel so any
    # accidental use without the package raises a clear AttributeError
    # instead of a NameError. (Removed stray `Path = None`: `Path` is
    # never imported or referenced anywhere in this module.)
    msc = None
    MSC_AVAILABLE = False
    logging.warning(
        "Multi-Storage Client (MSC) not available. "
        "Install with: pip install multi-storage-client"
    )

from dlio_benchmark.checkpointing.base_checkpointing import BaseCheckpointing
from dlio_benchmark.checkpointing.pytorch_checkpointing import PyTorchCheckpointing
from dlio_benchmark.common.constants import MODULE_CHECKPOINT
from dlio_benchmark.utils.utility import Profile, dft_ai

dlp = Profile(MODULE_CHECKPOINT)


class PyTorchMscCheckpointing(PyTorchCheckpointing):
    """
    PyTorch checkpointing via NVIDIA Multi-Storage Client (MSC).

    Singleton. Checkpoint tensors are saved/loaded through ``msc.torch``,
    which routes ``torch.save``/``torch.load`` through the MSC storage
    backends. Checkpoints are placed under ``args.storage_root``.
    """
    __instance = None

    @staticmethod
    def get_instance():
        """Static access method returning the process-wide singleton."""
        if PyTorchMscCheckpointing.__instance is None:
            PyTorchMscCheckpointing.__instance = PyTorchMscCheckpointing()
        return PyTorchMscCheckpointing.__instance

    @dft_ai.checkpoint.init
    def __init__(self):
        # Fail fast with an actionable message; without this, the first
        # save_state() would die with a NameError/AttributeError on `msc`.
        if not MSC_AVAILABLE:
            raise RuntimeError(
                "Multi-Storage Client (MSC) is required for pt_msc_save "
                "checkpointing. Install with: pip install multi-storage-client"
            )
        # NOTE(review): intentionally bypasses PyTorchCheckpointing.__init__
        # and calls the grandparent directly so the mechanism registers as
        # "ptmsc" — confirm this matches the S3 mechanism's pattern.
        BaseCheckpointing.__init__(self, "ptmsc")
        self.checkpoint_folder = self.args.storage_root

    @dft_ai.checkpoint.capture
    def save_state(self, suffix, state, fsync=False):
        """Save `state` under the checkpoint name derived from `suffix`.

        `fsync` is accepted for interface compatibility; durability is
        delegated to the MSC backend.
        """
        name = self.get_name(suffix)
        msc.torch.save(state, os.path.join(self.checkpoint_folder, name))

    @dft_ai.checkpoint.restart
    def load_state(self, suffix, state):
        """Load the checkpoint named by `suffix` and verify it is non-empty.

        Raises:
            ValueError: if the loaded checkpoint contains no entries.
        """
        name = self.get_name(suffix)
        loaded = msc.torch.load(os.path.join(self.checkpoint_folder, name))
        self.logger.debug(f"checkpoint state loaded: {loaded}")
        # Explicit check instead of `assert`, which is stripped under -O.
        if not loaded:
            raise ValueError(f"Loaded checkpoint '{name}' is empty")

    @dlp.log
    def save_checkpoint(self, epoch, step_number):
        # Thin override so the dlp profiler attributes the call to this class.
        super().save_checkpoint(epoch, step_number)

    @dlp.log
    def load_checkpoint(self, epoch, step_number):
        # Thin override so the dlp profiler attributes the call to this class.
        super().load_checkpoint(epoch, step_number)

    @dlp.log
    def finalize(self):
        super().finalize()
2 changes: 2 additions & 0 deletions dlio_benchmark/common/enumerations.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ class CheckpointMechanismType(Enum):
TF_SAVE = 'tf_save'
PT_SAVE = 'pt_save'
PT_S3_SAVE = 'pt_s3_save'
PT_MSC_SAVE = 'pt_msc_save'

def __str__(self):
return self.value
Expand Down Expand Up @@ -59,6 +60,7 @@ class StorageType(Enum):
PARALLEL_FS = 'parallel_fs'
S3 = 's3'
AISTORE = 'aistore'
MSC = 'msc'

def __str__(self):
return self.value
Expand Down
20 changes: 16 additions & 4 deletions dlio_benchmark/data_generator/generator_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,8 +14,10 @@
See the License for the specific language governing permissions and
limitations under the License.
"""
from dlio_benchmark.common.enumerations import FormatType
from dlio_benchmark.common.enumerations import FormatType, StorageType
from dlio_benchmark.common.error_code import ErrorCodes
from dlio_benchmark.utils.config import ConfigArguments


class GeneratorFactory(object):
def __init__(self):
Expand Down Expand Up @@ -45,10 +47,20 @@ def get_generator(type):
from dlio_benchmark.data_generator.png_generator import PNGGenerator
return PNGGenerator()
elif type == FormatType.SYNTHETIC:
from dlio_benchmark.data_generator.synthetic_generator import SyntheticGenerator
from dlio_benchmark.data_generator.synthetic_generator import (
SyntheticGenerator,
)
return SyntheticGenerator()
elif type == FormatType.INDEXED_BINARY or type == FormatType.MMAP_INDEXED_BINARY:
from dlio_benchmark.data_generator.indexed_binary_generator import IndexedBinaryGenerator
return IndexedBinaryGenerator()
if ConfigArguments.get_instance().storage_type == StorageType.MSC:
from dlio_benchmark.data_generator.indexed_binary_msc_generator import (
IndexedBinaryMscGenerator,
)
return IndexedBinaryMscGenerator()
else:
from dlio_benchmark.data_generator.indexed_binary_generator import (
IndexedBinaryGenerator,
)
return IndexedBinaryGenerator()
else:
raise Exception(str(ErrorCodes.EC1001))
123 changes: 123 additions & 0 deletions dlio_benchmark/data_generator/indexed_binary_msc_generator.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
Copyright (c) 2026, UChicago Argonne, LLC
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""

import os
import struct
import tempfile

import numpy as np

from dlio_benchmark.common.constants import MODULE_DATA_GENERATOR
from dlio_benchmark.data_generator.data_generator import DataGenerator
from dlio_benchmark.utils.utility import DLIOMPI, Profile, progress

dlp = Profile(MODULE_DATA_GENERATOR)


class IndexedBinaryMscGenerator(DataGenerator):
    """
    Generator for creating Indexed Binary data via the storage abstraction (MSC).

    Unlike IndexedBinaryGenerator, this class does not use MPI collective I/O
    (which requires a shared POSIX filesystem). Each rank independently writes
    its assigned files to local temporary files and then uploads them through
    self.storage.upload_file().
    """

    def __init__(self):
        super().__init__()

    def index_file_path_off(self, prefix_path):
        """Return the path of the per-sample offset index for a data file."""
        return prefix_path + '.off.idx'

    def index_file_path_size(self, prefix_path):
        """Return the path of the per-sample size index for a data file."""
        return prefix_path + '.sz.idx'

    @dlp.log
    def generate(self):
        """Generate this rank's indexed-binary files and upload them.

        For each assigned file, writes record bytes plus offset/size index
        files in chunks bounded by ``generation_buffer_size``, then uploads
        all three artifacts via the storage abstraction. Temporary files are
        always removed, even on failure.
        """
        super().generate()
        np.random.seed(10)
        dim = self.get_dimension(self.total_files_to_generate)

        for i in dlp.iter(range(self.my_rank, int(self.total_files_to_generate), self.comm_size)):
            dim_ = dim[2 * i]
            if isinstance(dim_, list):
                shape_size = np.prod(dim_)
            else:
                shape_size = dim_ * dim[2 * i + 1]

            sample_size = shape_size * self._args.record_element_bytes
            total_size = sample_size * self.num_samples
            memory_size = self._args.generation_buffer_size
            write_size = total_size
            if total_size > memory_size:
                # Largest whole-sample multiple that fits in the buffer.
                write_size = memory_size - (memory_size % sample_size)

            out_path_spec = self._file_list[i]
            out_path_spec_off = self.index_file_path_off(out_path_spec)
            out_path_spec_sz = self.index_file_path_size(out_path_spec)

            progress(i + 1, self.total_files_to_generate, "Generating Indexed Binary Data (MSC)")

            records = np.random.randint(255, size=write_size, dtype=np.uint8)

            tmp_data = tempfile.NamedTemporaryFile(delete=False)
            tmp_off = tempfile.NamedTemporaryFile(delete=False)
            tmp_sz = tempfile.NamedTemporaryFile(delete=False)
            try:
                written_bytes = 0
                while written_bytes < total_size:
                    data_to_write = min(write_size, total_size - written_bytes)
                    samples_to_write = data_to_write // sample_size

                    # Write data
                    myfmt = 'B' * data_to_write
                    binary_data = struct.pack(myfmt, *records[:data_to_write])
                    tmp_data.write(binary_data)
                    # Drop struct's cached (huge) format string to bound memory.
                    struct._clearcache()

                    # Write offsets. BUG FIX: offsets are absolute positions in
                    # the data file, so they must continue from written_bytes
                    # rather than restarting at 0 for every buffered chunk;
                    # otherwise every sample past the first chunk seeks to the
                    # wrong place when read back.
                    myfmt = 'Q' * samples_to_write
                    offsets = range(written_bytes, written_bytes + data_to_write, sample_size)
                    offsets = offsets[:samples_to_write]
                    binary_offsets = struct.pack(myfmt, *offsets)
                    tmp_off.write(binary_offsets)

                    # Write sizes (all samples share the same fixed size).
                    binary_sizes = struct.pack(myfmt, *([sample_size] * samples_to_write))
                    tmp_sz.write(binary_sizes)

                    written_bytes += data_to_write

                tmp_data.close()
                tmp_off.close()
                tmp_sz.close()

                self.storage.upload_file(out_path_spec, tmp_data.name)
                self.storage.upload_file(out_path_spec_off, tmp_off.name)
                self.storage.upload_file(out_path_spec_sz, tmp_sz.name)
            finally:
                os.unlink(tmp_data.name)
                os.unlink(tmp_off.name)
                os.unlink(tmp_sz.name)

        np.random.seed()
        DLIOMPI.get_instance().comm().Barrier()
121 changes: 121 additions & 0 deletions dlio_benchmark/reader/indexed_binary_msc_reader.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,121 @@
"""
Copyright (c) 2026, UChicago Argonne, LLC
All Rights Reserved

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
import numpy as np

from dlio_benchmark.common.constants import MODULE_DATA_READER
from dlio_benchmark.common.enumerations import DataLoaderSampler
from dlio_benchmark.reader.reader_handler import FormatReader
from dlio_benchmark.storage.storage_factory import StorageFactory
from dlio_benchmark.utils.utility import Profile, dft_ai

dlp = Profile(MODULE_DATA_READER)


class IndexedBinaryMscReader(FormatReader):
    """
    Reader for Indexed Binary files via MSC.

    Per-file offset/size index files are fetched in full through the storage
    abstraction and cached in ``file_map_ibr``; samples are then read with
    seek/readinto on a storage-provided file handle.
    """

    @dlp.log_init
    def __init__(self, dataset_type, thread_index, epoch):
        super().__init__(dataset_type, thread_index)
        self.storage = StorageFactory().get_storage(
            self._args.storage_type, self._args.storage_root, self._args.framework
        )
        # filename -> [offsets ndarray, sizes ndarray], filled by load_index().
        self.file_map_ibr = {}
        self.load_index()

    def index_file_path_off(self, prefix_path):
        """Storage-relative path of the offset index for a data file."""
        prefix_path = prefix_path.replace(self.storage.storage_root, "")
        return prefix_path + '.off.idx'

    def index_file_path_size(self, prefix_path):
        """Storage-relative path of the size index for a data file."""
        prefix_path = prefix_path.replace(self.storage.storage_root, "")
        return prefix_path + '.sz.idx'

    def binary_file_path(self, prefix_path):
        """Storage-relative path of the data file itself."""
        prefix_path = prefix_path.replace(self.storage.storage_root, "")
        return prefix_path

    def _load_index_array(self, path, dtype=np.uint64):
        """Fetch an entire index file and parse it as a numpy array."""
        raw = self.storage.get_data(path, None)
        return np.frombuffer(raw, dtype=dtype)

    def load_index_file(self, global_sample_idx, filename, sample_index):
        """Load and cache the offset/size indices for `filename` (idempotent)."""
        assert isinstance(filename, str), "filename must be a string"
        if filename not in self.file_map_ibr:
            offset_file = self.index_file_path_off(filename)
            sz_file = self.index_file_path_size(filename)
            self.file_map_ibr[filename] = [
                self._load_index_array(offset_file),
                self._load_index_array(sz_file),
            ]
            # BUG FIX: log the actual filename instead of the literal
            # placeholder "(unknown)".
            self.logger.debug(
                f"loaded index for {filename}: "
                f"{len(self.file_map_ibr[filename][0])} offsets, "
                f"{len(self.file_map_ibr[filename][1])} sizes"
            )

    @dlp.log
    def load_index(self):
        """Preload index files for every sample this reader will serve."""
        if self._args.data_loader_sampler == DataLoaderSampler.ITERATIVE:
            for global_sample_idx, filename, sample_index in self.file_map[self.thread_index]:
                self.load_index_file(global_sample_idx, filename, sample_index)
        elif self._args.data_loader_sampler == DataLoaderSampler.INDEX:
            for global_sample_idx, (filename, sample_index) in self.global_index_map.items():
                self.load_index_file(global_sample_idx, filename, sample_index)

    @dlp.log
    def open(self, filename):
        """Open a data file through the storage abstraction."""
        super().open(filename)
        return self.storage.open(self.binary_file_path(filename))

    @dlp.log
    def close(self, filename):
        super().close(filename)
        self.open_file_map[filename].close()

    @dlp.log
    def get_sample(self, filename, sample_index):
        """Read one sample's bytes into a fresh buffer and record its size."""
        super().get_sample(filename, sample_index)
        file = self.open_file_map[filename]
        offset = self.file_map_ibr[filename][0][sample_index]
        size = self.file_map_ibr[filename][1][sample_index]
        # BUG FIX: log the actual filename instead of the literal
        # placeholder "(unknown)".
        self.logger.debug(f"reading sample from offset {offset} of size {size} from file {filename}")
        file.seek(offset)
        image = np.empty(size, dtype=np.uint8)
        file.readinto(image)
        dlp.update(image_size=size)

    def next(self):
        for batch in super().next():
            yield batch

    @dft_ai.data.item
    def read_index(self, image_idx, step):
        return super().read_index(image_idx, step)

    @dlp.log
    def finalize(self):
        super().finalize()

    def is_index_based(self):
        return True

    def is_iterator_based(self):
        return True
Loading